diff --git a/pom.xml b/pom.xml
index 0abd544f99bac..62a8802132a01 100644
--- a/pom.xml
+++ b/pom.xml
@@ -182,7 +182,7 @@
2.8
1.8
1.0.0
- 0.1.0
+ 0.1.1-SNAPSHOT
${java.home}
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 29469e020143e..48582e05017e0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1971,12 +1971,27 @@ def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)
+ def assertFramesEqual(self, df_with_arrow, df_without):
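+ """Assert that the pandas DataFrame produced with Arrow equals the one produced without it."""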
+ msg = ("DataFrame from Arrow is not equal" +
+ ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
+ ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
+ self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
+
def test_arrow_toPandas(self):
- schema = StructType().add("key", IntegerType()).add("value", IntegerType())
- df = self.spark.createDataFrame([(1, 2), (2, 4), (3, 6), (4, 8)], schema=schema)
+ schema = StructType([
+ StructField("str_t", StringType(), True), # Fails in conversion
+ StructField("int_t", IntegerType(), True), # Fails, without is converted to int64
+ StructField("long_t", LongType(), True), # Fails if nullable=False
+ StructField("double_t", DoubleType(), True)])
+ data = [("a", 1, 10, 2.0),
+ ("b", 2, 20, 4.0),
+ ("c", 3, 30, 6.0)]
+
+ df = self.spark.createDataFrame(data, schema=schema)
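+ # Compare only the columns that currently convert correctly.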
+ df = df.select("long_t", "double_t")
pdf = df.toPandas(useArrow=False)
pdf_arrow = df.toPandas(useArrow=True)
- self.assertTrue(pdf.equals(pdf_arrow))
+ self.assertFramesEqual(pdf_arrow, pdf)
if __name__ == "__main__":
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 86febae8aa079..303473a1334d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,6 +29,7 @@ import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowWriter
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
+import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
import org.apache.commons.lang3.StringUtils
@@ -2291,6 +2292,18 @@ class Dataset[T] private[sql](
dt match {
case IntegerType =>
new ArrowType.Int(8 * IntegerType.defaultSize, true)
+ case LongType =>
+ new ArrowType.Int(8 * LongType.defaultSize, true)
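+ // NOTE: Strings are temporarily represented as a List of Utf8 items (see schemaToArrowSchema).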
+ case StringType =>
+ ArrowType.List.INSTANCE
+ case DoubleType =>
+ new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ case FloatType =>
+ new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ case BooleanType =>
+ ArrowType.Bool.INSTANCE
+ case ByteType =>
+ new ArrowType.Int(8, true) // ByteType is a signed 8-bit integer
case _ =>
throw new IllegalArgumentException(s"Unsupported data type")
}
@@ -2302,8 +2315,16 @@ class Dataset[T] private[sql](
private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
val arrowFields = schema.fields.map {
case StructField(name, dataType, nullable, metadata) =>
- // TODO: Consider nested types
- new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
+ dataType match {
+ // TODO: Consider other nested types
+ case StringType =>
+ // TODO: Confirm that String should map to a List of Utf8 items
+ val itemField =
+ new Field("item", false, ArrowType.Utf8.INSTANCE, List.empty[Field].asJava)
+ new Field(name, nullable, dataTypeToArrowType(dataType), List(itemField).asJava)
+ case _ =>
+ new Field(name, nullable, dataTypeToArrowType(dataType), List.empty[Field].asJava)
+ }
}
val arrowSchema = new Schema(arrowFields.toIterable.asJava)
arrowSchema
@@ -2319,16 +2340,91 @@ class Dataset[T] private[sql](
}
/**
- * Infer the validity map from the internal rows.
- * @param rows An array of InternalRows
- * @param idx Index of current column in the array of InternalRows
- * @param field StructField related to the current column
- * @param allocator ArrowBuf allocator
+ * Read the value at the given ordinal from the InternalRow and write it to the ArrowBuf.
+ * Note: the caller is responsible for null checks; none are performed here.
+ */
+ private def getAndSetToArrow(
+ row: InternalRow, buf: ArrowBuf, dataType: DataType, ordinal: Int): Unit = {
+ dataType match {
+ case NullType =>
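+ // NullType has no value to write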
+ case BooleanType =>
+ buf.writeBoolean(row.getBoolean(ordinal))
+ case ShortType =>
+ buf.writeShort(row.getShort(ordinal))
+ case IntegerType =>
+ buf.writeInt(row.getInt(ordinal))
+ case LongType =>
+ buf.writeLong(row.getLong(ordinal))
+ case FloatType =>
+ buf.writeFloat(row.getFloat(ordinal))
+ case DoubleType =>
+ buf.writeDouble(row.getDouble(ordinal))
+ case ByteType =>
+ buf.writeByte(row.getByte(ordinal))
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported data type ${dataType.simpleString}")
+ }
+ }
+
+ /**
+ * Convert one column of an array of InternalRows into Arrow buffers and field nodes.
*/
- private def internalRowToValidityMap(
- rows: Array[InternalRow], idx: Int, field: StructField, allocator: RootAllocator): ArrowBuf = {
- val buf = allocator.buffer(numBytesOfBitmap(rows.length))
- buf
+ private def internalRowToArrowBuf(
+ rows: Array[InternalRow],
+ ordinal: Int,
+ field: StructField,
+ allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
+ val numOfRows = rows.length
+
+ field.dataType match {
+ case IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
+ val validity = allocator.buffer(numBytesOfBitmap(numOfRows))
+ val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
+ var nullCount = 0
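+ // TODO: Null handling is incomplete: the validity bitmap is never populated, and a null row writes nothing to buf, so values after a null are not slot-aligned.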
+ rows.foreach { row =>
+ if (row.isNullAt(ordinal)) {
+ nullCount += 1
+ } else {
+ getAndSetToArrow(row, buf, field.dataType, ordinal)
+ }
+ }
+
+ val fieldNode = new ArrowFieldNode(numOfRows, nullCount)
+
+ (Array(validity, buf), Array(fieldNode))
+
+ case StringType =>
+ val validityOffset = allocator.buffer(numBytesOfBitmap(numOfRows))
+ val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
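+ // The offset buffer holds numOfRows + 1 end offsets; a null row repeats the previous offset.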
+ var bytesCount = 0
+ bufOffset.writeInt(bytesCount) // Start position
+ val validityValues = allocator.buffer(numBytesOfBitmap(numOfRows))
+ val bufValues = allocator.buffer(Int.MaxValue) // TODO: Avoid over-allocating; size this from the actual string byte lengths
+ var nullCount = 0
+ rows.foreach { row =>
+ if (row.isNullAt(ordinal)) {
+ nullCount += 1
+ bufOffset.writeInt(bytesCount)
+ } else {
+ val bytes = row.getUTF8String(ordinal).getBytes
+ bytesCount += bytes.length
+ bufOffset.writeInt(bytesCount)
+ bufValues.writeBytes(bytes)
+ }
+ }
+
+ val fieldNodeOffset = if (field.nullable) {
+ new ArrowFieldNode(numOfRows, nullCount)
+ } else {
+ new ArrowFieldNode(numOfRows, 0)
+ }
+
+ val fieldNodeValues = new ArrowFieldNode(bytesCount, 0)
+
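+ // Buffers for a String column: outer validity, offsets, inner validity, value bytes.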
+ (Array(validityOffset, bufOffset, validityValues, bufValues),
+ Array(fieldNodeOffset, fieldNodeValues))
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported data type ${field.dataType.simpleString}")
+ }
}
/**
@@ -2336,24 +2432,14 @@ class Dataset[T] private[sql](
*/
private[sql] def internalRowsToArrowRecordBatch(
rows: Array[InternalRow], allocator: RootAllocator): ArrowRecordBatch = {
- val numOfRows = rows.length
-
- val buffers = this.schema.fields.zipWithIndex.flatMap { case (field, idx) =>
- val validity = internalRowToValidityMap(rows, idx, field, allocator)
- val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
- rows.foreach { row => buf.writeInt(row.getInt(idx)) }
- Array(validity, buf)
- }.toList.asJava
+ val bufAndField = this.schema.fields.zipWithIndex.map { case (field, ordinal) =>
+ internalRowToArrowBuf(rows, ordinal, field, allocator)
+ }
- val fieldNodes = this.schema.fields.zipWithIndex.map { case (field, idx) =>
- if (field.nullable) {
- new ArrowFieldNode(numOfRows, 0)
- } else {
- new ArrowFieldNode(numOfRows, 0)
- }
- }.toList.asJava
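+ // Flatten the per-column buffers and field nodes in schema order.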
+ val buffers = bufAndField.flatMap(_._1).toList.asJava
+ val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
- new ArrowRecordBatch(numOfRows, fieldNodes, buffers)
+ new ArrowRecordBatch(rows.length, fieldNodes, buffers)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
index e954cdc751a6c..8aec3699c9dd1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetToArrowSuite.scala
@@ -17,100 +17,156 @@
package org.apache.spark.sql
-import java.io.{DataInputStream, EOFException, RandomAccessFile}
+import java.io._
import java.net.{InetAddress, Socket}
+import java.nio.{ByteBuffer, ByteOrder}
import java.nio.channels.FileChannel
+import scala.util.Random
+
import io.netty.buffer.ArrowBuf
+import org.apache.arrow.flatbuf.Precision
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.file.ArrowReader
-import org.apache.arrow.vector.schema.ArrowRecordBatch
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
-case class ArrowIntTest(a: Int, b: Int)
+case class ArrowTestClass(col1: Int, col2: Double, col3: String)
class DatasetToArrowSuite extends QueryTest with SharedSQLContext {
import testImplicits._
+ final val numElements = 4
+ @transient var data: Seq[ArrowTestClass] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ data = Seq.fill(numElements)(ArrowTestClass(
+ Random.nextInt, Random.nextDouble, Random.nextString(Random.nextInt(100))))
+ }
+
test("Collect as arrow to python") {
+ val dataset = data.toDS()
+
+ val port = dataset.collectAsArrowToPython()
+
+ val receiver: RecordBatchReceiver = new RecordBatchReceiver
+ val (buffer, numBytesRead) = receiver.connectAndRead(port)
+ val channel = receiver.makeFile(buffer)
+ val reader = new ArrowReader(channel, receiver.allocator)
+
+ val footer = reader.readFooter()
+ val schema = footer.getSchema
+
+ val numCols = schema.getFields.size()
+ assert(numCols === dataset.schema.fields.length)
+ for (i <- 0 until numCols) {
+ val arrowField = schema.getFields.get(i)
+ val sparkField = dataset.schema.fields(i)
+ assert(arrowField.getName === sparkField.name)
+ assert(arrowField.isNullable === sparkField.nullable)
+ assert(DatasetToArrowSuite.compareSchemaTypes(arrowField, sparkField))
+ }
+
+ val blockMetadata = footer.getRecordBatches
+ assert(blockMetadata.size() === 1)
+
+ val recordBatch = reader.readRecordBatch(blockMetadata.get(0))
+ val nodes = recordBatch.getNodes
+ assert(nodes.size() === numCols + 1) // +1 because the String column contributes two nodes
+
+ val firstNode = nodes.get(0)
+ assert(firstNode.getLength === numElements)
+ assert(firstNode.getNullCount === 0)
+
+ val buffers = recordBatch.getBuffers
+ assert(buffers.size() === (numCols + 1) * 2) // +1 because the String column contributes an extra buffer pair
+
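+ // Each Int/Double column contributes (validity, values); the String column contributes (validity, offsets, validity, values), hence indices 1, 3, 5, and 7.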
+ assert(receiver.getIntArray(buffers.get(1)) === data.map(_.col1))
+ assert(receiver.getDoubleArray(buffers.get(3)) === data.map(_.col2))
+ assert(receiver.getStringArray(buffers.get(5), buffers.get(7)) ===
+ data.map(d => UTF8String.fromString(d.col3)).toArray)
+ }
+}
- val ds = Seq(ArrowIntTest(1, 2), ArrowIntTest(2, 3), ArrowIntTest(3, 4)).toDS()
-
- val port = ds.collectAsArrowToPython()
-
- val clientThread: Thread = new Thread(new Runnable() {
- def run() {
- try {
- val receiver: RecordBatchReceiver = new RecordBatchReceiver
- val record: ArrowRecordBatch = receiver.read(port)
- }
- catch {
- case e: Exception =>
- throw e
- }
- }
- })
-
- clientThread.start()
-
- try {
- clientThread.join()
- } catch {
- case e: InterruptedException =>
- throw e
- case _ =>
+object DatasetToArrowSuite {
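+ // Checks that an Arrow field's type matches the expected mapping of the corresponding Spark type.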
+ def compareSchemaTypes(arrowField: Field, sparkField: StructField): Boolean = {
+ val arrowType = arrowField.getType
+ val sparkType = sparkField.dataType
+ (arrowType, sparkType) match {
+ case (_: ArrowType.Int, _: IntegerType) => true
+ case (fp: ArrowType.FloatingPoint, _: DoubleType) =>
+ fp.getPrecision == Precision.DOUBLE
+ case (fp: ArrowType.FloatingPoint, _: FloatType) =>
+ fp.getPrecision == Precision.SINGLE
+ case (_: ArrowType.List, _: StringType) =>
+ val children = arrowField.getChildren
+ (children.size() == 1) && children.get(0).getType.isInstanceOf[ArrowType.Utf8]
+ case (_: ArrowType.Bool, _: BooleanType) => true
+ case _ => false
}
}
}
class RecordBatchReceiver {
- def array(buf: ArrowBuf): Array[Byte] = {
- val bytes = Array.ofDim[Byte](buf.readableBytes())
- buf.readBytes(bytes)
- bytes
+ val allocator = new RootAllocator(Long.MaxValue)
+
+ def getIntArray(buf: ArrowBuf): Array[Int] = {
+ val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer()
+ val resultArray = Array.ofDim[Int](buffer.remaining())
+ buffer.get(resultArray)
+ resultArray
}
- def connectAndRead(port: Int): (Array[Byte], Int) = {
- val s = new Socket(InetAddress.getByName("localhost"), port)
- val is = s.getInputStream
+ def getDoubleArray(buf: ArrowBuf): Array[Double] = {
+ val buffer = ByteBuffer.wrap(array(buf)).order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer()
+ val resultArray = Array.ofDim[Double](buffer.remaining())
+ buffer.get(resultArray)
+ resultArray
+ }
- val dis = new DataInputStream(is)
- val len = dis.readInt()
+ def getStringArray(bufOffsets: ArrowBuf, bufValues: ArrowBuf): Array[UTF8String] = {
+ val offsets = getIntArray(bufOffsets)
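+ // Consecutive offsets give each string's byte length.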
+ val lens = offsets.zip(offsets.drop(1))
+ .map { case (prevOffset, offset) => offset - prevOffset }
- val buffer = Array.ofDim[Byte](len)
- val bytesRead = dis.read(buffer)
- if (bytesRead != len) {
- throw new EOFException("Wrong EOF")
+ val values = array(bufValues)
+ val strings = offsets.zip(lens).map { case (offset, len) =>
+ UTF8String.fromBytes(values, offset, len)
}
- (buffer, len)
+ strings
}
- def makeFile(buffer: Array[Byte]): FileChannel = {
- var aFile = new RandomAccessFile("/tmp/nio-data.txt", "rw")
- aFile.write(buffer)
- aFile.close()
-
- aFile = new RandomAccessFile("/tmp/nio-data.txt", "r")
- val fChannel = aFile.getChannel
- fChannel
+ private def array(buf: ArrowBuf): Array[Byte] = {
+ val bytes = Array.ofDim[Byte](buf.readableBytes())
+ buf.readBytes(bytes)
+ bytes
}
- def readRecordBatch(fc: FileChannel, len: Int): ArrowRecordBatch = {
- val allocator = new RootAllocator(len)
- val reader = new ArrowReader(fc, allocator)
- val footer = reader.readFooter()
- val schema = footer.getSchema
- val blocks = footer.getRecordBatches
- val recordBatch = reader.readRecordBatch(blocks.get(0))
- recordBatch
+ def connectAndRead(port: Int): (Array[Byte], Int) = {
+ val clientSocket = new Socket(InetAddress.getByName("localhost"), port)
+ val clientDataIns = new DataInputStream(clientSocket.getInputStream)
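+ // The payload is length-prefixed with a 4-byte int (big-endian, as read by readInt).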
+ val messageLength = clientDataIns.readInt()
+ val buffer = Array.ofDim[Byte](messageLength)
+ clientDataIns.readFully(buffer, 0, messageLength)
+ (buffer, messageLength)
}
- def read(port: Int): ArrowRecordBatch = {
- val (buffer, len) = connectAndRead(port)
- val fc = makeFile(buffer)
- readRecordBatch(fc, len)
+ def makeFile(buffer: Array[Byte]): FileChannel = {
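+ // Write the received bytes to a temp file and reopen it as a FileChannel for ArrowReader.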
+ val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName).getPath
+ val arrowFile = new File(tempDir, "arrow-bytes")
+ val arrowOus = new FileOutputStream(arrowFile.getPath)
+ arrowOus.write(buffer)
+ arrowOus.close()
+
+ val arrowIns = new FileInputStream(arrowFile.getPath)
+ arrowIns.getChannel
}
}