diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 3b734616b213..9652fce5425f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
 
   // The maximum batch size in bytes for a single batch of data to be returned via proto.
-  val MAX_BATCH_SIZE: Long = 10 * 1024 * 1024
+  private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024
 
   def handle(v: Request): Unit = {
     val session =
@@ -127,8 +127,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
-    // TODO: control the batch size instead of max records
-    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
 
     SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
@@ -141,7 +139,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
 
       val batches = rows.mapPartitionsInternal { iter =>
         ArrowConverters
-          .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+          .toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId)
       }
 
       val signal = new Object
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index a2dce31bc6d3..c233ac32c125 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -33,12 +33,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
-import org.apache.spark.util.{ByteBufferOutputStream, Utils}
+import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}
 
 
 /**
@@ -128,10 +128,14 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
-  private[sql] def toArrowBatchIterator(
+  /**
+   * Convert the input rows into fully contained arrow batches.
+   * Different from [[toBatchIterator]], each output arrow batch starts with the schema.
+   */
+  private[sql] def toBatchWithSchemaIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int,
+      maxBatchSize: Long,
       timeZoneId: String): Iterator[(Array[Byte], Long)] = {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator(
@@ -140,6 +144,7 @@ private[sql] object ArrowConverters extends Logging {
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val unloader = new VectorUnloader(root)
     val arrowWriter = ArrowWriter.create(root)
+    val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
 
     Option(TaskContext.get).foreach {
       _.addTaskCompletionListener[Unit] { _ =>
@@ -161,17 +166,23 @@ private[sql] object ArrowConverters extends Logging {
         val writeChannel = new WriteChannel(Channels.newChannel(out))
 
         var rowCount = 0L
+        var estimatedBatchSize = arrowSchemaSize
         Utils.tryWithSafeFinally {
-          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+          // Always write the schema.
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+
+          // Always write the first row.
+          while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
             val row = rowIter.next()
             arrowWriter.write(row)
+            estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
             rowCount += 1
           }
           arrowWriter.finish()
           val batch = unloader.getRecordBatch()
-
-          MessageSerializer.serialize(writeChannel, arrowSchema)
          MessageSerializer.serialize(writeChannel, batch)
+
+          // Always write the Ipc options at the end.
          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
 
          batch.close()
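
For context on the new batching policy in toBatchWithSchemaIterator: the size estimate is seeded with SizeEstimator.estimate(arrowSchema) and grows by each UnsafeRow's getSizeInBytes, the first row is always admitted, and the batch is cut once the estimate reaches maxBatchSize, so a batch can overshoot the cap by at most one row. Below is a minimal standalone sketch of that policy, with plain byte arrays and array lengths standing in for rows and row-size estimates; the object name, helper, and the 4 MiB figure are illustrative, not part of the patch.

```scala
// Standalone sketch (not part of the patch): the same size-capped batching policy,
// applied to plain byte arrays so it runs without Spark. Array length stands in for
// UnsafeRow.getSizeInBytes.
object SizeCappedBatchingSketch {

  def batchBySize(rows: Iterator[Array[Byte]], maxBatchSize: Long): Iterator[Seq[Array[Byte]]] =
    new Iterator[Seq[Array[Byte]]] {
      override def hasNext: Boolean = rows.hasNext

      override def next(): Seq[Array[Byte]] = {
        val batch = Seq.newBuilder[Array[Byte]]
        var estimatedBatchSize = 0L
        var rowCount = 0L
        // Mirrors the loop condition in toBatchWithSchemaIterator: always take the
        // first row, then keep appending while the running estimate is below the cap.
        while (rows.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
          val row = rows.next()
          batch += row
          estimatedBatchSize += row.length
          rowCount += 1
        }
        batch.result()
      }
    }

  def main(args: Array[String]): Unit = {
    // Ten ~3 MiB rows with a 4 MiB cap: each batch takes a first row unconditionally,
    // admits a second because 3 MiB is still under the cap, then cuts the batch,
    // yielding five batches of two rows each.
    val rows = Iterator.fill(10)(new Array[Byte](3 * 1024 * 1024))
    batchBySize(rows, maxBatchSize = 4L * 1024 * 1024).zipWithIndex.foreach {
      case (b, i) => println(s"batch $i: ${b.size} rows, ${b.map(_.length).sum} bytes")
    }
  }
}
```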