41 changes: 35 additions & 6 deletions benchmark.py
@@ -1,16 +1,45 @@
import pyspark
import timeit
import random
import sys
from pyspark.sql import SparkSession
import numpy as np
import pandas as pd

numPartition = 8

def time(df, repeat, number):
def scala_object(jpkg, obj):
return jpkg.__getattr__(obj + "$").__getattr__("MODULE$")

def time(spark, df, repeat, number):
print("collect as internal rows")
time = timeit.repeat(lambda: df._jdf.queryExecution().executedPlan().executeCollect(), repeat=repeat, number=number)
time_df = pd.Series(time)
print(time_df.describe())

print("internal rows to arrow record batch")
arrow = scala_object(spark._jvm.org.apache.spark.sql, "Arrow")
root_allocator = spark._jvm.org.apache.arrow.memory.RootAllocator(sys.maxsize)
internal_rows = df._jdf.queryExecution().executedPlan().executeCollect()
jschema = df._jdf.schema()
def internalRowsToArrowRecordBatch():
rb = arrow.internalRowsToArrowRecordBatch(internal_rows, jschema, root_allocator)
rb.close()

time = timeit.repeat(internalRowsToArrowRecordBatch, repeat=repeat, number=number)
root_allocator.close()
time_df = pd.Series(time)
print(time_df.describe())

print("toPandas with arrow")
print(timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number))
time = timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number)
time_df = pd.Series(time)
print(time_df.describe())

print("toPandas without arrow")
print(timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number))
time = timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number)
time_df = pd.Series(time)
print(time_df.describe())

def long():
return random.randint(0, 10000)
@@ -32,10 +61,10 @@ def genData(spark, size, columns):

if __name__ == "__main__":
spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
df = genData(spark, 1000 * 1000, [long, double])
df = genData(spark, 1000 * 1000, [double])
df.cache()
df.count()
df.collect()

time(df, 10, 1)

time(spark, df, 50, 1)
df.unpersist()
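
Note on the scala_object helper above: it works because a Scala singleton object compiles to a JVM class whose name ends in "$", with the single instance held in a static MODULE$ field, which is exactly what py4j walks to from spark._jvm. A minimal, self-contained Scala illustration (the Greeter name is hypothetical and not part of this PR):

// Hypothetical example: `object Greeter` compiles to the JVM class `Greeter$`,
// whose static MODULE$ field holds the singleton instance.
object Greeter {
  def hello(name: String): String = s"hello, $name"
}

// From the JVM side (or from py4j, as scala_object does via
// __getattr__("Greeter$").__getattr__("MODULE$")) the call is:
//   Greeter$.MODULE$.hello("arrow")
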
57 changes: 54 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
@@ -25,7 +25,7 @@ import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
import org.apache.arrow.vector._
import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

import org.apache.spark.sql.catalyst.InternalRow
@@ -43,6 +43,9 @@ object Arrow {
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case ByteType => new ArrowType.Int(8, true)
case StringType => ArrowType.Utf8.INSTANCE
case BinaryType => ArrowType.Binary.INSTANCE
case DateType => ArrowType.Date.INSTANCE
case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
}
@@ -61,7 +64,10 @@ object Arrow {
val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
val buffers = bufAndField.flatMap(_._2).toList.asJava

new ArrowRecordBatch(rows.length, fieldNodes, buffers)
val recordBatch = new ArrowRecordBatch(rows.length, fieldNodes, buffers)
buffers.asScala.foreach(_.release())

recordBatch
}

/**
@@ -115,6 +121,9 @@ object ColumnWriter {
case DoubleType => new DoubleColumnWriter(allocator)
case ByteType => new ByteColumnWriter(allocator)
case StringType => new UTF8StringColumnWriter(allocator)
case BinaryType => new BinaryColumnWriter(allocator)
case DateType => new DateColumnWriter(allocator)
case TimestampType => new TimeStampColumnWriter(allocator)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
}
@@ -124,6 +133,11 @@ private[sql] trait ColumnWriter {
def init(initialSize: Int): Unit
def writeNull(): Unit
def write(row: InternalRow, ordinal: Int): Unit

/**
* Clear the column writer and return the ArrowFieldNode and ArrowBuf.
* This should be called only once after all the data is written.
*/
def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
}

@@ -140,7 +154,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA

protected def setNull(): Unit
protected def setValue(row: InternalRow, ordinal: Int): Unit
protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag
protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true)

override def init(initialSize: Int): Unit = {
valueVector.allocateNew()
@@ -255,3 +269,40 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
valueMutator.setSafe(count, bytes, 0, bytes.length)
}
}

private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableVarBinaryVector
= new NullableVarBinaryVector("UTF8StringValue", allocator)
override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator

override def setNull(): Unit = valueMutator.setNull(count)
override def setValue(row: InternalRow, ordinal: Int): Unit = {
val bytes = row.getBinary(ordinal)
valueMutator.setSafe(count, bytes, 0, bytes.length)
}
}

private[sql] class DateColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableDateVector
= new NullableDateVector("DateValue", allocator)
override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator

override protected def setNull(): Unit = valueMutator.setNull(count)
override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
}
}

private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
extends PrimitiveColumnWriter(allocator) {
override protected val valueVector: NullableTimeStampVector
= new NullableTimeStampVector("TimeStampValue", allocator)
override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator

override protected def setNull(): Unit = valueMutator.setNull(count)
override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
}
}
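
A note for readers tracing the new writer classes: the ColumnWriter trait above implies an init / write-per-row / finish-once lifecycle. The sketch below spells that sequence out under stated assumptions (a single DoubleType column at ordinal 0, rows already collected as InternalRow); the code path the PR actually exercises is Arrow.internalRowsToArrowRecordBatch, not this helper.

import org.apache.arrow.memory.RootAllocator
import org.apache.spark.sql.catalyst.InternalRow

// Sketch only: drive one writer through its full lifecycle and hand back the
// field nodes and buffers that would go into an ArrowRecordBatch.
def writeDoubleColumn(rows: Seq[InternalRow]) = {
  val allocator = new RootAllocator(Long.MaxValue)
  val writer = new DoubleColumnWriter(allocator)
  writer.init(rows.length)                 // allocate the underlying value vector
  rows.foreach { row =>
    if (row.isNullAt(0)) writer.writeNull() else writer.write(row, 0)
  }
  writer.finish()                          // call exactly once; returns (field nodes, buffers)
}
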
2 changes: 2 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2773,6 +2773,8 @@ class Dataset[T] private[sql](
} catch {
case e: Exception =>
throw e
} finally {
recordBatch.close()
}

withNewExecutionId {
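
The functional change above is that recordBatch.close() now runs in a finally block, so the reference-counted Arrow buffers behind the batch are released even when the downstream conversion throws. A generic loan-pattern helper, not part of this PR and named hypothetically, captures the same guarantee:

import org.apache.arrow.vector.schema.ArrowRecordBatch

// Hypothetical helper: run `body` against a record batch and close the batch
// (releasing its ArrowBufs) on both the success and failure paths.
def withRecordBatch[T](batch: ArrowRecordBatch)(body: ArrowRecordBatch => T): T =
  try body(batch) finally batch.close()
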
32 changes: 32 additions & 0 deletions sql/core/src/test/resources/test-data/arrow/timestampData.json
@@ -0,0 +1,32 @@
{
"schema": {
"fields": [
{
"name": "a_timestamp",
"type": {"name": "timestamp", "unit": "MILLISECOND"},
"nullable": true,
"children": [],
"typeLayout": {
"vectors": [
{"type": "VALIDITY", "typeBitWidth": 1},
{"type": "DATA", "typeBitWidth": 64}
]
}
}
]
},

"batches": [
{
"count": 2,
"columns": [
{
"name": "a_timestamp",
"count": 2,
"VALIDITY": [1, 1],
"DATA": [1365383415567, 1365426610789]
}
]
}
]
}
29 changes: 21 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.io.File
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale
import java.util.{Locale, TimeZone}

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
@@ -90,17 +90,30 @@ class ArrowSuite extends SharedSQLContext {
collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
}

ignore("time and date conversion") {
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime)
val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime)
ignore("date conversion") {
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime)
val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime)
val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
.toDF("a_date", "b_string", "c_timestamp")
.toDF("a_date", "b_string", "c_timestamp")
collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
}

test("timestamp conversion") {
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp")
collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json")
}

// Arrow json reader doesn't support binary data
ignore("binary type conversion") {
collectAndValidate(binaryData, "test-data/arrow/binaryData.json")
}

test("nested type conversion") { }

test("array type conversion") { }
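
On the unit conversions exercised by the timestamp test above: Spark stores TimestampType values internally as microseconds since the Unix epoch and DateType values as days since the epoch, which is why TimeStampColumnWriter divides by 1000 and DateColumnWriter multiplies by 24 * 3600 * 1000 to produce the millisecond values the Arrow vectors expect (the schema in timestampData.json declares MILLISECOND precision). A small standalone check of the first expected value, assuming those internal representations:

import java.text.SimpleDateFormat
import java.util.Locale

// Parse the same literal the test uses and mimic Spark's microsecond storage.
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
val millis = sdf.parse("2013-04-08 01:10:15.567 UTC").getTime  // 1365383415567
val micros = millis * 1000L                                    // TimestampType's internal value
assert(micros / 1000 == 1365383415567L)                        // matches DATA[0] in timestampData.json
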