-
Notifications
You must be signed in to change notification settings - Fork 264
feat: Add support for complex types in native shuffle #1655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
eec034a
307beef
ae416b3
7811f61
788801b
4ff9f2e
1707161
da79d4e
e9d0029
a9aa537
b5b4d27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -558,7 +558,7 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| case c @ Cast(child, dt, timeZoneId, _) => | ||
| handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c)) | ||
|
|
||
| case add @ Add(left, right, _) if supportedDataType(left.dataType) => | ||
| case add @ Add(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| createMathExpression( | ||
| expr, | ||
| left, | ||
|
|
@@ -569,11 +569,11 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| add.evalMode == EvalMode.ANSI, | ||
| (builder, mathExpr) => builder.setAdd(mathExpr)) | ||
|
|
||
| case add @ Add(left, _, _) if !supportedDataType(left.dataType) => | ||
| case add @ Add(left, _, _) if !supportedShuffleDataType(left.dataType) => | ||
| withInfo(add, s"Unsupported datatype ${left.dataType}") | ||
| None | ||
|
|
||
| case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => | ||
| case sub @ Subtract(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| createMathExpression( | ||
| expr, | ||
| left, | ||
|
|
@@ -584,11 +584,11 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| sub.evalMode == EvalMode.ANSI, | ||
| (builder, mathExpr) => builder.setSubtract(mathExpr)) | ||
|
|
||
| case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) => | ||
| case sub @ Subtract(left, _, _) if !supportedShuffleDataType(left.dataType) => | ||
| withInfo(sub, s"Unsupported datatype ${left.dataType}") | ||
| None | ||
|
|
||
| case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) => | ||
| case mul @ Multiply(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| createMathExpression( | ||
| expr, | ||
| left, | ||
|
|
@@ -600,12 +600,12 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| (builder, mathExpr) => builder.setMultiply(mathExpr)) | ||
|
|
||
| case mul @ Multiply(left, _, _) => | ||
| if (!supportedDataType(left.dataType)) { | ||
| if (!supportedShuffleDataType(left.dataType)) { | ||
| withInfo(mul, s"Unsupported datatype ${left.dataType}") | ||
| } | ||
| None | ||
|
|
||
| case div @ Divide(left, right, _) if supportedDataType(left.dataType) => | ||
| case div @ Divide(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| // Datafusion now throws an exception for dividing by zero | ||
| // See https://github.com/apache/arrow-datafusion/pull/6792 | ||
| // For now, use NullIf to swap zeros with nulls. | ||
|
|
@@ -622,12 +622,12 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| (builder, mathExpr) => builder.setDivide(mathExpr)) | ||
|
|
||
| case div @ Divide(left, _, _) => | ||
| if (!supportedDataType(left.dataType)) { | ||
| if (!supportedShuffleDataType(left.dataType)) { | ||
| withInfo(div, s"Unsupported datatype ${left.dataType}") | ||
| } | ||
| None | ||
|
|
||
| case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) => | ||
| case div @ IntegralDivide(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| val rightExpr = nullIfWhenPrimitive(right) | ||
|
|
||
| val dataType = (left.dataType, right.dataType) match { | ||
|
|
@@ -671,12 +671,12 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| } | ||
|
|
||
| case div @ IntegralDivide(left, _, _) => | ||
| if (!supportedDataType(left.dataType)) { | ||
| if (!supportedShuffleDataType(left.dataType)) { | ||
| withInfo(div, s"Unsupported datatype ${left.dataType}") | ||
| } | ||
| None | ||
|
|
||
| case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => | ||
| case rem @ Remainder(left, right, _) if supportedShuffleDataType(left.dataType) => | ||
| val rightExpr = nullIfWhenPrimitive(right) | ||
|
|
||
| createMathExpression( | ||
|
|
@@ -690,7 +690,7 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| (builder, mathExpr) => builder.setRemainder(mathExpr)) | ||
|
|
||
| case rem @ Remainder(left, _, _) => | ||
| if (!supportedDataType(left.dataType)) { | ||
| if (!supportedShuffleDataType(left.dataType)) { | ||
| withInfo(rem, s"Unsupported datatype ${left.dataType}") | ||
| } | ||
| None | ||
|
|
@@ -816,7 +816,7 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| withInfo(expr, s"Unsupported datatype $dataType") | ||
| None | ||
| } | ||
| case Literal(_, dataType) if !supportedDataType(dataType) => | ||
| case Literal(_, dataType) if !supportedShuffleDataType(dataType) => | ||
| withInfo(expr, s"Unsupported datatype $dataType") | ||
| None | ||
|
|
||
|
|
@@ -1786,7 +1786,7 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() | ||
| } | ||
|
|
||
| case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) => | ||
| case s @ execution.ScalarSubquery(_, _) if supportedShuffleDataType(s.dataType) => | ||
| val dataType = serializeDataType(s.dataType) | ||
| if (dataType.isEmpty) { | ||
| withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}") | ||
|
|
@@ -2785,52 +2785,28 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle | ||
| * which supports struct/array. | ||
| */ | ||
| def supportPartitioningTypes( | ||
| inputs: Seq[Attribute], | ||
| partitioning: Partitioning): (Boolean, String) = { | ||
| def supportedDataType(dt: DataType): Boolean = dt match { | ||
| case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | | ||
| _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | | ||
| _: DateType | _: BooleanType => | ||
| true | ||
| case StructType(fields) => | ||
| fields.forall(f => supportedDataType(f.dataType)) && | ||
| // Java Arrow stream reader cannot work on duplicate field name | ||
| fields.map(f => f.name).distinct.length == fields.length | ||
| case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported | ||
| case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported | ||
| case ArrayType(elementType, _) => | ||
| supportedDataType(elementType) | ||
| case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported | ||
| case MapType(_, MapType(_, _, _), _) => false | ||
| case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported | ||
| case MapType(_, StructType(_), _) => false | ||
| case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported | ||
| case MapType(_, ArrayType(_, _), _) => false | ||
| case MapType(keyType, valueType, _) => | ||
| supportedDataType(keyType) && supportedDataType(valueType) | ||
| case _ => | ||
| false | ||
| } | ||
|
|
||
| def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = { | ||
| val inputs = s.child.output | ||
| val partitioning = s.outputPartitioning | ||
| var msg = "" | ||
| val supported = partitioning match { | ||
| case HashPartitioning(expressions, _) => | ||
| val supported = | ||
| expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && | ||
| expressions.forall(e => supportedDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| expressions.forall(e => supportedShuffleDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| if (!supported) { | ||
| msg = s"unsupported Spark partitioning expressions: $expressions" | ||
| } | ||
| supported | ||
| case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| case RoundRobinPartitioning(_) => | ||
| inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| case RangePartitioning(orderings, _) => | ||
| val supported = | ||
| orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && | ||
| orderings.forall(e => supportedDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| orderings.forall(e => supportedShuffleDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| if (!supported) { | ||
| msg = s"unsupported Spark partitioning expressions: $orderings" | ||
| } | ||
|
|
@@ -2849,33 +2825,23 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| } | ||
|
|
||
| /** | ||
| * Whether the given Spark partitioning is supported by Comet. | ||
| * Whether the given Spark partitioning is supported by Comet native shuffle. | ||
| */ | ||
| def supportPartitioning( | ||
| inputs: Seq[Attribute], | ||
| partitioning: Partitioning): (Boolean, String) = { | ||
| def supportedDataType(dt: DataType): Boolean = dt match { | ||
| case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | | ||
| _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | | ||
| _: DateType | _: BooleanType => | ||
| true | ||
| case _ => | ||
| // Native shuffle doesn't support struct/array yet | ||
| false | ||
| } | ||
|
|
||
| def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = { | ||
| val inputs = s.child.output | ||
| val partitioning = s.outputPartitioning | ||
| var msg = "" | ||
| val supported = partitioning match { | ||
| case HashPartitioning(expressions, _) => | ||
| val supported = | ||
| expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) && | ||
| expressions.forall(e => supportedDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| expressions.forall(e => supportedShuffleDataType(e.dataType)) && | ||
| inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| if (!supported) { | ||
| msg = s"unsupported Spark partitioning expressions: $expressions" | ||
| } | ||
| supported | ||
| case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType)) | ||
| case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType)) | ||
| case _ => | ||
| msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}" | ||
| false | ||
|
|
@@ -2889,6 +2855,31 @@ object QueryPlanSerde extends Logging with CometExprShim { | |
| } | ||
| } | ||
|
|
||
| def supportedShuffleDataType(dt: DataType): Boolean = dt match { | ||
| case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | | ||
| _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | | ||
|
||
| _: DecimalType | _: DateType | _: BooleanType => | ||
| true | ||
| case StructType(fields) => | ||
| fields.forall(f => supportedShuffleDataType(f.dataType)) && | ||
| // Java Arrow stream reader cannot work on duplicate field name | ||
| fields.map(f => f.name).distinct.length == fields.length | ||
| case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported | ||
| case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported | ||
| case ArrayType(elementType, _) => | ||
| supportedShuffleDataType(elementType) | ||
| case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported | ||
| case MapType(_, MapType(_, _, _), _) => false | ||
| case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported | ||
| case MapType(_, StructType(_), _) => false | ||
| case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported | ||
| case MapType(_, ArrayType(_, _), _) => false | ||
| case MapType(keyType, valueType, _) => | ||
| supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType) | ||
| case _ => | ||
| false | ||
| } | ||
|
|
||
| // Utility method. Adds explain info if the result of calling exprToProto is None | ||
| def optExprWithInfo( | ||
| optExpr: Option[Expr], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ import org.scalatest.Tag | |
| import org.apache.commons.io.FileUtils | ||
| import org.apache.spark.sql.CometTestBase | ||
| import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec} | ||
| import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec | ||
| import org.apache.spark.sql.execution.SparkPlan | ||
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
@@ -161,6 +162,18 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { | |
| } | ||
| } | ||
|
|
||
| test("shuffle") { | ||
| val df = spark.read.parquet(filename) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does the data have complex types?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the data has arrays and structs but not maps yet |
||
| val df2 = df.repartition(8, df.col("c0")).sort("c1") | ||
| df2.collect() | ||
| if (CometConf.isExperimentalNativeScan) { | ||
| val cometShuffles = collect(df2.queryExecution.executedPlan) { | ||
| case exec: CometShuffleExchangeExec => exec | ||
| } | ||
| assert(1 == cometShuffles.length) | ||
| } | ||
| } | ||
|
|
||
| test("join") { | ||
| val df = spark.read.parquet(filename) | ||
| df.createOrReplaceTempView("t1") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, it does! This method is removed and we now have a single
supportedShuffleDataType method that is used for both native and columnar shuffle type checks. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for that — I was so confused about having this supported check in at least 3 places.