[WIP] Arrow conversion at partitions #23
Changes from 2 commits
@@ -54,8 +54,8 @@ private[sql] class ArrowConverters {
     override def hasNext: Boolean = iter.hasNext
   }

-  def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
-    val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
+  def interalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = {
+    val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, allocator, None)
     new ArrowStaticPayload(batch)
   }
 }
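This first hunk swaps the array-based entry point for an iterator-based one, so a partition can be converted without materializing its rows first. A minimal stand-alone sketch of that shape (the types below are stand-ins, not the project's classes):

```scala
// Stand-in payload type, not the project's ArrowPayload.
case class Payload(numRows: Int)

// Old shape: rows must already be materialized in memory.
def rowsToPayload(rows: Array[Int]): Payload = Payload(rows.length)

// New shape: rows are consumed lazily, one at a time.
def rowIterToPayload(rows: Iterator[Int]): Payload = {
  var n = 0
  rows.foreach(_ => n += 1) // a real writer would append each row here
  Payload(n)
}

// Both produce the same payload; only peak memory differs.
assert(rowsToPayload(Array(1, 2, 3)) == rowIterToPayload(Iterator(1, 2, 3)))
```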
@@ -83,52 +83,40 @@ private[sql] object ArrowConverters {
   }

   /**
-   * Transfer an array of InternalRow to an ArrowRecordBatch.
+   * Iterate over InternalRows and write to an ArrowRecordBatch.
    */
-  private[sql] def internalRowsToArrowRecordBatch(
-      rows: Array[InternalRow],
+  private def internalRowIterToArrowBatch(
+      rowIter: Iterator[InternalRow],
       schema: StructType,
-      allocator: RootAllocator): ArrowRecordBatch = {
-    val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) =>
-      internalRowToArrowBuf(rows, ordinal, field, allocator)
+      allocator: RootAllocator,
+      initialSize: Option[Int]): ArrowRecordBatch = {
+
+    val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) =>
+      ColumnWriter(ordinal, allocator, field.dataType)
+        .init(initialSize)
+    }
+
+    rowIter.foreach { row =>
+      columnWriters.foreach { writer =>
+        writer.write(row)
+      }
+    }
+
+    val fieldAndBuf = columnWriters.map { writer =>
+      writer.finish()
     }.unzip
     val fieldNodes = fieldAndBuf._1.flatten
     val buffers = fieldAndBuf._2.flatten

-    val recordBatch = new ArrowRecordBatch(rows.length,
+    val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0
Owner (author): Does this seem acceptable to get the row length for creating an ArrowRecordBatch?

Collaborator: I don't think we need another counter here. When is fieldNodes empty?

Owner (author): It will probably never be empty, but might be good to keep the check just in case.
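For reference, the two options weighed in this thread: reading the length back from the first field node, as the diff does, versus keeping a counter while draining the iterator. A self-contained sketch of both (FieldNode is a stand-in, not Arrow's ArrowFieldNode API):

```scala
// Stand-in for Arrow's per-field metadata.
case class FieldNode(length: Int, nullCount: Int)

// Approach in the diff: derive the row count from the first field node,
// guarding against an empty schema.
def rowLengthFromNodes(fieldNodes: Seq[FieldNode]): Int =
  if (fieldNodes.nonEmpty) fieldNodes.head.length else 0

// Alternative: count rows as they stream past the writers.
def writeAndCount[T](rowIter: Iterator[T])(write: T => Unit): Int = {
  var count = 0
  rowIter.foreach { row => write(row); count += 1 }
  count
}

// Three rows, two columns: both approaches agree.
assert(rowLengthFromNodes(Seq(FieldNode(3, 0), FieldNode(3, 1))) == 3)
assert(writeAndCount(Iterator(10, 20, 30))(_ => ()) == 3)
```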
+
+    val recordBatch = new ArrowRecordBatch(rowLength,
       fieldNodes.toList.asJava, buffers.toList.asJava)

     buffers.foreach(_.release())
     recordBatch
   }

-  /**
-   * Write a Field from array of InternalRow to an ArrowBuf.
-   */
-  private def internalRowToArrowBuf(
-      rows: Array[InternalRow],
-      ordinal: Int,
-      field: StructField,
-      allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
-    val numOfRows = rows.length
-    val columnWriter = ColumnWriter(allocator, field.dataType)
-    columnWriter.init(numOfRows)
-    var index = 0
-
-    while(index < numOfRows) {
-      val row = rows(index)
-      if (row.isNullAt(ordinal)) {
-        columnWriter.writeNull()
-      } else {
-        columnWriter.write(row, ordinal)
-      }
-      index += 1
-    }
-
-    val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
-    (arrowFieldNodes.toArray, arrowBufs.toArray)
-  }
-
   /**
    * Convert a Spark Dataset schema to Arrow schema.
    */
@@ -160,9 +148,8 @@ private[sql] object ArrowConverters {
   }

 private[sql] trait ColumnWriter {
-  def init(initialSize: Int): Unit
-  def writeNull(): Unit
-  def write(row: InternalRow, ordinal: Int): Unit
+  def init(initialSize: Option[Int]): this.type
+  def write(row: InternalRow): Unit

   /**
    * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
@@ -174,7 +161,9 @@ private[sql] trait ColumnWriter {
 /**
  * Base class for flat arrow column writer, i.e., column without children.
  */
-private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
+private[sql] abstract class PrimitiveColumnWriter(
+    val ordinal: Int,
+    protected val allocator: BaseAllocator)
   extends ColumnWriter {
   protected def valueVector: BaseDataValueVector
   protected def valueMutator: BaseMutator
@@ -185,18 +174,19 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
   protected var count = 0
   protected var nullCount = 0

-  override def init(initialSize: Int): Unit = {
+  override def init(initialSize: Option[Int]): this.type = {
+    initialSize.foreach(valueVector.setInitialCapacity)
Owner (author): I noticed this was never used before; would it increase performance much to set the exact capacity? Although I'm not sure it's possible when using an iterator.

Collaborator: For an iterator, we cannot know the size. I would just set a constant initial capacity.

Owner (author): Ok, I'll remove the option to set an initial capacity now. The default capacity seems to be fine.
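Given the agreement above, `init` would lose its size parameter entirely in a follow-up. A hedged sketch of what that could look like (StubVector is a stand-in for an Arrow value vector; the real change may differ):

```scala
// Stand-in vector; only mirrors the two calls used in the diff above.
class StubVector {
  private var capacity = 0
  def setInitialCapacity(n: Int): Unit = capacity = n
  def allocateNew(): Unit = if (capacity == 0) capacity = 1024 // default-ish
  def currentCapacity: Int = capacity
}

trait ColumnWriterSketch {
  protected def valueVector: StubVector

  // No initialSize: with an iterator the row count is unknown up front,
  // so the vector's default allocation is used and it grows as needed.
  def init(): this.type = {
    valueVector.allocateNew()
    this
  }
}
```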
     valueVector.allocateNew()
+    this
   }

-  override def writeNull(): Unit = {
-    setNull()
-    nullCount += 1
-    count += 1
-  }
-
-  override def write(row: InternalRow, ordinal: Int): Unit = {
-    setValue(row, ordinal)
+  override def write(row: InternalRow): Unit = {
+    if (row.isNullAt(ordinal)) {
+      setNull()
+      nullCount += 1
+    } else {
+      setValue(row, ordinal)
+    }
     count += 1
   }
@@ -208,8 +198,8 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
   }
 }

-private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   private def bool2int(b: Boolean): Int = if (b) 1 else 0

   override protected val valueVector: NullableBitVector

@@ -221,8 +211,8 @@ private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
 }

-private[sql] class ShortColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableSmallIntVector
     = new NullableSmallIntVector("ShortValue", allocator)
   override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator

@@ -232,8 +222,8 @@ private[sql] class ShortColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getShort(ordinal))
 }

-private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableIntVector
     = new NullableIntVector("IntValue", allocator)
   override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator

@@ -243,8 +233,8 @@ private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getInt(ordinal))
 }

-private[sql] class LongColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableBigIntVector
     = new NullableBigIntVector("LongValue", allocator)
   override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator

@@ -254,8 +244,8 @@ private[sql] class LongColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getLong(ordinal))
 }

-private[sql] class FloatColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableFloat4Vector
     = new NullableFloat4Vector("FloatValue", allocator)
   override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator

@@ -265,8 +255,8 @@ private[sql] class FloatColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getFloat(ordinal))
 }

-private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableFloat8Vector
     = new NullableFloat8Vector("DoubleValue", allocator)
   override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator

@@ -276,8 +266,8 @@ private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getDouble(ordinal))
 }

-private[sql] class ByteColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableUInt1Vector
     = new NullableUInt1Vector("ByteValue", allocator)
   override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator

@@ -287,8 +277,8 @@ private[sql] class ByteColumnWriter(allocator: BaseAllocator)
     = valueMutator.setSafe(count, row.getByte(ordinal))
 }

-private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableVarBinaryVector
     = new NullableVarBinaryVector("UTF8StringValue", allocator)
   override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator

@@ -300,8 +290,8 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
   }
 }

-private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableVarBinaryVector
     = new NullableVarBinaryVector("BinaryValue", allocator)
   override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator

@@ -313,8 +303,8 @@ private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
   }
 }

-private[sql] class DateColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableDateVector
     = new NullableDateVector("DateValue", allocator)
   override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator

@@ -326,8 +316,8 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)
   }
 }

-private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
-  extends PrimitiveColumnWriter(allocator) {
+private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(ordinal, allocator) {
   override protected val valueVector: NullableTimeStampVector
     = new NullableTimeStampVector("TimeStampValue", allocator)
   override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator

@@ -341,19 +331,19 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
 }

 private[sql] object ColumnWriter {
-  def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
+  def apply(ordinal: Int, allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
     dataType match {
-      case BooleanType => new BooleanColumnWriter(allocator)
-      case ShortType => new ShortColumnWriter(allocator)
-      case IntegerType => new IntegerColumnWriter(allocator)
-      case LongType => new LongColumnWriter(allocator)
-      case FloatType => new FloatColumnWriter(allocator)
-      case DoubleType => new DoubleColumnWriter(allocator)
-      case ByteType => new ByteColumnWriter(allocator)
-      case StringType => new UTF8StringColumnWriter(allocator)
-      case BinaryType => new BinaryColumnWriter(allocator)
-      case DateType => new DateColumnWriter(allocator)
-      case TimestampType => new TimeStampColumnWriter(allocator)
+      case BooleanType => new BooleanColumnWriter(ordinal, allocator)
+      case ShortType => new ShortColumnWriter(ordinal, allocator)
+      case IntegerType => new IntegerColumnWriter(ordinal, allocator)
+      case LongType => new LongColumnWriter(ordinal, allocator)
+      case FloatType => new FloatColumnWriter(ordinal, allocator)
+      case DoubleType => new DoubleColumnWriter(ordinal, allocator)
+      case ByteType => new ByteColumnWriter(ordinal, allocator)
+      case StringType => new UTF8StringColumnWriter(ordinal, allocator)
+      case BinaryType => new BinaryColumnWriter(ordinal, allocator)
+      case DateType => new DateColumnWriter(ordinal, allocator)
+      case TimestampType => new TimeStampColumnWriter(ordinal, allocator)
       case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
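A hypothetical usage of the ordinal-aware factory above, e.g. from a test in the same package where the `private[sql]` members are visible; the schema and rows here are made up for illustration:

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

val allocator = new RootAllocator(Long.MaxValue)
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("score", DoubleType)))

// One writer per column; each writer now carries its own ordinal.
val writers = schema.fields.zipWithIndex.map { case (field, ordinal) =>
  ColumnWriter(ordinal, allocator, field.dataType).init(None)
}

// Writers pull their column out of each row themselves, nulls included.
Seq(InternalRow(1, 2.0), InternalRow(2, null)).foreach { row =>
  writers.foreach(_.write(row))
}
```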
@@ -2375,8 +2375,8 @@ class Dataset[T] private[sql](
     val cnvtr = converter.getOrElse(new ArrowConverters)
     withNewExecutionId {
       try {
-        val collectedRows = queryExecution.executedPlan.executeCollect()
-        cnvtr.internalRowsToPayload(collectedRows, this.schema)
+        val rowIter = queryExecution.executedPlan.executeToIterator()
+        cnvtr.interalRowIterToPayload(rowIter, this.schema)
       } catch {
         case e: Exception =>
           throw e
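The switch from executeCollect() to executeToIterator() means the driver no longer holds every row in an Array before conversion starts. A toy contrast of the two shapes (hypothetical generators, not Spark's API):

```scala
// Materialized: O(n) memory before any conversion work begins.
def collectLike(n: Int): Array[Int] = Array.tabulate(n)(identity)

// Streamed: each row exists only as it is consumed.
def iteratorLike(n: Int): Iterator[Int] = Iterator.range(0, n)

// Both feed the same downstream conversion; only peak memory differs.
assert(collectLike(1000).sum == iteratorLike(1000).sum)
```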
Review comment: I think we can get rid of initialSize for an iterator-oriented implementation.