diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index 502fa0bd11..a17fb1933e 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -434,6 +434,19 @@ impl ScanExec {
 
         Ok(selection_indices_arrays)
     }
+
+    pub(crate) fn with_ordering(mut self, input_sorted: Vec<PhysicalSortExpr>) -> Self {
+        assert_ne!(input_sorted.len(), 0, "input_sorted cannot be empty");
+        let mut eq_properties = self.cache.eq_properties.clone();
+
+        eq_properties.add_ordering(
+            LexOrdering::new(input_sorted).expect("Must be able to create LexOrdering"),
+        );
+
+        self.cache = self.cache.with_eq_properties(eq_properties);
+
+        self
+    }
 }
 
 fn scan_schema(input_batch: &InputBatch, data_types: &[DataType]) -> SchemaRef {
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 5aa6ece3bc..b3cc4897aa 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -889,29 +889,38 @@ impl PhysicalPlanner {
     /// Create a DataFusion physical sort expression from Spark physical expression
     fn create_sort_expr<'a>(
         &'a self,
-        spark_expr: &'a Expr,
+        spark_expr: &'a spark_expression::Expr,
         input_schema: SchemaRef,
     ) -> Result<PhysicalSortExpr, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             ExprStruct::SortOrder(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
-                let descending = expr.direction == 1;
-                let nulls_first = expr.null_ordering == 0;
-
-                let options = SortOptions {
-                    descending,
-                    nulls_first,
-                };
-
-                Ok(PhysicalSortExpr {
-                    expr: child,
-                    options,
-                })
+                self.sort_order_to_physical_sort_expr(expr, input_schema)
             }
             expr => Err(GeneralError(format!("{expr:?} isn't a SortOrder"))),
         }
     }
 
+    /// Create a DataFusion physical sort expression from a Spark physical SortOrder
+    fn sort_order_to_physical_sort_expr<'a>(
+        &'a self,
+        spark_sort_order: &'a spark_expression::SortOrder,
+        input_schema: SchemaRef,
+    ) -> Result<PhysicalSortExpr, ExecutionError> {
+        let child = self.create_expr(spark_sort_order.child.as_ref().unwrap(), input_schema)?;
+        let descending = spark_sort_order.direction == 1;
+        let nulls_first = spark_sort_order.null_ordering == 0;
+
+        let options = SortOptions {
+            descending,
+            nulls_first,
+        };
+
+        Ok(PhysicalSortExpr {
+            expr: child,
+            options,
+        })
+    }
+
     fn create_binary_expr(
         &self,
         left: &Expr,
@@ -1384,8 +1393,10 @@
                     Some(inputs.remove(0))
                 };
 
+                let input_ordering = scan.input_ordering.clone();
+
                 // The `ScanExec` operator will take actual arrays from Spark during execution
-                let scan = ScanExec::new(
+                let mut scan = ScanExec::new(
                     self.exec_context_id,
                     input_source,
                     &scan.source,
@@ -1393,6 +1404,17 @@
                     scan.arrow_ffi_safe,
                 )?;
 
+                if !input_ordering.is_empty() {
+                    let sort_exprs: Vec<PhysicalSortExpr> = input_ordering
+                        .iter()
+                        .map(|expr| {
+                            self.sort_order_to_physical_sort_expr(expr, Arc::clone(&scan.schema()))
+                        })
+                        .collect::<Result<Vec<_>, _>>()?;
+
+                    scan = scan.with_ordering(sort_exprs)
+                }
+
                 Ok((
                     vec![scan.clone()],
                     Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2932,6 +2954,7 @@ mod tests {
                 }],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
@@ -3006,6 +3029,7 @@ mod tests {
                 }],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
@@ -3217,6 +3241,7 @@ mod tests {
                 fields: vec![create_proto_datatype()],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         }
     }
@@ -3260,6 +3285,7 @@ mod tests {
                 ],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
@@ -3375,6 +3401,7 @@ mod tests {
                 ],
                 source: "".to_string(),
                 arrow_ffi_safe: false,
