diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 2d587076cd..1f969b3f4b 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -240,7 +240,7 @@ class CometSparkSessionExtensions
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 =>
+            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")

         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -253,7 +253,7 @@ class CometSparkSessionExtensions
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(
             conf) &&
-            QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
+            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -719,7 +719,7 @@ class CometSparkSessionExtensions
         case s: ShuffleExchangeExec =>
           val nativePrecondition = isCometShuffleEnabled(conf) &&
             isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1
+            QueryPlanSerde.nativeShuffleSupported(s)._1

           val nativeShuffle: Option[SparkPlan] =
             if (nativePrecondition) {
@@ -753,7 +753,7 @@ class CometSparkSessionExtensions
           // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
           // convert it to CometColumnarShuffle,
           if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
+            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child)) {

             val newOp = QueryPlanSerde.operator2Proto(s)
@@ -781,22 +781,22 @@ class CometSparkSessionExtensions
             nativeOrColumnarShuffle.get
           } else {
             val isShuffleEnabled = isCometShuffleEnabled(conf)
-            val outputPartitioning = s.outputPartitioning
+            s.outputPartitioning
             val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
             val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
             val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
             val msg2 = createMessage(
               isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-                .supportPartitioning(s.child.output, outputPartitioning)
+                .nativeShuffleSupported(s)
                 ._1,
               "Native shuffle: " +
-                s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}")
+                s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
             val typeInfo = QueryPlanSerde
-              .supportPartitioningTypes(s.child.output, outputPartitioning)
+              .columnarShuffleSupported(s)
               ._2
             val msg3 = createMessage(
               isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-                .supportPartitioningTypes(s.child.output, outputPartitioning)
+                .columnarShuffleSupported(s)
                 ._1,
               "JVM shuffle: " +
                 s"$typeInfo")
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4bcccf948e..325cf15a1d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2785,52 +2785,31 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
    * which supports struct/array.
    */
-  def supportPartitioningTypes(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
-        true
-      case StructType(fields) =>
-        fields.forall(f => supportedDataType(f.dataType)) &&
-          // Java Arrow stream reader cannot work on duplicate field name
-          fields.map(f => f.name).distinct.length == fields.length
-      case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-      case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-      case ArrayType(elementType, _) =>
-        supportedDataType(elementType)
-      case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-      case MapType(_, MapType(_, _, _), _) => false
-      case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-      case MapType(_, StructType(_), _) => false
-      case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-      case MapType(_, ArrayType(_, _), _) => false
-      case MapType(keyType, valueType, _) =>
-        supportedDataType(keyType) && supportedDataType(valueType)
-      case _ =>
-        false
-    }
-
+  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
-      case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
@@ -2849,33 +2828,42 @@ object QueryPlanSerde extends Logging with CometExprShim {
   }

   /**
-   * Whether the given Spark partitioning is supported by Comet.
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
    */
-  def supportPartitioning(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
+  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
         true
       case _ =>
-        // Native shuffle doesn't support struct/array yet
         false
     }

+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case _ =>
         msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
         false
@@ -2889,6 +2877,34 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // Java Arrow stream reader cannot work on duplicate field name
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],
@@ -2920,7 +2936,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     val canSort = sortOrder.head.dataType match {
       case _: BooleanType => true
       case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: TimestampType | _: TimestampType | _: DecimalType | _: DateType =>
+          _: DoubleType | _: TimestampType | _: TimestampNTZType | _: DecimalType |
+          _: DateType =>
         true
       case _: BinaryType | _: StringType => true
       case ArrayType(elementType, _) => canRank(elementType)
diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
index 1953085269..3ced5fb8d8 100644
--- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.Tag
 import org.apache.commons.io.FileUtils
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -161,6 +162,18 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("shuffle") {
+    val df = spark.read.parquet(filename)
+    val df2 = df.repartition(8, df.col("c0")).sort("c1")
+    df2.collect()
+    if (CometConf.isExperimentalNativeScan) {
+      val cometShuffles = collect(df2.queryExecution.executedPlan) {
+        case exec: CometShuffleExchangeExec => exec
+      }
+      assert(1 == cometShuffles.length)
+    }
+  }
+
   test("join") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
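Reviewer note (not part of the patch): the sketch below is a hypothetical, self-contained illustration of the split the renamed helpers encode. `nativeShuffleSupported` rejects complex types as hash-partition keys because native shuffle lacks hashing support for them, while `columnarShuffleSupported` accepts any key or column type that passes `supportedShuffleDataType`. The object, class, and column names below are invented for the example; it uses plain Spark only and does not exercise Comet itself.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct

// Hypothetical standalone example of the two partition-key shapes discussed above.
object ShufflePartitionKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("shuffle-key-example")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b"), (3, "c"))
      .toDF("id", "name")
      .withColumn("key", struct($"id", $"name"))

    // Struct partition key: under this patch, only the JVM columnar shuffle path is eligible,
    // because nativeShuffleSupported() rejects complex hash-partition keys.
    df.repartition(4, $"key").collect()

    // Primitive partition key: eligible for both native and columnar shuffle, provided all
    // output column types pass supportedShuffleDataType().
    df.repartition(4, $"id").collect()

    spark.stop()
  }
}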