diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
index 31b90259de0e..1b68a9d0429c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -32,66 +32,70 @@ import org.apache.spark.sql.types._
 object Arrow {
-  /**
-   * Compute the number of bytes needed to build validity map. According to
-   * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps),
-   * the length of the validity bitmap should be multiples of 64 bytes.
-   */
-  private def numBytesOfBitmap(numOfRows: Int): Int = {
-    Math.ceil(numOfRows / 64.0).toInt * 8
-  }
+  private case class TypeFuncs(getType: () => ArrowType,
+      fill: ArrowBuf => Unit,
+      write: (InternalRow, Int, ArrowBuf) => Unit)
-  private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = {
-    dataType match {
-      case NullType =>
-      case BooleanType =>
-        buf.writeBoolean(false)
-      case ShortType =>
-        buf.writeShort(0)
-      case IntegerType =>
-        buf.writeInt(0)
-      case LongType =>
-        buf.writeLong(0L)
-      case FloatType =>
-        buf.writeFloat(0f)
-      case DoubleType =>
-        buf.writeDouble(0d)
-      case ByteType =>
-        buf.writeByte(0)
-      case _ =>
-        throw new UnsupportedOperationException(
-          s"Unsupported data type ${dataType.simpleString}")
-    }
-  }
+  private def getTypeFuncs(dataType: DataType): TypeFuncs = {
+    val err = s"Unsupported data type ${dataType.simpleString}"
-  /**
-   * Get an entry from the InternalRow, and then set to ArrowBuf.
-   * Note: No Null check for the entry.
-   */
-  private def getAndSetToArrow(
-      row: InternalRow,
-      buf: ArrowBuf,
-      dataType: DataType,
-      ordinal: Int): Unit = {
     dataType match {
       case NullType =>
+        TypeFuncs(
+          () => ArrowType.Null.INSTANCE,
+          (buf: ArrowBuf) => (),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => ())
       case BooleanType =>
-        buf.writeBoolean(row.getBoolean(ordinal))
+        TypeFuncs(
+          () => ArrowType.Bool.INSTANCE,
+          (buf: ArrowBuf) => buf.writeBoolean(false),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
+            buf.writeBoolean(row.getBoolean(ordinal)))
       case ShortType =>
-        buf.writeShort(row.getShort(ordinal))
+        TypeFuncs(
+          () => new ArrowType.Int(8 * ShortType.defaultSize, true), // 16-bit signed
+          (buf: ArrowBuf) => buf.writeShort(0),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal)))
       case IntegerType =>
-        buf.writeInt(row.getInt(ordinal))
+        TypeFuncs(
+          () => new ArrowType.Int(8 * IntegerType.defaultSize, true),
+          (buf: ArrowBuf) => buf.writeInt(0),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal)))
      case LongType =>
-        buf.writeLong(row.getLong(ordinal))
+        TypeFuncs(
+          () => new ArrowType.Int(8 * LongType.defaultSize, true),
+          (buf: ArrowBuf) => buf.writeLong(0L),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal)))
       case FloatType =>
-        buf.writeFloat(row.getFloat(ordinal))
+        TypeFuncs(
+          () => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
+          (buf: ArrowBuf) => buf.writeFloat(0f),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal)))
       case DoubleType =>
-        buf.writeDouble(row.getDouble(ordinal))
+        TypeFuncs(
+          () => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
+          (buf: ArrowBuf) => buf.writeDouble(0d),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
+            buf.writeDouble(row.getDouble(ordinal)))
       case ByteType =>
-        buf.writeByte(row.getByte(ordinal))
+        TypeFuncs(
+          () => new ArrowType.Int(8, false),
+          (buf: ArrowBuf) => buf.writeByte(0),
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal)))
+      case StringType =>
+        TypeFuncs(
+          () => ArrowType.Utf8.INSTANCE,
+          (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
+            throw new UnsupportedOperationException(err))
+      case StructType(_) =>
+        TypeFuncs(
+          () => ArrowType.Struct.INSTANCE,
+          (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
+          (row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
+            throw new UnsupportedOperationException(err))
       case _ =>
-        throw new UnsupportedOperationException(
-          s"Unsupported data type ${dataType.simpleString}")
+        throw new IllegalArgumentException(err)
     }
   }
@@ -130,6 +134,7 @@ object Arrow {
     validityMutator.setValueCount(numOfRows)
     val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
+    val typeFunc = getTypeFuncs(field.dataType)
     var nullCount = 0
     var index = 0
     while (index < rows.length) {
@@ -137,10 +142,10 @@ object Arrow {
       if (row.isNullAt(ordinal)) {
         nullCount += 1
         validityMutator.set(index, 0)
-        fillArrow(buf, field.dataType)
+        typeFunc.fill(buf)
       } else {
         validityMutator.set(index, 1)
-        getAndSetToArrow(row, buf, field.dataType, ordinal)
+        typeFunc.write(row, ordinal, buf)
       }
       index += 1
     }
@@ -182,7 +187,7 @@ object Arrow {
   }
   private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
-    val arrowFields = schema.fields.map(sparkFieldToArrowField(_))
+    val arrowFields = schema.fields.map(sparkFieldToArrowField)
     new Schema(arrowFields.toList.asJava)
   }
@@ -193,36 +198,10 @@ object Arrow {
     dataType match {
       case StructType(fields) =>
-        val childrenFields = fields.map(sparkFieldToArrowField(_)).toList.asJava
+        val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
         new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
       case _ =>
-        new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
-    }
-  }
-
-  /**
-   * Transform Spark DataType to Arrow ArrowType.
-   */
-  private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = {
-    dt match {
-      case IntegerType =>
-        new ArrowType.Int(8 * IntegerType.defaultSize, true)
-      case LongType =>
-        new ArrowType.Int(8 * LongType.defaultSize, true)
-      case StringType =>
-        ArrowType.Utf8.INSTANCE
-      case DoubleType =>
-        new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
-      case FloatType =>
-        new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
-      case BooleanType =>
-        ArrowType.Bool.INSTANCE
-      case ByteType =>
-        new ArrowType.Int(8, false)
-      case StructType(_) =>
-        ArrowType.Struct.INSTANCE
-      case _ =>
-        throw new IllegalArgumentException(s"Unsupported data type")
+        new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava)
     }
   }
 }
diff --git a/sql/core/src/test/resources/test-data/arrowNullInts.json b/sql/core/src/test/resources/test-data/arrow/null-ints.json
similarity index 100%
rename from sql/core/src/test/resources/test-data/arrowNullInts.json
rename to sql/core/src/test/resources/test-data/arrow/null-ints.json
diff --git a/sql/core/src/test/resources/test-data/arrowNullStrings.json b/sql/core/src/test/resources/test-data/arrow/null-strings.json
similarity index 100%
rename from sql/core/src/test/resources/test-data/arrowNullStrings.json
rename to sql/core/src/test/resources/test-data/arrow/null-strings.json
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index 036a367cb0db..ff65f13151ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -32,14 +32,15 @@ class ArrowSuite extends SharedSQLContext {
   }
   test("convert int column with null to arrow") {
-    testCollect(nullInts, "test-data/arrowNullInts.json")
+    testCollect(nullInts, "test-data/arrow/null-ints.json")
   }
   test("convert string column with null to arrow") {
     val nullStringsColOnly = nullStrings.select(nullStrings.columns(1))
-    testCollect(nullStringsColOnly, "test-data/arrowNullStrings.json")
+    testCollect(nullStringsColOnly, "test-data/arrow/null-strings.json")
   }
+  /** Test that a DataFrame converted to an Arrow record batch matches a batch read from a JSON file. */
   private def testCollect(df: DataFrame, arrowFile: String) {
     val jsonFilePath = testFile(arrowFile)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
deleted file mode 100644
index 8aec3699c9dd..000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
+++ /dev/null
@@ -1,172 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.io._
-import java.net.{InetAddress, Socket}
-import java.nio.{ByteBuffer, ByteOrder}
-import java.nio.channels.FileChannel
-
-import scala.util.Random
-
-import io.netty.buffer.ArrowBuf
-import org.apache.arrow.flatbuf.Precision
-import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.file.ArrowReader
-import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
-
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
-
-
-case class ArrowTestClass(col1: Int, col2: Double, col3: String)
-
-class DatasetToArrowSuite extends QueryTest with SharedSQLContext {
-
-  import testImplicits._
-
-  final val numElements = 4
-  @transient var data: Seq[ArrowTestClass] = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    data = Seq.fill(numElements)(ArrowTestClass(
-      Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100))))
-  }
-
-  test("Collect as arrow to python") {
-    val dataset = data.toDS()
-
-    val port = dataset.collectAsArrowToPython()
-
-    val receiver: RecordBatchReceiver = new RecordBatchReceiver
-    val (buffer, numBytesRead) = receiver.connectAndRead(port)
-    val channel = receiver.makeFile(buffer)
-    val reader = new ArrowReader(channel, receiver.allocator)
-
-    val footer = reader.readFooter()
-    val schema = footer.getSchema
-
-    val numCols = schema.getFields.size()
-    assert(numCols === dataset.schema.fields.length)
-    for (i <- 0 until schema.getFields.size()) {
-      val arrowField = schema.getFields.get(i)
-      val sparkField = dataset.schema.fields(i)
-      assert(arrowField.getName === sparkField.name)
-      assert(arrowField.isNullable === sparkField.nullable)
-      assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField))
-    }
-
-    val blockMetadata = footer.getRecordBatches
-    assert(blockMetadata.size() === 1)
-
-    val recordBatch = reader.readRecordBatch(blockMetadata.get(0))
-    val nodes = recordBatch.getNodes
-    assert(nodes.size() === numCols + 1) // +1 for Type String, which has two nodes.
-
-    val firstNode = nodes.get(0)
-    assert(firstNode.getLength === numElements)
-    assert(firstNode.getNullCount === 0)
-
-    val buffers = recordBatch.getBuffers
-    assert(buffers.size() === (numCols + 1) * 2) // +1 for Type String
-
-    assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1))
-    assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2))
-    assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) ===
-      data.map(d => UTF8String.fromString(d.col3)).toArray)
-  }
-}
-
-object DatasetToArrowSuite {
-  def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = {
-    val arrowType = arrowField.getType
-    val sparkType = sparkField.dataType
-    (arrowType, sparkType) match {
-      case (_: ArrowType.Int, _: IntegerType) => true
-      case (_: ArrowType.FloatingPoint, _: DoubleType) =>
-        arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.DOUBLE
-      case (_: ArrowType.FloatingPoint, _: FloatType) =>
-        arrowType.asInstanceOf[ArrowType.FloatingPoint].getPrecision == Precision.SINGLE
-      case (_: ArrowType.List, _: StringType) =>
-        val subField = arrowField.getChildren
-        (subField.size() == 1) && subField.get(0).getType.isInstanceOf[ArrowType.Utf8]
-      case (_: ArrowType.Bool, _: BooleanType) => true
-      case _ => false
-    }
-  }
-}
-
-class RecordBatchReceiver {
-
-  val allocator = new RootAllocator(Long.MaxValue)
-
-  def getIntArray(buf: ArrowBuf): Array[Int] = {
-    val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer()
-    val resultArray = Array.ofDim[Int](buffer.remaining())
-    buffer.get(resultArray)
-    resultArray
-  }
-
-  def getDoubleArray(buf: ArrowBuf): Array[Double] = {
-    val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer()
-    val resultArray = Array.ofDim[Double](buffer.remaining())
-    buffer.get(resultArray)
-    resultArray
-  }
-
-  def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = {
-    val offsets = getIntArray(bufOffsets)
-    val lens = offsets.zip(offsets.drop(1))
-      .map { case (prevOffset, offset) => offset - prevOffset }
-
-    val values = array(bufValues)
-    val strings = offsets.zip(lens).map { case (offset, len) =>
-      UTF8String.fromBytes(values, offset, len)
-    }
-    strings
-  }
-
-  private def array(buf: ArrowBuf): Array[Byte] = {
-    val bytes = Array.ofDim[Byte](buf.readableBytes())
-    buf.readBytes(bytes)
-    bytes
-  }
-
-  def connectAndRead(port: Int): (Array[Byte], Int) = {
-    val clientSocket = new Socket(InetAddress.getByName("localhost"), port)
-    val clientDataIns = new DataInputStream(clientSocket.getInputStream)
-    val messageLength = clientDataIns.readInt()
-    val buffer = Array.ofDim[Byte](messageLength)
-    clientDataIns.readFully(buffer, 0, messageLength)
-    (buffer, messageLength)
-  }
-
-  def makeFile(buffer: Array[Byte]): FileChannel = {
-    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath
-    val arrowFile = new File(tempDir, "arrow-bytes")
-    val arrowOus = new FileOutputStream(arrowFile.getPath)
-    arrowOus.write(buffer)
-    arrowOus.close()
-
-    val arrowIns = new FileInputStream(arrowFile.getPath)
-    arrowIns.getChannel
-  }
-}
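
Usage note: a minimal sketch (not part of the patch) of how the consolidated type mapping can be exercised end-to-end through Arrow.schemaToArrowSchema. It assumes the snippet runs inside the org.apache.spark.sql package, since schemaToArrowSchema is private[sql]; the schema and field names below are made up for illustration.

  import org.apache.arrow.vector.types.pojo.ArrowType
  import org.apache.spark.sql.types._

  // Build a small Spark schema and convert it via the new getTypeFuncs-based path.
  val sparkSchema = StructType(Seq(
    StructField("i", IntegerType, nullable = true),
    StructField("d", DoubleType, nullable = false)))
  val arrowSchema = Arrow.schemaToArrowSchema(sparkSchema)

  // Field names, nullability, and Arrow types should carry over from the Spark schema.
  assert(arrowSchema.getFields.get(0).getName == "i")
  assert(arrowSchema.getFields.get(0).isNullable)
  assert(arrowSchema.getFields.get(0).getType.isInstanceOf[ArrowType.Int])
  assert(arrowSchema.getFields.get(1).getType.isInstanceOf[ArrowType.FloatingPoint])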