@@ -127,19 +127,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone

     SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
       val rows = dataframe.queryExecution.executedPlan.execute()
       val numPartitions = rows.getNumPartitions
+      // Conservatively sets it 70% because the size is not accurate but estimated.
+      val maxBatchSize = (MAX_BATCH_SIZE * 0.7).toLong
       var numSent = 0

       if (numPartitions > 0) {
         type Batch = (Array[Byte], Long)

         val batches = rows.mapPartitionsInternal { iter =>
-          ArrowConverters
-            .toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId)
+          val newIter = ArrowConverters
+            .toBatchWithSchemaIterator(iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
+          newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
         }

         val signal = new Object
@@ -71,158 +71,146 @@ private[sql] class ArrowBatchStreamWriter(
}

private[sql] object ArrowConverters extends Logging {

/**
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
* in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
*/
private[sql] def toBatchIterator(
private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int,
maxRecordsPerBatch: Long,
timeZoneId: String,
context: TaskContext): Iterator[Array[Byte]] = {
context: TaskContext) extends Iterator[Array[Byte]] {

val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator =
ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

val root = VectorSchemaRoot.create(arrowSchema, allocator)
val unloader = new VectorUnloader(root)
val arrowWriter = ArrowWriter.create(root)
private val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val unloader = new VectorUnloader(root)
protected val arrowWriter = ArrowWriter.create(root)

context.addTaskCompletionListener[Unit] { _ =>
Option(context).foreach {_.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}}

override def hasNext: Boolean = rowIter.hasNext || {
root.close()
allocator.close()
false
}

new Iterator[Array[Byte]] {
override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

override def hasNext: Boolean = rowIter.hasNext || {
root.close()
allocator.close()
false
Utils.tryWithSafeFinally {
var rowCount = 0L
while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
val row = rowIter.next()
arrowWriter.write(row)
rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
batch.close()
} {
arrowWriter.reset()
}

override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

Utils.tryWithSafeFinally {
var rowCount = 0
while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
val row = rowIter.next()
arrowWriter.write(row)
rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)
batch.close()
} {
arrowWriter.reset()
out.toByteArray
}
}

private[sql] class ArrowBatchWithSchemaIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String,
context: TaskContext)
extends ArrowBatchIterator(
rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {

private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
Contributor:

What is this supposed to achieve? This uses a lot of reflective code to figure out the size of the schema object. How is this related to the size of the batch?

Member Author:

@hvanhovell, this PR is virtually pure refactoring except for the couple of points I mentioned in the PR description. As for this question: it came from #38612, where the schema size is used to estimate the size of the batch before the Arrow batch is created.
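
Editor's note: below is a minimal, self-contained sketch (not the Spark code) of the batching rule that this schema-size estimate feeds. `BatchSizingSketch`, `sliceIntoBatches`, `sizeOf` and `schemaOverheadBytes` are illustrative names; the real iterator casts each row to UnsafeRow and adds getSizeInBytes, and `schemaOverheadBytes` stands in for SizeEstimator.estimate(arrowSchema). The loop condition is a simplified reading of the one in ArrowBatchWithSchemaIterator.next().

object BatchSizingSketch {
  def sliceIntoBatches[T](
      rows: Iterator[T],
      sizeOf: T => Long,
      schemaOverheadBytes: Long,   // each batch re-sends the schema, so start the estimate here
      maxEstimatedBatchSize: Long, // non-positive means "no size limit"
      maxRecordsPerBatch: Long     // non-positive means "no record limit"
  ): Iterator[Seq[T]] = new Iterator[Seq[T]] {
    override def hasNext: Boolean = rows.hasNext

    override def next(): Seq[T] = {
      val batch = Seq.newBuilder[T]
      var estimatedBatchSize = schemaOverheadBytes
      var rowCount = 0L
      // Always take at least one row, then stop once either limit is reached.
      while (rows.hasNext && (rowCount == 0 ||
          ((maxEstimatedBatchSize <= 0 || estimatedBatchSize < maxEstimatedBatchSize) &&
            (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)))) {
        val row = rows.next()
        batch += row
        estimatedBatchSize += sizeOf(row)
        rowCount += 1
      }
      batch.result()
    }
  }

  def main(args: Array[String]): Unit = {
    // Ten 5-character rows, 4 bytes of schema overhead, 20-byte size cap:
    // yields batches of 4, 4 and 2 rows.
    val batches = sliceIntoBatches[String](
      Iterator.tabulate(10)(i => s"row-$i"),
      sizeOf = (r: String) => r.length.toLong,
      schemaOverheadBytes = 4L,
      maxEstimatedBatchSize = 20L,
      maxRecordsPerBatch = 0L)
    batches.zipWithIndex.foreach { case (b, i) => println(s"batch $i: $b") }
  }
}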

var rowCountInLastBatch: Long = 0
Contributor:

A couple of questions. Why do we need the row count? It is already encoded in the batch itself. If we do need it, please make the iterator return it from the next() call instead of relying on a side effect.

Member Author:

This logic is also from #38468, and this PR is a follow-up. The return type here is Array[Byte], i.e. the raw binary record batch, so we cannot get the count from it unless we define another case class to carry the row count. This class is private and is only used in this specific case.
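
Editor's note: a self-contained illustration of the pattern the caller relies on; the first hunk above does exactly this with `newIter`, pairing each serialized batch with the row count recorded while producing it. `CountingBatchIterator` and its byte "serialization" are illustrative stand-ins, not the Spark classes.

class CountingBatchIterator(rows: Iterator[Int], maxRecordsPerBatch: Int)
    extends Iterator[Array[Byte]] {
  // Side effect read by the caller right after each next(), as in the real iterator.
  var rowCountInLastBatch: Long = 0

  override def hasNext: Boolean = rows.hasNext

  override def next(): Array[Byte] = {
    rowCountInLastBatch = 0
    val buf = scala.collection.mutable.ArrayBuffer.empty[Byte]
    while (rows.hasNext && rowCountInLastBatch < maxRecordsPerBatch) {
      buf += rows.next().toByte // stand-in for Arrow serialization
      rowCountInLastBatch += 1
    }
    buf.toArray
  }
}

val iter = new CountingBatchIterator(Iterator.range(0, 10), maxRecordsPerBatch = 4)
// Pair each batch with its row count, the same shape as the handler code above.
val pairs: Iterator[(Array[Byte], Long)] =
  iter.map { batch => (batch, iter.rowCountInLastBatch) }
pairs.foreach { case (bytes, n) => println(s"${bytes.length} bytes, $n rows") }

Because Iterator.map is lazy, each tuple is built immediately after the corresponding next() call, which is what makes the side-effect read safe for this single-threaded consumption.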


override def next(): Array[Byte] = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

rowCountInLastBatch = 0
var estimatedBatchSize = arrowSchemaSize
Utils.tryWithSafeFinally {
// Always write the schema.
MessageSerializer.serialize(writeChannel, arrowSchema)

// Always write the first row.
while (rowIter.hasNext && (
// For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
// If the size in bytes is positive (set properly), always write the first row.
rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
// If the size in bytes of rows are 0 or negative, unlimit it.
estimatedBatchSize <= 0 ||
estimatedBatchSize < maxEstimatedBatchSize ||
// If the size of rows are 0 or negative, unlimit it.
maxRecordsPerBatch <= 0 ||
rowCountInLastBatch < maxRecordsPerBatch)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
rowCountInLastBatch += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)

// Always write the Ipc options at the end.
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

out.toByteArray
batch.close()
} {
arrowWriter.reset()
}

out.toByteArray
}
}

/**
* Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
* in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
*/
private[sql] def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
timeZoneId: String,
context: TaskContext): ArrowBatchIterator = {
new ArrowBatchIterator(
rowIter, schema, maxRecordsPerBatch, timeZoneId, context)
}

/**
* Convert the input rows into fully contained arrow batches.
* Different from [[toBatchIterator]], each output arrow batch starts with the schema.
*/
private[sql] def toBatchWithSchemaIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxBatchSize: Long,
timeZoneId: String): Iterator[(Array[Byte], Long)] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
"toArrowBatchIterator", 0, Long.MaxValue)

val root = VectorSchemaRoot.create(arrowSchema, allocator)
val unloader = new VectorUnloader(root)
val arrowWriter = ArrowWriter.create(root)
val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)

Option(TaskContext.get).foreach {
_.addTaskCompletionListener[Unit] { _ =>
root.close()
allocator.close()
}
}

new Iterator[(Array[Byte], Long)] {

override def hasNext: Boolean = rowIter.hasNext || {
root.close()
allocator.close()
false
}

override def next(): (Array[Byte], Long) = {
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

var rowCount = 0L
var estimatedBatchSize = arrowSchemaSize
Utils.tryWithSafeFinally {
// Always write the schema.
MessageSerializer.serialize(writeChannel, arrowSchema)

// Always write the first row.
while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()
MessageSerializer.serialize(writeChannel, batch)

// Always write the Ipc options at the end.
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

batch.close()
} {
arrowWriter.reset()
}

(out.toByteArray, rowCount)
}
}
maxRecordsPerBatch: Long,
maxEstimatedBatchSize: Long,
timeZoneId: String): ArrowBatchWithSchemaIterator = {
new ArrowBatchWithSchemaIterator(
rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get)
}

private[sql] def createEmptyArrowBatch(
schema: StructType,
timeZoneId: String): Array[Byte] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
"createEmptyArrowBatch", 0, Long.MaxValue)

val root = VectorSchemaRoot.create(arrowSchema, allocator)
val unloader = new VectorUnloader(root)
val arrowWriter = ArrowWriter.create(root)

val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))

Utils.tryWithSafeFinally {
arrowWriter.finish()
val batch = unloader.getRecordBatch() // empty batch

MessageSerializer.serialize(writeChannel, arrowSchema)
MessageSerializer.serialize(writeChannel, batch)
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

batch.close()
} {
arrowWriter.reset()
}

out.toByteArray
new ArrowBatchWithSchemaIterator(
Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) {
override def hasNext: Boolean = true
}.next()
}

/**