diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index f490202537..bd0d3be068 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -21,6 +21,7 @@ package org.apache.comet.rules
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial, PartialMerge}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -28,7 +29,7 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
 import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -103,6 +104,91 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
   private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
 
+  /**
+   * Pre-processes the plan to ensure coordination between partial and final hash aggregates.
+   *
+   * This method walks the plan top-down to identify final hash aggregates that cannot be
+   * converted to Comet. For such cases, it finds and tags the corresponding partial aggregate
+   * with a fallback reason to prevent mixed Comet partial + Spark final aggregation.
+   *
+   * @param plan
+   *   The input plan to pre-process
+   * @return
+   *   The plan with appropriate fallback tags added
+   */
+  private def tagUnsupportedPartialAggregates(plan: SparkPlan): SparkPlan = {
+    plan.transformDown {
+      case finalAgg: BaseAggregateExec if hasFinalMode(finalAgg) =>
+        // Check if this final aggregate can be converted to Comet
+        val handler = allExecs
+          .get(finalAgg.getClass)
+          .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]])
+
+        handler match {
+          case Some(serde) =>
+            // Get the actual support level and reason for the final aggregate
+            serde.getSupportLevel(finalAgg) match {
+              case Unsupported(reasonOpt) =>
+                // Final aggregate cannot be converted; extract the actual reason
+                val actualReason = reasonOpt.getOrElse("Final aggregate not supported by Comet")
+                val reason = s"Cannot convert final aggregate to Comet ($actualReason), " +
+                  "so partial aggregates must also use Spark to avoid mixed execution"
+                tagRelatedPartialAggregates(finalAgg, reason)
+              case Incompatible(reasonOpt) =>
+                // Final aggregate cannot be converted; extract the actual reason
+                val actualReason = reasonOpt.getOrElse("Final aggregate incompatible with Comet")
+                val reason = s"Cannot convert final aggregate to Comet ($actualReason), " +
+                  "so partial aggregates must also use Spark to avoid mixed execution"
+                tagRelatedPartialAggregates(finalAgg, reason)
+              case Compatible(_) =>
+                finalAgg
+            }
+          case _ =>
+            finalAgg
+        }
+      case other => other
+    }
+  }
+
+  /**
+   * Helper method to check if an aggregate has Final mode expressions.
+   */
+  private def hasFinalMode(agg: BaseAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists(_.mode == Final)
+  }
+
+  /**
+   * Tags the first related partial aggregate in the subtree with the fallback reason. Stops
+   * transforming after finding and tagging the first partial aggregate to avoid affecting
+   * unrelated aggregates elsewhere in the tree.
+   */
+  private def tagRelatedPartialAggregates(plan: SparkPlan, reason: String): SparkPlan = {
+    var found = false
+
+    def transformOnce(node: SparkPlan): SparkPlan = {
+      if (found) {
+        node
+      } else {
+        node match {
+          case partialAgg: BaseAggregateExec if hasPartialMode(partialAgg) =>
+            found = true
+            withInfo(partialAgg, reason)
+          case other =>
+            other.withNewChildren(other.children.map(transformOnce))
+        }
+      }
+    }
+
+    transformOnce(plan)
+  }
+
+  /**
+   * Helper method to check if an aggregate has Partial or PartialMerge mode expressions.
+   */
+  private def hasPartialMode(agg: BaseAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)
+  }
+
   // spotless:off
 
   /**
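The tag-then-convert flow added above is easiest to see outside of Spark. The sketch below is editorial, not part of the patch: it uses toy plan nodes (`FinalAgg`, `PartialAgg`, `Scan` are made-up names, not Spark operators) to show the same two-pass idea, where a top-down pass tags the first partial aggregate under an unconvertible final aggregate, and a later conversion pass would leave any tagged node on Spark.

```scala
// Editorial sketch of the tag-then-convert coordination, with toy plan nodes
// standing in for Spark's SparkPlan tree. Not the Comet implementation.
object TagThenConvertSketch extends App {
  sealed trait Node
  case class Scan() extends Node
  case class PartialAgg(child: Node, tag: Option[String] = None) extends Node
  case class FinalAgg(convertible: Boolean, child: Node) extends Node

  // Pass 1: walk top-down; when a final aggregate must stay on Spark,
  // tag the first partial aggregate beneath it with the reason.
  def tagUnsupported(node: Node): Node = node match {
    case FinalAgg(false, child) =>
      FinalAgg(convertible = false, tagFirstPartial(child, "final agg stays on Spark"))
    case FinalAgg(true, child) => FinalAgg(convertible = true, tagUnsupported(child))
    case PartialAgg(child, tag) => PartialAgg(tagUnsupported(child), tag)
    case other => other
  }

  // Stop at the first partial aggregate, mirroring tagRelatedPartialAggregates
  def tagFirstPartial(node: Node, reason: String): Node = node match {
    case PartialAgg(child, _) => PartialAgg(child, Some(reason))
    case FinalAgg(c, child) => FinalAgg(c, tagFirstPartial(child, reason))
    case other => other
  }

  // Pass 2 (conversion) would skip tagged nodes, so partial and final
  // aggregates stay on Spark together.
  println(tagUnsupported(FinalAgg(convertible = false, PartialAgg(Scan()))))
  // FinalAgg(false,PartialAgg(Scan(),Some(final agg stays on Spark)))
}
```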
+ */ + private def hasPartialMode(agg: BaseAggregateExec): Boolean = { + agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge) + } + // spotless:off /** @@ -239,6 +325,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { convertToComet(s, CometShuffleExchangeExec).getOrElse(s) case op => + // Check if this operator has already been tagged with fallback reasons + if (hasExplainInfo(op)) { + return op + } + // if all children are native (or if this is a leaf node) then see if there is a // registered handler for creating a fully native plan if (op.children.forall(_.isInstanceOf[CometNativeExec])) { @@ -365,7 +456,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { normalizedPlan } - var newPlan = transform(planWithJoinRewritten) + // Pre-process the plan to ensure coordination between partial and final hash aggregates + val planWithAggregateCoordination = tagUnsupportedPartialAggregates(planWithJoinRewritten) + + var newPlan = transform(planWithAggregateCoordination) // if the plan cannot be run fully natively then explain why (when appropriate // config is enabled) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 0a435e5b7a..a10cf1a7b9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1276,6 +1276,20 @@ object CometObjectHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) + override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { + // some unit tests need to disable partial or final hash aggregate support to test that + // CometExecRule does not allow mixed Spark/Comet aggregates + if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { + return Unsupported(Some("Partial aggregates disabled via test config")) + } + if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(_.mode == Final)) { + return Unsupported(Some("Final aggregates disabled via test config")) + } + Compatible() + } + override def convert( aggregate: ObjectHashAggregateExec, builder: Operator.Builder, diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index cf6f8918f4..67cbe982bd 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,15 +22,19 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, Final, Partial, PartialMerge} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.QueryStageExec -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, QueryStageExec, 
diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala
index cf6f8918f4..67cbe982bd 100644
--- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala
@@ -22,15 +22,19 @@ package org.apache.comet.rules
 import scala.util.Random
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, Final, Partial, PartialMerge}
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.QueryStageExec
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 
-import org.apache.comet.CometConf
+import org.apache.comet.{CometConf, CometExplainInfo}
 import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
 
 /**
@@ -73,6 +77,27 @@ class CometExecRuleSuite extends CometTestBase {
     }.sum
   }
 
+  /** Helper method to find all partial aggregates in a plan */
+  private def findPartialAggregates(plan: SparkPlan): Seq[BaseAggregateExec] = {
+    plan.collect {
+      case agg: BaseAggregateExec
+          if agg.aggregateExpressions.exists(expr =>
+            expr.mode == Partial || expr.mode == PartialMerge) =>
+        agg
+    }
+  }
+
+  /** Helper method to check if an operator has a specific fallback message */
+  private def hasFallbackMessage(op: SparkPlan, expectedMessage: String): Boolean = {
+    op.getTagValue(CometExplainInfo.EXTENSION_INFO)
+      .exists(_.contains(expectedMessage))
+  }
+
+  /** Helper method to check if an aggregate has Partial or PartialMerge mode expressions */
+  private def hasPartialMode(agg: BaseAggregateExec): Boolean = {
+    agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)
+  }
+
   test(
     "CometExecRule should apply basic operator transformations, but only when Comet is enabled") {
     withTempView("test_data") {
@@ -131,9 +156,7 @@ class CometExecRuleSuite extends CometTestBase {
     }
   }
 
-  // TODO this test exposes the bug described in
-  // https://github.com/apache/datafusion-comet/issues/1389
-  ignore("CometExecRule should not allow Comet partial and Spark final hash aggregate") {
+  test("CometExecRule should not allow Comet partial and Spark final hash aggregate") {
     withTempView("test_data") {
       createTestDataFrame.createOrReplaceTempView("test_data")
 
@@ -181,6 +204,323 @@ class CometExecRuleSuite extends CometTestBase {
     }
   }
 
+  test("CometExecRule should not allow Comet partial and Spark final object hash aggregate") {
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    try {
+      withTempView("test_data") {
+        createTestDataFrame.createOrReplaceTempView("test_data")
+
+        val sparkPlan =
+          createSparkPlan(
+            spark,
+            "SELECT bloom_filter_agg(cast(id as long)) FROM test_data GROUP BY (id % 3)")
+
+        // Count original Spark operators - bloom filter should generate ObjectHashAggregateExec
+        val originalObjectHashAggCount =
+          countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
+        assert(originalObjectHashAggCount == 2)
+
+        withSQLConf(
+          CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
+          CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+          val transformedPlan = applyCometExecRule(sparkPlan)
+
+          // if the final object aggregate cannot be converted to Comet, then neither aggregate should be
+          assert(
+            countOperators(
+              transformedPlan,
+              classOf[ObjectHashAggregateExec]) == originalObjectHashAggCount)
+          assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
+
+          // Verify that the partial aggregate has the expected fallback message
+          val partialAggregates = findPartialAggregates(transformedPlan)
+          assert(partialAggregates.nonEmpty, "Should have found at least one partial aggregate")
+          val expectedMessage =
+            "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " +
+              "so partial aggregates must also use Spark to avoid mixed execution"
+          assert(
+            partialAggregates.exists(hasFallbackMessage(_, expectedMessage)),
+            s"Partial aggregate should have fallback message: $expectedMessage")
+        }
+      }
+    } finally {
+      spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    }
+  }
+
+  test("CometExecRule should not allow Spark partial and Comet final object hash aggregate") {
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    try {
+      withTempView("test_data") {
+        createTestDataFrame.createOrReplaceTempView("test_data")
+
+        val sparkPlan =
+          createSparkPlan(
+            spark,
+            "SELECT bloom_filter_agg(cast(id as long)) FROM test_data GROUP BY (id % 3)")
+
+        // Count original Spark operators - bloom filter should generate ObjectHashAggregateExec
+        val originalObjectHashAggCount =
+          countOperators(sparkPlan, classOf[ObjectHashAggregateExec])
+        assert(originalObjectHashAggCount == 2)
+
+        withSQLConf(
+          CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+          CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", // ObjectHashAggregateExec requires shuffle
+          CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+          val transformedPlan = applyCometExecRule(sparkPlan)
+
+          // if the partial object aggregate cannot be converted to Comet, then neither aggregate should be
+          assert(
+            countOperators(
+              transformedPlan,
+              classOf[ObjectHashAggregateExec]) == originalObjectHashAggCount)
+          assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
+        }
+      }
+    } finally {
+      spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    }
+  }
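The assertions in these tests hinge on the contract between `withInfo` (which attaches a fallback reason to an operator during tagging) and the suite's `hasFallbackMessage` helper (which reads it back via the `CometExplainInfo.EXTENSION_INFO` tag). The sketch below is editorial and imitates that write-then-read flow with a plain map instead of Spark's TreeNode tag mechanism, purely to illustrate what the tests are checking.

```scala
// Editorial sketch of the withInfo / hasFallbackMessage contract, using a
// mutable map in place of Spark's TreeNode tags. Toy Op type, not SparkPlan.
object FallbackTagSketch extends App {
  final class Op(val name: String)
  private val explainInfo = scala.collection.mutable.Map.empty[Op, Set[String]]

  // Record a fallback reason against an operator (stand-in for Comet's withInfo)
  def withInfo(op: Op, reason: String): Op = {
    explainInfo(op) = explainInfo.getOrElse(op, Set.empty) + reason
    op
  }

  // Check whether any recorded reason contains the expected message
  def hasFallbackMessage(op: Op, expected: String): Boolean =
    explainInfo.get(op).exists(_.exists(_.contains(expected)))

  val partialAgg = new Op("partial-agg")
  withInfo(partialAgg, "Cannot convert final aggregate to Comet (disabled via test config)")
  println(hasFallbackMessage(partialAgg, "Cannot convert final aggregate")) // true
  println(hasFallbackMessage(partialAgg, "some other reason")) // false
}
```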
+
+  test("CometExecRule should coordinate across AQE stages for ObjectHashAggregateExec") {
+    val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
+    spark.sessionState.functionRegistry.registerFunction(
+      funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) =>
+        children.size match {
+          case 1 => new BloomFilterAggregate(children.head)
+          case 2 => new BloomFilterAggregate(children.head, children(1))
+          case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+        })
+
+    try {
+      withTempView("test_data1", "test_data2", "test_data3") {
+        // Create datasets large enough to force AQE stage creation
+        val testSchema = new StructType(
+          Array(
+            StructField("id", DataTypes.IntegerType, nullable = true),
+            StructField("group_key", DataTypes.IntegerType, nullable = true),
+            StructField("value", DataTypes.StringType, nullable = true)))
+
+        // Create multiple tables with larger datasets to force shuffle stages
+        val data1 = FuzzDataGenerator.generateDataFrame(
+          new Random(42),
+          spark,
+          testSchema,
+          1000,
+          DataGenOptions())
+        val data2 = FuzzDataGenerator.generateDataFrame(
+          new Random(43),
+          spark,
+          testSchema,
+          1000,
+          DataGenOptions())
+        val data3 = FuzzDataGenerator.generateDataFrame(
+          new Random(44),
+          spark,
+          testSchema,
+          1000,
+          DataGenOptions())
+
+        data1.createOrReplaceTempView("test_data1")
+        data2.createOrReplaceTempView("test_data2")
+        data3.createOrReplaceTempView("test_data3")
+
+        // More aggressive AQE settings to force stage creation
+        withSQLConf(
+          "spark.sql.adaptive.enabled" -> "true",
+          "spark.sql.adaptive.coalescePartitions.enabled" -> "true",
+          "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "4KB", // Very small to force stages
+          "spark.sql.adaptive.skewJoin.enabled" -> "true",
+          "spark.default.parallelism" -> "8",
+          "spark.sql.shuffle.partitions" -> "8",
+          CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
+          CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+          CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+
+          // Use a complex query with multiple joins and subqueries to force stage boundaries
+          // This pattern is more likely to create distinct partial/final aggregate stages
+          val df = spark.sql("""
+            |WITH combined_data AS (
+            |  SELECT t1.id, t1.group_key, t1.value
+            |  FROM test_data1 t1
+            |  JOIN test_data2 t2 ON t1.group_key = t2.group_key
+            |  WHERE t1.id % 3 = 0
+            |  UNION ALL
+            |  SELECT t3.id, t3.group_key, t3.value
+            |  FROM test_data3 t3
+            |  WHERE t3.id % 5 = 0
+            |)
+            |SELECT
+            |  group_key,
+            |  bloom_filter_agg(cast(id as long)) as bloom_result,
+            |  count(*) as cnt
+            |FROM combined_data
+            |GROUP BY group_key
+            |HAVING count(*) > 1
+            |ORDER BY group_key
+            |""".stripMargin)
+
+          // Execute the plan to trigger AQE stage creation
+          df.collect()
+
+          // Get the executed plan which should have AQE stages
+          val executedPlan = df.queryExecution.executedPlan
+
+          // Check if we have QueryStageExec nodes (indicating AQE created stages)
+          val queryStages = stripAQEPlan(executedPlan).collect { case qs: QueryStageExec => qs }
+
+          assert(queryStages.nonEmpty)
+
+          // Verify that we have ObjectHashAggregateExec operators in the plan
+          // Need to recursively search through AQE stages
+          def findObjectHashAggs(plan: SparkPlan): Seq[ObjectHashAggregateExec] = {
+            val buffer = scala.collection.mutable.ListBuffer[ObjectHashAggregateExec]()
+            def collect(p: SparkPlan): Unit = {
+              p match {
+                case agg: ObjectHashAggregateExec => buffer += agg
+                case stage: ShuffleQueryStageExec => collect(stage.plan)
+                case stage: BroadcastQueryStageExec => collect(stage.plan)
+                case _ => p.children.foreach(collect)
+              }
+            }
+            collect(plan)
+            buffer.toSeq
+          }
+
+          val objectHashAggs = findObjectHashAggs(stripAQEPlan(executedPlan))
+          assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators")
+
+          // Verify coordination worked - no mixed Comet/Spark aggregation
+          def findCometHashAggs(plan: SparkPlan): Seq[CometHashAggregateExec] = {
+            val buffer = scala.collection.mutable.ListBuffer[CometHashAggregateExec]()
+            def collect(p: SparkPlan): Unit = {
+              p match {
+                case agg: CometHashAggregateExec => buffer += agg
+                case stage: ShuffleQueryStageExec => collect(stage.plan)
+                case stage: BroadcastQueryStageExec => collect(stage.plan)
+                case _ => p.children.foreach(collect)
+              }
+            }
+            collect(plan)
+            buffer.toSeq
+          }
+
+          val cometHashAggs = findCometHashAggs(executedPlan)
+          assert(
+            cometHashAggs.isEmpty,
+            "Should have no CometHashAggregateExec - coordination should prevent mixed execution")
+
+          // Verify that partial aggregates have the expected fallback message
+          def findPartialAggsInAQE(plan: SparkPlan): Seq[BaseAggregateExec] = {
+            val buffer = scala.collection.mutable.ListBuffer[BaseAggregateExec]()
+            def collect(p: SparkPlan): Unit = {
+              p match {
+                case agg: BaseAggregateExec if hasPartialMode(agg) => buffer += agg
+                case stage: ShuffleQueryStageExec => collect(stage.plan)
+                case stage: BroadcastQueryStageExec => collect(stage.plan)
+                case _ => p.children.foreach(collect)
+              }
+            }
+            collect(plan)
+            buffer.toSeq
+          }
+
+          // Create a mapping from aggregate operators to their containing QueryStageExec
+          def buildStageMapping(plan: SparkPlan): Map[BaseAggregateExec, QueryStageExec] = {
+            val mapping = scala.collection.mutable.Map[BaseAggregateExec, QueryStageExec]()
+            def collect(p: SparkPlan, currentStage: Option[QueryStageExec]): Unit = {
+              p match {
+                case stage: QueryStageExec =>
+                  collect(stage.plan, Some(stage))
+                case agg: BaseAggregateExec if currentStage.isDefined =>
+                  mapping += (agg -> currentStage.get)
+                  p.children.foreach(collect(_, currentStage))
+                case _ =>
+                  p.children.foreach(collect(_, currentStage))
+              }
+            }
+            collect(plan, None)
+            mapping.toMap
+          }
+
+          val partialAggregates = findPartialAggsInAQE(executedPlan)
+          if (partialAggregates.nonEmpty) {
+            val expectedMessage =
+              "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " +
+                "so partial aggregates must also use Spark to avoid mixed execution"
+
+            val partialAggsWithFallback =
+              partialAggregates.filter(hasFallbackMessage(_, expectedMessage))
+            assert(
+              partialAggsWithFallback.nonEmpty,
+              s"Should have partial aggregates with fallback message: $expectedMessage")
+
+            // Find final aggregates to verify cross-stage coordination
+            def findFinalAggsInAQE(plan: SparkPlan): Seq[BaseAggregateExec] = {
+              val buffer = scala.collection.mutable.ListBuffer[BaseAggregateExec]()
+              def collect(p: SparkPlan): Unit = {
+                p match {
+                  case agg: BaseAggregateExec
+                      if agg.aggregateExpressions.exists(_.mode == Final) =>
+                    buffer += agg
+                  case stage: ShuffleQueryStageExec => collect(stage.plan)
+                  case stage: BroadcastQueryStageExec => collect(stage.plan)
+                  case _ => p.children.foreach(collect)
+                }
+              }
+              collect(plan)
+              buffer.toSeq
+            }
+
+            val finalAggregates = findFinalAggsInAQE(executedPlan)
+            val stageMapping = buildStageMapping(stripAQEPlan(executedPlan))
+
+            if (finalAggregates.nonEmpty && partialAggsWithFallback.nonEmpty) {
+              // Verify that partial and final aggregates are in different stages
+              val partialStages = partialAggsWithFallback.flatMap(stageMapping.get).distinct
+              val finalStages = finalAggregates.flatMap(stageMapping.get).distinct
+
+              assert(
+                partialStages.nonEmpty && finalStages.nonEmpty,
+                "Should find both partial and final aggregates within QueryStageExec nodes")
+
+              assert(
+                partialStages.intersect(finalStages).isEmpty,
+                s"Partial aggregates (stages: ${partialStages.map(_.id)}) and " +
+                  s"final aggregates (stages: ${finalStages.map(_.id)}) should be in different AQE stages " +
+                  "to prove cross-stage coordination is working")
+            }
+          }
+        }
+      }
+    } finally {
+      spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    }
+  }
 
   test("CometExecRule should apply broadcast exchange transformations") {
     withTempView("test_data") {
       createTestDataFrame.createOrReplaceTempView("test_data")
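A closing note on the recursive collectors in the AQE test: `ShuffleQueryStageExec` and `BroadcastQueryStageExec` expose their compiled subtree through a separate `plan` field that a plain children-based traversal does not enter, which is why the test unwraps stages explicitly instead of calling `collect` once. The editorial walk below shows the pattern in isolation; the node types are made up, not Spark's.

```scala
// Editorial sketch of why the AQE test recurses through stages by hand:
// a stage node hides its subtree in `plan`, outside of `children`.
object StageWalkSketch extends App {
  sealed trait Plan { def children: Seq[Plan] }
  case class Agg(name: String) extends Plan { val children: Seq[Plan] = Nil }
  case class Exchange(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
  // Stand-in for ShuffleQueryStageExec / BroadcastQueryStageExec
  case class Stage(plan: Plan) extends Plan { val children: Seq[Plan] = Nil }

  def collectAggs(p: Plan): Seq[Agg] = p match {
    case a: Agg => Seq(a)
    case s: Stage => collectAggs(s.plan) // unwrap the stage explicitly
    case other => other.children.flatMap(collectAggs)
  }

  val plan = Exchange(Stage(Exchange(Agg("partial"))))
  println(collectAggs(plan).map(_.name)) // List(partial)
}
```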