@@ -34,8 +34,8 @@ class ArrowPythonRunner(
     protected override val timeZoneId: String,
     protected override val workerConf: Map[String, String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowInput
-  with PythonArrowOutput {
+  with BasicPythonArrowInput
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
@@ -48,7 +48,7 @@ class CoGroupedArrowPythonRunner(
     conf: Map[String, String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowOutput {
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
@@ -32,15 +32,21 @@ import org.apache.spark.util.Utils
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * JVM (an iterator of internal rows) to Python (Arrow).
+ * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
-private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
+private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
   protected val workerConf: Map[String, String]
 
   protected val schema: StructType
 
   protected val timeZoneId: String
 
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[IN]): Unit
+
   protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     // Write config for the worker as a number of key -> value pairs of strings
     stream.writeInt(workerConf.size)
@@ -53,7 +59,7 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
   protected override def newWriterThread(
       env: SparkEnv,
       worker: Socket,
-      inputIterator: Iterator[Iterator[InternalRow]],
+      inputIterator: Iterator[IN],
       partitionIndex: Int,
       context: TaskContext): WriterThread = {
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
@@ -74,17 +80,8 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
 
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
-            }
+          writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
 
-            arrowWriter.finish()
-            writer.writeBatch()
-            arrowWriter.reset()
-          }
           // end writes footer to the output stream and doesn't clean any resources.
           // It could throw exception if the output stream is closed, so it should be
           // in the try block.
@@ -107,3 +104,27 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
     }
   }
 }
+
+private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
+  self: BasePythonRunner[Iterator[InternalRow], _] =>
+
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
+    val arrowWriter = ArrowWriter.create(root)
+
+    while (inputIterator.hasNext) {
+      val nextBatch = inputIterator.next()
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }
+
+      arrowWriter.finish()
+      writer.writeBatch()
+      arrowWriter.reset()
+    }
+  }
+}
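For illustration only (not part of this change), here is a hedged sketch of what the new writeIteratorToArrowStream hook enables: an input type whose elements carry an extra binary payload alongside the rows, sent to the worker before each Arrow batch. The trait name and the length-prefixed framing are assumptions, and the sketch presumes the same imports as PythonArrowInput.scala.

// Hypothetical sketch: IN pairs a per-group payload with the rows for that group.
private[python] trait PayloadPythonArrowInput
  extends PythonArrowInput[(Array[Byte], Iterator[InternalRow])] {
  self: BasePythonRunner[(Array[Byte], Iterator[InternalRow]), _] =>

  protected def writeIteratorToArrowStream(
      root: VectorSchemaRoot,
      writer: ArrowStreamWriter,
      dataOut: DataOutputStream,
      inputIterator: Iterator[(Array[Byte], Iterator[InternalRow])]): Unit = {
    val arrowWriter = ArrowWriter.create(root)

    while (inputIterator.hasNext) {
      val (payload, rows) = inputIterator.next()

      // Assumed framing: length-prefixed payload first, then the Arrow batch for the group.
      dataOut.writeInt(payload.length)
      dataOut.write(payload)

      while (rows.hasNext) {
        arrowWriter.write(rows.next())
      }
      arrowWriter.finish()
      writer.writeBatch()
      arrowWriter.reset()
    }
  }
}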
@@ -33,12 +33,14 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * Python (Arrow) to JVM (ColumnarBatch).
+ * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch).
  */
-private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
+private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>
Review comment on this line:

Member: qq: should it be <: AnyRef?

Contributor Author: We assign null to the OUT type (although that's a trick), hence it needs to be at least AnyRef, if I understand correctly.
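As a side note on that exchange (an illustration, not part of the diff): casting null to a primitive type parameter in Scala silently produces the type's default value, while an AnyRef-bounded type keeps the null that the END_OF_DATA_SECTION branch below returns. A hypothetical REPL session:

scala> null.asInstanceOf[Int]      // value type: the null is erased to the default value
res0: Int = 0

scala> null.asInstanceOf[String]   // AnyRef type: stays null, as the reader iterator expects
res1: String = null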


   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
 
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
@@ -47,7 +49,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
       worker: Socket,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
-      context: TaskContext): Iterator[ColumnarBatch] = {
+      context: TaskContext): Iterator[OUT] = {
 
     new ReaderIterator(
       stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
@@ -74,7 +76,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
         super.handleEndOfDataSection()
       }
 
-      protected override def read(): ColumnarBatch = {
+      protected override def read(): OUT = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get
         }
@@ -84,7 +86,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
             if (batchLoaded) {
               val batch = new ColumnarBatch(vectors)
               batch.setNumRows(root.getRowCount)
-              batch
+              deserializeColumnarBatch(batch, schema)
             } else {
               reader.close(false)
               allocator.close()
@@ -108,11 +110,19 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
                 throw handlePythonException()
               case SpecialLengths.END_OF_DATA_SECTION =>
                 handleEndOfDataSection()
-                null
+                null.asInstanceOf[OUT]
             }
           }
         } catch handleException
       }
     }
   }
 }
+
+private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarBatch] {
+  self: BasePythonRunner[_, ColumnarBatch] =>
+
+  protected def deserializeColumnarBatch(
+      batch: ColumnarBatch,
+      schema: StructType): ColumnarBatch = batch
+}