diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala new file mode 100644 index 000000000000..31b90259de0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -0,0 +1,228 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions + +import io.netty.buffer.ArrowBuf +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.BitVector +import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} +import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +object Arrow { + + /** + * Compute the number of bytes needed to build validity map. According to + * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), + * the length of the validity bitmap should be multiples of 64 bytes. + */ + private def numBytesOfBitmap(numOfRows: Int): Int = { + Math.ceil(numOfRows / 64.0).toInt * 8 + } + + private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(false) + case ShortType => + buf.writeShort(0) + case IntegerType => + buf.writeInt(0) + case LongType => + buf.writeLong(0L) + case FloatType => + buf.writeFloat(0f) + case DoubleType => + buf.writeDouble(0d) + case ByteType => + buf.writeByte(0) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + + /** + * Get an entry from the InternalRow, and then set to ArrowBuf. + * Note: No Null check for the entry. + */ + private def getAndSetToArrow( + row: InternalRow, + buf: ArrowBuf, + dataType: DataType, + ordinal: Int): Unit = { + dataType match { + case NullType => + case BooleanType => + buf.writeBoolean(row.getBoolean(ordinal)) + case ShortType => + buf.writeShort(row.getShort(ordinal)) + case IntegerType => + buf.writeInt(row.getInt(ordinal)) + case LongType => + buf.writeLong(row.getLong(ordinal)) + case FloatType => + buf.writeFloat(row.getFloat(ordinal)) + case DoubleType => + buf.writeDouble(row.getDouble(ordinal)) + case ByteType => + buf.writeByte(row.getByte(ordinal)) + case _ => + throw new UnsupportedOperationException( + s"Unsupported data type ${dataType.simpleString}") + } + } + + /** + * Transfer an array of InternalRow to an ArrowRecordBatch. 
+ */ + def internalRowsToArrowRecordBatch( + rows: Array[InternalRow], + schema: StructType, + allocator: RootAllocator): ArrowRecordBatch = { + val bufAndField = schema.fields.zipWithIndex.map { case (field, ordinal) => + internalRowToArrowBuf(rows, ordinal, field, allocator) + } + + val buffers = bufAndField.flatMap(_._1).toList.asJava + val fieldNodes = bufAndField.flatMap(_._2).toList.asJava + + new ArrowRecordBatch(rows.length, fieldNodes, buffers) + } + + /** + * Convert an array of InternalRow to an ArrowBuf. + */ + def internalRowToArrowBuf( + rows: Array[InternalRow], + ordinal: Int, + field: StructField, + allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { + val numOfRows = rows.length + + field.dataType match { + case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => + val validityVector = new BitVector("validity", allocator) + val validityMutator = validityVector.getMutator + validityVector.allocateNew(numOfRows) + validityMutator.setValueCount(numOfRows) + + val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) + var nullCount = 0 + var index = 0 + while (index < rows.length) { + val row = rows(index) + if (row.isNullAt(ordinal)) { + nullCount += 1 + validityMutator.set(index, 0) + fillArrow(buf, field.dataType) + } else { + validityMutator.set(index, 1) + getAndSetToArrow(row, buf, field.dataType, ordinal) + } + index += 1 + } + + val fieldNode = new ArrowFieldNode(numOfRows, nullCount) + + (Array(validityVector.getBuffer, buf), Array(fieldNode)) + + case StringType => + val validityVector = new BitVector("validity", allocator) + val validityMutator = validityVector.getMutator() + validityVector.allocateNew(numOfRows) + validityMutator.setValueCount(numOfRows) + + val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) + var bytesCount = 0 + bufOffset.writeInt(bytesCount) + val bufValues = allocator.buffer(1024) + var nullCount = 0 + rows.zipWithIndex.foreach { case (row, index) => + if (row.isNullAt(ordinal)) { + nullCount += 1 + validityMutator.set(index, 0) + bufOffset.writeInt(bytesCount) + } else { + validityMutator.set(index, 1) + val bytes = row.getUTF8String(ordinal).getBytes + bytesCount += bytes.length + bufOffset.writeInt(bytesCount) + bufValues.writeBytes(bytes) + } + } + + val fieldNode = new ArrowFieldNode(numOfRows, nullCount) + + (Array(validityVector.getBuffer, bufOffset, bufValues), + Array(fieldNode)) + } + } + + private[sql] def schemaToArrowSchema(schema: StructType): Schema = { + val arrowFields = schema.fields.map(sparkFieldToArrowField(_)) + new Schema(arrowFields.toList.asJava) + } + + private[sql] def sparkFieldToArrowField(sparkField: StructField): Field = { + val name = sparkField.name + val dataType = sparkField.dataType + val nullable = sparkField.nullable + + dataType match { + case StructType(fields) => + val childrenFields = fields.map(sparkFieldToArrowField(_)).toList.asJava + new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields) + case _ => + new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) + } + } + + /** + * Transform Spark DataType to Arrow ArrowType. 
+ */ + private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { + dt match { + case IntegerType => + new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => + new ArrowType.Int(8 * LongType.defaultSize, true) + case StringType => + ArrowType.Utf8.INSTANCE + case DoubleType => + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case FloatType => + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case BooleanType => + ArrowType.Bool.INSTANCE + case ByteType => + new ArrowType.Int(8, false) + case StructType(_) => + ArrowType.Struct.INSTANCE + case _ => + throw new IllegalArgumentException(s"Unsupported data type") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1f35aecb49c2..8b39b5448500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -25,13 +25,9 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.BitVector import org.apache.arrow.vector.file.ArrowWriter -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} @@ -2372,195 +2368,6 @@ class Dataset[T] private[sql]( } } - /** - * Transform Spark DataType to Arrow ArrowType. - */ - private[sql] def dataTypeToArrowType(dt: DataType): ArrowType = { - dt match { - case IntegerType => - new ArrowType.Int(8 * IntegerType.defaultSize, true) - case LongType => - new ArrowType.Int(8 * LongType.defaultSize, true) - case StringType => - ArrowType.List.INSTANCE - case DoubleType => - new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case FloatType => - new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case BooleanType => - ArrowType.Bool.INSTANCE - case ByteType => - new ArrowType.Int(8, false) - case _ => - throw new IllegalArgumentException(s"Unsupported data type") - } - } - - /** - * Transform Spark StructType to Arrow Schema. - */ - private[sql] def schemaToArrowSchema(schema: StructType): Schema = { - val arrowFields = schema.fields.map { - case StructField(name, dataType, nullable, metadata) => - dataType match { - // TODO: Consider other nested types - case StringType => - // TODO: Make sure String => List - val itemField = - new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava) - new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava) - case _ => - new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava) - } - } - val arrowSchema = new Schema(arrowFields.toIterable.asJava) - arrowSchema - } - - /** - * Compute the number of bytes needed to build validity map. According to - * [Arrow Layout](https://github.com/apache/arrow/blob/master/format/Layout.md#null-bitmaps), - * the length of the validity bitmap should be multiples of 64 bytes. 
- */ - private def numBytesOfBitmap(numOfRows: Int): Int = { - Math.ceil(numOfRows / 64.0).toInt * 8 - } - - private def fillArrow(buf: ArrowBuf, dataType: DataType): Unit = { - dataType match { - case NullType => - case BooleanType => - buf.writeBoolean(false) - case ShortType => - buf.writeShort(0) - case IntegerType => - buf.writeInt(0) - case LongType => - buf.writeLong(0L) - case FloatType => - buf.writeFloat(0f) - case DoubleType => - buf.writeDouble(0d) - case ByteType => - buf.writeByte(0) - case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") - } - } - - /** - * Get an entry from the InternalRow, and then set to ArrowBuf. - * Note: No Null check for the entry. - */ - private def getAndSetToArrow( - row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = { - dataType match { - case NullType => - case BooleanType => - buf.writeBoolean(row.getBoolean(ordinal)) - case ShortType => - buf.writeShort(row.getShort(ordinal)) - case IntegerType => - buf.writeInt(row.getInt(ordinal)) - case LongType => - buf.writeLong(row.getLong(ordinal)) - case FloatType => - buf.writeFloat(row.getFloat(ordinal)) - case DoubleType => - buf.writeDouble(row.getDouble(ordinal)) - case ByteType => - buf.writeByte(row.getByte(ordinal)) - case _ => - throw new UnsupportedOperationException( - s"Unsupported data type ${dataType.simpleString}") - } - } - - /** - * Convert an array of InternalRow to an ArrowBuf. - */ - private def internalRowToArrowBuf( - rows: Array[InternalRow], - ordinal: Int, - field: StructField, - allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { - val numOfRows = rows.length - - field.dataType match { - case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator - validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) - var nullCount = 0 - var index = 0 - while (index < rows.length) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - fillArrow(buf, field.dataType) - } else { - validityMutator.set(index, 1) - getAndSetToArrow(row, buf, field.dataType, ordinal) - } - index += 1 - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, buf), Array(fieldNode)) - - case StringType => - val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows)) - val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) - var bytesCount = 0 - bufOffset.writeInt(bytesCount) // Start position - val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows)) - val bufValues = allocator.buffer(Int.MaxValue) // TODO: Reduce the size? 
- var nullCount = 0 - rows.foreach { row => - if (row.isNullAt(ordinal)) { - nullCount += 1 - bufOffset.writeInt(bytesCount) - } else { - val bytes = row.getUTF8String(ordinal).getBytes - bytesCount += bytes.length - bufOffset.writeInt(bytesCount) - bufValues.writeBytes(bytes) - } - } - - val fieldNodeOffset = if (field.nullable) { - new ArrowFieldNode(numOfRows, nullCount) - } else { - new ArrowFieldNode(numOfRows, 0) - } - - val fieldNodeValues = new ArrowFieldNode(bytesCount, 0) - - (Array(validityOffset, bufOffset, validityValues, bufValues), - Array(fieldNodeOffset, fieldNodeValues)) - } - } - - /** - * Transfer an array of InternalRow to an ArrowRecordBatch. - */ - private[sql] def internalRowsToArrowRecordBatch( - rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = { - val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) => - internalRowToArrowBuf(rows, ordinal, field, allocator) - } - - val buffers = bufAndField.flatMap(_._1).toList.asJava - val fieldNodes = bufAndField.flatMap(_._2).toList.asJava - - new ArrowRecordBatch(rows.length, fieldNodes, buffers) - } - /** * Collect a Dataset to an ArrowRecordBatch. * @@ -2573,7 +2380,8 @@ class Dataset[T] private[sql]( withNewExecutionId { try { val collectedRows = queryExecution.executedPlan.executeCollect() - val recordBatch = internalRowsToArrowRecordBatch(collectedRows, allocator) + val recordBatch = Arrow.internalRowsToArrowRecordBatch( + collectedRows, this.schema, allocator) recordBatch } catch { case e: Exception => @@ -2956,7 +2764,7 @@ class Dataset[T] private[sql]( */ private[sql] def collectAsArrowToPython(): Int = { val recordBatch = collectAsArrow() - val arrowSchema = schemaToArrowSchema(this.schema) + val arrowSchema = Arrow.schemaToArrowSchema(this.schema) val out = new ByteArrayOutputStream() try { val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema) diff --git a/sql/core/src/test/resources/test-data/arrowNullInts.json b/sql/core/src/test/resources/test-data/arrowNullInts.json index 31b272af7d12..1a2447abdc0b 100644 --- a/sql/core/src/test/resources/test-data/arrowNullInts.json +++ b/sql/core/src/test/resources/test-data/arrowNullInts.json @@ -15,6 +15,7 @@ } ] }, + "batches": [ { "count": 4, diff --git a/sql/core/src/test/resources/test-data/arrowNullStrings.json b/sql/core/src/test/resources/test-data/arrowNullStrings.json new file mode 100644 index 000000000000..a1c8ae3d932b --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrowNullStrings.json @@ -0,0 +1,34 @@ +{ + "schema": { + "fields": [ + { + "name": "value", + "type": {"name": "utf8"}, + "nullable": true, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "OFFSET", "typeBitWidth": 32}, + {"type": "DATA", "typeBitWidth": 8} + ] + } + } + ] + }, + + "batches": [ + { + "count": 6, + "columns": [ + { + "name": "value", + "count": 6, + "VALIDITY": [1, 0, 1, 1, 1, 0], + "OFFSET": [0, 1, 1, 2, 4, 7, 7], + "DATA": ["a", "", "b", "ab", "abc", ""] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala index 9b1786c83f16..b9614d1e5ae1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import java.io.File import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.{BitVector, VectorLoader, 
VectorSchemaRoot} +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.test.{ArrowTestData, SharedSQLContext} -class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - private val nullIntsFile = "test-data/arrowNullInts.json" +class ArrowSuite extends QueryTest with SharedSQLContext with ArrowTestData { private def testFile(fileName: String): String = { // TODO: Copied from CSVSuite, find a better way to read test files @@ -34,13 +33,20 @@ class ArrowSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("convert int column with null to arrow") { - val df = nullInts - val jsonFilePath = testFile(nullIntsFile) + test(arrowNullInts, arrowNullIntsFile) + } + + test("convert string column with null to arrow") { + test(arrowNullStrings, arrowNullStringsFile) + } + + private def test(df: DataFrame, arrowFile: String) { + val jsonFilePath = testFile(arrowFile) val allocator = new RootAllocator(Integer.MAX_VALUE) val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator) - val arrowSchema = df.schemaToArrowSchema(df.schema) + val arrowSchema = Arrow.schemaToArrowSchema(df.schema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/arrow/ColumnWriter.scala b/sql/core/src/test/scala/org/apache/spark/sql/arrow/ColumnWriter.scala new file mode 100644 index 000000000000..d8592b79b19a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/arrow/ColumnWriter.scala @@ -0,0 +1,21 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.arrow + +trait ColumnWriter { +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ArrowTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ArrowTestData.scala new file mode 100644 index 000000000000..6e9a91bbc560 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ArrowTestData.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} + +private[sql] trait ArrowTestData {self => + protected def spark: SparkSession + + // Helper object to import SQL implicits without a concrete SQLContext + private object internalImplicits extends SQLImplicits { + protected override def _sqlContext: SQLContext = self.spark.sqlContext + } + + import internalImplicits._ + import ArrowTestData._ + + protected val arrowNullIntsFile = "test-data/arrowNullInts.json" + protected val arrowNullStringsFile = "test-data/arrowNullStrings.json" + + protected lazy val arrowNullInts: DataFrame = spark.sparkContext.parallelize( + NullInts(1) :: + NullInts(2) :: + NullInts(3) :: + NullInts(null) :: Nil + ).toDF() + + protected lazy val arrowNullStrings: DataFrame = spark.sparkContext.parallelize( + NullStrings("a") :: + NullStrings(null) :: + NullStrings("b") :: + NullStrings("ab") :: + NullStrings("abc") :: + NullStrings(null) :: Nil + ).toDF() +} + +/** + * Case classes used in test data. + */ +private[sql] object ArrowTestData { + case class NullInts(a: Integer) + case class NullStrings(value: String) +} +
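For context on how the extracted helpers fit together, below is a minimal usage sketch of the conversion path that `Dataset.collectAsArrow()` / `collectAsArrowToPython()` now delegate to `Arrow`. It is not part of the patch: the object name `ArrowConversionSketch`, the two-column schema, and the sample rows are illustrative only, it assumes the Arrow Java version this patch builds against exposes `ArrowRecordBatch.getLength`, and it is placed in package `org.apache.spark.sql` because `schemaToArrowSchema` is `private[sql]`.

```scala
// Hypothetical sketch, not part of the patch. Lives in org.apache.spark.sql
// so it can reach the private[sql] schemaToArrowSchema helper.
package org.apache.spark.sql

import org.apache.arrow.memory.RootAllocator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object ArrowConversionSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative schema: a nullable int column and a nullable string column.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = true),
      StructField("name", StringType, nullable = true)))

    // InternalRow stores strings as UTF8String and nulls as plain null.
    val rows: Array[InternalRow] = Array(
      InternalRow(1, UTF8String.fromString("a")),
      InternalRow(null, UTF8String.fromString("b")),
      InternalRow(3, null))

    val allocator = new RootAllocator(Integer.MAX_VALUE)

    // The same two calls the Dataset code paths now make.
    val arrowSchema = Arrow.schemaToArrowSchema(schema)
    val recordBatch = Arrow.internalRowsToArrowRecordBatch(rows, schema, allocator)

    // getLength is the row count passed to the ArrowRecordBatch constructor
    // (rows.length in internalRowsToArrowRecordBatch).
    println(s"Arrow schema: $arrowSchema")
    println(s"Rows in batch: ${recordBatch.getLength}")
  }
}
```

Per `internalRowToArrowBuf`, each fixed-width column contributes a validity buffer plus one data buffer, and each string column contributes validity, offset, and value buffers, all under a single `ArrowFieldNode`; the record batch simply collects these buffers and nodes across columns together with the row count.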