Implement Arrow column writers #16

Closed: icexelloss wants to merge 5 commits into BryanCutler:arrow-integration from icexelloss:arrow-integration (a fork of apache/spark).

Commits:
- 8c7de51 Implement Arrow column writers (icexelloss)
- 0fd25ea Move column writers to Arrow.scala (icexelloss)
- 48155e0 Add support for more types; Switch to arrow NullableVector (icexelloss)
- 3ede886 Address comments (icexelloss)
- f432bac Fix tests (icexelloss)

Changes from all commits:
@@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
 import scala.language.implicitConversions

 import io.netty.buffer.ArrowBuf
-import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.BitVector
+import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.BaseValueVector.BaseMutator
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
@@ -32,70 +33,17 @@ import org.apache.spark.sql.types._

 object Arrow {

-  private case class TypeFuncs(getType: () => ArrowType,
-      fill: ArrowBuf => Unit,
-      write: (InternalRow, Int, ArrowBuf) => Unit)
-
-  private def getTypeFuncs(dataType: DataType): TypeFuncs = {
-    val err = s"Unsupported data type ${dataType.simpleString}"
-
+  private def sparkTypeToArrowType(dataType: DataType): ArrowType = {
     dataType match {
-      case NullType =>
-        TypeFuncs(
-          () => ArrowType.Null.INSTANCE,
-          (buf: ArrowBuf) => (),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => ())
-      case BooleanType =>
-        TypeFuncs(
-          () => ArrowType.Bool.INSTANCE,
-          (buf: ArrowBuf) => buf.writeBoolean(false),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
-            buf.writeBoolean(row.getBoolean(ordinal)))
-      case ShortType =>
-        TypeFuncs(
-          () => new ArrowType.Int(8 * ShortType.defaultSize, true),
-          (buf: ArrowBuf) => buf.writeShort(0),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal)))
-      case IntegerType =>
-        TypeFuncs(
-          () => new ArrowType.Int(8 * IntegerType.defaultSize, true),
-          (buf: ArrowBuf) => buf.writeInt(0),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal)))
-      case LongType =>
-        TypeFuncs(
-          () => new ArrowType.Int(8 * LongType.defaultSize, true),
-          (buf: ArrowBuf) => buf.writeLong(0L),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal)))
-      case FloatType =>
-        TypeFuncs(
-          () => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
-          (buf: ArrowBuf) => buf.writeFloat(0f),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal)))
-      case DoubleType =>
-        TypeFuncs(
-          () => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
-          (buf: ArrowBuf) => buf.writeDouble(0d),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
-            buf.writeDouble(row.getDouble(ordinal)))
-      case ByteType =>
-        TypeFuncs(
-          () => new ArrowType.Int(8, false),
-          (buf: ArrowBuf) => buf.writeByte(0),
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal)))
-      case StringType =>
-        TypeFuncs(
-          () => ArrowType.Utf8.INSTANCE,
-          (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
-            throw new UnsupportedOperationException(err))
-      case StructType(_) =>
-        TypeFuncs(
-          () => ArrowType.Struct.INSTANCE,
-          (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
-          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
-            throw new UnsupportedOperationException(err))
-      case _ =>
-        throw new IllegalArgumentException(err)
+      case BooleanType => ArrowType.Bool.INSTANCE
+      case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
+      case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
+      case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
+      case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+      case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+      case ByteType => new ArrowType.Int(8, false)
+      case StringType => ArrowType.Utf8.INSTANCE
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
     }
   }
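As a quick illustration of the new mapping: Spark's IntegerType.defaultSize is 4 bytes, so the 8 * defaultSize expression above yields a signed 32-bit Arrow integer. A minimal standalone sketch using the Arrow pojo API imported in this diff (illustrative only, not part of the change):

    import org.apache.arrow.vector.types.pojo.ArrowType
    import org.apache.spark.sql.types.IntegerType

    // 8 bits per byte * 4-byte IntegerType gives a signed 32-bit Arrow integer.
    val arrowInt = new ArrowType.Int(8 * IntegerType.defaultSize, true)
    assert(arrowInt.getBitWidth == 32 && arrowInt.getIsSigned)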
@@ -110,8 +58,8 @@ object Arrow {
       internalRowToArrowBuf(rows, ordinal, field, allocator)
     }

-    val buffers = bufAndField.flatMap(_._1).toList.asJava
-    val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
+    val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
+    val buffers = bufAndField.flatMap(_._2).toList.asJava

     new ArrowRecordBatch(rows.length, fieldNodes, buffers)
   }
@@ -123,67 +71,24 @@ object Arrow {
       rows: Array[InternalRow],
       ordinal: Int,
       field: StructField,
-      allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
+      allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
     val numOfRows = rows.length
+    val columnWriter = ColumnWriter(allocator, field.dataType)
+    columnWriter.init(numOfRows)
+    var index = 0

-    field.dataType match {
-      case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
-        val validityVector = new BitVector("validity", allocator)
-        val validityMutator = validityVector.getMutator
-        validityVector.allocateNew(numOfRows)
-        validityMutator.setValueCount(numOfRows)
-
-        val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
-        val typeFunc = getTypeFuncs(field.dataType)
-        var nullCount = 0
-        var index = 0
-        while (index < rows.length) {
-          val row = rows(index)
-          if (row.isNullAt(ordinal)) {
-            nullCount += 1
-            validityMutator.set(index, 0)
-            typeFunc.fill(buf)
-          } else {
-            validityMutator.set(index, 1)
-            typeFunc.write(row, ordinal, buf)
-          }
-          index += 1
-        }
-
-        val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
-
-        (Array(validityVector.getBuffer, buf), Array(fieldNode))
-
-      case StringType =>
-        val validityVector = new BitVector("validity", allocator)
-        val validityMutator = validityVector.getMutator()
-        validityVector.allocateNew(numOfRows)
-        validityMutator.setValueCount(numOfRows)
-
-        val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
-        var bytesCount = 0
-        bufOffset.writeInt(bytesCount)
-        val bufValues = allocator.buffer(1024)
-        var nullCount = 0
-        rows.zipWithIndex.foreach { case (row, index) =>
-          if (row.isNullAt(ordinal)) {
-            nullCount += 1
-            validityMutator.set(index, 0)
-            bufOffset.writeInt(bytesCount)
-          } else {
-            validityMutator.set(index, 1)
-            val bytes = row.getUTF8String(ordinal).getBytes
-            bytesCount += bytes.length
-            bufOffset.writeInt(bytesCount)
-            bufValues.writeBytes(bytes)
-          }
-        }
-
-        val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
-
-        (Array(validityVector.getBuffer, bufOffset, bufValues),
-          Array(fieldNode))
-    }
+    while (index < numOfRows) {
+      val row = rows(index)
+      if (row.isNullAt(ordinal)) {
+        columnWriter.writeNull()
+      } else {
+        columnWriter.write(row, ordinal)
+      }
+      index += 1
+    }
+
+    val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
+    (arrowFieldNodes.toArray, arrowBufs.toArray)
   }

   private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
@@ -195,13 +100,162 @@ object Arrow {
     val name = sparkField.name
     val dataType = sparkField.dataType
     val nullable = sparkField.nullable
-    dataType match {
-      case StructType(fields) =>
-        val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
-        new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
-      case _ =>
-        new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava)
-    }
+    new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava)
   }
 }
+
+object ColumnWriter {
+  def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
+    dataType match {
+      case BooleanType => new BooleanColumnWriter(allocator)
+      case ShortType => new ShortColumnWriter(allocator)
+      case IntegerType => new IntegerColumnWriter(allocator)
+      case LongType => new LongColumnWriter(allocator)
+      case FloatType => new FloatColumnWriter(allocator)
+      case DoubleType => new DoubleColumnWriter(allocator)
+      case ByteType => new ByteColumnWriter(allocator)
+      case StringType => new UTF8StringColumnWriter(allocator)
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+    }
+  }
+}
+
+private[sql] trait ColumnWriter {
+  def init(initialSize: Int): Unit
+  def writeNull(): Unit
+  def write(row: InternalRow, ordinal: Int): Unit
+  def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
+}
+
+/**
+ * Base class for flat Arrow column writers, i.e., columns without children.
+ */
+private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
+    extends ColumnWriter {
+  protected val valueVector: BaseDataValueVector
+  protected val valueMutator: BaseMutator
+
+  var count = 0
+  var nullCount = 0
+
+  protected def setNull(): Unit
+  protected def setValue(row: InternalRow, ordinal: Int): Unit
+  protected def valueBuffers(): Seq[ArrowBuf] =
+    valueVector.getBuffers(true) // TODO: check the flag
+
+  override def init(initialSize: Int): Unit = {
+    valueVector.allocateNew()
+  }
+
+  override def writeNull(): Unit = {
+    setNull()
+    nullCount += 1
+    count += 1
+  }
+
+  override def write(row: InternalRow, ordinal: Int): Unit = {
+    setValue(row, ordinal)
+    count += 1
+  }
+
+  override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
+    valueMutator.setValueCount(count)
+    val fieldNode = new ArrowFieldNode(count, nullCount)
+    (List(fieldNode), valueBuffers)
+  }
+}
+
+private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  private def bool2int(b: Boolean): Int = if (b) 1 else 0
+
+  override protected val valueVector: NullableBitVector =
+    new NullableBitVector("BooleanValue", allocator)
+  override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
+}
+
+private[sql] class ShortColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableSmallIntVector =
+    new NullableSmallIntVector("ShortValue", allocator)
+  override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getShort(ordinal))
+}
+
+private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableIntVector =
+    new NullableIntVector("IntValue", allocator)
+  override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getInt(ordinal))
+}
+
+private[sql] class LongColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableBigIntVector =
+    new NullableBigIntVector("LongValue", allocator)
+  override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getLong(ordinal))
+}
+
+private[sql] class FloatColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableFloat4Vector =
+    new NullableFloat4Vector("FloatValue", allocator)
+  override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getFloat(ordinal))
+}
+
+private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableFloat8Vector =
+    new NullableFloat8Vector("DoubleValue", allocator)
+  override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getDouble(ordinal))
+}
+
+private[sql] class ByteColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableUInt1Vector =
+    new NullableUInt1Vector("ByteValue", allocator)
+  override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit =
+    valueMutator.setSafe(count, row.getByte(ordinal))
+}
+
+private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableVarBinaryVector =
+    new NullableVarBinaryVector("UTF8StringValue", allocator)
+  override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit = {
+    val bytes = row.getUTF8String(ordinal).getBytes
+    valueMutator.setSafe(count, bytes, 0, bytes.length)
+  }
+}
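To make the new writer API concrete, here is a minimal illustrative driver for one of the writers defined above. The allocator setup and the GenericInternalRow-based rows are assumptions made for this sketch, not part of the PR:

    import org.apache.arrow.memory.RootAllocator
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

    val allocator = new RootAllocator(Long.MaxValue)
    val writer = new IntegerColumnWriter(allocator)
    writer.init(3) // initialSize hint; PrimitiveColumnWriter currently just calls allocateNew()

    // Three single-column rows; the middle one is null and exercises writeNull().
    val rows: Array[InternalRow] = Array(
      new GenericInternalRow(Array[Any](1)),
      new GenericInternalRow(Array[Any](null)),
      new GenericInternalRow(Array[Any](42)))

    rows.foreach { row =>
      if (row.isNullAt(0)) writer.writeNull() else writer.write(row, 0)
    }

    // finish() seals the vector and hands back the Arrow metadata and buffers:
    // here one ArrowFieldNode(3, 1) plus the validity and value buffers.
    val (fieldNodes, buffers) = writer.finish()
    // Real code would release the buffers and close the allocator when done.

This mirrors the loop in internalRowToArrowBuf above, which drives one writer per column and then assembles the per-column field nodes and buffers into an ArrowRecordBatch.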
Review discussion (on the wildcard import org.apache.arrow.vector._):

> Usually explicit imports are preferred. Is it basically an import for every vector type?

> This is similar to import org.apache.spark.sql.types._; otherwise we end up with a very long import list (10+ entries).

> Ok, it will probably be fine like this then.
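For reference, the explicit alternative being discussed would look roughly like the following, listing just the vector classes this file actually uses (reconstructed from the diff, so treat the exact list as illustrative):

    import org.apache.arrow.vector.{BaseDataValueVector, NullableBigIntVector,
      NullableBitVector, NullableFloat4Vector, NullableFloat8Vector,
      NullableIntVector, NullableSmallIntVector, NullableUInt1Vector,
      NullableVarBinaryVector}

With BaseValueVector.BaseMutator imported separately, that is ten names, which is the long import list the wildcard avoids.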