+                input_ordering: vec![],
             })),
         };
 
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 57e012b369..c89a94a9fd 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -79,6 +79,8 @@ message Scan {
   string source = 2;
   // Whether native code can assume ownership of batches that it receives
   bool arrow_ffi_safe = 3;
+
+  repeated spark.spark_expression.SortOrder input_ordering = 4;
 }
 
 message NativeScan {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 258d275e5b..ae22c61de4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,6 +19,7 @@
 
 package org.apache.comet.serde
 
+import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
 
@@ -731,31 +732,18 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case SortOrder(child, direction, nullOrdering, _) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-
-        if (childExpr.isDefined) {
-          val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
-          sortOrderBuilder.setChild(childExpr.get)
-
-          direction match {
-            case Ascending => sortOrderBuilder.setDirectionValue(0)
-            case Descending => sortOrderBuilder.setDirectionValue(1)
-          }
-
-          nullOrdering match {
-            case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
-            case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
-          }
+      case sortOrder @ SortOrder(child, direction, nullOrdering, _) =>
+        val sortOrderProto = sortOrderingToProto(sortOrder, inputs, binding)
+        if (sortOrderProto.isEmpty) {
+          withInfo(expr, child)
+          None
+        } else {
           Some(
             ExprOuterClass.Expr
              .newBuilder()
-              .setSortOrder(sortOrderBuilder)
+              .setSortOrder(sortOrderProto.get)
              .build())
-        } else {
-          withInfo(expr, child)
-          None
         }
 
      case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
@@ -1243,18 +1231,16 @@
           if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
         val output = child.output
 
-        val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
-          expr match {
-            case alias: Alias =>
-              alias.child match {
-                case winExpr: WindowExpression =>
-                  Some(winExpr)
-                case _ =>
-                  None
-              }
-            case _ =>
-              None
-          }
+        val winExprs: Array[WindowExpression] = windowExpression.flatMap {
+          case alias: Alias =>
+            alias.child match {
+              case winExpr: WindowExpression =>
+                Some(winExpr)
+              case _ =>
+                None
+            }
+          case _ =>
+            None
         }.toArray
 
         if (winExprs.length != windowExpression.length) {
@@ -1582,6 +1568,11 @@
           scanBuilder.setSource(source)
         }
 
+        if (op.children.length == 1) {
+          scanBuilder.addAllInputOrdering(
+            QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(op.children.head).asJava)
+        }
+
         val ffiSafe = op match {
           case _ if isExchangeSink(op) =>
             // Source of broadcast exchange batches is ArrowStreamReader
@@ -1815,6 +1806,79 @@
     })
     nativeScanBuilder.addFilePartitions(partitionBuilder.build())
   }
