diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 9652fce5425fa..7adb514715859 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -127,19 +127,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
+    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
 
     SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
       val rows = dataframe.queryExecution.executedPlan.execute()
       val numPartitions = rows.getNumPartitions
+      // Conservatively use 70% of the limit because the size is estimated, not exact.
+      val maxBatchSize = (MAX_BATCH_SIZE * 0.7).toLong
       var numSent = 0
 
       if (numPartitions > 0) {
         type Batch = (Array[Byte], Long)
 
         val batches = rows.mapPartitionsInternal { iter =>
-          ArrowConverters
-            .toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId)
+          val newIter = ArrowConverters
+            .toBatchWithSchemaIterator(iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
+          newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
         }
 
         val signal = new Object
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index c233ac32c125b..a60f7b5970d1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -71,63 +71,125 @@ private[sql] class ArrowBatchStreamWriter(
 }
 
 private[sql] object ArrowConverters extends Logging {
-
-  /**
-   * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
-   * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
-   */
-  private[sql] def toBatchIterator(
+  private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int,
+      maxRecordsPerBatch: Long,
       timeZoneId: String,
-      context: TaskContext): Iterator[Array[Byte]] = {
+      context: TaskContext) extends Iterator[Array[Byte]] {
 
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator =
-      ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
+    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    private val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator(
+        s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
 
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    protected val unloader = new VectorUnloader(root)
+    protected val arrowWriter = ArrowWriter.create(root)
 
-    context.addTaskCompletionListener[Unit] { _ =>
+    Option(context).foreach { _.addTaskCompletionListener[Unit] { _ =>
       root.close()
       allocator.close()
+    }}
+
+    override def hasNext: Boolean = rowIter.hasNext || {
+      root.close()
+      allocator.close()
+      false
     }
 
-    new Iterator[Array[Byte]] {
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
 
-      override def hasNext: Boolean = rowIter.hasNext || {
-        root.close()
-        allocator.close()
-        false
+      Utils.tryWithSafeFinally {
+        var rowCount = 0L
+        while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+        batch.close()
+      } {
+        arrowWriter.reset()
       }
 
-      override def next(): Array[Byte] = {
-        val out = new ByteArrayOutputStream()
-        val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-        Utils.tryWithSafeFinally {
-          var rowCount = 0
-          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
-            val row = rowIter.next()
-            arrowWriter.write(row)
-            rowCount += 1
-          }
-          arrowWriter.finish()
-          val batch = unloader.getRecordBatch()
-          MessageSerializer.serialize(writeChannel, batch)
-          batch.close()
-        } {
-          arrowWriter.reset()
+      out.toByteArray
+    }
+  }
+
+  private[sql] class ArrowBatchWithSchemaIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String,
+      context: TaskContext)
+    extends ArrowBatchIterator(
+      rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {
+
+    private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
+    var rowCountInLastBatch: Long = 0
+
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+      rowCountInLastBatch = 0
+      var estimatedBatchSize = arrowSchemaSize
+      Utils.tryWithSafeFinally {
+        // Always write the schema.
+        MessageSerializer.serialize(writeChannel, arrowSchema)
+
+        // Always write the first row.
+        while (rowIter.hasNext && (
+            // For maxEstimatedBatchSize and maxRecordsPerBatch, respect whichever is smaller.
+            // If the size in bytes is positive (set properly), always write the first row.
+            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
+            // If the estimated size in bytes is 0 or negative, do not cap the batch size.
+            estimatedBatchSize <= 0 ||
+            estimatedBatchSize < maxEstimatedBatchSize ||
+            // If maxRecordsPerBatch is 0 or negative, do not cap the record count.
+            maxRecordsPerBatch <= 0 ||
+            rowCountInLastBatch < maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
+          rowCountInLastBatch += 1
         }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+
+        // Always write the IPC options at the end.
+        ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
 
-        out.toByteArray
+        batch.close()
+      } {
+        arrowWriter.reset()
       }
+
+      out.toByteArray
     }
   }
 
+  /**
+   * Maps an Iterator of InternalRow to serialized ArrowRecordBatches. Limit the number of rows
+   * per batch by setting maxRecordsPerBatch, or pass 0 to fully consume rowIter.
+   */
+  private[sql] def toBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext): ArrowBatchIterator = {
+    new ArrowBatchIterator(
+      rowIter, schema, maxRecordsPerBatch, timeZoneId, context)
+  }
+
   /**
    * Convert the input rows into fully contained arrow batches.
    * Different from [[toBatchIterator]], each output arrow batch starts with the schema.
@@ -135,94 +197,20 @@ private[sql] object ArrowConverters extends Logging {
   private[sql] def toBatchWithSchemaIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxBatchSize: Long,
-      timeZoneId: String): Iterator[(Array[Byte], Long)] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      "toArrowBatchIterator", 0, Long.MaxValue)
-
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
-    val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
-
-    Option(TaskContext.get).foreach {
-      _.addTaskCompletionListener[Unit] { _ =>
-        root.close()
-        allocator.close()
-      }
-    }
-
-    new Iterator[(Array[Byte], Long)] {
-
-      override def hasNext: Boolean = rowIter.hasNext || {
-        root.close()
-        allocator.close()
-        false
-      }
-
-      override def next(): (Array[Byte], Long) = {
-        val out = new ByteArrayOutputStream()
-        val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-        var rowCount = 0L
-        var estimatedBatchSize = arrowSchemaSize
-        Utils.tryWithSafeFinally {
-          // Always write the schema.
-          MessageSerializer.serialize(writeChannel, arrowSchema)
-
-          // Always write the first row.
-          while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
-            val row = rowIter.next()
-            arrowWriter.write(row)
-            estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
-            rowCount += 1
-          }
-          arrowWriter.finish()
-          val batch = unloader.getRecordBatch()
-          MessageSerializer.serialize(writeChannel, batch)
-
-          // Always write the Ipc options at the end.
-          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
-
-          batch.close()
-        } {
-          arrowWriter.reset()
-        }
-
-        (out.toByteArray, rowCount)
-      }
-    }
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String): ArrowBatchWithSchemaIterator = {
+    new ArrowBatchWithSchemaIterator(
+      rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get)
   }
 
   private[sql] def createEmptyArrowBatch(
       schema: StructType,
       timeZoneId: String): Array[Byte] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      "createEmptyArrowBatch", 0, Long.MaxValue)
-
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
-
-    val out = new ByteArrayOutputStream()
-    val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-    Utils.tryWithSafeFinally {
-      arrowWriter.finish()
-      val batch = unloader.getRecordBatch() // empty batch
-
-      MessageSerializer.serialize(writeChannel, arrowSchema)
-      MessageSerializer.serialize(writeChannel, batch)
-      ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
-
-      batch.close()
-    } {
-      arrowWriter.reset()
-    }
-
-    out.toByteArray
+    new ArrowBatchWithSchemaIterator(
+      Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) {
+      override def hasNext: Boolean = true
+    }.next()
  }
 
   /**
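Note for reviewers: below is a minimal sketch of how the new iterator contract is meant to be driven end to end. It is illustrative only — the local session, `df`, and the 4 MiB stand-in for `MAX_BATCH_SIZE` are assumptions, not part of this patch, and since `ArrowConverters` and `mapPartitionsInternal` are package-private it would only compile from code inside the `org.apache.spark.sql` package (e.g. a test suite).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.arrow.ArrowConverters

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(100000).toDF("id")

val schema = df.schema
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
// Hypothetical stand-in for the handler's MAX_BATCH_SIZE, derated to 70%
// as in the patch, because the per-row accounting is only an estimate.
val maxBatchSize = (4L * 1024 * 1024 * 0.7).toLong

val batches = df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter =>
  val arrowIter = ArrowConverters.toBatchWithSchemaIterator(
    iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)
  // rowCountInLastBatch is mutable state on the iterator and is only valid
  // for the batch most recently produced by next(), so each serialized batch
  // is paired with its count lazily, at the moment it is consumed.
  arrowIter.map { bytes => (bytes, arrowIter.rowCountInLastBatch) }
}
batches.collect().foreach { case (bytes, rows) =>
  println(s"batch: ${bytes.length} bytes, $rows rows")
}
```

Reading `rowCountInLastBatch` inside the lazy `map` is the crux of the new contract: materializing the counts up front would observe stale state, which is why the handler adopts exactly this shape when it pairs each payload with its row count.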