Merged
15 changes: 11 additions & 4 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
* The input iterators producing sequence of batches of Arrow Arrays.
* @param protobufQueryPlan
* The serialized bytes of Spark execution plan.
* @param numParts
* The number of partitions.
* @param partitionIndex
* The index of the partition.
*/
class CometExecIterator(
val id: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode)
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIndex: Int)
extends Iterator[ColumnarBatch] {

private val nativeLib = new Native()
@@ -55,13 +61,12 @@ class CometExecIterator(
}.toArray
private val plan = {
val configs = createNativeConf
TaskContext.get().numPartitions()
nativeLib.createPlan(
id,
configs,
cometBatchIterators,
protobufQueryPlan,
TaskContext.get().numPartitions(),
numParts,
nativeMetrics,
new CometTaskMemoryManager(id))
}
@@ -103,10 +108,12 @@
}

def getNextBatch(): Option[ColumnarBatch] = {
assert(partitionIndex >= 0 && partitionIndex < numParts)

nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
nativeLib.executePlan(plan, TaskContext.get().partitionId(), arrayAddrs, schemaAddrs)
nativeLib.executePlan(plan, partitionIndex, arrayAddrs, schemaAddrs)
})
}

@@ -1029,12 +1029,20 @@ class CometSparkSessionExtensions
var firstNativeOp = true
newPlan.transformDown {
case op: CometNativeExec =>
if (firstNativeOp) {
val newPlan = if (firstNativeOp) {
firstNativeOp = false
op.convertBlock()
} else {
op
}

// If reaching leaf node, reset `firstNativeOp` to true
// because it will start a new block in next iteration.
if (op.children.isEmpty) {
firstNativeOp = true
}

newPlan
Comment on lines +1039 to +1045 (PR author):
This is the main difference from the main branch. Because the native scan is now a real "native" operator, once we encounter it while transforming down we need to reset the `firstNativeOp` flag (a standalone sketch of this reset logic follows the hunk below).
case op =>
firstNativeOp = true
op
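To make the reset logic easier to follow, here is a standalone Scala sketch (toy Node types of my own, not Comet classes) of the walk in the hunk above: only the first native operator of each block is converted, and hitting a leaf native operator, such as a native scan, re-arms the flag so the next native operator starts a new block.

```scala
object BlockHeadSketch {
  sealed trait Node { def children: Seq[Node] }
  case class NativeOp(name: String, children: Seq[Node] = Nil) extends Node
  case class NonNativeOp(name: String, children: Seq[Node] = Nil) extends Node

  // Pre-order walk, analogous to transformDown in the hunk above.
  def blockHeads(root: Node): Seq[String] = {
    var firstNativeOp = true
    val heads = scala.collection.mutable.ArrayBuffer.empty[String]

    def visit(n: Node): Unit = {
      n match {
        case NativeOp(name, _) =>
          if (firstNativeOp) {
            firstNativeOp = false
            heads += name // in Comet this operator would call convertBlock()
          }
          // A leaf native operator (e.g. a native scan) ends the current block,
          // so the next native operator encountered starts a new block.
          if (n.children.isEmpty) firstNativeOp = true
        case _: NonNativeOp =>
          firstNativeOp = true
      }
      n.children.foreach(visit)
    }

    visit(root)
    heads.toSeq
  }

  def main(args: Array[String]): Unit = {
    // Two native scans under a non-native join: each scan heads its own block.
    val plan = NonNativeOp("join", Seq(NativeOp("scanA"), NativeOp("scanB")))
    println(blockHeads(plan)) // prints a Seq containing scanA and scanB
  }
}
```

Without the leaf reset, only scanA would be reported, which illustrates why the comment above says the flag must be re-armed now that the native scan is a real native operator.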
@@ -29,11 +29,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
private[spark] class CometExecRDD(
sc: SparkContext,
partitionNum: Int,
var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch])
var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
extends RDD[ColumnarBatch](sc, Nil) {

override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
f(Seq.empty)
f(Seq.empty, partitionNum, s.index)
}

override protected def getPartitions: Array[Partition] = {
@@ -46,7 +46,8 @@

object CometExecRDD {
def apply(sc: SparkContext, partitionNum: Int)(
f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
: RDD[ColumnarBatch] =
withScope(sc) {
new CometExecRDD(sc, partitionNum, f)
}
@@ -51,9 +51,10 @@ object CometExecUtils {
childPlan: RDD[ColumnarBatch],
outputAttribute: Seq[Attribute],
limit: Int): RDD[ColumnarBatch] = {
childPlan.mapPartitionsInternal { iter =>
val numParts = childPlan.getNumPartitions
childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
}
}

@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
val localTopK = if (orderingSatisfies) {
CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
} else {
childRDD.mapPartitionsInternal { iter =>
val numParts = childRDD.getNumPartitions
childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
val topK =
CometExecUtils
.getTopKNativePlan(child.output, sortOrder, child, limit)
.get
CometExec.getCometIterator(Seq(iter), child.output.length, topK)
CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
}
}

@@ -102,7 +103,7 @@
val topKAndProjection = CometExecUtils
.getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
.get
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
setSubqueries(it.id, this)

Option(TaskContext.get()).foreach { context =>
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
*/
private[spark] class ZippedPartitionsRDD(
sc: SparkContext,
var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
var zipRdds: Seq[RDD[ColumnarBatch]],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {

// We need to get the number of partitions in `compute` but `getNumPartitions` is not available
// on the executors. So we need to capture it here.
private val numParts: Int = this.getNumPartitions

override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
val iterators =
zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
f(iterators)
f(iterators, numParts, s.index)
}

override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@

object ZippedPartitionsRDD {
def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
: RDD[ColumnarBatch] =
withScope(sc) {
new ZippedPartitionsRDD(sc, f, rdds)
}
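The comment in ZippedPartitionsRDD above about capturing getNumPartitions reflects a general Spark rule: values that are only computable on the driver must be read into a val before the closure is serialized to executors. A minimal sketch of that pattern using the public mapPartitionsWithIndex (my own example with assumed names; Comet itself uses the internal mapPartitionsWithIndexInternal variant shown elsewhere in this diff):

```scala
import org.apache.spark.sql.SparkSession

object NumPartsCaptureSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("numparts-capture").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)

    // Capture on the driver: getNumPartitions is not available inside the closure
    // on executors, so only this plain Int travels with each task.
    val numParts = rdd.getNumPartitions

    val tagged = rdd.mapPartitionsWithIndex { case (idx, iter) =>
      // idx plays the role of partitionIndex; numParts was captured above.
      iter.map(v => s"partition $idx of $numParts -> $v")
    }

    tagged.take(5).foreach(println)
    spark.stop()
  }
}
```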
@@ -224,13 +224,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
outputPartitioning: Partitioning,
serializer: Serializer,
metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
val numParts = rdd.getNumPartitions
val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
rdd.map(
(0, _)
), // adding fake partitionId that is always 0 because ShuffleDependency requires it
serializer = serializer,
shuffleWriterProcessor =
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
shuffleType = CometNativeShuffle,
partitioner = new Partitioner {
override def numPartitions: Int = outputPartitioning.numPartitions
@@ -446,7 +447,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
class CometShuffleWriteProcessor(
outputPartitioning: Partitioning,
outputAttributes: Seq[Attribute],
metrics: Map[String, SQLMetric])
metrics: Map[String, SQLMetric],
numParts: Int)
extends ShimCometShuffleWriteProcessor {

private val OFFSET_LENGTH = 8
@@ -489,7 +491,9 @@
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
nativePlan,
nativeMetrics)
nativeMetrics,
numParts,
context.partitionId())

while (cometIter.hasNext) {
cometIter.next()
40 changes: 31 additions & 9 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
def getCometIterator(
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
nativePlan: Operator): CometExecIterator = {
getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
nativePlan: Operator,
numParts: Int,
partitionIdx: Int): CometExecIterator = {
getCometIterator(
inputs,
numOutputCols,
nativePlan,
CometMetricNode(Map.empty),
numParts,
partitionIdx)
}

def getCometIterator(
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
nativePlan: Operator,
nativeMetrics: CometMetricNode): CometExecIterator = {
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIdx: Int): CometExecIterator = {
val outputStream = new ByteArrayOutputStream()
nativePlan.writeTo(outputStream)
outputStream.close()
val bytes = outputStream.toByteArray
new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
new CometExecIterator(
newIterId,
inputs,
numOutputCols,
bytes,
nativeMetrics,
numParts,
partitionIdx)
}

/**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
// TODO: support native metrics for all operators.
val nativeMetrics = CometMetricNode.fromCometPlan(this)

def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
def createCometExecIter(
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
partitionIndex: Int): CometExecIterator = {
val it = new CometExecIterator(
CometExec.newIterId,
inputs,
output.length,
serializedPlanCopy,
nativeMetrics)
nativeMetrics,
numParts,
partitionIndex)

setSubqueries(it.id, this)

Expand Down Expand Up @@ -271,7 +293,7 @@ abstract class CometNativeExec extends CometExec {
// Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with
// same partition number. But for Comet, we need to zip them so we need to adjust the
// partition number of Broadcast RDDs to make sure they have the same partition number.
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
sparkPlans.zipWithIndex.foreach { case (plan, _) =>
plan match {
case c: CometBroadcastExchangeExec if firstNonBroadcastPlanNumPartitions.nonEmpty =>
inputs += c
Expand Down Expand Up @@ -315,10 +337,10 @@ abstract class CometNativeExec extends CometExec {
}

if (inputs.nonEmpty) {
ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
} else {
val partitionNum = firstNonBroadcastPlanNumPartitions.get
CometExecRDD(sparkContext, partitionNum)(createCometExecIter(_))
CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
}
}
}
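A small side effect of the new three-argument closure type shows up in the last hunk above: createCometExecIter(_) becomes a bare createCometExecIter. The sketch below (toy types of my own, not Comet code) shows why a three-parameter method can be passed directly where a Function3 is expected:

```scala
object EtaExpansionSketch {
  type Batch = String // stand-in for ColumnarBatch in this toy example

  // Mirrors the shape of createCometExecIter(inputs, numParts, partitionIndex).
  def createIter(inputs: Seq[Iterator[Batch]], numParts: Int, partitionIndex: Int): Iterator[Batch] =
    Iterator(s"partition $partitionIndex of $numParts, ${inputs.size} input iterator(s)")

  // Mirrors the (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]
  // parameter now taken by ZippedPartitionsRDD.apply and CometExecRDD.apply.
  def run(f: (Seq[Iterator[Batch]], Int, Int) => Iterator[Batch]): Unit =
    f(Seq.empty, 4, 0).foreach(println)

  def main(args: Array[String]): Unit = {
    run(createIter)          // method reference, eta-expanded to a Function3
    run(createIter(_, _, _)) // equivalent explicit placeholder form
  }
}
```

With the old single-argument signature, createCometExecIter(_) was simply the placeholder form of the same idea.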
4 changes: 3 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
override def next(): ColumnarBatch = throw new NullPointerException()
}),
1,
limitOp)
limitOp,
1,
0)
cometIter.next()
cometIter.close()
value