Changes from 1 commit
@@ -106,50 +106,56 @@ case class AggregateInPandasExec(
})

inputRDD.mapPartitionsInternal { iter =>
val prunedProj = UnsafeProjection.create(allInputs, child.output)

val grouped = if (groupingExpressions.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, groupingExpressions, child.output)
}.map { case (key, rows) =>
(key, rows.map(prunedProj))
}
// Only execute on non-empty partitions
if (iter.nonEmpty) {
Member Author: The diff is a little off, this is really the only change

val prunedProj = UnsafeProjection.create(allInputs, child.output)

val context = TaskContext.get()
val grouped = if (groupingExpressions.isEmpty) {
// Use an empty unsafe row as a place holder for the grouping key
Iterator((new UnsafeRow(), iter))
} else {
GroupedIterator(iter, groupingExpressions, child.output)
}.map { case (key, rows) =>
(key, rows.map(prunedProj))
}

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
context.addTaskCompletionListener[Unit] { _ =>
queue.close()
}
val context = TaskContext.get()

// Add rows to queue to join later with the result.
val projectedRowIter = grouped.map { case (groupingKey, rows) =>
queue.add(groupingKey.asInstanceOf[UnsafeRow])
rows
}
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
context.addTaskCompletionListener[Unit] { _ =>
queue.close()
}

// Add rows to queue to join later with the result.
val projectedRowIter = grouped.map { case (groupingKey, rows) =>
queue.add(groupingKey.asInstanceOf[UnsafeRow])
rows
}

val columnarBatchIter = new ArrowPythonRunner(
pyFuncs,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)

columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, aggOutputRow)
resultProj(joinedRow)
val columnarBatchIter = new ArrowPythonRunner(
pyFuncs,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
argOffsets,
aggInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context)

val joinedAttributes =
groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)

columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, aggOutputRow)
resultProj(joinedRow)
}
} else {
Iterator.empty
}
}
}
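
The net effect of the hunk above is that the whole per-partition pipeline (the projection, the HybridRowQueue buffering, the ArrowPythonRunner call, and the final join with the grouping keys) now runs only when the partition actually contains rows. Below is a minimal, standalone Scala sketch of that guard pattern with no Spark dependencies; NonEmptyPartitionGuard, expensiveSetup, and processPartition are made-up names, and expensiveSetup merely stands in for costly per-partition work such as launching a Python worker.

object NonEmptyPartitionGuard {
  // Stand-in for costly per-partition setup, e.g. starting a Python worker.
  def expensiveSetup(): Int => Int = {
    println("starting worker")
    x => x * 2
  }

  def processPartition(iter: Iterator[Int]): Iterator[Int] = {
    // Iterator.nonEmpty only delegates to hasNext, so this check does not
    // consume any element of the partition iterator.
    if (iter.nonEmpty) {
      val f = expensiveSetup()
      iter.map(f)
    } else {
      // Empty partitions skip the setup entirely.
      Iterator.empty
    }
  }

  def main(args: Array[String]): Unit = {
    println(processPartition(Iterator(1, 2, 3)).toList) // List(2, 4, 6)
    println(processPartition(Iterator.empty).toList)    // List(); expensiveSetup never runs
  }
}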
@@ -126,36 +126,44 @@ case class FlatMapGroupsInPandasExec(
val dedupSchema = StructType.fromAttributes(dedupAttributes)

inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)

// Only execute on non-empty partitions
if (iter.nonEmpty) {

val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
}
}
}

val context = TaskContext.get()

val columnarBatchIter = new ArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
argOffsets,
dedupSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(grouped, context.partitionId(), context)

val unsafeProj = UnsafeProjection.create(output, output)

columnarBatchIter.flatMap { batch =>
// Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = output.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)
val context = TaskContext.get()

val columnarBatchIter = new ArrowPythonRunner(
chainedFunc,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
argOffsets,
dedupSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(grouped, context.partitionId(), context)

val unsafeProj = UnsafeProjection.create(output, output)

columnarBatchIter.flatMap { batch =>
// Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
val outputVectors = output.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)

} else {
Iterator.empty
}
}
}
}
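
Both hunks hand GroupedIterator the partition's rows together with the grouping expressions or attributes and the child output, and get back (grouping key, rows for that key) pairs to feed to the pandas UDF runner. As a rough standalone analogue of that contract, here is a short Scala sketch that groups consecutive elements sharing a key, assuming the input is already sorted or clustered by that key; GroupedIteratorSketch and its names are made up for illustration and this is a simplification, not Spark's GroupedIterator implementation.

import scala.collection.mutable.ArrayBuffer

object GroupedIteratorSketch {
  // Yields (key, values) pairs for runs of consecutive elements with the same key.
  def grouped[K, V](input: Iterator[(K, V)]): Iterator[(K, Seq[V])] =
    new Iterator[(K, Seq[V])] {
      private val rows = input.buffered
      def hasNext: Boolean = rows.hasNext
      def next(): (K, Seq[V]) = {
        val key = rows.head._1
        val values = ArrayBuffer.empty[V]
        while (rows.hasNext && rows.head._1 == key) {
          values += rows.next()._2
        }
        (key, values.toSeq)
      }
    }

  def main(args: Array[String]): Unit = {
    val sortedRows = Iterator(("a", 1), ("a", 2), ("b", 3))
    grouped(sortedRows).foreach { case (k, vs) => println(s"$k -> $vs") }
    // Prints one line per key: "a" with its two values, then "b" with one.
  }
}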