diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 7eaae60552..24fbfbab62 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -754,6 +754,24 @@ object CometConf extends ShimCometConf {
     .booleanConf
     .createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false)
 
+  val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.cost.enabled")
+      .category(CATEGORY_TUNING)
+      .doc(
+        "Whether to enable cost-based optimization for Comet. When enabled, Comet will " +
+          "use a cost model to estimate acceleration factors for operators and make decisions " +
+          "about whether to use Comet or Spark operators based on estimated performance.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val COMET_COST_MODEL_CLASS: ConfigEntry[String] =
+    conf("spark.comet.cost.model.class")
+      .category(CATEGORY_TUNING)
+      .doc("The fully qualified class name of the cost model implementation to use for " +
+        "cost-based optimization. The class must implement the CometCostModel trait.")
+      .stringConf
+      .createWithDefault("org.apache.comet.cost.DefaultCometCostModel")
+
   /** Create a config to enable a specific operator */
   private def createExecEnabledConfig(
       exec: String,
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 5b416f927d..74b10e83ee 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -122,6 +122,8 @@ These settings can be used to determine which parts of the plan are accelerated
 | Config | Description | Default Value |
 |--------|-------------|---------------|
 | `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
+| `spark.comet.cost.enabled` | Whether to enable cost-based optimization for Comet. When enabled, Comet will use a cost model to estimate acceleration factors for operators and make decisions about whether to use Comet or Spark operators based on estimated performance. | false |
+| `spark.comet.cost.model.class` | The fully qualified class name of the cost model implementation to use for cost-based optimization. The class must implement the CometCostModel trait. | org.apache.comet.cost.DefaultCometCostModel |
 | `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified |
 | `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 |
 | `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/contributor-guide/tracing.html). | false |
diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
index 9adf829580..694ae8b0e1 100644
--- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
+++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -79,4 +79,7 @@ object DataTypeSupport {
     case _: StructType | _: ArrayType | _: MapType => true
     case _ => false
   }
+
+  def hasComplexTypes(schema: StructType): Boolean =
+    schema.fields.exists(f => isComplexType(f.dataType))
 }
diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala
new file mode 100644
index 0000000000..7903e7feb5
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.cost
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec}
+import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
+import org.apache.spark.sql.execution.SparkPlan
+
+import org.apache.comet.DataTypeSupport
+import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass}
+
+case class CometCostEstimate(acceleration: Double)
+
+trait CometCostModel {
+
+  /** Estimate the relative acceleration of a query plan */
+  def estimateCost(plan: SparkPlan): CometCostEstimate
+}
+
+class DefaultCometCostModel extends CometCostModel with Logging {
+
+  // optimistic default of 2x acceleration
+  private val defaultAcceleration = 2.0
+
+  override def estimateCost(plan: SparkPlan): CometCostEstimate = {
+
+    logTrace(s"estimateCost for $plan")
+
+    // Walk the entire plan tree and accumulate costs
+    var totalAcceleration = 0.0
+    var operatorCount = 0
+
+    def collectOperatorCosts(node: SparkPlan): Unit = {
+      val operatorCost = estimateOperatorCost(node)
+      logTrace(
+        s"Operator: ${node.getClass.getSimpleName}, " +
+          s"Cost: ${operatorCost.acceleration}")
+      totalAcceleration += operatorCost.acceleration
+      operatorCount += 1
+
+      // Recursively process children
+      node.children.foreach(collectOperatorCosts)
+    }
+
+    collectOperatorCosts(plan)
+
+    // Calculate average acceleration across all operators
+    // This is crude but gives us a starting point
+    val averageAcceleration = if (operatorCount > 0) {
+      totalAcceleration / operatorCount.toDouble
+    } else {
+      1.0 // No acceleration if no operators
+    }
+
+    logTrace(
+      s"Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " +
+        s"Average acceleration: $averageAcceleration")
+
+    CometCostEstimate(averageAcceleration)
+  }
+
+  /** Estimate the cost of a single operator */
+  private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = {
+    val result = plan match {
+      case op: CometProjectExec =>
+        logTrace("CometProjectExec found - evaluating expressions")
+        // Cast nativeOp to Operator and extract projection expressions
+        val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator]
+        val projection = operator.getProjection
+        val expressions = projection.getProjectListList.asScala
+        logTrace(s"Found ${expressions.length} expressions in projection")
+
+        val costs = expressions.map { expr =>
+          val cost = estimateCometExpressionCost(expr)
+          logTrace(s"Expression cost: $cost")
+          cost
+        }
+        val total = costs.sum
+        val average = total / expressions.length.toDouble
+        logTrace(s"CometProjectExec total cost: $total, average: $average")
+        CometCostEstimate(average)
+
+      case op: CometShuffleExchangeExec =>
+        op.shuffleType match {
+          case CometNativeShuffle => CometCostEstimate(1.5)
+          case CometColumnarShuffle =>
+            if (DataTypeSupport.hasComplexTypes(op.schema)) {
+              CometCostEstimate(0.8)
+            } else {
+              CometCostEstimate(1.1)
+            }
+        }
+      case _: CometColumnarToRowExec =>
+        CometCostEstimate(1.0)
+      case _: CometPlan =>
+        logTrace(s"Generic CometPlan: ${plan.getClass.getSimpleName}")
+        CometCostEstimate(defaultAcceleration)
+      case _ =>
+        logTrace(s"Non-Comet operator: ${plan.getClass.getSimpleName}")
+        // Spark operator
+        CometCostEstimate(1.0)
+    }
+
+    logTrace(s"${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}")
+    result
+  }
+
+  /** Estimate the cost of a Comet protobuf expression */
+  private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double = {
+    val result = expr.getExprStructCase match {
+      // Handle specialized expression types
+      case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3
+
+      // Handle generic scalar functions
+      case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC =>
+        val funcName = expr.getScalarFunc.getFunc
+        funcName match {
+          // String expression numbers from CometStringExpressionBenchmark
+          case "ascii" => 0.6
+          case "octet_length" => 0.6
+          case "lower" => 3.0
+          case "upper" => 3.0
+          case "char" => 0.6
+          case "initcap" => 0.9
+          case "trim" => 0.4
+          case "concat_ws" => 0.5
+          case "length" => 9.1
+          case "repeat" => 0.4
+          case "reverse" => 6.9
+          case "instr" => 0.6
+          case "replace" => 1.3
+          case "string_space" => 0.8
+          case "translate" => 0.8
+          case _ => defaultAcceleration
+        }
+
+      case _ =>
+        logTrace(
+          s"Expression: Unknown type ${expr.getExprStructCase} -> " +
+            s"$defaultAcceleration")
+        defaultAcceleration
+    }
+    result
+  }
+
+}
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index ed48e36f07..7025bb357d 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -47,8 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
-import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
+import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_COST_MODEL_CLASS, COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
 import org.apache.comet.CometSparkSessionExtensions._
+import org.apache.comet.cost.CometCostModel
 import org.apache.comet.rules.CometExecRule.allExecs
 import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
 import org.apache.comet.serde.operator._
@@ -97,6 +98,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
   private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
 
+  // Cache the cost model to avoid loading the class on every call
+  @transient private lazy val costModel: Option[CometCostModel] = {
+    if (COMET_COST_BASED_OPTIMIZATION_ENABLED.get(conf)) {
+      try {
+        val costModelClassName = COMET_COST_MODEL_CLASS.get(conf)
+        // scalastyle:off classforname
+        val costModelClass = Class.forName(costModelClassName)
+        // scalastyle:on classforname
+        val constructor = costModelClass.getConstructor()
+        Some(constructor.newInstance().asInstanceOf[CometCostModel])
+      } catch {
+        case e: Exception =>
+          logWarning(
+            s"Failed to load cost model class: ${e.getMessage}. " +
+              "Falling back to Spark query plan without cost-based optimization.")
+          None
+      }
+    } else {
+      None
+    }
+  }
+
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
@@ -344,7 +367,24 @@
   }
 
   override def apply(plan: SparkPlan): SparkPlan = {
-    val newPlan = _apply(plan)
+    val candidatePlan = _apply(plan)
+
+    // Only apply cost-based optimization if enabled and cost model is available
+    val newPlan = costModel match {
+      case Some(model) =>
+        val costBefore = model.estimateCost(plan)
+        val costAfter = model.estimateCost(candidatePlan)
+
+        if (costAfter.acceleration > costBefore.acceleration) {
+          candidatePlan
+        } else {
+          plan
+        }
+      case None =>
+        // Cost-based optimization is disabled or failed to load, return candidate plan
+        candidatePlan
+    }
+
     if (showTransformations && !newPlan.fastEquals(plan)) {
       logInfo(s"""
            |=== Applying Rule $ruleName ===
diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
new file mode 100644
index 0000000000..5ee530ae69
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.comet.CometProjectExec
+import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.execution.SparkPlan
+
+class CometCostModelSuite extends CometTestBase {
+
+  test("CBO should prefer Comet for fast expressions (length)") {
+    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
+      withTempView("test_data") {
+        createSimpleTestData()
+        val query = "SELECT length(text1), length(text2) FROM test_data"
+        executeAndCheckOperator(
+          query,
+          classOf[CometProjectExec],
+          "Expected CometProjectExec for fast expression")
+      }
+    }
+  }
+
+  test("CBO should prefer Spark for slow expressions (trim)") {
+    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
+      withTempView("test_data") {
+        createPaddedTestData()
+        val query = "SELECT trim(text1), trim(text2) FROM test_data"
+        executeAndCheckOperator(
+          query,
+          classOf[ProjectExec],
+          "Expected Spark ProjectExec for slow expression")
+      }
+    }
+  }
+
+  /** Create simple test data for string operations using parquet to prevent pushdown */
+  private def createSimpleTestData(): Unit = {
+    import testImplicits._
+    val df = Seq(
+      ("hello world", "test string"),
+      ("comet rocks", "another test"),
+      ("fast execution", "performance")).toDF("text1", "text2")
+
+    // Write to parquet and read back to prevent projection pushdown
+    val tempPath = s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_${System.nanoTime()}"
+    df.write.mode("overwrite").parquet(tempPath)
+
+    val parquetDf = spark.read.parquet(tempPath).repartition(5)
+    parquetDf.createOrReplaceTempView("test_data")
+  }
+
+  /** Create padded test data for trim operations using parquet to prevent pushdown */
+  private def createPaddedTestData(): Unit = {
+    import testImplicits._
+    val df = Seq(
+      ("  hello world  ", "  test string  "),
+      ("  comet rocks  ", "  another test  "),
+      ("  slow execution  ", "  performance  ")).toDF("text1", "text2")
+
+    // Write to parquet and read back to prevent projection pushdown
+    val tempPath =
+      s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_padded_${System.nanoTime()}"
+    df.write.mode("overwrite").parquet(tempPath)
+
+    val parquetDf = spark.read.parquet(tempPath).repartition(5)
+    parquetDf.createOrReplaceTempView("test_data")
+  }
+
+  /** Execute query and check that the expected operator type is used */
+  private def executeAndCheckOperator(
+      query: String,
+      expectedClass: Class[_],
+      message: String): Unit = {
+
+    val result = sql(query)
+    result.collect() // Materialize the plan
+
+    val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
+
+    val projectExec = findProjectExec(executedPlan)
+
+    assert(projectExec.isDefined, "Should have a project operator")
+    assert(
+      expectedClass.isInstance(projectExec.get),
+      s"$message, got ${projectExec.get.getClass.getSimpleName}")
+  }
+
+  /** Helper method to find ProjectExec or CometProjectExec in the plan tree */
+  private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = {
+    // Recursive search that handles deep nesting
+    def searchPlan(node: SparkPlan): Option[SparkPlan] = {
+      if (node.isInstanceOf[ProjectExec] || node.isInstanceOf[CometProjectExec]) {
+        Some(node)
+      } else {
+        // Search all children recursively
+        for (child <- node.children) {
+          searchPlan(child) match {
+            case Some(found) => return Some(found)
+            case None => // continue searching
+          }
+        }
+        None
+      }
+    }
+
+    searchPlan(plan)
+  }
+}
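
As a rough sketch of the extension point this patch introduces (not part of the patch itself): a custom cost model only needs to implement the CometCostModel trait and expose a public no-arg constructor, because CometExecRule instantiates the class named by spark.comet.cost.model.class reflectively. The class name org.example.comet.CountingCostModel and its flat 1.5x heuristic below are hypothetical illustrations; CometCostModel, CometCostEstimate, and the config keys are the ones added above.

    package org.example.comet

    import org.apache.spark.sql.comet.CometPlan
    import org.apache.spark.sql.execution.SparkPlan

    import org.apache.comet.cost.{CometCostEstimate, CometCostModel}

    // Hypothetical cost model: treat every Comet operator as a flat 1.5x win,
    // every Spark operator as neutral, and average over the whole plan.
    class CountingCostModel extends CometCostModel {
      override def estimateCost(plan: SparkPlan): CometCostEstimate = {
        val operators = plan.collect { case p => p }
        if (operators.isEmpty) {
          CometCostEstimate(1.0)
        } else {
          val total = operators.map {
            case _: CometPlan => 1.5
            case _ => 1.0
          }.sum
          CometCostEstimate(total / operators.length)
        }
      }
    }

It would then be wired in with, for example:

    --conf spark.comet.cost.enabled=true
    --conf spark.comet.cost.model.class=org.example.comet.CountingCostModel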