diff --git a/benchmark.py b/benchmark.py
index f6e7c0ae8b2b..b23f34833666 100644
--- a/benchmark.py
+++ b/benchmark.py
@@ -1,16 +1,45 @@
 import pyspark
 import timeit
 import random
+import sys
 from pyspark.sql import SparkSession
+import numpy as np
+import pandas as pd

 numPartition = 8

-def time(df, repeat, number):
+def scala_object(jpkg, obj):
+    return jpkg.__getattr__(obj + "$").__getattr__("MODULE$")
+
+def time(spark, df, repeat, number):
+    print("collect as internal rows")
+    time = timeit.repeat(lambda: df._jdf.queryExecution().executedPlan().executeCollect(), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())
+
+    print("internal rows to arrow record batch")
+    arrow = scala_object(spark._jvm.org.apache.spark.sql, "Arrow")
+    root_allocator = spark._jvm.org.apache.arrow.memory.RootAllocator(sys.maxsize)
+    internal_rows = df._jdf.queryExecution().executedPlan().executeCollect()
+    jschema = df._jdf.schema()
+    def internalRowsToArrowRecordBatch():
+        rb = arrow.internalRowsToArrowRecordBatch(internal_rows, jschema, root_allocator)
+        rb.close()
+
+    time = timeit.repeat(internalRowsToArrowRecordBatch, repeat=repeat, number=number)
+    root_allocator.close()
+    time_df = pd.Series(time)
+    print(time_df.describe())
+
     print("toPandas with arrow")
-    print(timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number))
+    time = timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())

     print("toPandas without arrow")
-    print(timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number))
+    time = timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())

 def long():
     return random.randint(0, 10000)
@@ -32,10 +61,10 @@ def genData(spark, size, columns):

 if __name__ == "__main__":
     spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
-    df = genData(spark, 1000 * 1000, [long, double])
+    df = genData(spark, 1000 * 1000, [double])
     df.cache()
     df.count()
+    df.collect()

-    time(df, 10, 1)
-
+    time(spark, df, 50, 1)
     df.unpersist()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
index d58a25fd05e2..beca6313f10c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -25,7 +25,7 @@ import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.FloatingPointPrecision
+import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

 import org.apache.spark.sql.catalyst.InternalRow
@@ -43,6 +43,9 @@ object Arrow {
       case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
       case ByteType => new ArrowType.Int(8, true)
       case StringType => ArrowType.Utf8.INSTANCE
+      case BinaryType => ArrowType.Binary.INSTANCE
+      case DateType => ArrowType.Date.INSTANCE
+      case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
       case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
     }
   }
@@ -61,7 +64,10 @@ object Arrow {
     val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
     val buffers = bufAndField.flatMap(_._2).toList.asJava

-    new ArrowRecordBatch(rows.length, fieldNodes, buffers)
+    val recordBatch = new ArrowRecordBatch(rows.length, fieldNodes, buffers)
+    buffers.asScala.foreach(_.release())
+
+    recordBatch
   }

   /**
@@ -115,6 +121,9 @@ object ColumnWriter {
       case DoubleType => new DoubleColumnWriter(allocator)
       case ByteType => new ByteColumnWriter(allocator)
       case StringType => new UTF8StringColumnWriter(allocator)
+      case BinaryType => new BinaryColumnWriter(allocator)
+      case DateType => new DateColumnWriter(allocator)
+      case TimestampType => new TimeStampColumnWriter(allocator)
       case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
     }
   }
@@ -124,6 +133,11 @@ private[sql] trait ColumnWriter {
   def init(initialSize: Int): Unit
   def writeNull(): Unit
   def write(row: InternalRow, ordinal: Int): Unit
+
+  /**
+   * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
+   * This should be called only once after all the data is written.
+   */
   def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
 }

@@ -140,7 +154,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA
   protected def setNull(): Unit
   protected def setValue(row: InternalRow, ordinal: Int): Unit

-  protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true)  // TODO: check the flag
+  protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true)

   override def init(initialSize: Int): Unit = {
     valueVector.allocateNew()
@@ -255,3 +269,40 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
     valueMutator.setSafe(count, bytes, 0, bytes.length)
   }
 }
+
+private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableVarBinaryVector
+    = new NullableVarBinaryVector("UTF8StringValue", allocator)
+  override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit = {
+    val bytes = row.getBinary(ordinal)
+    valueMutator.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[sql] class DateColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableDateVector
+    = new NullableDateVector("DateValue", allocator)
+  override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
+  }
+}
+
+private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableTimeStampVector
+    = new NullableTimeStampVector("TimeStampValue", allocator)
+  override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8b39b5448500..dbfca1882ee8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2773,6 +2773,8 @@ class Dataset[T] private[sql](
       } catch {
         case e: Exception =>
           throw e
+      } finally {
+        recordBatch.close()
       }

     withNewExecutionId {
diff --git a/sql/core/src/test/resources/test-data/arrow/timestampData.json b/sql/core/src/test/resources/test-data/arrow/timestampData.json
new file mode 100644
index 000000000000..174c62e4a12d
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/arrow/timestampData.json
@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a_timestamp",
+        "type": {"name": "timestamp", "unit": "MILLISECOND"},
+        "nullable": true,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 64}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 2,
+      "columns": [
+        {
+          "name": "a_timestamp",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1365383415567, 1365426610789]
+        }
+      ]
+    }
+  ]
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
index f51a74084a10..13b38c8c8568 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 import java.io.File
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.Locale
+import java.util.{Locale, TimeZone}

 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
@@ -90,17 +90,30 @@ class ArrowSuite extends SharedSQLContext {
     collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
   }

-  ignore("time and date conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-    val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
-    val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime)
+  ignore("date conversion") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
+    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
+    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime)
     val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
-    .toDF("a_date", "b_string", "c_timestamp")
+      .toDF("a_date", "b_string", "c_timestamp")
     collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
   }

+  test("timestamp conversion") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
+    val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp")
+    collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json")
+  }
+
+  // Arrow json reader doesn't support binary data
+  ignore("binary type conversion") {
+    collectAndValidate(binaryData, "test-data/arrow/binaryData.json")
+  }
+
   test("nested type conversion") { }

   test("array type conversion") { }
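For reference, below is a minimal sketch of how the toPandas comparison in benchmark.py above can be driven by hand. It assumes a PySpark build that includes the patched DataFrame.toPandas(useArrow) flag from this diff; the app name and the rand()-based column (a stand-in for genData's double generator) are illustrative only.

import timeit

import pandas as pd
from pyspark.sql import SparkSession

# Build a session and a cached million-row DataFrame with one double column.
spark = SparkSession.builder.appName("ArrowBenchmarkExample").getOrCreate()
df = spark.range(1000 * 1000).selectExpr("rand() AS a_double")
df.cache()
df.count()  # materialize the cache before timing

# Time collection with and without Arrow; pandas' describe() summarizes the runs,
# mirroring the pd.Series(...).describe() reporting used in benchmark.py.
with_arrow = pd.Series(timeit.repeat(lambda: df.toPandas(True), repeat=10, number=1))
without_arrow = pd.Series(timeit.repeat(lambda: df.toPandas(False), repeat=10, number=1))
print(with_arrow.describe())
print(without_arrow.describe())

df.unpersist()
spark.stop()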