[SPARK-41108][SPARK-41005][CONNECT][FOLLOW-UP] Deduplicate ArrowConverters codes #38618
Changes from all commits
```diff
@@ -71,158 +71,146 @@ private[sql] class ArrowBatchStreamWriter(
 }

 private[sql] object ArrowConverters extends Logging {

   /**
    * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
    * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
    */
-  private[sql] def toBatchIterator(
+  private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int,
+      maxRecordsPerBatch: Long,
       timeZoneId: String,
-      context: TaskContext): Iterator[Array[Byte]] = {
+      context: TaskContext) extends Iterator[Array[Byte]] {

-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator =
-      ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
+    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    private val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator(
+        s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    protected val unloader = new VectorUnloader(root)
+    protected val arrowWriter = ArrowWriter.create(root)

-    context.addTaskCompletionListener[Unit] { _ =>
+    Option(context).foreach {_.addTaskCompletionListener[Unit] { _ =>
       root.close()
       allocator.close()
-    }
+    }}

-    new Iterator[Array[Byte]] {
-      override def hasNext: Boolean = rowIter.hasNext || {
-        root.close()
-        allocator.close()
-        false
-      }
+    override def hasNext: Boolean = rowIter.hasNext || {
+      root.close()
+      allocator.close()
+      false
+    }

-      override def next(): Array[Byte] = {
-        val out = new ByteArrayOutputStream()
-        val writeChannel = new WriteChannel(Channels.newChannel(out))
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))

-        Utils.tryWithSafeFinally {
-          var rowCount = 0
-          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
-            val row = rowIter.next()
-            arrowWriter.write(row)
-            rowCount += 1
-          }
-          arrowWriter.finish()
-          val batch = unloader.getRecordBatch()
-          MessageSerializer.serialize(writeChannel, batch)
-          batch.close()
-        } {
-          arrowWriter.reset()
-        }
+      Utils.tryWithSafeFinally {
+        var rowCount = 0L
+        while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+        batch.close()
+      } {
+        arrowWriter.reset()
+      }

-        out.toByteArray
-      }
+      out.toByteArray
     }
   }

+  private[sql] class ArrowBatchWithSchemaIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String,
+      context: TaskContext)
+    extends ArrowBatchIterator(
+      rowIter, schema, maxRecordsPerBatch, timeZoneId, context) {
+
+    private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
+    var rowCountInLastBatch: Long = 0
```
Contributor: A couple of questions. Why do we need the row count? It is already encoded in the batch itself. If we do need the row count, please make the iterator return it in the `next` call instead of relying on a side effect.

Member (Author): This logic is also from #38468, and this PR is a follow-up. The return type here is …
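For illustration, a minimal sketch of the alternative the reviewer suggests: have `next()` return the per-batch row count alongside the serialized bytes, rather than exposing it through a mutable `rowCountInLastBatch` field. The names and the placeholder serialization below are hypothetical and not part of the PR.

```scala
// Hypothetical sketch only: the row count is part of next()'s result, not a side effect.
case class BatchWithCount(batchBytes: Array[Byte], rowCount: Long)

class CountingBatchIterator(rows: Iterator[String], maxRecordsPerBatch: Long)
    extends Iterator[BatchWithCount] {

  override def hasNext: Boolean = rows.hasNext

  override def next(): BatchWithCount = {
    val builder = new StringBuilder
    var rowCount = 0L
    while (rows.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
      builder.append(rows.next()).append('\n')  // stand-in for arrowWriter.write(row)
      rowCount += 1
    }
    // Stand-in for the Arrow IPC serialization done in the real iterator.
    BatchWithCount(builder.toString.getBytes("UTF-8"), rowCount)
  }
}

// Callers read the count directly from the returned value:
// val it = new CountingBatchIterator(Iterator("a", "b", "c"), maxRecordsPerBatch = 2)
// it.foreach { b => println(s"${b.rowCount} rows, ${b.batchBytes.length} bytes") }
```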
```diff
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+      rowCountInLastBatch = 0
+      var estimatedBatchSize = arrowSchemaSize
+      Utils.tryWithSafeFinally {
+        // Always write the schema.
+        MessageSerializer.serialize(writeChannel, arrowSchema)
+
+        // Always write the first row.
+        while (rowIter.hasNext && (
+            // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller.
+            // If the size in bytes is positive (set properly), always write the first row.
+            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
+            // If the size in bytes of rows are 0 or negative, unlimit it.
+            estimatedBatchSize <= 0 ||
+            estimatedBatchSize < maxEstimatedBatchSize ||
+            // If the size of rows are 0 or negative, unlimit it.
+            maxRecordsPerBatch <= 0 ||
+            rowCountInLastBatch < maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
+          rowCountInLastBatch += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+
+        // Always write the Ipc options at the end.
+        ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+        batch.close()
+      } {
+        arrowWriter.reset()
+      }
+
+      out.toByteArray
+    }
+  }

+  /**
+   * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
+   * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
+   */
+  private[sql] def toBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext): ArrowBatchIterator = {
+    new ArrowBatchIterator(
+      rowIter, schema, maxRecordsPerBatch, timeZoneId, context)
+  }

   /**
    * Convert the input rows into fully contained arrow batches.
    * Different from [[toBatchIterator]], each output arrow batch starts with the schema.
    */
   private[sql] def toBatchWithSchemaIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxBatchSize: Long,
-      timeZoneId: String): Iterator[(Array[Byte], Long)] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      "toArrowBatchIterator", 0, Long.MaxValue)
-
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
-    val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
-
-    Option(TaskContext.get).foreach {
-      _.addTaskCompletionListener[Unit] { _ =>
-        root.close()
-        allocator.close()
-      }
-    }
-
-    new Iterator[(Array[Byte], Long)] {
-
-      override def hasNext: Boolean = rowIter.hasNext || {
-        root.close()
-        allocator.close()
-        false
-      }
-
-      override def next(): (Array[Byte], Long) = {
-        val out = new ByteArrayOutputStream()
-        val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-        var rowCount = 0L
-        var estimatedBatchSize = arrowSchemaSize
-        Utils.tryWithSafeFinally {
-          // Always write the schema.
-          MessageSerializer.serialize(writeChannel, arrowSchema)
-
-          // Always write the first row.
-          while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
-            val row = rowIter.next()
-            arrowWriter.write(row)
-            estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
-            rowCount += 1
-          }
-          arrowWriter.finish()
-          val batch = unloader.getRecordBatch()
-          MessageSerializer.serialize(writeChannel, batch)
-
-          // Always write the Ipc options at the end.
-          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
-
-          batch.close()
-        } {
-          arrowWriter.reset()
-        }
-
-        (out.toByteArray, rowCount)
-      }
-    }
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String): ArrowBatchWithSchemaIterator = {
+    new ArrowBatchWithSchemaIterator(
+      rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get)
   }

   private[sql] def createEmptyArrowBatch(
       schema: StructType,
       timeZoneId: String): Array[Byte] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      "createEmptyArrowBatch", 0, Long.MaxValue)
-
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val unloader = new VectorUnloader(root)
-    val arrowWriter = ArrowWriter.create(root)
-
-    val out = new ByteArrayOutputStream()
-    val writeChannel = new WriteChannel(Channels.newChannel(out))
-
-    Utils.tryWithSafeFinally {
-      arrowWriter.finish()
-      val batch = unloader.getRecordBatch() // empty batch
-
-      MessageSerializer.serialize(writeChannel, arrowSchema)
-      MessageSerializer.serialize(writeChannel, batch)
-      ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
-
-      batch.close()
-    } {
-      arrowWriter.reset()
-    }
-
-    out.toByteArray
+    new ArrowBatchWithSchemaIterator(
+        Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) {
+      override def hasNext: Boolean = true
+    }.next()
   }

   /**
```
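Since each element produced by `toBatchWithSchemaIterator` (and the single batch returned by `createEmptyArrowBatch`) starts with the schema and ends with the IPC end-of-stream marker, it should be readable on its own as a complete Arrow stream. A rough usage sketch with Arrow's Java stream reader, assuming `batchBytes` holds one such element; this is not code from the PR:

```scala
import java.io.ByteArrayInputStream
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Reads one self-contained batch back and prints its schema and row count.
def readSelfContainedBatch(batchBytes: Array[Byte]): Unit = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(new ByteArrayInputStream(batchBytes), allocator)
  try {
    val root = reader.getVectorSchemaRoot  // schema is available before any batch is loaded
    while (reader.loadNextBatch()) {
      println(s"schema: ${root.getSchema}, rows in batch: ${root.getRowCount}")
    }
  } finally {
    reader.close()
    allocator.close()
  }
}
```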
Contributor: What is this supposed to achieve? This uses a lot of reflective code to figure out the size of the schema object. How is this related to the size of the batch?

Member (Author): @hvanhovell, this PR is virtually pure refactoring except for the couple of points I mentioned in the PR description. As for the question, it came from #38612 to estimate the size of the batch before creating an Arrow batch.
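To make that heuristic concrete, here is a minimal sketch of the estimation pattern used in `ArrowBatchWithSchemaIterator`: start from an estimate of the schema's in-memory size and add each row's size until a byte budget is crossed. The helper and its parameters below are illustrative only; in the diff the schema estimate comes from `SizeEstimator.estimate(arrowSchema)` and the per-row size from `row.asInstanceOf[UnsafeRow].getSizeInBytes`.

```scala
// Illustrative only: counts how many rows fit under an approximate byte budget.
// estimatedSchemaSize stands in for SizeEstimator.estimate(arrowSchema);
// rowSizes stands in for row.asInstanceOf[UnsafeRow].getSizeInBytes per row.
def rowsUnderBudget(
    estimatedSchemaSize: Long,
    rowSizes: Iterator[Long],
    maxEstimatedBatchSize: Long): Long = {
  var estimatedBatchSize = estimatedSchemaSize
  var rowCount = 0L
  // Always admit the first row, then stop once the estimate crosses the cap.
  while (rowSizes.hasNext && (rowCount == 0 || estimatedBatchSize < maxEstimatedBatchSize)) {
    estimatedBatchSize += rowSizes.next()
    rowCount += 1
  }
  rowCount
}

// Example: with a 1 KB schema estimate and 100-byte rows under a 4 KB cap,
// the first 31 rows are admitted before the batch is cut:
// rowsUnderBudget(1024L, Iterator.continually(100L).take(1000), 4096L)  // == 31
```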