Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{MAX, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand All @@ -43,6 +44,8 @@ case class Max(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, prettyName)

final override val nodePatterns: Seq[TreePattern] = Seq(MAX)

private lazy val max = AttributeReference("max", child.dataType)()

override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{MIN, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand All @@ -43,6 +44,8 @@ case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, prettyName)

final override val nodePatterns: Seq[TreePattern] = Seq(MIN)

private lazy val min = AttributeReference("min", child.dataType)()

override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2576,9 +2576,9 @@ object DecimalAggregates extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
_.containsAnyPattern(SUM, AVERAGE, MIN, MAX), ruleId) {
case q: LogicalPlan => q.transformExpressionsDownWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
_.containsAnyPattern(SUM, AVERAGE, MIN, MAX), ruleId) {
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
// Window arm: `ExtractWindowExpressions` hoists composite children
// (here the widening Cast) into a child Project, so widened-Cast
Expand Down Expand Up @@ -2636,6 +2636,23 @@ object DecimalAggregates extends Rule[LogicalPlan] {
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))

// Hoist a scale-preserving widening Cast out of Min so the Min runs on
// the narrower inner Decimal. Min picks an existing row's value, so a
// widening Cast (same scale, larger precision) is bit-identical to
// applying the Cast after the aggregate. The outer Cast preserves the
// pre-rewrite result dataType (Min.dataType == child.dataType).
case m @ Min(WidenedDecimalChild(inner, _, pPrime, sPrime)) =>
Cast(
ae.copy(aggregateFunction = m.copy(child = inner)),
DecimalType(pPrime, sPrime), Option(conf.sessionLocalTimeZone))

// Hoist a scale-preserving widening Cast out of Max (same reasoning
// as the Min arm above).
case m @ Max(WidenedDecimalChild(inner, _, pPrime, sPrime)) =>
Cast(
ae.copy(aggregateFunction = m.copy(child = inner)),
DecimalType(pPrime, sPrime), Option(conf.sessionLocalTimeZone))

case _ => ae
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ object TreePattern extends Enumeration {
val MAP_FROM_ARRAYS: Value = Value
val MAP_FROM_ENTRIES: Value = Value
val MAP_OBJECTS: Value = Value
val MAX: Value = Value
val MEASURE: Value = Value
val MIN: Value = Value
val MULTI_ALIAS: Value = Value
val NEW_INSTANCE: Value = Value
val NOT: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, MaxBy, MaxMinByK, MinBy, Sum}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -613,4 +613,92 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck
s"evalMode should be preserved as TRY after rewrite, got " +
avgs.map(_.evalMode).mkString(","))
}

// ---------------------------------------------------------------------------
// SPARK-57023: DecimalAggregates widened-Cast peel for MIN/MAX. Same scale
// (s == sPrime) + same-or-larger precision (pPrime >= p) widening Cast is
// bit-identical to applying the Cast after the aggregate, because Min/Max
// pick an existing row's value (no arithmetic). The peel hoists the Cast
// out and runs Min/Max on the narrower inner Decimal.
//
// Vanilla 5.0.0-SNAPSHOT ground-truth (rule OFF vs ON) and design rationale:
// features/spark-decimal-minmax-cast-peel/docs/0002-decision-design.md (rev 3)
// ---------------------------------------------------------------------------

test("SPARK-57023: MIN(CAST(dec(7,2) AS dec(12,2))) peels via widened-Cast fast path") {
val widened = $"d7_2".cast(DecimalType(12, 2))
val originalQuery = widenRel.select(min(widened).as("min_widened"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = widenRel
.select(
Cast(
min($"d7_2"),
DecimalType(12, 2),
Option(conf.sessionLocalTimeZone))
.as("min_widened"))
.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-57023: MAX(CAST(dec(7,2) AS dec(12,2))) peels via widened-Cast fast path") {
val widened = $"d7_2".cast(DecimalType(12, 2))
val originalQuery = widenRel.select(max(widened).as("max_widened"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = widenRel
.select(
Cast(
max($"d7_2"),
DecimalType(12, 2),
Option(conf.sessionLocalTimeZone))
.as("max_widened"))
.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-57023: MIN(CAST(dec(7,2) AS dec(12,4))) does NOT peel (scale change)") {
val rescaled = $"d7_2".cast(DecimalType(12, 4))
val originalQuery = widenRel.select(min(rescaled).as("min_rescaled"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = originalQuery.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-57023: MIN(CAST(dec(17,2) AS dec(10,2))) does NOT peel (narrowing)") {
val narrowed = $"d17_2".cast(DecimalType(10, 2))
val originalQuery = widenRel.select(min(narrowed).as("min_narrowed"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = originalQuery.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-57023: MIN/MAX(CheckOverflow) does NOT peel (CheckOverflow guard)") {
val co = CheckOverflow($"d7_2", DecimalType(7, 2), nullOnOverflow = true)
val widened = Cast(co, DecimalType(12, 2))
val originalQuery = widenRel.select(min(widened).as("min_co"), max(widened).as("max_co"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = originalQuery.analyze

comparePlans(optimized, correctAnswer)
}

test("SPARK-57023: MinBy/MaxBy/MaxMinByK with widened-Cast value do NOT peel " +
"(rule pattern matches only Min/Max)") {
val widened = $"d7_2".cast(DecimalType(12, 2))
val ordering = $"i"
val minByExpr = MinBy(widened, ordering).toAggregateExpression()
val maxByExpr = MaxBy(widened, ordering).toAggregateExpression()
val maxMinByKExpr = MaxMinByK(widened, ordering, Literal(3)).toAggregateExpression()
val originalQuery = widenRel.select(
minByExpr.as("min_by_w"),
maxByExpr.as("max_by_w"),
maxMinByKExpr.as("mmbk_w"))
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = originalQuery.analyze

comparePlans(optimized, correctAnswer)
}
}
Loading