Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ package org.apache.comet.rules

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
Expand Down Expand Up @@ -103,6 +104,91 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {

private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]

/**
* Pre-processes the plan to ensure coordination between partial and final hash aggregates.
*
* This method walks the plan top-down to identify final hash aggregates that cannot be
* converted to Comet. For such cases, it finds and tags any corresponding partial aggregates
* with fallback reasons to prevent mixed Comet partial + Spark final aggregation.
*
* @param plan
* The input plan to pre-process
* @return
* The plan with appropriate fallback tags added
*/
private def tagUnsupportedPartialAggregates(plan: SparkPlan): SparkPlan = {
plan.transformDown {
case finalAgg: BaseAggregateExec if hasFinalMode(finalAgg) =>
// Check if this final aggregate can be converted to Comet
val handler = allExecs
.get(finalAgg.getClass)
.map(_.asInstanceOf[CometOperatorSerde[SparkPlan]])

handler match {
case Some(serde) =>
// Get the actual support level and reason for the final aggregate
serde.getSupportLevel(finalAgg) match {
case Unsupported(reasonOpt) =>
// Final aggregate cannot be converted, extract the actual reason
val actualReason = reasonOpt.getOrElse("Final aggregate not supported by Comet")
val reason = s"Cannot convert final aggregate to Comet ($actualReason), " +
"so partial aggregates must also use Spark to avoid mixed execution"
tagRelatedPartialAggregates(finalAgg, reason)
case Incompatible(reasonOpt) =>
// Final aggregate cannot be converted, extract the actual reason
val actualReason = reasonOpt.getOrElse("Final aggregate incompatible with Comet")
val reason = s"Cannot convert final aggregate to Comet ($actualReason), " +
"so partial aggregates must also use Spark to avoid mixed execution"
tagRelatedPartialAggregates(finalAgg, reason)
case Compatible(_) =>
finalAgg
}
case _ =>
finalAgg
}
case other => other
}
}

/**
* Helper method to check if an aggregate has Final mode expressions.
*/
private def hasFinalMode(agg: BaseAggregateExec): Boolean = {
agg.aggregateExpressions.exists(_.mode == Final)
}

/**
* Tags the first related partial aggregate in the subtree with fallback reasons. Stops
* transforming after finding and tagging the first partial aggregate to avoid affecting
* unrelated aggregates elsewhere in the tree.
*/
private def tagRelatedPartialAggregates(plan: SparkPlan, reason: String): SparkPlan = {
var found = false

def transformOnce(node: SparkPlan): SparkPlan = {
if (found) {
node
} else {
node match {
case partialAgg: BaseAggregateExec if hasPartialMode(partialAgg) =>
found = true
withInfo(partialAgg, reason)
case other =>
other.withNewChildren(other.children.map(transformOnce))
}
}
}

transformOnce(plan)
}

/**
* Helper method to check if an aggregate has Partial or PartialMerge mode expressions.
*/
private def hasPartialMode(agg: BaseAggregateExec): Boolean = {
agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)
}

// spotless:off

/**
Expand Down Expand Up @@ -239,6 +325,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
convertToComet(s, CometShuffleExchangeExec).getOrElse(s)

case op =>
// Check if this operator has already been tagged with fallback reasons
if (hasExplainInfo(op)) {
return op
}
Comment on lines +328 to +331
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this is also a performance optimization in the case where the rule is applied to a plan multiple times (which does happen with AQE)


// if all children are native (or if this is a leaf node) then see if there is a
// registered handler for creating a fully native plan
if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
Expand Down Expand Up @@ -365,7 +456,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
normalizedPlan
}

var newPlan = transform(planWithJoinRewritten)
// Pre-process the plan to ensure coordination between partial and final hash aggregates
val planWithAggregateCoordination = tagUnsupportedPartialAggregates(planWithJoinRewritten)

var newPlan = transform(planWithAggregateCoordination)

// if the plan cannot be run fully natively then explain why (when appropriate
// config is enabled)
Expand Down
14 changes: 14 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1276,6 +1276,20 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
// some unit tests need to disable partial or final hash aggregate support to test that
// CometExecRule does not allow mixed Spark/Comet aggregates
if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
return Unsupported(Some("Partial aggregates disabled via test config"))
}
if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(_.mode == Final)) {
return Unsupported(Some("Final aggregates disabled via test config"))
}
Compatible()
}

override def convert(
aggregate: ObjectHashAggregateExec,
builder: Operator.Builder,
Expand Down
Loading
Loading