13 changes: 13 additions & 0 deletions native/core/src/execution/operators/scan.rs
@@ -434,6 +434,19 @@ impl ScanExec {

Ok(selection_indices_arrays)
}

pub(crate) fn with_ordering(mut self, input_sorted: Vec<PhysicalSortExpr>) -> Self {
assert_ne!(input_sorted.len(), 0, "input_sorted cannot be empty");
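// Recording the ordering in the equivalence properties lets the DataFusion optimizer treat the scan output as already sorted.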
let mut eq_properties = self.cache.eq_properties.clone();

eq_properties.add_ordering(
LexOrdering::new(input_sorted).expect("Must be able to create LexOrdering"),
);

self.cache = self.cache.with_eq_properties(eq_properties);

self
}
}

fn scan_schema(input_batch: &InputBatch, data_types: &[DataType]) -> SchemaRef {
57 changes: 42 additions & 15 deletions native/core/src/execution/planner.rs
@@ -889,29 +889,38 @@ impl PhysicalPlanner {
/// Create a DataFusion physical sort expression from a Spark physical expression
fn create_sort_expr<'a>(
&'a self,
spark_expr: &'a Expr,
spark_expr: &'a spark_expression::Expr,
input_schema: SchemaRef,
) -> Result<PhysicalSortExpr, ExecutionError> {
match spark_expr.expr_struct.as_ref().unwrap() {
ExprStruct::SortOrder(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let descending = expr.direction == 1;
let nulls_first = expr.null_ordering == 0;

let options = SortOptions {
descending,
nulls_first,
};

Ok(PhysicalSortExpr {
expr: child,
options,
})
self.sort_order_to_physical_sort_expr(expr, input_schema)
}
expr => Err(GeneralError(format!("{expr:?} isn't a SortOrder"))),
}
}

/// Create a DataFusion physical sort expression from a Spark SortOrder
fn sort_order_to_physical_sort_expr<'a>(
&'a self,
spark_sort_order: &'a spark_expression::SortOrder,
input_schema: SchemaRef,
) -> Result<PhysicalSortExpr, ExecutionError> {
let child = self.create_expr(spark_sort_order.child.as_ref().unwrap(), input_schema)?;
let descending = spark_sort_order.direction == 1;
let nulls_first = spark_sort_order.null_ordering == 0;

let options = SortOptions {
descending,
nulls_first,
};

Ok(PhysicalSortExpr {
expr: child,
options,
})
}

fn create_binary_expr(
&self,
left: &Expr,
@@ -1384,15 +1393,28 @@ impl PhysicalPlanner {
Some(inputs.remove(0))
};

let input_ordering = scan.input_ordering.clone();

// The `ScanExec` operator will take actual arrays from Spark during execution
let scan = ScanExec::new(
let mut scan = ScanExec::new(
self.exec_context_id,
input_source,
&scan.source,
data_types,
scan.arrow_ffi_safe,
)?;

// Attach any input ordering declared in the protobuf Scan to the ScanExec.
if !input_ordering.is_empty() {
let sort_exprs: Vec<PhysicalSortExpr> = input_ordering
.iter()
.map(|expr| {
self.sort_order_to_physical_sort_expr(expr, Arc::clone(&scan.schema()))
})
.collect::<Result<_, ExecutionError>>()?;

scan = scan.with_ordering(sort_exprs);
}

Ok((
vec![scan.clone()],
Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2932,6 +2954,7 @@ mod tests {
}],
source: "".to_string(),
arrow_ffi_safe: false,
input_ordering: vec![],
})),
};

@@ -3006,6 +3029,7 @@ mod tests {
}],
source: "".to_string(),
arrow_ffi_safe: false,
input_ordering: vec![],
})),
};

@@ -3217,6 +3241,7 @@ mod tests {
fields: vec![create_proto_datatype()],
source: "".to_string(),
arrow_ffi_safe: false,
input_ordering: vec![],
})),
}
}
@@ -3260,6 +3285,7 @@ mod tests {
],
source: "".to_string(),
arrow_ffi_safe: false,
input_ordering: vec![],
})),
};

@@ -3375,6 +3401,7 @@ mod tests {
],
source: "".to_string(),
arrow_ffi_safe: false,
input_ordering: vec![],
})),
};

2 changes: 2 additions & 0 deletions native/proto/src/proto/operator.proto
@@ -79,6 +79,8 @@ message Scan {
string source = 2;
// Whether native code can assume ownership of batches that it receives
bool arrow_ffi_safe = 3;

// Known sort order of the batches this scan receives, if any. This may be a prefix of
// the producing operator's full output ordering.
repeated spark.spark_expression.SortOrder input_ordering = 4;
}

message NativeScan {
128 changes: 96 additions & 32 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,6 +19,7 @@

package org.apache.comet.serde

import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

@@ -731,29 +732,18 @@ object QueryPlanSerde extends Logging with CometExprShim {
None
}

case SortOrder(child, direction, nullOrdering, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)

if (childExpr.isDefined) {
val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
sortOrderBuilder.setChild(childExpr.get)

direction match {
case Ascending => sortOrderBuilder.setDirectionValue(0)
case Descending => sortOrderBuilder.setDirectionValue(1)
}

nullOrdering match {
case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
}
case sortOrder @ SortOrder(child, direction, nullOrdering, _) =>
val sortOrderProto = sortOrderingToProto(sortOrder, inputs, binding)

if (sortOrderProto.isEmpty) {
withInfo(expr, child)
None
} else {
Some(
ExprOuterClass.Expr
.newBuilder()
.setSortOrder(sortOrderBuilder)
.setSortOrder(sortOrderProto.get)
.build())
} else {
withInfo(expr, child)
None
}

case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
@@ -1243,18 +1231,16 @@ object QueryPlanSerde extends Logging with CometExprShim {
if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
val output = child.output

val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
expr match {
case alias: Alias =>
alias.child match {
case winExpr: WindowExpression =>
Some(winExpr)
case _ =>
None
}
case _ =>
None
}
val winExprs: Array[WindowExpression] = windowExpression.flatMap {
case alias: Alias =>
alias.child match {
case winExpr: WindowExpression =>
Some(winExpr)
case _ =>
None
}
case _ =>
None
}.toArray

if (winExprs.length != windowExpression.length) {
@@ -1582,6 +1568,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
scanBuilder.setSource(source)
}

// Record the prefix of the child's output ordering that can be expressed over its output
// attributes, so the native scan keeps the input ordering information.
if (op.children.length == 1) {
scanBuilder.addAllInputOrdering(
QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(op.children.head).asJava)
}

val ffiSafe = op match {
case _ if isExchangeSink(op) =>
// Source of broadcast exchange batches is ArrowStreamReader
@@ -1815,6 +1806,79 @@ object QueryPlanSerde extends Logging with CometExprShim {
})
nativeScanBuilder.addFilePartitions(partitionBuilder.build())
}

def sortOrderingToProto(
sortOrder: SortOrder,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.SortOrder] = {
val childExpr = exprToProtoInternal(sortOrder.child, inputs, binding)

if (childExpr.isDefined) {
val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
sortOrderBuilder.setChild(childExpr.get)

sortOrder.direction match {
case Ascending => sortOrderBuilder.setDirectionValue(0)
case Descending => sortOrderBuilder.setDirectionValue(1)
}

sortOrder.nullOrdering match {
case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
}

Some(sortOrderBuilder.build())
} else {
withInfo(sortOrder, sortOrder.child)
None
}
}

/**
* Return the prefix of the plan's output ordering that can be expressed purely in terms of
* the plan's output attributes.
*
* A sort expression that is not a direct reference to an output attribute (for example
* coalesce(c, d)) cannot be serialized here, so the returned ordering stops at the first
* such expression. For a plan sorted by a, b, coalesce(c, d), e this returns a, b: the data
* is still ordered by a, b. It does not return a, b, e, because the data is not ordered by
* a, b, e.
*
* This is used when serializing a scan so that the input ordering information is not lost,
* since that information can enable further optimizations.
*/
def parsePlanSortOrderAsMuchAsCan(plan: SparkPlan): Seq[ExprOuterClass.SortOrder] = {
if (plan.outputOrdering.isEmpty) {
Seq.empty
} else {
val outputAttributes = plan.output
val sortOrders = plan.outputOrdering.map(so => {
if (!isExprOneOfAttributes(so.child, outputAttributes)) {
None
} else {
QueryPlanSerde.sortOrderingToProto(so, outputAttributes, binding = true)
}
})

// Take the sort orders until the first None
sortOrders.takeWhile(_.isDefined).map(_.get)
}
}

@tailrec
private def isExprOneOfAttributes(expr: Expression, attrs: Seq[Attribute]): Boolean = {
expr match {
case attr: Attribute => attrs.exists(_.exprId == attr.exprId)
case alias: Alias => isExprOneOfAttributes(alias.child, attrs)
case _ => false
}
}
}

sealed trait SupportLevel
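A minimal, self-contained Scala sketch of the ordering-prefix behavior described above. The SortKey model and the toProto helper are hypothetical stand-ins for Spark's SortOrder and for the attribute check in isExprOneOfAttributes; only the takeWhile-on-Option pattern mirrors parsePlanSortOrderAsMuchAsCan.

object SortPrefixSketch extends App {
  // Hypothetical stand-in for a sort key: either a plain output column or a computed expression.
  sealed trait SortKey
  final case class ColumnRef(name: String) extends SortKey
  final case class Computed(description: String) extends SortKey

  // A key can be "serialized" only if it is a direct reference to an output column,
  // mirroring isExprOneOfAttributes + sortOrderingToProto above.
  def toProto(key: SortKey, outputColumns: Set[String]): Option[String] = key match {
    case ColumnRef(name) if outputColumns.contains(name) => Some(s"SortOrder($name)")
    case _ => None
  }

  val outputColumns = Set("a", "b", "c", "d", "e")
  val ordering: Seq[SortKey] =
    Seq(ColumnRef("a"), ColumnRef("b"), Computed("coalesce(c, d)"), ColumnRef("e"))

  // Stop at the first key that cannot be serialized: keeping "e" after dropping
  // coalesce(c, d) would claim an ordering the data does not actually have.
  val prefix = ordering.map(toProto(_, outputColumns)).takeWhile(_.isDefined).map(_.get)

  println(prefix) // List(SortOrder(a), SortOrder(b))
}

Running the sketch cuts the ordering a, b, coalesce(c, d), e down to a, b, as in the scaladoc example.
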
@@ -77,7 +77,7 @@ case class CometCollectLimitExec(
childRDD
} else {
val localLimitedRDD = if (limit >= 0) {
CometExecUtils.getNativeLimitRDD(childRDD, output, limit)
CometExecUtils.getNativeLimitRDD(child, childRDD, output, limit)
} else {
childRDD
}
@@ -92,7 +92,7 @@

new CometShuffledBatchRDD(dep, readMetrics)
}
CometExecUtils.getNativeLimitRDD(singlePartitionRDD, output, limit, offset)
CometExecUtils.getNativeLimitRDD(child, singlePartitionRDD, output, limit, offset)
}
}

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, So
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.{OperatorOuterClass, QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}

@@ -48,13 +48,15 @@ object CometExecUtils {
* partition. The limit operation is performed on the native side.
*/
def getNativeLimitRDD(
childPlan: RDD[ColumnarBatch],
childPlan: SparkPlan,
child: RDD[ColumnarBatch],
outputAttribute: Seq[Attribute],
limit: Int,
offset: Int = 0): RDD[ColumnarBatch] = {
val numParts = childPlan.getNumPartitions
childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get
val numParts = child.getNumPartitions
child.mapPartitionsWithIndexInternal { case (idx, iter) =>
val limitOp =
CometExecUtils.getLimitNativePlan(childPlan, outputAttribute, limit, offset).get
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
}
}
@@ -90,10 +92,15 @@ object CometExecUtils {
* child partition
*/
def getLimitNativePlan(
child: SparkPlan,
outputAttributes: Seq[Attribute],
limit: Int,
offset: Int = 0): Option[Operator] = {
val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("LimitInput")
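// Propagate the child's known output ordering into the scan that feeds the native limit.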
val scanBuilder = OperatorOuterClass.Scan
.newBuilder()
.setSource("LimitInput")
.addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)

val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()

val scanTypes = outputAttributes.flatten { attr =>
@@ -125,7 +132,11 @@
child: SparkPlan,
limit: Int,
offset: Int = 0): Option[Operator] = {
val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("TopKInput")
val scanBuilder = OperatorOuterClass.Scan
.newBuilder()
.setSource("TopKInput")
.addAllInputOrdering(QueryPlanSerde.parsePlanSortOrderAsMuchAsCan(child).asJava)

val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()

val scanTypes = outputAttributes.flatten { attr =>