18 changes: 18 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -754,6 +754,24 @@ object CometConf extends ShimCometConf {
      .booleanConf
      .createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false)

  val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] =
    conf("spark.comet.cost.enabled")
      .category(CATEGORY_TUNING)
      .doc(
        "Whether to enable cost-based optimization for Comet. When enabled, Comet will " +
          "use a cost model to estimate acceleration factors for operators and make decisions " +
          "about whether to use Comet or Spark operators based on estimated performance.")
      .booleanConf
      .createWithDefault(false)

  val COMET_COST_MODEL_CLASS: ConfigEntry[String] =
    conf("spark.comet.cost.model.class")
      .category(CATEGORY_TUNING)
      .doc("The fully qualified class name of the cost model implementation to use for " +
        "cost-based optimization. The class must implement the CometCostModel trait.")
      .stringConf
      .createWithDefault("org.apache.comet.cost.DefaultCometCostModel")

  /** Create a config to enable a specific operator */
  private def createExecEnabledConfig(
      exec: String,
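With these two entries in place, cost-based optimization can be opted into at session startup. A minimal sketch (illustrative, not part of this diff; assumes the standard Comet plugin class `org.apache.spark.CometPlugin`):

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  // opt in to the experimental cost model; defaults to false
  .config("spark.comet.cost.enabled", "true")
  // optional: swap in a custom model class (see CometCostModel below)
  .config("spark.comet.cost.model.class", "org.apache.comet.cost.DefaultCometCostModel")
  .getOrCreate()
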
2 changes: 2 additions & 0 deletions docs/source/user-guide/latest/configs.md
@@ -122,6 +122,8 @@ These settings can be used to determine which parts of the plan are accelerated
| Config | Description | Default Value |
|--------|-------------|---------------|
| `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
| `spark.comet.cost.enabled` | Whether to enable cost-based optimization for Comet. When enabled, Comet will use a cost model to estimate acceleration factors for operators and make decisions about whether to use Comet or Spark operators based on estimated performance. | false |
| `spark.comet.cost.model.class` | The fully qualified class name of the cost model implementation to use for cost-based optimization. The class must implement the CometCostModel trait. | org.apache.comet.cost.DefaultCometCostModel |
| `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified |
| `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 |
| `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/contributor-guide/tracing.html). | false |
3 changes: 3 additions & 0 deletions spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -79,4 +79,7 @@ object DataTypeSupport {
    case _: StructType | _: ArrayType | _: MapType => true
    case _ => false
  }

  def hasComplexTypes(schema: StructType): Boolean =
    schema.fields.exists(f => isComplexType(f.dataType))
}
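Note that `hasComplexTypes` only inspects top-level field types, which is all the cost model below needs. A quick illustration (hypothetical snippet, not in the diff):

import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("tags", ArrayType(StringType))))

DataTypeSupport.hasComplexTypes(schema) // true: `tags` is an ArrayType
DataTypeSupport.hasComplexTypes(StructType(Seq(StructField("id", LongType)))) // false
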
98 changes: 98 additions & 0 deletions spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala
@@ -0,0 +1,98 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.cost

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{Cost, CostEvaluator}
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf

/**
 * Simple Cost implementation for Comet cost evaluator.
 */
case class CometCost(value: Double) extends Cost {
  override def compare(that: Cost): Int = that match {
    case CometCost(thatValue) => java.lang.Double.compare(value, thatValue)
    case _ => 0 // If we can't compare, assume equal
  }

  override def toString: String = s"CometCost($value)"
}

/**
 * Comet implementation of Spark's CostEvaluator for adaptive query execution.
 *
 * This evaluator uses the configured CometCostModel to estimate costs for query plans, allowing
 * Spark's adaptive query execution to make informed decisions about whether to use Comet or Spark
 * operators based on estimated performance.
 */
class CometCostEvaluator extends CostEvaluator with Logging {

  @transient private lazy val costModel: CometCostModel = {
    val conf = SQLConf.get
    val costModelClass = CometConf.COMET_COST_MODEL_CLASS.get(conf)

    try {
      // scalastyle:off classforname
      val clazz = Class.forName(costModelClass)
      // scalastyle:on classforname
      val constructor = clazz.getConstructor()
      constructor.newInstance().asInstanceOf[CometCostModel]
    } catch {
      case e: Exception =>
        logWarning(
          s"Failed to instantiate cost model class '$costModelClass', " +
            s"falling back to DefaultCometCostModel. Error: ${e.getMessage}")
        new DefaultCometCostModel()
    }
  }

  /**
   * Evaluates the cost of executing the given SparkPlan.
   *
   * This method uses the configured CometCostModel to estimate the acceleration factor for the
   * plan, then converts it to a Cost object that Spark's adaptive query execution can use for
   * decision making.
   *
   * @param plan
   *   The SparkPlan to evaluate
   * @return
   *   A Cost representing the estimated execution cost
   */
  override def evaluateCost(plan: SparkPlan): Cost = {
    val estimate = costModel.estimateCost(plan)

    // Convert acceleration factor to cost
    // Lower cost means better performance, so we use the inverse of acceleration factor
    // For example:
    // - 2.0x acceleration -> cost = 0.5 (half the cost)
    // - 0.8x acceleration -> cost = 1.25 (25% more cost)
    val costValue = 1.0 / estimate.acceleration

    logDebug(
      s"Cost evaluation for ${plan.getClass.getSimpleName}: " +
        s"acceleration=${estimate.acceleration}, cost=$costValue")

    // Create Cost object with the calculated value
    CometCost(costValue)
  }
}
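
The inversion above means that ordering by CometCost is ordering by estimated slowdown. A quick spot-check of the semantics (illustrative snippet, not part of the diff):

val accelerated = CometCost(1.0 / 2.0) // 2.0x acceleration -> cost 0.5
val regressed = CometCost(1.0 / 0.8)   // 0.8x "acceleration" -> cost 1.25
assert(accelerated.compare(regressed) < 0) // AQE prefers the lower-cost (faster) plan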
103 changes: 103 additions & 0 deletions spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala
@@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.cost

import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression}
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.SparkPlan

import org.apache.comet.DataTypeSupport

case class CometCostEstimate(acceleration: Double)

trait CometCostModel {

  /** Estimate the relative cost of one operator */
  def estimateCost(plan: SparkPlan): CometCostEstimate
}

class DefaultCometCostModel extends CometCostModel {

  // optimistic default of 2x acceleration
  private val defaultAcceleration = 2.0

  override def estimateCost(plan: SparkPlan): CometCostEstimate = {
    // Walk the entire plan tree and accumulate costs
    var totalAcceleration = 0.0
    var operatorCount = 0

    def collectOperatorCosts(node: SparkPlan): Unit = {
      val operatorCost = estimateOperatorCost(node)
      totalAcceleration += operatorCost.acceleration
      operatorCount += 1

      // Recursively process children
      node.children.foreach(collectOperatorCosts)

Review comment (Contributor):

Perhaps we could remove the usage of vars and let the function return the totalAcceleration and operator count itself? Something like:

def countItems(list: List[Any], accumulator: Int = 0): Int = {
  list match {
    case head :: tail => countItems(tail, accumulator + 1)
    case Nil => accumulator
  }
}

val myList = List(1, 2, 3, 4)
val count = countItems(myList) // result: 4

Reply (Member, Author):

I don't know how much of this code will still exist by the time the proof-of-concept is working and ready for detailed code review, so I'll hold off on making these changes for now.

I am really looking for high-level feedback on the general approach at the moment.

    }

    collectOperatorCosts(plan)

    // Calculate average acceleration across all operators
    // This is crude but gives us a starting point
    val averageAcceleration = if (operatorCount > 0) {
      totalAcceleration / operatorCount.toDouble
    } else {
      1.0 // No acceleration if no operators
    }

    CometCostEstimate(averageAcceleration)
  }
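
  // Worked example (illustrative, not in the original diff): a three-operator plan with a
  // native shuffle (1.5x), one other Comet operator (2.0x default), and one Spark fallback
  // (1.0x) averages to (1.5 + 2.0 + 1.0) / 3 = 1.5x, which CometCostEvaluator then inverts
  // to a cost of 1.0 / 1.5 ≈ 0.67.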

  /** Estimate the cost of a single operator */
  private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = {
    plan match {
      case op: CometShuffleExchangeExec =>
        op.shuffleType match {
          case CometNativeShuffle => CometCostEstimate(1.5)
          case CometColumnarShuffle =>
            if (DataTypeSupport.hasComplexTypes(op.schema)) {
              CometCostEstimate(0.8)
            } else {
              CometCostEstimate(1.1)
            }
        }
      case _: CometColumnarToRowExec =>
        CometCostEstimate(1.0)
      // guard avoids a NaN cost for an empty projection (falls through to the CometPlan case)
      case op: CometProjectExec if op.expressions.nonEmpty =>
        val total: Double = op.expressions.map(estimateExpressionCost).sum
        CometCostEstimate(total / op.expressions.length.toDouble)
      case _: CometPlan =>
        CometCostEstimate(defaultAcceleration)
      case _ =>
        // Spark operator
        CometCostEstimate(1.0)
    }
  }

  /** Estimate the cost of an expression */
  private def estimateExpressionCost(expr: Expression): Double = {
    expr match {
      case _: BinaryArithmetic =>
        2.0
      case _ => defaultAcceleration
    }
  }
}
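Because CometCostEvaluator loads the model reflectively from `spark.comet.cost.model.class`, a third-party model only needs to implement the trait and expose a no-arg constructor. A hypothetical sketch (class name and numbers invented for illustration) that is pessimistic about complex types:

package com.example.comet

import org.apache.spark.sql.comet.CometPlan
import org.apache.spark.sql.execution.SparkPlan

import org.apache.comet.DataTypeSupport
import org.apache.comet.cost.{CometCostEstimate, CometCostModel}

class ConservativeCostModel extends CometCostModel {
  // Unlike DefaultCometCostModel, this only looks at the root operator
  override def estimateCost(plan: SparkPlan): CometCostEstimate = plan match {
    // Assume complex-type handling erases most of Comet's advantage
    case p: CometPlan if DataTypeSupport.hasComplexTypes(p.schema) =>
      CometCostEstimate(1.1)
    case _: CometPlan => CometCostEstimate(2.0)
    case _ => CometCostEstimate(1.0) // plain Spark operator: no acceleration
  }
}

It would be registered via `--conf spark.comet.cost.model.class=com.example.comet.ConservativeCostModel`.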
11 changes: 10 additions & 1 deletion spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR}
import org.apache.spark.sql.internal.StaticSQLConf

import org.apache.comet.CometConf.COMET_ONHEAP_ENABLED
import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_ONHEAP_ENABLED}
import org.apache.comet.CometSparkSessionExtensions

/**
@@ -57,6 +57,15 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
    // register CometSparkSessionExtensions if it isn't already registered
    CometDriverPlugin.registerCometSessionExtension(sc.conf)

    // Enable cost-based optimization if configured
    if (sc.getConf.getBoolean(COMET_COST_BASED_OPTIMIZATION_ENABLED.key, false)) {
      // Set the custom cost evaluator for Spark's adaptive query execution
      sc.conf.set(
        "spark.sql.adaptive.customCostEvaluatorClass",
        "org.apache.comet.cost.CometCostEvaluator")
      logInfo("Enabled Comet cost-based optimization with CometCostEvaluator")
    }

    if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) {
      val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) {
        sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
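For deployments that don't go through CometDriverPlugin, the same wiring can be reproduced by hand (a sketch, assuming a Spark version that supports `spark.sql.adaptive.customCostEvaluatorClass`; the plugin path above is the intended route):

import org.apache.spark.SparkConf

// Hand-wired equivalent of the plugin logic above (illustrative)
val conf = new SparkConf()
  .set("spark.comet.cost.enabled", "true")
  .set("spark.sql.adaptive.customCostEvaluatorClass", "org.apache.comet.cost.CometCostEvaluator")

Note that the evaluator only takes effect when adaptive query execution is enabled (`spark.sql.adaptive.enabled`, on by default in recent Spark versions).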