diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala
similarity index 87%
rename from sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala
index 41ee20ab34a8..8a7379536ab5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ArrowConverters.scala
@@ -17,13 +17,16 @@
 
 package org.apache.spark.sql
 
+import java.io.ByteArrayOutputStream
+import java.nio.channels.Channels
+
 import scala.collection.JavaConverters._
-import scala.language.implicitConversions
 
 import io.netty.buffer.ArrowBuf
 import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
+import org.apache.arrow.vector.file.ArrowWriter
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
 import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
@@ -31,7 +34,33 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
-object Arrow {
+/**
+ * Intermediate data structure returned from Arrow conversions
+ */
+private[sql] abstract class ArrowPayload extends Iterator[ArrowRecordBatch]
+
+/**
+ * Class that wraps an Arrow RootAllocator used in conversion
+ */
+private[sql] class ArrowConverters {
+  private val _allocator = new RootAllocator(Long.MaxValue)
+
+  private[sql] def allocator: RootAllocator = _allocator
+
+  private class ArrowStaticPayload(batches: ArrowRecordBatch*) extends ArrowPayload {
+    private val iter = batches.iterator
+
+    override def next(): ArrowRecordBatch = iter.next()
+    override def hasNext: Boolean = iter.hasNext
+  }
+
+  def internalRowsToPayload(rows: Array[InternalRow], schema: StructType): ArrowPayload = {
+    val batch = ArrowConverters.internalRowsToArrowRecordBatch(rows, schema, allocator)
+    new ArrowStaticPayload(batch)
+  }
+}
+
+private[sql] object ArrowConverters {
 
   /**
    * Map a Spark Dataset type to ArrowType.
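
The hunk above introduces ArrowPayload (an Iterator[ArrowRecordBatch]) and ArrowConverters, which owns the RootAllocator used for a conversion. The following is an illustrative sketch of how those pieces fit together; it is not part of the patch, and it assumes the code is compiled inside the org.apache.spark.sql package, since ArrowPayload and ArrowConverters are private[sql].

// Illustrative sketch only (not part of this patch). Assumes placement in the
// org.apache.spark.sql package, because ArrowPayload/ArrowConverters are private[sql].
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object ArrowPayloadSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("i", IntegerType, nullable = false)))
    val rows = Array(InternalRow(1), InternalRow(2), InternalRow(3))

    // One ArrowConverters instance wraps one RootAllocator for the whole conversion.
    val converter = new ArrowConverters

    // The payload is just an Iterator[ArrowRecordBatch]; currently it holds a single batch.
    val payload = converter.internalRowsToPayload(rows, schema)
    payload.foreach { batch =>
      println(s"rows in batch: ${batch.getLength}")
      batch.close()
    }
  }
}
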
@@ -49,7 +78,7 @@ object Arrow {
       case BinaryType => ArrowType.Binary.INSTANCE
       case DateType => ArrowType.Date.INSTANCE
       case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
 
@@ -109,6 +138,25 @@ object Arrow {
     }
     new Schema(arrowFields.toList.asJava)
   }
+
+  /**
+   * Write an ArrowPayload to a byte array
+   */
+  private[sql] def payloadToByteArray(payload: ArrowPayload, schema: StructType): Array[Byte] = {
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(schema)
+    val out = new ByteArrayOutputStream()
+    val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
+    try {
+      payload.foreach(writer.writeRecordBatch)
+    } catch {
+      case e: Exception =>
+        throw e
+    } finally {
+      writer.close()
+      payload.foreach(_.close())
+    }
+    out.toByteArray
+  }
 }
 
 private[sql] trait ColumnWriter {
@@ -255,7 +303,7 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
 private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
     extends PrimitiveColumnWriter(allocator) {
   override protected val valueVector: NullableVarBinaryVector
-    = new NullableVarBinaryVector("UTF8StringValue", allocator)
+    = new NullableVarBinaryVector("BinaryValue", allocator)
   override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
 
   override def setNull(): Unit = valueMutator.setNull(count)
@@ -273,6 +321,7 @@ private[sql] class DateColumnWriter(allocator: BaseAllocator)
   override protected def setNull(): Unit = valueMutator.setNull(count)
 
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: comment on diff btw value representations of date/timestamp
     valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
   }
 }
@@ -286,6 +335,7 @@ private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
   override protected def setNull(): Unit = valueMutator.setNull(count)
 
   override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    // TODO: use microsecond timestamp when ARROW-477 is resolved
     valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bc0888e98df5..8caecb7fd8ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,17 +17,13 @@
 
 package org.apache.spark.sql
 
-import java.io.{ByteArrayOutputStream, CharArrayWriter}
-import java.nio.channels.Channels
+import java.io.CharArrayWriter
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
-import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.file.ArrowWriter
-import org.apache.arrow.vector.schema.ArrowRecordBatch
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
@@ -2375,14 +2371,12 @@ class Dataset[T] private[sql](
    * @since 2.2.0
    */
   @DeveloperApi
-  def collectAsArrow(rootAllocator: Option[RootAllocator] = None): ArrowRecordBatch = {
-    val allocator = rootAllocator.getOrElse(new RootAllocator(Long.MaxValue))
+  def collectAsArrow(converter: Option[ArrowConverters] = None): ArrowPayload = {
+    val cnvtr = converter.getOrElse(new ArrowConverters)
     withNewExecutionId {
       try {
         val collectedRows = queryExecution.executedPlan.executeCollect()
-        val recordBatch = Arrow.internalRowsToArrowRecordBatch(
-          collectedRows, this.schema, allocator)
-        recordBatch
+        cnvtr.internalRowsToPayload(collectedRows, this.schema)
       } catch {
         case e: Exception =>
           throw e
@@ -2763,22 +2757,11 @@ class Dataset[T] private[sql](
    * Collect a Dataset as an ArrowRecordBatch, and serve the ArrowRecordBatch to PySpark.
    */
   private[sql] def collectAsArrowToPython(): Int = {
-    val recordBatch = collectAsArrow()
-    val arrowSchema = Arrow.schemaToArrowSchema(this.schema)
-    val out = new ByteArrayOutputStream()
-    try {
-      val writer = new ArrowWriter(Channels.newChannel(out), arrowSchema)
-      writer.writeRecordBatch(recordBatch)
-      writer.close()
-    } catch {
-      case e: Exception =>
-        throw e
-    } finally {
-      recordBatch.close()
-    }
+    val payload = collectAsArrow()
+    val payloadBytes = ArrowConverters.payloadToByteArray(payload, this.schema)
     withNewExecutionId {
-      PythonRDD.serveIterator(Iterator(out.toByteArray), "serve-Arrow")
+      PythonRDD.serveIterator(Iterator(payloadBytes), "serve-Arrow")
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala
similarity index 90%
rename from sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala
index c784b3eefb74..d4a6b6672e07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowConvertersSuite.scala
@@ -19,15 +19,13 @@
 import java.io.File
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.{Locale, TimeZone}
+import java.util.Locale
 
-import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
 import org.apache.arrow.vector.file.json.JsonFileReader
 import org.apache.arrow.vector.util.Validator
 
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.types.CalendarInterval
 
 // NOTE - nullable type can be declared as Option[*] or java.lang.*
 
@@ -38,7 +36,7 @@ private[sql] case class FloatData(i: Int, a_f: Float, b_f: Option[Float])
 
 private[sql] case class DoubleData(i: Int, a_d: Double, b_d: Option[Double])
 
-class ArrowSuite extends SharedSQLContext {
+class ArrowConvertersSuite extends SharedSQLContext {
   import testImplicits._
 
   private def testFile(fileName: String): String = {
@@ -46,10 +44,11 @@
   }
 
   test("collect to arrow record batch") {
-    val arrowRecordBatch = indexData.collectAsArrow()
-    assert(arrowRecordBatch.getLength > 0)
-    assert(arrowRecordBatch.getNodes.size() > 0)
-    arrowRecordBatch.close()
+    val arrowPayload = indexData.collectAsArrow()
+    assert(arrowPayload.nonEmpty)
+    arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getLength > 0))
+    arrowPayload.foreach(arrowRecordBatch => assert(arrowRecordBatch.getNodes.size() > 0))
+    arrowPayload.foreach(arrowRecordBatch => arrowRecordBatch.close())
   }
 
   test("standard type conversion") {
@@ -124,8 +123,9 @@
   }
 
   test("empty frame collect") {
-    val emptyBatch = spark.emptyDataFrame.collectAsArrow()
-    assert(emptyBatch.getLength == 0)
+    val arrowPayload = spark.emptyDataFrame.collectAsArrow()
+    assert(arrowPayload.nonEmpty)
+    arrowPayload.foreach(emptyBatch => assert(emptyBatch.getLength == 0))
   }
 
   test("unsupported types") {
@@ -163,17 +163,17 @@
   private def collectAndValidate(df: DataFrame, arrowFile: String) {
     val jsonFilePath = testFile(arrowFile)
 
-    val allocator = new RootAllocator(Integer.MAX_VALUE)
-    val jsonReader = new JsonFileReader(new File(jsonFilePath), allocator)
+    val converter = new ArrowConverters
+    val jsonReader = new JsonFileReader(new File(jsonFilePath), converter.allocator)
 
-    val arrowSchema = Arrow.schemaToArrowSchema(df.schema)
+    val arrowSchema = ArrowConverters.schemaToArrowSchema(df.schema)
     val jsonSchema = jsonReader.start()
     Validator.compareSchemas(arrowSchema, jsonSchema)
 
-    val arrowRecordBatch = df.collectAsArrow(Some(allocator))
-    val arrowRoot = new VectorSchemaRoot(arrowSchema, allocator)
+    val arrowPayload = df.collectAsArrow(Some(converter))
+    val arrowRoot = new VectorSchemaRoot(arrowSchema, converter.allocator)
     val vectorLoader = new VectorLoader(arrowRoot)
-    vectorLoader.load(arrowRecordBatch)
+    arrowPayload.foreach(vectorLoader.load)
 
     val jsonRoot = jsonReader.read()
     Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot)
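
Taken together, the changes above replace the raw ArrowRecordBatch handling in Dataset with the ArrowPayload/ArrowConverters pair: collectAsArrow() now returns a payload, and collectAsArrowToPython() serializes it with ArrowConverters.payloadToByteArray. A rough usage sketch follows; it is illustrative only, assumes a live SparkSession named spark, and must sit in the org.apache.spark.sql package because the Arrow helpers are private[sql].

// Illustrative sketch only (not part of this patch). Assumes a live SparkSession
// and placement in org.apache.spark.sql, since the Arrow helpers are private[sql].
package org.apache.spark.sql

object ArrowCollectSketch {
  def toArrowBytes(spark: SparkSession): Array[Byte] = {
    import spark.implicits._
    val df = Seq(1, 2, 3).toDF("i")

    // collectAsArrow() runs the plan and wraps the collected rows in an ArrowPayload.
    val payload = df.collectAsArrow()

    // payloadToByteArray writes the batches through Arrow's ArrowWriter, closing
    // both the writer and the batches, and returns the serialized bytes.
    ArrowConverters.payloadToByteArray(payload, df.schema)
  }
}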