+
+  def sortOrderingToProto(
+      sortOrder: SortOrder,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.SortOrder] = {
+    val childExpr = exprToProtoInternal(sortOrder.child, inputs, binding)
+
+    if (childExpr.isDefined) {
+      val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
+      sortOrderBuilder.setChild(childExpr.get)
+
+      sortOrder.direction match {
+        case Ascending => sortOrderBuilder.setDirectionValue(0)
+        case Descending => sortOrderBuilder.setDirectionValue(1)
+      }
+
+      sortOrder.nullOrdering match {
+        case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
+        case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
+      }
+
+      Some(sortOrderBuilder.build())
+    } else {
+      withInfo(sortOrder, sortOrder.child)
+      None
+    }
+  }
+
+  /**
+   * Return the plan's input sort order.
+   *
+   * This will not return the full sort order if it can't be fully mapped to the child (if the
+   * sort order is on an expression that is not a direct output attribute of the input).
+   *
+   * For example, if this is the sort: Sort by a, b, coalesce(c, d), e
+   *
+   * we will return this sort order: a, b
+   *
+   * as it is still correct: the data IS ordered by a, b.
+   *
+   * We will not return: a, b, e
+   *
+   * because the data IS NOT ordered by a, b, e.
+   *
+   * This is meant to be used for scan, where we don't want to lose the input ordering
+   * information, as it can enable certain optimizations.
+   */
+  def parsePlanSortOrderAsMuchAsCan(plan: SparkPlan): Seq[ExprOuterClass.SortOrder] = {
+    if (plan.outputOrdering.isEmpty) {
+      Seq.empty
+    } else {
+      val outputAttributes = plan.output
+      val sortOrders = plan.outputOrdering.map(so => {
+        if (!isExprOneOfAttributes(so.child, outputAttributes)) {
+          None
+        } else {
+          QueryPlanSerde.sortOrderingToProto(so, outputAttributes, binding = true)
+        }
+      })
+
+      // Take the sort orders until the first None
+      sortOrders.takeWhile(_.isDefined).map(_.get)
+    }
+  }
+
+  @tailrec
+  private def isExprOneOfAttributes(expr: Expression, attrs: Seq[Attribute]): Boolean = {
+    expr match {
+      case attr: Attribute => attrs.exists(_.exprId == attr.exprId)
+      case alias: Alias => isExprOneOfAttributes(alias.child, attrs)
+      case _ => false
+    }
+  }
 }
 
 sealed trait SupportLevel
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
index 09794e8e26..edfe4ca425 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCollectLimitExec.scala
@@ -77,7 +77,7 @@ case class CometCollectLimitExec(
       childRDD
     } else {
       val localLimitedRDD = if (limit >= 0) {
-        CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
+        CometExecUtils.getNativeLimitRDD(child, childRDD, output, limit)
       } else {
         childRDD
       }
@@ -92,7 +92,7 @@ case class CometCollectLimitExec(
        new CometShuffledBatchRDD(dep, readMetrics)
      }
 
-      CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit, offset)
+      CometExecUtils.getNativeLimitRDD(child, singlePartitionRDD, output, limit, offset)
    }
  }
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
index fd97fe3fa2..bd455aeeea 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala
@@ -28,7 +28,7 @@
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.{OperatorOuterClass, QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
@@ -48,13 +48,15 @@
    * partition. The limit operation is performed on the native side.
    */
   def getNativeLimitRDD(
-      childPlan: RDD[ColumnarBatch],
+      childPlan: SparkPlan,
+      child: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int,
       offset: Int = 0): RDD[ColumnarBatch] = {
-    val numParts = childPlan.getNumPartitions
-    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
-      val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
+    val numParts = child.getNumPartitions
+    child.mapPartitionsWithIndexInternal { case (idx, iter) =>
+      val limitOp =
+        CometExecUtils.getLimitNativePlan(childPlan, outputAttribute, limit, offset).get
       CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }
@@ -90,10 +92,15 @@
    * child partition
    */
   def getLimitNativePlan(
+      child: SparkPlan,
       outputAttributes: Seq[Attribute],
       limit: Int,
       offset: Int = 0): Option[Operator] = {
-    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("LimitInput")
+    val scanBuilder = OperatorOuterClass.Scan
+      .newBuilder()
+      .setSource("LimitInput")
+      .addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)
+
     val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
 
     val scanTypes = outputAttributes.flatten { attr =>
@@ -125,7 +132,11 @@
       child: SparkPlan,
       limit: Int,
       offset: Int = 0): Option[Operator] = {
-    val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("TopKInput")
+    val scanBuilder = OperatorOuterClass.Scan
+      .newBuilder()
+      .setSource("TopKInput")
+      .addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)
+
     val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
 
     val scanTypes = outputAttributes.flatten { attr =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index aa89dec137..12a8282cda 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -77,7 +77,7 @@ case class CometTakeOrderedAndProjectExec(
       childRDD
     } else {
       val localTopK = if (orderingSatisfies) {
-        CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
+        CometExecUtils.getNativeLimitRDD(child, childRDD, child.output, limit)
       } else {
         val numParts = childRDD.getNumPartitions
         childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
index fd6d3ef535..10a34d36d6 100644
--- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -28,9 +28,15 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class CometNativeSuite extends CometTestBase {
   test("test handling NPE thrown by JVM") {
-    val rdd = spark.range(0, 1).rdd.map { value =>
+    val dataset = spark.range(0, 1)
+    val rdd = dataset.rdd.map { value =>
       val limitOp =
-        CometExecUtils.getLimitNativePlan(Seq(PrettyAttribute("test", LongType)), 100).get
+        CometExecUtils
+          .getLimitNativePlan(
+            dataset.queryExecution.executedPlan.children.head,
+            Seq(PrettyAttribute("test", LongType)),
+            100)
+          .get
       val cometIter = CometExec.getCometIterator(
         Seq(new Iterator[ColumnarBatch] {
           override def hasNext: Boolean = true