Commit 12dc279

icexelloss authored and BryanCutler committed
Add support for date/timestamp/binary; Add more numbers to benchmark.py; Fix memory leaking bug
closes apache#19
1 parent a6c2970 commit 12dc279

File tree: 5 files changed (+150 additions, -20 deletions)

benchmark.py

Lines changed: 35 additions & 6 deletions

@@ -1,16 +1,45 @@
 import pyspark
 import timeit
 import random
+import sys
 from pyspark.sql import SparkSession
+import numpy as np
+import pandas as pd

 numPartition = 8

-def time(df, repeat, number):
+def scala_object(jpkg, obj):
+    return jpkg.__getattr__(obj + "$").__getattr__("MODULE$")
+
+def time(spark, df, repeat, number):
+    print("collect as internal rows")
+    time = timeit.repeat(lambda: df._jdf.queryExecution().executedPlan().executeCollect(), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())
+
+    print("internal rows to arrow record batch")
+    arrow = scala_object(spark._jvm.org.apache.spark.sql, "Arrow")
+    root_allocator = spark._jvm.org.apache.arrow.memory.RootAllocator(sys.maxsize)
+    internal_rows = df._jdf.queryExecution().executedPlan().executeCollect()
+    jschema = df._jdf.schema()
+    def internalRowsToArrowRecordBatch():
+        rb = arrow.internalRowsToArrowRecordBatch(internal_rows, jschema, root_allocator)
+        rb.close()
+
+    time = timeit.repeat(internalRowsToArrowRecordBatch, repeat=repeat, number=number)
+    root_allocator.close()
+    time_df = pd.Series(time)
+    print(time_df.describe())
+
     print("toPandas with arrow")
-    print(timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number))
+    time = timeit.repeat(lambda: df.toPandas(True), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())

     print("toPandas without arrow")
-    print(timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number))
+    time = timeit.repeat(lambda: df.toPandas(False), repeat=repeat, number=number)
+    time_df = pd.Series(time)
+    print(time_df.describe())

 def long():
     return random.randint(0, 10000)
@@ -32,10 +61,10 @@ def genData(spark, size, columns):

 if __name__ == "__main__":
     spark = SparkSession.builder.appName("ArrowBenchmark").getOrCreate()
-    df = genData(spark, 1000 * 1000, [long, double])
+    df = genData(spark, 1000 * 1000, [double])
     df.cache()
     df.count()
+    df.collect()

-    time(df, 10, 1)
-
+    time(spark, df, 50, 1)
     df.unpersist()

sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala

Lines changed: 60 additions & 6 deletions

@@ -25,7 +25,7 @@ import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.FloatingPointPrecision
+import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

 import org.apache.spark.sql.catalyst.InternalRow
@@ -46,6 +46,9 @@ object Arrow {
       case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
       case ByteType => new ArrowType.Int(8, true)
       case StringType => ArrowType.Utf8.INSTANCE
+      case BinaryType => ArrowType.Binary.INSTANCE
+      case DateType => ArrowType.Date.INSTANCE
+      case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
       case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
     }
   }
@@ -57,12 +60,17 @@ object Arrow {
       rows: Array[InternalRow],
       schema: StructType,
       allocator: RootAllocator): ArrowRecordBatch = {
-    val (fieldNodes, buffers) = schema.fields.zipWithIndex.map { case (field, ordinal) =>
+    val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) =>
       internalRowToArrowBuf(rows, ordinal, field, allocator)
     }.unzip
+    val fieldNodes = fieldAndBuf._1.flatten
+    val buffers = fieldAndBuf._2.flatten

-    new ArrowRecordBatch(rows.length,
-      fieldNodes.flatten.toList.asJava, buffers.flatten.toList.asJava)
+    val recordBatch = new ArrowRecordBatch(rows.length,
+      fieldNodes.toList.asJava, buffers.toList.asJava)
+
+    buffers.foreach(_.release())
+    recordBatch
   }

   /**
@@ -107,6 +115,11 @@ private[sql] trait ColumnWriter {
   def init(initialSize: Int): Unit
   def writeNull(): Unit
   def write(row: InternalRow, ordinal: Int): Unit
+
+  /**
+   * Clear the column writer and return the ArrowFieldNode and ArrowBuf.
+   * This should be called only once after all the data is written.
+   */
   def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
 }

@@ -142,7 +155,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseA
   override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
     valueMutator.setValueCount(count)
     val fieldNode = new ArrowFieldNode(count, nullCount)
-    val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag
+    val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true)
     (List(fieldNode), valueBuffers)
   }
 }
@@ -239,6 +252,44 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
   }
 }

+private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableVarBinaryVector
+    = new NullableVarBinaryVector("UTF8StringValue", allocator)
+  override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
+
+  override def setNull(): Unit = valueMutator.setNull(count)
+  override def setValue(row: InternalRow, ordinal: Int): Unit = {
+    val bytes = row.getBinary(ordinal)
+    valueMutator.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[sql] class DateColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableDateVector
+    = new NullableDateVector("DateValue", allocator)
+  override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
+  }
+}
+
+private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
+    extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableTimeStampVector
+    = new NullableTimeStampVector("TimeStampValue", allocator)
+  override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
+  }
+}
+
 private[sql] object ColumnWriter {
   def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
     dataType match {
@@ -250,7 +301,10 @@ private[sql] object ColumnWriter {
       case DoubleType => new DoubleColumnWriter(allocator)
       case ByteType => new ByteColumnWriter(allocator)
       case StringType => new UTF8StringColumnWriter(allocator)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+      case BinaryType => new BinaryColumnWriter(allocator)
+      case DateType => new DateColumnWriter(allocator)
+      case TimestampType => new TimeStampColumnWriter(allocator)
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
 }
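
Note: the new DateColumnWriter and TimeStampColumnWriter convert Spark's internal representations (DateType is stored as days since the epoch, TimestampType as microseconds since the epoch) into the millisecond values the Arrow vectors above are filled with. A minimal, self-contained sketch of just that arithmetic (the object and method names below are illustrative, not part of the patch):

object ArrowUnitConversionSketch {
  private val MillisPerDay: Long = 24L * 3600 * 1000

  // DateType: days since 1970-01-01 -> milliseconds since the epoch
  def daysToMillis(days: Int): Long = days.toLong * MillisPerDay

  // TimestampType: microseconds since the epoch -> milliseconds (TimeUnit.MILLISECOND)
  def microsToMillis(micros: Long): Long = micros / 1000

  def main(args: Array[String]): Unit = {
    println(daysToMillis(16533))               // 2015-04-08 -> 1428451200000
    println(microsToMillis(1365383415567000L)) // -> 1365383415567 (2013-04-08 01:10:15.567 UTC)
  }
}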

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 0 deletions

@@ -2773,6 +2773,8 @@ class Dataset[T] private[sql](
     } catch {
       case e: Exception =>
         throw e
+    } finally {
+      recordBatch.close()
     }

     withNewExecutionId {
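
Note: the added finally block is the memory-leak fix; the Arrow record batch owns off-heap buffers, so it must be closed even when the conversion throws. A small, self-contained sketch of the pattern, using a stand-in class rather than the real ArrowRecordBatch:

object CloseInFinallySketch {
  // Stand-in for ArrowRecordBatch, purely for illustration.
  final class Batch extends AutoCloseable {
    override def close(): Unit = println("record batch closed, off-heap buffers released")
  }

  def collectAndSend(fail: Boolean): Unit = {
    val recordBatch = new Batch
    try {
      if (fail) throw new RuntimeException("conversion failed")
      println("record batch sent")
    } finally {
      recordBatch.close() // runs on both the success and failure paths
    }
  }

  def main(args: Array[String]): Unit = {
    collectAndSend(fail = false)
    try collectAndSend(fail = true) catch { case _: RuntimeException => () }
  }
}
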
sql/core/src/test/resources/test-data/arrow/timestampData.json

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+{
+  "schema": {
+    "fields": [
+      {
+        "name": "a_timestamp",
+        "type": {"name": "timestamp", "unit": "MILLISECOND"},
+        "nullable": true,
+        "children": [],
+        "typeLayout": {
+          "vectors": [
+            {"type": "VALIDITY", "typeBitWidth": 1},
+            {"type": "DATA", "typeBitWidth": 64}
+          ]
+        }
+      }
+    ]
+  },
+
+  "batches": [
+    {
+      "count": 2,
+      "columns": [
+        {
+          "name": "a_timestamp",
+          "count": 2,
+          "VALIDITY": [1, 1],
+          "DATA": [1365383415567, 1365426610789]
+        }
+      ]
+    }
+  ]
+}
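
Note: the DATA values above are the epoch-millisecond encodings of the two UTC timestamps used by the new "timestamp conversion" test below. A quick stand-alone check (illustrative only, not part of the patch):

object TimestampJsonCheck {
  import java.text.SimpleDateFormat
  import java.util.Locale

  def main(args: Array[String]): Unit = {
    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
    println(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) // 1365383415567
    println(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) // 1365426610789
  }
}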

sql/core/src/test/scala/org/apache/spark/sql/ArrowSuite.scala

Lines changed: 21 additions & 8 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.sql
 import java.io.File
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.util.Locale
+import java.util.{Locale, TimeZone}

 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
@@ -90,17 +90,30 @@ class ArrowSuite extends SharedSQLContext {
     collectAndValidate(lowerCaseData, "test-data/arrow/lowercase-strings.json")
   }

-  ignore("time and date conversion") {
-    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-    val d1 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
-    val d2 = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
-    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15").getTime)
-    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10").getTime)
+  ignore("date conversion") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US)
+    val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
+    val d2 = new Date(sdf.parse("2015-04-08 13:10:15.000").getTime)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789").getTime)
     val dateTimeData = Seq((d1, sdf.format(d1), ts1), (d2, sdf.format(d2), ts2))
-        .toDF("a_date", "b_string", "c_timestamp")
+      .toDF("a_date", "b_string", "c_timestamp")
     collectAndValidate(dateTimeData, "test-data/arrow/datetimeData-strings.json")
   }

+  test("timestamp conversion") {
+    val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US)
+    val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime)
+    val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime)
+    val dateTimeData = Seq((ts1), (ts2)).toDF("a_timestamp")
+    collectAndValidate(dateTimeData, "test-data/arrow/timestampData.json")
+  }
+
+  // Arrow json reader doesn't support binary data
+  ignore("binary type conversion") {
+    collectAndValidate(binaryData, "test-data/arrow/binaryData.json")
+  }
+
   test("nested type conversion") { }

   test("array type conversion") { }
