314 changes: 184 additions & 130 deletions sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
import scala.language.implicitConversions

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.BitVector
import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
import org.apache.arrow.vector._
Owner: Usually explicit imports are preferred. Is it basically an import for every vector type?

Collaborator (Author): This is similar to import org.apache.spark.sql.types._; otherwise we end up with a very long import list (10+ types).

Owner: OK, it will probably be fine like this then.
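
(For reference, a rough sketch of what the explicit form of this import would look like, based on the vector types referenced later in this patch; the exact set is an assumption and may differ:)

// Hypothetical explicit import that the wildcard replaces, listing the vector
// types used by the ColumnWriter implementations below.
import org.apache.arrow.vector.{BaseDataValueVector, NullableBigIntVector, NullableBitVector,
  NullableFloat4Vector, NullableFloat8Vector, NullableIntVector, NullableSmallIntVector,
  NullableUInt1Vector, NullableVarBinaryVector}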

import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
@@ -32,70 +33,17 @@ import org.apache.spark.sql.types._

object Arrow {

private case class TypeFuncs(getType: () => ArrowType,
fill: ArrowBuf => Unit,
write: (InternalRow, Int, ArrowBuf) => Unit)

private def getTypeFuncs(dataType: DataType): TypeFuncs = {
val err = s"Unsupported data type ${dataType.simpleString}"

private def sparkTypeToArrowType(dataType: DataType): ArrowType = {
dataType match {
case NullType =>
TypeFuncs(
() => ArrowType.Null.INSTANCE,
(buf: ArrowBuf) => (),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => ())
case BooleanType =>
TypeFuncs(
() => ArrowType.Bool.INSTANCE,
(buf: ArrowBuf) => buf.writeBoolean(false),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
buf.writeBoolean(row.getBoolean(ordinal)))
case ShortType =>
TypeFuncs(
() => new ArrowType.Int(8 * ShortType.defaultSize, true),
(buf: ArrowBuf) => buf.writeShort(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal)))
case IntegerType =>
TypeFuncs(
() => new ArrowType.Int(8 * IntegerType.defaultSize, true),
(buf: ArrowBuf) => buf.writeInt(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal)))
case LongType =>
TypeFuncs(
() => new ArrowType.Int(8 * LongType.defaultSize, true),
(buf: ArrowBuf) => buf.writeLong(0L),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal)))
case FloatType =>
TypeFuncs(
() => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
(buf: ArrowBuf) => buf.writeFloat(0f),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal)))
case DoubleType =>
TypeFuncs(
() => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
(buf: ArrowBuf) => buf.writeDouble(0d),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
buf.writeDouble(row.getDouble(ordinal)))
case ByteType =>
TypeFuncs(
() => new ArrowType.Int(8, false),
(buf: ArrowBuf) => buf.writeByte(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal)))
case StringType =>
TypeFuncs(
() => ArrowType.Utf8.INSTANCE,
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
throw new UnsupportedOperationException(err))
case StructType(_) =>
TypeFuncs(
() => ArrowType.Struct.INSTANCE,
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
throw new UnsupportedOperationException(err))
case _ =>
throw new IllegalArgumentException(err)
case BooleanType => ArrowType.Bool.INSTANCE
case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case ByteType => new ArrowType.Int(8, false)
case StringType => ArrowType.Utf8.INSTANCE
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
}

@@ -110,8 +58,8 @@ object Arrow {
internalRowToArrowBuf(rows, ordinal, field, allocator)
}

val buffers = bufAndField.flatMap(_._1).toList.asJava
val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
val buffers = bufAndField.flatMap(_._2).toList.asJava

new ArrowRecordBatch(rows.length, fieldNodes, buffers)
}
@@ -123,67 +71,24 @@
rows: Array[InternalRow],
ordinal: Int,
field: StructField,
allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
val numOfRows = rows.length
val columnWriter = ColumnWriter(allocator, field.dataType)
columnWriter.init(numOfRows)
var index = 0

field.dataType match {
case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
val validityVector = new BitVector("validity", allocator)
val validityMutator = validityVector.getMutator
validityVector.allocateNew(numOfRows)
validityMutator.setValueCount(numOfRows)

val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
val typeFunc = getTypeFuncs(field.dataType)
var nullCount = 0
var index = 0
while (index < rows.length) {
val row = rows(index)
if (row.isNullAt(ordinal)) {
nullCount += 1
validityMutator.set(index, 0)
typeFunc.fill(buf)
} else {
validityMutator.set(index, 1)
typeFunc.write(row, ordinal, buf)
}
index += 1
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validityVector.getBuffer, buf), Array(fieldNode))

case StringType =>
val validityVector = new BitVector("validity", allocator)
val validityMutator = validityVector.getMutator()
validityVector.allocateNew(numOfRows)
validityMutator.setValueCount(numOfRows)

val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
var bytesCount = 0
bufOffset.writeInt(bytesCount)
val bufValues = allocator.buffer(1024)
var nullCount = 0
rows.zipWithIndex.foreach { case (row, index) =>
if (row.isNullAt(ordinal)) {
nullCount += 1
validityMutator.set(index, 0)
bufOffset.writeInt(bytesCount)
} else {
validityMutator.set(index, 1)
val bytes = row.getUTF8String(ordinal).getBytes
bytesCount += bytes.length
bufOffset.writeInt(bytesCount)
bufValues.writeBytes(bytes)
}
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validityVector.getBuffer, bufOffset, bufValues),
Array(fieldNode))
while (index < numOfRows) {
val row = rows(index)
if (row.isNullAt(ordinal)) {
columnWriter.writeNull()
} else {
columnWriter.write(row, ordinal)
}
index += 1
}

val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
(arrowFieldNodes.toArray, arrowBufs.toArray)
}

private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
@@ -195,13 +100,162 @@ object Arrow {
val name = sparkField.name
val dataType = sparkField.dataType
val nullable = sparkField.nullable
new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava)
}
}

object ColumnWriter {
def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
dataType match {
case StructType(fields) =>
val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
case _ =>
new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava)
case BooleanType => new BooleanColumnWriter(allocator)
case ShortType => new ShortColumnWriter(allocator)
case IntegerType => new IntegerColumnWriter(allocator)
case LongType => new LongColumnWriter(allocator)
case FloatType => new FloatColumnWriter(allocator)
case DoubleType => new DoubleColumnWriter(allocator)
case ByteType => new ByteColumnWriter(allocator)
case StringType => new UTF8StringColumnWriter(allocator)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
Owner: Let's move this to a ColumnWriter object apply() method.

Collaborator (Author): Will do.
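
(For reference, a minimal usage sketch of the resulting factory, mirroring internalRowToArrowBuf above; rows, ordinal, field, and allocator are assumed to be in scope:)

// Dispatch to the right writer through the ColumnWriter.apply() above,
// write each row's value (or a null), then collect the Arrow field nodes and buffers.
val writer = ColumnWriter(allocator, field.dataType)
writer.init(rows.length)
rows.foreach { row =>
  if (row.isNullAt(ordinal)) writer.writeNull() else writer.write(row, ordinal)
}
val (fieldNodes, buffers) = writer.finish()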

}
}
}

private[sql] trait ColumnWriter {
def init(initialSize: Int): Unit
def writeNull(): Unit
def write(row: InternalRow, ordinal: Int): Unit
def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
}

/**
* Base class for flat arrow column writer, i.e., column without children.
*/
private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
extends ColumnWriter {
protected val valueVector: BaseDataValueVector
protected val valueMutator: BaseMutator

var count = 0
var nullCount = 0

protected def setNull(): Unit
protected def setValue(row: InternalRow, ordinal: Int): Unit
protected def valueBuffers(): Seq[ArrowBuf]
= valueVector.getBuffers(true) // TODO: check the flag

override def init(initialSize: Int): Unit = {
valueVector.allocateNew()
}

override def writeNull(): Unit = {
setNull()
nullCount += 1
count += 1
}

override def write(row: InternalRow, ordinal: Int): Unit = {
setValue(row, ordinal)
count += 1
}

override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
valueMutator.setValueCount(count)
val fieldNode = new ArrowFieldNode(count, nullCount)
(List(fieldNode), valueBuffers)
}
}

private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
private def bool2int(b: Boolean): Int = if (b) 1 else 0

override protected val valueVector: NullableBitVector
= new NullableBitVector("BooleanValue", allocator)
override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal)))
}

private[sql] class ShortColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableSmallIntVector
= new NullableSmallIntVector("ShortValue", allocator)
override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getShort(ordinal))
}

private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableIntVector
= new NullableIntVector("IntValue", allocator)
override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getInt(ordinal))
}

private[sql] class LongColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableBigIntVector
= new NullableBigIntVector("LongValue", allocator)
override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getLong(ordinal))
}

private[sql] class FloatColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableFloat4Vector
= new NullableFloat4Vector("FloatValue", allocator)
override protected val valueMutator: NullableFloat4Vector#Mutator
= valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getFloat(ordinal))
}

private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableFloat8Vector
= new NullableFloat8Vector("DoubleValue", allocator)
override protected val valueMutator: NullableFloat8Vector#Mutator
= valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getDouble(ordinal))
}

private[sql] class ByteColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableUInt1Vector
= new NullableUInt1Vector("ByteValue", allocator)
override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit
= valueMutator.setSafe(count, row.getByte(ordinal))
}

private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("UTF8StringValue", allocator)
override protected val valueMutator: NullableVarBinaryVector#Mutator
= valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit = {
val bytes = row.getUTF8String(ordinal).getBytes
valueMutator.setSafe(count, bytes, 0, bytes.length)
}
}