diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 9ef943216f88..072cef6d7b35 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -394,11 +394,11 @@ def collect(self): @ignore_unicode_prefix @since(2.0) def collectAsArrow(self): - """Returns all the records as an ArrowRecordBatch + """Returns all records as list of deserialized ArrowPayloads """ with SCCallSiteSync(self._sc) as css: port = self._jdf.collectAsArrowToPython() - return list(_load_from_socket(port, ArrowSerializer()))[0] + return list(_load_from_socket(port, ArrowSerializer())) @ignore_unicode_prefix @since(2.0) @@ -1573,6 +1573,9 @@ def toPandas(self, useArrow=False): This is only available if Pandas is installed and available. + :param useArrow: Make use of Apache Arrow for conversion, pyarrow must be installed + on the calling Python process. + .. note:: This method should only be used if the resulting Pandas's DataFrame is expected to be small, as all the data is loaded into the driver's memory. @@ -1581,11 +1584,13 @@ def toPandas(self, useArrow=False): 0 2 Alice 1 5 Bob """ - import pandas as pd - if useArrow: - return self.collectAsArrow().to_pandas() + from pyarrow.table import concat_tables + tables = self.collectAsArrow() + table = concat_tables(tables) + return table.to_pandas() else: + import pandas as pd return pd.DataFrame.from_records(self.collect(), columns=self.columns) ########################################################################################## diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala index 8a7379536ab5..47a2d966b0c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import java.nio.channels.Channels +import java.nio.ByteBuffer +import java.nio.channels.{SeekableByteChannel, Channels} import scala.collection.JavaConverters._ @@ -26,7 +27,7 @@ import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BaseAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator -import org.apache.arrow.vector.file.ArrowWriter +import org.apache.arrow.vector.file.{ArrowReader, ArrowWriter} import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -34,11 +35,65 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ + +/** + * ArrowReader requires a seekable byte channel. 
+ * NOTE - this is taken from test org.apache.vector.file, see about moving to public util pkg + */ +private[sql] class ByteArrayReadableSeekableByteChannel(var byteArray: Array[Byte]) + extends SeekableByteChannel { + var _position: Long = 0L + + override def isOpen: Boolean = { + byteArray != null + } + + override def close(): Unit = { + byteArray = null + } + + override def read(dst: ByteBuffer): Int = { + val remainingBuf = byteArray.length - _position + val length = Math.min(dst.remaining(), remainingBuf).toInt + dst.put(byteArray, _position.toInt, length) + _position += length + length.toInt + } + + override def position(): Long = _position + + override def position(newPosition: Long): SeekableByteChannel = { + _position = newPosition.toLong + this + } + + override def size: Long = { + byteArray.length.toLong + } + + override def write(src: ByteBuffer): Int = { + throw new UnsupportedOperationException("Read Only") + } + + override def truncate(size: Long): SeekableByteChannel = { + throw new UnsupportedOperationException("Read Only") + } +} + /** * Intermediate data structure returned from Arrow conversions */ private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch] +/** + * Build a payload from existing ArrowRecordBatches + */ +private[sql] class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { + private val iter = batches.iterator + override def next(): ArrowRecordBatch = iter.next() + override def hasNext: Boolean = iter.hasNext +} + /** * Class that wraps an Arrow RootAllocator used in conversion */ @@ -47,16 +102,24 @@ private[sql] class ArrowConverters { private[sql] def allocator: RootAllocator = _allocator - private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload { - private val iter = batches.iterator - - override def next(): ArrowRecordBatch = iter.next() - override def hasNext: Boolean = iter.hasNext + def interalRowIterToPayload(rowIter: Iterator[InternalRow], schema: StructType): ArrowPayload = { + val batch = ArrowConverters.internalRowIterToArrowBatch(rowIter, schema, allocator) + new ArrowStaticPayload(batch) } - def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = { - val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator) - new ArrowStaticPayload(batch) + def readPayloadByteArrays(payloadByteArrays: Array[Array[Byte]]): ArrowPayload = { + val batches = scala.collection.mutable.ArrayBuffer.empty[ArrowRecordBatch] + var i = 0 + while (i < payloadByteArrays.length) { + val payloadBytes = payloadByteArrays(i) + val in = new ByteArrayReadableSeekableByteChannel(payloadBytes) + val reader = new ArrowReader(in, _allocator) + val footer = reader.readFooter() + val batchBlocks = footer.getRecordBatches.asScala.toArray + batchBlocks.foreach(block => batches += reader.readRecordBatch(block)) + i += 1 + } + new ArrowStaticPayload(batches: _*) } } @@ -83,52 +146,43 @@ private[sql] object ArrowConverters { } /** - * Transfer an array of InternalRow to an ArrowRecordBatch. + * Iterate over InternalRows and write to an ArrowRecordBatch. 
*/ - private[sql] def internalRowsToArrowRecordBatch( - rows: Array[InternalRow], + private def internalRowIterToArrowBatch( + rowIter: Iterator[InternalRow], schema: StructType, allocator: RootAllocator): ArrowRecordBatch = { - val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) => - internalRowToArrowBuf(rows, ordinal, field, allocator) + + val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => + ColumnWriter(ordinal, allocator, field.dataType) + .init() + } + + val writerLength = columnWriters.length + while (rowIter.hasNext) { + val row = rowIter.next() + var i = 0 + while (i < writerLength) { + columnWriters(i).write(row) + i += 1 + } + } + + val fieldAndBuf = columnWriters.map { writer => + writer.finish() }.unzip - val fieldNodes = fieldAndBuf._1.flatten + val fieldNodes = fieldAndBuf._1 val buffers = fieldAndBuf._2.flatten - val recordBatch = new ArrowRecordBatch(rows.length, + val rowLength = if(fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 + + val recordBatch = new ArrowRecordBatch(rowLength, fieldNodes.toList.asJava, buffers.toList.asJava) buffers.foreach(_.release()) recordBatch } - /** - * Write a Field from array of InternalRow to an ArrowBuf. - */ - private def internalRowToArrowBuf( - rows: Array[InternalRow], - ordinal: Int, - field: StructField, - allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = { - val numOfRows = rows.length - val columnWriter = ColumnWriter(allocator, field.dataType) - columnWriter.init(numOfRows) - var index = 0 - - while(index < numOfRows) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - columnWriter.writeNull() - } else { - columnWriter.write(row, ordinal) - } - index += 1 - } - - val (arrowFieldNodes, arrowBufs) = columnWriter.finish() - (arrowFieldNodes.toArray, arrowBufs.toArray) - } - /** * Convert a Spark Dataset schema to Arrow schema. */ @@ -160,138 +214,139 @@ private[sql] object ArrowConverters { } private[sql] trait ColumnWriter { - def init(initialSize: Int): Unit - def writeNull(): Unit - def write(row: InternalRow, ordinal: Int): Unit + def init(): this.type + def write(row: InternalRow): Unit /** * Clear the column writer and return the ArrowFieldNode and ArrowBuf. * This should be called only once after all the data is written. */ - def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) + def finish(): (ArrowFieldNode, Array[ArrowBuf]) } /** * Base class for flat arrow column writer, i.e., column without children. 
*/ -private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator) +private[sql] abstract class PrimitiveColumnWriter( + val ordinal: Int, + val allocator: BaseAllocator) extends ColumnWriter { - protected def valueVector: BaseDataValueVector - protected def valueMutator: BaseMutator + def valueVector: BaseDataValueVector + def valueMutator: BaseMutator - protected def setNull(): Unit - protected def setValue(row: InternalRow, ordinal: Int): Unit + def setNull(): Unit + def setValue(row: InternalRow, ordinal: Int): Unit protected var count = 0 protected var nullCount = 0 - override def init(initialSize: Int): Unit = { + override def init(): this.type = { valueVector.allocateNew() + this } - override def writeNull(): Unit = { - setNull() - nullCount += 1 - count += 1 - } - - override def write(row: InternalRow, ordinal: Int): Unit = { - setValue(row, ordinal) + override def write(row: InternalRow): Unit = { + if (row.isNullAt(ordinal)) { + setNull() + nullCount += 1 + } else { + setValue(row, ordinal) + } count += 1 } - override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { + override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { valueMutator.setValueCount(count) val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) - (List(fieldNode), valueBuffers) + val valueBuffers = valueVector.getBuffers(true) + (fieldNode, valueBuffers) } } -private[sql] class BooleanColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { +private[sql] class BooleanColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { private def bool2int(b: Boolean): Int = if (b) 1 else 0 - override protected val valueVector: NullableBitVector + override val valueVector: NullableBitVector = new NullableBitVector("BooleanValue", allocator) - override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) } -private[sql] class ShortColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableSmallIntVector +private[sql] class ShortColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableSmallIntVector = new NullableSmallIntVector("ShortValue", allocator) - override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator + override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getShort(ordinal)) } -private[sql] class IntegerColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableIntVector +private[sql] class IntegerColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableIntVector = new NullableIntVector("IntValue", allocator) - override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator + override val valueMutator: 
NullableIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getInt(ordinal)) } -private[sql] class LongColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableBigIntVector +private[sql] class LongColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableBigIntVector = new NullableBigIntVector("LongValue", allocator) - override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator + override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getLong(ordinal)) } -private[sql] class FloatColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableFloat4Vector +private[sql] class FloatColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableFloat4Vector = new NullableFloat4Vector("FloatValue", allocator) - override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getFloat(ordinal)) } -private[sql] class DoubleColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableFloat8Vector +private[sql] class DoubleColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableFloat8Vector = new NullableFloat8Vector("DoubleValue", allocator) - override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getDouble(ordinal)) } -private[sql] class ByteColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableUInt1Vector +private[sql] class ByteColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableUInt1Vector = new NullableUInt1Vector("ByteValue", allocator) - override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator + override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = valueMutator.setSafe(count, row.getByte(ordinal)) } -private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableVarBinaryVector +private[sql] class UTF8StringColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val 
valueVector: NullableVarBinaryVector = new NullableVarBinaryVector("UTF8StringValue", allocator) - override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = { @@ -300,11 +355,11 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator) } } -private[sql] class BinaryColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableVarBinaryVector +private[sql] class BinaryColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableVarBinaryVector = new NullableVarBinaryVector("BinaryValue", allocator) - override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator + override val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator override def setNull(): Unit = valueMutator.setNull(count) override def setValue(row: InternalRow, ordinal: Int): Unit = { @@ -313,47 +368,45 @@ private[sql] class BinaryColumnWriter(allocator: BaseAllocator) } } -private[sql] class DateColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableDateVector +private[sql] class DateColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableDateVector = new NullableDateVector("DateValue", allocator) - override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator + override val valueMutator: NullableDateVector#Mutator = valueVector.getMutator - override protected def setNull(): Unit = valueMutator.setNull(count) - override protected def setValue(row: InternalRow, ordinal: Int): Unit = { + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { // TODO: comment on diff btw value representations of date/timestamp valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000) } } -private[sql] class TimeStampColumnWriter(allocator: BaseAllocator) - extends PrimitiveColumnWriter(allocator) { - override protected val valueVector: NullableTimeStampVector - = new NullableTimeStampVector("TimeStampValue", allocator) - override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator +private[sql] class TimeStampColumnWriter(ordinal: Int, allocator: BaseAllocator) + extends PrimitiveColumnWriter(ordinal, allocator) { + override val valueVector: NullableTimeStampMicroVector + = new NullableTimeStampMicroVector("TimeStampValue", allocator) + override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - override protected def setNull(): Unit = valueMutator.setNull(count) - - override protected def setValue(row: InternalRow, ordinal: Int): Unit = { - // TODO: use microsecond timestamp when ARROW-477 is resolved + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit = { valueMutator.setSafe(count, row.getLong(ordinal) / 1000) } } private[sql] object ColumnWriter { - def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { + def apply(ordinal: Int, allocator: BaseAllocator, dataType: 
DataType): ColumnWriter = { dataType match { - case BooleanType => new BooleanColumnWriter(allocator) - case ShortType => new ShortColumnWriter(allocator) - case IntegerType => new IntegerColumnWriter(allocator) - case LongType => new LongColumnWriter(allocator) - case FloatType => new FloatColumnWriter(allocator) - case DoubleType => new DoubleColumnWriter(allocator) - case ByteType => new ByteColumnWriter(allocator) - case StringType => new UTF8StringColumnWriter(allocator) - case BinaryType => new BinaryColumnWriter(allocator) - case DateType => new DateColumnWriter(allocator) - case TimestampType => new TimeStampColumnWriter(allocator) + case BooleanType => new BooleanColumnWriter(ordinal, allocator) + case ShortType => new ShortColumnWriter(ordinal, allocator) + case IntegerType => new IntegerColumnWriter(ordinal, allocator) + case LongType => new LongColumnWriter(ordinal, allocator) + case FloatType => new FloatColumnWriter(ordinal, allocator) + case DoubleType => new DoubleColumnWriter(ordinal, allocator) + case ByteType => new ByteColumnWriter(ordinal, allocator) + case StringType => new UTF8StringColumnWriter(ordinal, allocator) + case BinaryType => new BinaryColumnWriter(ordinal, allocator) + case DateType => new DateColumnWriter(ordinal, allocator) + case TimestampType => new TimeStampColumnWriter(ordinal, allocator) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8caecb7fd8ac..31178706a65d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2364,26 +2364,6 @@ class Dataset[T] private[sql]( } } - /** - * Collect a Dataset to an ArrowRecordBatch. - * - * @group action - * @since 2.2.0 - */ - @DeveloperApi - def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = { - val cnvtr = converter.getOrElse(new ArrowConverters) - withNewExecutionId { - try { - val collectedRows = queryExecution.executedPlan.executeCollect() - cnvtr.internalRowsToPayload(collectedRows, this.schema) - } catch { - case e: Exception => - throw e - } - } - } - /** * Return an iterator that contains all rows in this Dataset. * @@ -2754,14 +2734,13 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark. + * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. 
*/ private[sql] def collectAsArrowToPython(): Int = { - val payload = collectAsArrow() - val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema) - + val payloadRdd = toArrowPayloadBytes() + val payloadByteArrays = payloadRdd.collect() withNewExecutionId { - PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow") + PythonRDD.serveIterator(payloadByteArrays.iterator, "serve-Arrow") } } @@ -2854,4 +2833,16 @@ class Dataset[T] private[sql]( Dataset(sparkSession, logicalPlan) } } + + /** Convert to an RDD of ArrowPayload byte arrays */ + private[sql] def toArrowPayloadBytes(): RDD[Array[Byte]] = { + val schema_captured = this.schema + queryExecution.toRdd.mapPartitionsInternal { iter => + val converter = new ArrowConverters + val payload = converter.interalRowIterToPayload(iter, schema_captured) + val payloadBytes = ArrowConverters.payloadToByteArray(payload, schema_captured) + payload.foreach(_.close()) + Iterator(payloadBytes) + } + } } diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json similarity index 79% rename from sql/core/src/test/resources/test-data/arrow/testData2-ints.json rename to sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json index 6edc2a030287..bf6f0a38a332 100644 --- a/sql/core/src/test/resources/test-data/arrow/testData2-ints.json +++ b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part1.json @@ -30,19 +30,19 @@ "batches": [ { - "count": 6, + "count": 3, "columns": [ { "name": "a", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 1, 2, 2, 3, 3] + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [1, 1, 2] }, { "name": "b", - "count": 6, - "VALIDITY": [1, 1, 1, 1, 1, 1], - "DATA": [1, 2, 1, 2, 1, 2] + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [1, 2, 1] } ] } diff --git a/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json new file mode 100644 index 000000000000..5261d51ff218 --- /dev/null +++ b/sql/core/src/test/resources/test-data/arrow/testData2-ints-part2.json @@ -0,0 +1,50 @@ +{ + "schema": { + "fields": [ + { + "name": "a", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + }, + { + "name": "b", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": false, + "children": [], + "typeLayout": { + "vectors": [ + {"type": "VALIDITY", "typeBitWidth": 1}, + {"type": "DATA", "typeBitWidth": 32} + ] + } + } + ] + }, + + "batches": [ + { + "count": 3, + "columns": [ + { + "name": "a", + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [2, 3, 3] + }, + { + "name": "b", + "count": 3, + "VALIDITY": [1, 1, 1], + "DATA": [2, 1, 2] + } + ] + } + ] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala index d4a6b6672e07..e0497b855e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala @@ -24,8 +24,10 @@ import java.util.Locale import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.file.json.JsonFileReader import org.apache.arrow.vector.util.Validator +import 
org.apache.spark.SparkException import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType // NOTE - nullable type can be declared as Option[*] or java.lang.* @@ -39,16 +41,26 @@ private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double]) class ArrowConvertersSuite extends SharedSQLContext { import testImplicits._ + private def collectAsArrow(df: DataFrame, + converter: Option[ArrowConverters] = None): ArrowPayload = { + val cnvtr = converter.getOrElse(new ArrowConverters) + val payloadByteArrays = df.toArrowPayloadBytes().collect() + cnvtr.readPayloadByteArrays(payloadByteArrays) + } + private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).getFile } test("collect to arrow record batch") { - val arrowPayload = indexData.collectAsArrow() + val arrowPayload = collectAsArrow(indexData) assert(arrowPayload.nonEmpty) - arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0)) - arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0)) - arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close()) + val arrowBatches = arrowPayload.toArray + assert(arrowBatches.length == indexData.rdd.getNumPartitions) + val rowCount = arrowBatches.map(batch => batch.getLength).sum + assert(rowCount === indexData.count()) + arrowBatches.foreach(batch => assert(batch.getNodes.size() > 0)) + arrowBatches.foreach(batch => batch.close()) } test("standard type conversion") { @@ -82,7 +94,16 @@ class ArrowConvertersSuite extends SharedSQLContext { } test("partitioned DataFrame") { - collectAndValidate(testData2, "test-data/arrow/testData2-ints.json") + val converter = new ArrowConverters + val schema = testData2.schema + val arrowPayload = collectAsArrow(testData2, Some(converter)) + val arrowBatches = arrowPayload.toArray + // NOTE: testData2 should have 2 partitions -> 2 arrow batches in payload + assert(arrowBatches.length === 2) + val pl1 = new ArrowStaticPayload(arrowBatches(0)) + val pl2 = new ArrowStaticPayload(arrowBatches(1)) + validateConversion(schema, pl1,"test-data/arrow/testData2-ints-part1.json", Some(converter)) + validateConversion(schema, pl2,"test-data/arrow/testData2-ints-part2.json", Some(converter)) } test("string type conversion") { @@ -105,11 +126,14 @@ class ArrowConvertersSuite extends SharedSQLContext { collectAndValidate(binaryData, "test-data/arrow/binaryData.json") } - test("nested type conversion") { } + // Type not yet supported + ignore("nested type conversion") { } - test("array type conversion") { } + // Type not yet supported + ignore("array type conversion") { } - test("mapped type conversion") { } + // Type not yet supported + ignore("mapped type conversion") { } test("floating-point NaN") { val nanData = Seq((1, 1.2F, Double.NaN), (2, Float.NaN, 1.23)).toDF("i", "NaN_f", "NaN_d") @@ -123,22 +147,32 @@ class ArrowConvertersSuite extends SharedSQLContext { } test("empty frame collect") { - val arrowPayload = spark.emptyDataFrame.collectAsArrow() - assert(arrowPayload.nonEmpty) - arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0)) + val arrowPayload = collectAsArrow(spark.emptyDataFrame) + assert(arrowPayload.isEmpty) + } + + test("empty partition collect") { + val emptyPart = spark.sparkContext.parallelize(Seq(1), 2).toDF("i") + val arrowPayload = collectAsArrow(emptyPart) + val arrowBatches = arrowPayload.toArray + assert(arrowBatches.length === 2) + assert(arrowBatches.count(_.getLength == 
0) === 1) + assert(arrowBatches.count(_.getLength == 1) === 1) } test("unsupported types") { def runUnsupported(block: => Unit): Unit = { - val msg = intercept[UnsupportedOperationException] { + val msg = intercept[SparkException] { block } assert(msg.getMessage.contains("Unsupported data type")) + assert(msg.getCause.getClass === classOf[UnsupportedOperationException]) } - runUnsupported { - collectAndValidate(decimalData, "test-data/arrow/decimalData-BigDecimal.json") - } + runUnsupported { collectAsArrow(decimalData) } + runUnsupported { collectAsArrow(arrayData.toDF()) } + runUnsupported { collectAsArrow(mapData.toDF()) } + runUnsupported { collectAsArrow(complexData) } } test("test Arrow Validator") { @@ -160,22 +194,29 @@ class ArrowConvertersSuite extends SharedSQLContext { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, arrowFile: String) { - val jsonFilePath = testFile(arrowFile) - + private def collectAndValidate(df: DataFrame, arrowFile: String): Unit = { val converter = new ArrowConverters + // NOTE: coalesce to single partition because can only load 1 batch in validator + val arrowPayload = collectAsArrow(df.coalesce(1), Some(converter)) + validateConversion(df.schema, arrowPayload, arrowFile, Some(converter)) + } + + private def validateConversion(sparkSchema: StructType, + arrowPayload: ArrowPayload, + arrowFile: String, + converterOpt: Option[ArrowConverters] = None): Unit = { + val converter = converterOpt.getOrElse(new ArrowConverters) + val jsonFilePath = testFile(arrowFile) val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator) - val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema) + val arrowSchema = ArrowConverters.schemaToArrowSchema(sparkSchema) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) - val arrowPayload = df.collectAsArrow(Some(converter)) val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator) val vectorLoader = new VectorLoader(arrowRoot) arrowPayload.foreach(vectorLoader.load) val jsonRoot = jsonReader.read() - Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot) }
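
For reference, a minimal sketch of how the pieces added in this patch fit together on the JVM side. This is illustrative only and not part of the patch: it assumes an existing SparkSession named `spark`, and because `toArrowPayloadBytes()` and `readPayloadByteArrays()` are `private[sql]`, it only compiles from within the `org.apache.spark.sql` package, the same way the `collectAsArrow` helper in `ArrowConvertersSuite` does.

    // Sketch only: assumes a SparkSession named `spark`; must live inside the
    // org.apache.spark.sql package since the APIs below are private[sql].
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("i", "s")

    // One serialized ArrowPayload byte array is produced per partition and
    // collected to the driver.
    val payloadByteArrays: Array[Array[Byte]] = df.toArrowPayloadBytes().collect()

    // Deserialize the byte arrays back into ArrowRecordBatches using the
    // converter's allocator, then release each batch when done.
    val converter = new ArrowConverters
    val payload: ArrowPayload = converter.readPayloadByteArrays(payloadByteArrays)
    payload.foreach { batch =>
      println(batch.getLength)  // row count of this partition's batch
      batch.close()
    }

On the Python side, `toPandas(useArrow=True)` fetches the same per-partition payloads through `collectAsArrowToPython()`, deserializes them with pyarrow, and concatenates the resulting Arrow tables before converting to a pandas DataFrame.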