SparkConnectStreamHandler.scala

@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {

// The maximum batch size in bytes for a single batch of data to be returned via proto.
val MAX_BATCH_SIZE: Long = 10 * 1024 * 1024
private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024

def handle(v: Request): Unit = {
val session =
@@ -127,8 +127,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
val spark = dataframe.sparkSession
val schema = dataframe.schema
// TODO: control the batch size instead of max records
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone

SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
@@ -141,7 +139,7 @@

val batches = rows.mapPartitionsInternal { iter =>
ArrowConverters
.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
.toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId)
}

val signal = new Object

ArrowConverters.scala

@@ -33,12 +33,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.{ByteBufferOutputStream, Utils}
import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}


/**
@@ -128,10 +128,14 @@ private[sql] object ArrowConverters extends Logging {
}
}

private[sql] def toArrowBatchIterator(
/**
* Convert the input rows into fully contained arrow batches.
* Different from [[toBatchIterator]], each output arrow batch starts with the schema.
*/
private[sql] def toBatchWithSchemaIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Int,
maxBatchSize: Long,
timeZoneId: String): Iterator[(Array[Byte], Long)] = {
val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
@@ -140,6 +144,7 @@
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val unloader = new VectorUnloader(root)
val arrowWriter = ArrowWriter.create(root)
val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)

Option(TaskContext.get).foreach {
_.addTaskCompletionListener[Unit] { _ =>
@@ -161,17 +166,23 @@
val writeChannel = new WriteChannel(Channels.newChannel(out))

var rowCount = 0L
var estimatedBatchSize = arrowSchemaSize
Utils.tryWithSafeFinally {
while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
// Always write the schema.
MessageSerializer.serialize(writeChannel, arrowSchema)

// Always write the first row.
while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes

Contributor Author:

This refers to how the size is computed in BroadcastExchange, but I'm not 100% sure. Should I use this instead?

row match {
  case unsafe: UnsafeRow => estimatedBatchSize += unsafe.getSizeInBytes
  case _ => estimatedBatchSize += SizeEstimator.estimate(row)
}

cc @HyukjinKwon

Member:

The size of the message should be based on Arrow, but we can only know the size of an Arrow batch once it has been created. So I am fine with the current approach; I believe an UnsafeRow is generally larger than the corresponding Arrow batch.

One nit: we should probably use a lower threshold than maxBatchSize to be conservative, for example maxBatchSize * 0.7.

Contributor Author:

will update maxBatchSize * 0.7
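
For illustration, a rough sketch of the size check discussed in this thread: the UnsafeRow / SizeEstimator pattern match combined with the conservative factor suggested above. This is a hypothetical variant, not the code merged in this PR; the sizeLimit name is made up here, and the surrounding names (rowIter, arrowWriter, maxBatchSize, arrowSchemaSize) are taken from the diff.

// Hypothetical sketch only; UnsafeRow and SizeEstimator are already imported in ArrowConverters.scala.
val sizeLimit = (maxBatchSize * 0.7).toLong // conservative limit suggested in the review
var rowCount = 0L
var estimatedBatchSize: Long = arrowSchemaSize
while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < sizeLimit)) {
  val row = rowIter.next()
  arrowWriter.write(row)
  // UnsafeRow reports its exact byte size; fall back to SizeEstimator for other row types.
  estimatedBatchSize += (row match {
    case unsafe: UnsafeRow => unsafe.getSizeInBytes.toLong
    case other => SizeEstimator.estimate(other)
  })
  rowCount += 1
}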

rowCount += 1
}
arrowWriter.finish()
val batch = unloader.getRecordBatch()

MessageSerializer.serialize(writeChannel, arrowSchema)
MessageSerializer.serialize(writeChannel, batch)

// Always write the Ipc options at the end.
ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

batch.close()
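
Because every element produced by toBatchWithSchemaIterator starts with the serialized schema and ends with the end-of-stream marker, each (bytes, rowCount) pair is a self-contained Arrow IPC stream. A minimal sketch of decoding one such element with the standard Arrow Java stream reader follows; the readOneBatch helper is hypothetical and only illustrates the format, it is not part of this PR.

import java.io.ByteArrayInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Hypothetical helper: decodes one batch produced by toBatchWithSchemaIterator
// and returns the number of rows it contains.
def readOneBatch(batchBytes: Array[Byte]): Long = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(new ByteArrayInputStream(batchBytes), allocator)
  try {
    var rows = 0L
    // The schema is read from the stream itself, since each batch carries its own copy.
    while (reader.loadNextBatch()) {
      rows += reader.getVectorSchemaRoot.getRowCount
    }
    rows
  } finally {
    reader.close()
    allocator.close()
  }
}

This is why the batches can be streamed independently: a client can decode any batch it receives without out-of-band schema information.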