
Commit eb386be

ueshin authored and HyukjinKwon committed
[SPARK-21552][SQL] Add DecimalType support to ArrowWriter.
## What changes were proposed in this pull request?

Decimal type is not yet supported in `ArrowWriter`. This is adding the decimal type support.

## How was this patch tested?

Added a test to `ArrowConvertersSuite`.

Author: Takuya UESHIN <[email protected]>

Closes #18754 from ueshin/issues/SPARK-21552.
1 parent 0e68330 commit eb386be
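
As a quick illustration of what this change enables from the Python side, here is a minimal sketch. It is not part of the patch: the session setup and column names are illustrative, and the exact Arrow config key depends on the Spark version.

```python
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
# Arrow-based conversion must be enabled; the key is
# spark.sql.execution.arrow.enabled on Spark 2.3/2.4 and
# spark.sql.execution.arrow.pyspark.enabled on Spark 3.x.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

schema = StructType([StructField("amount", DecimalType(38, 18), True)])
df = spark.createDataFrame([(Decimal("2.0"),), (Decimal("4.0"),), (None,)], schema)

# Before this commit the Arrow code path rejected DecimalType with
# "Unsupported data type"; with DecimalWriter in place the conversion succeeds.
pdf = df.toPandas()
print(pdf)
```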

File tree

5 files changed: +131 −22 lines changed


python/pyspark/sql/tests.py

Lines changed: 42 additions & 19 deletions
@@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase):
     @classmethod
     def setUpClass(cls):
         from datetime import datetime
+        from decimal import Decimal
         ReusedSQLTestCase.setUpClass()

         # Synchronize default timezone between Python and Java
@@ -3158,11 +3159,15 @@ def setUpClass(cls):
             StructField("3_long_t", LongType(), True),
             StructField("4_float_t", FloatType(), True),
             StructField("5_double_t", DoubleType(), True),
-            StructField("6_date_t", DateType(), True),
-            StructField("7_timestamp_t", TimestampType(), True)])
-        cls.data = [(u"a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                    (u"b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                    (u"c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+            StructField("6_decimal_t", DecimalType(38, 18), True),
+            StructField("7_date_t", DateType(), True),
+            StructField("8_timestamp_t", TimestampType(), True)])
+        cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
+                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                    (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
+                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                    (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
+                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

     @classmethod
     def tearDownClass(cls):
@@ -3190,10 +3195,11 @@ def create_pandas_data_frame(self):
         return pd.DataFrame(data=data_dict)

     def test_unsupported_datatype(self):
-        schema = StructType([StructField("decimal", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
-            self.assertRaises(Exception, lambda: df.toPandas())
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.toPandas()

     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
@@ -3293,7 +3299,7 @@ def test_createDataFrame_respect_session_timezone(self):
         self.assertNotEqual(result_ny, result_la)

         # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-        result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '7_timestamp_t' else v
+        result_la_corrected = [Row(**{k: v - timedelta(hours=3) if k == '8_timestamp_t' else v
                                       for k, v in row.asDict().items()})
                                for row in result_la]
         self.assertEqual(result_ny, result_la_corrected)
@@ -3317,11 +3323,11 @@ def test_createDataFrame_with_incorrect_schema(self):
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         # Test that schema as a list of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=list('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=list('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))
         # Test that schema as tuple of column names gets applied
-        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefg'))
-        self.assertEquals(df.schema.fieldNames(), list('abcdefg'))
+        df = self.spark.createDataFrame(pdf, schema=tuple('abcdefgh'))
+        self.assertEquals(df.schema.fieldNames(), list('abcdefgh'))

     def test_createDataFrame_column_name_encoding(self):
         import pandas as pd
@@ -3344,7 +3350,7 @@ def test_createDataFrame_does_not_modify_input(self):
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()
         # Use a nanosecond value to make sure it is not truncated
-        pdf.ix[0, '7_timestamp_t'] = pd.Timestamp(1)
+        pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1)
         # Integers with nulls will get NaNs filled with 0 and will be casted
         pdf.ix[1, '2_int_t'] = None
         pdf_copy = pdf.copy(deep=True)
@@ -3514,17 +3520,20 @@ def test_vectorized_udf_basic(self):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, StringType())
         int_f = pandas_udf(f, IntegerType())
         long_f = pandas_udf(f, LongType())
         float_f = pandas_udf(f, FloatType())
         double_f = pandas_udf(f, DoubleType())
+        decimal_f = pandas_udf(f, DecimalType())
         bool_f = pandas_udf(f, BooleanType())
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())

     def test_vectorized_udf_null_boolean(self):
@@ -3590,6 +3599,16 @@ def test_vectorized_udf_null_double(self):
         res = df.select(double_f(col('double')))
         self.assertEquals(df.collect(), res.collect())

+    def test_vectorized_udf_null_decimal(self):
+        from decimal import Decimal
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
+        schema = StructType().add("decimal", DecimalType(38, 18))
+        df = self.spark.createDataFrame(data, schema)
+        decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
+        res = df.select(decimal_f(col('decimal')))
+        self.assertEquals(df.collect(), res.collect())
+
     def test_vectorized_udf_null_string(self):
         from pyspark.sql.functions import pandas_udf, col
         data = [("foo",), (None,), ("bar",), ("bar",)]
@@ -3607,17 +3626,20 @@ def test_vectorized_udf_datatype_string(self):
             col('id').alias('long'),
             col('id').cast('float').alias('float'),
             col('id').cast('double').alias('double'),
+            col('id').cast('decimal').alias('decimal'),
             col('id').cast('boolean').alias('bool'))
         f = lambda x: x
         str_f = pandas_udf(f, 'string')
         int_f = pandas_udf(f, 'integer')
         long_f = pandas_udf(f, 'long')
         float_f = pandas_udf(f, 'float')
         double_f = pandas_udf(f, 'double')
+        decimal_f = pandas_udf(f, 'decimal(38, 18)')
         bool_f = pandas_udf(f, 'boolean')
         res = df.select(str_f(col('str')), int_f(col('int')),
                         long_f(col('long')), float_f(col('float')),
-                        double_f(col('double')), bool_f(col('bool')))
+                        double_f(col('double')), decimal_f('decimal'),
+                        bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())

     def test_vectorized_udf_complex(self):
@@ -3713,12 +3735,12 @@ def test_vectorized_udf_varargs(self):

     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DecimalType(), True)])
+        schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
-        f = pandas_udf(lambda x: x, DecimalType())
+        f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
-                df.select(f(col('dt'))).collect()
+                df.select(f(col('map'))).collect()

     def test_vectorized_udf_null_date(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4012,7 +4034,8 @@ def test_wrong_args(self):
     def test_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
         schema = StructType(
-            [StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
+            [StructField("id", LongType(), True),
+             StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(1, None,)], schema=schema)
         f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
         with QuietTest(self.sc):

python/pyspark/sql/types.py

Lines changed: 1 addition & 1 deletion
@@ -1617,7 +1617,7 @@ def to_arrow_type(dt):
     elif type(dt) == DoubleType:
         arrow_type = pa.float64()
     elif type(dt) == DecimalType:
-        arrow_type = pa.decimal(dt.precision, dt.scale)
+        arrow_type = pa.decimal128(dt.precision, dt.scale)
     elif type(dt) == StringType:
         arrow_type = pa.string()
     elif type(dt) == DateType:
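
The corrected call here reflects the pyarrow API, where fixed-precision decimals are built with `pa.decimal128`. A small standalone sketch of the mapping (assuming only a pyarrow installation; the values are illustrative, not from the patch):

```python
import decimal

import pyarrow as pa

# Spark's DecimalType(38, 18) maps to a 128-bit Arrow decimal type.
arrow_type = pa.decimal128(38, 18)
arr = pa.array([decimal.Decimal("2.0"), None, decimal.Decimal("6.0")], type=arrow_type)
print(arrow_type)  # decimal128(38, 18)
print(arr)
```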

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala

Lines changed: 21 additions & 0 deletions
@@ -53,6 +53,8 @@ object ArrowWriter {
       case (LongType, vector: BigIntVector) => new LongWriter(vector)
       case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
       case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
       case (StringType, vector: VarCharVector) => new StringWriter(vector)
       case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
       case (DateType, vector: DateDayVector) => new DateWriter(vector)
@@ -214,6 +216,25 @@ private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter {
   }
 }

+private[arrow] class DecimalWriter(
+    val valueVector: DecimalVector,
+    precision: Int,
+    scale: Int) extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val decimal = input.getDecimal(ordinal, precision, scale)
+    if (decimal.changePrecision(precision, scale)) {
+      valueVector.setSafe(count, decimal.toJavaBigDecimal)
+    } else {
+      setNull()
+    }
+  }
+}
+
 private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {

   override def setNull(): Unit = {

sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala

Lines changed: 65 additions & 2 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils


@@ -304,6 +304,70 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
     collectAndValidate(df, json, "floating_point-double_precision.json")
   }

+  test("decimal conversion") {
+    val json =
+      s"""
+         |{
+         |  "schema" : {
+         |    "fields" : [ {
+         |      "name" : "a_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "type" : {
+         |        "name" : "decimal",
+         |        "precision" : 38,
+         |        "scale" : 18
+         |      },
+         |      "nullable" : true,
+         |      "children" : [ ]
+         |    } ]
+         |  },
+         |  "batches" : [ {
+         |    "count" : 7,
+         |    "columns" : [ {
+         |      "name" : "a_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 1, 1, 1, 1, 1, 1 ],
+         |      "DATA" : [
+         |        "1000000000000000000",
+         |        "2000000000000000000",
+         |        "10000000000000000",
+         |        "200000000000000000000",
+         |        "100000000000000",
+         |        "20000000000000000000000",
+         |        "30000000000000000000" ]
+         |    }, {
+         |      "name" : "b_d",
+         |      "count" : 7,
+         |      "VALIDITY" : [ 1, 0, 0, 1, 0, 1, 0 ],
+         |      "DATA" : [
+         |        "1100000000000000000",
+         |        "0",
+         |        "0",
+         |        "2200000000000000000",
+         |        "0",
+         |        "3300000000000000000",
+         |        "0" ]
+         |    } ]
+         |  } ]
+         |}
+       """.stripMargin
+
+    val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_))
+    val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)),
+      Some(Decimal("123456789012345678901234567890")))
+    val df = a_d.zip(b_d).toDF("a_d", "b_d")
+
+    collectAndValidate(df, json, "decimalData.json")
+  }
+
   test("index conversion") {
     val data = List[Int](1, 2, 3, 4, 5, 6)
     val json =
@@ -1153,7 +1217,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
       assert(msg.getCause.getClass === classOf[UnsupportedOperationException])
     }

-    runUnsupported { decimalData.toArrowPayload.collect() }
     runUnsupported { mapData.toDF().toArrowPayload.collect() }
     runUnsupported { complexData.toArrowPayload.collect() }
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@ class ArrowWriterSuite extends SparkFunSuite {
       case LongType => reader.getLong(rowId)
       case FloatType => reader.getFloat(rowId)
       case DoubleType => reader.getDouble(rowId)
+      case DecimalType.Fixed(precision, scale) => reader.getDecimal(rowId, precision, scale)
       case StringType => reader.getUTF8String(rowId)
       case BinaryType => reader.getBinary(rowId)
       case DateType => reader.getInt(rowId)
@@ -66,6 +67,7 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(LongType, Seq(1L, 2L, null, 4L))
     check(FloatType, Seq(1.0f, 2.0f, null, 4.0f))
     check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d))
+    check(DecimalType.SYSTEM_DEFAULT, Seq(Decimal(1), Decimal(2), null, Decimal(4)))
     check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString))
     check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes()))
     check(DateType, Seq(0, 1, 2, null, 4))
