@@ -3142,6 +3142,7 @@ class ArrowTests(ReusedSQLTestCase):
31423142 @classmethod
31433143 def setUpClass (cls ):
31443144 from datetime import datetime
3145+ from decimal import Decimal
31453146 ReusedSQLTestCase .setUpClass ()
31463147
31473148 # Synchronize default timezone between Python and Java
@@ -3158,11 +3159,15 @@ def setUpClass(cls):
31583159 StructField ("3_long_t" , LongType (), True ),
31593160 StructField ("4_float_t" , FloatType (), True ),
31603161 StructField ("5_double_t" , DoubleType (), True ),
3161- StructField ("6_date_t" , DateType (), True ),
3162- StructField ("7_timestamp_t" , TimestampType (), True )])
3163- cls .data = [(u"a" , 1 , 10 , 0.2 , 2.0 , datetime (1969 , 1 , 1 ), datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3164- (u"b" , 2 , 20 , 0.4 , 4.0 , datetime (2012 , 2 , 2 ), datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3165- (u"c" , 3 , 30 , 0.8 , 6.0 , datetime (2100 , 3 , 3 ), datetime (2100 , 3 , 3 , 3 , 3 , 3 ))]
3162+ StructField ("6_decimal_t" , DecimalType (38 , 18 ), True ),
3163+ StructField ("7_date_t" , DateType (), True ),
3164+ StructField ("8_timestamp_t" , TimestampType (), True )])
3165+ cls .data = [(u"a" , 1 , 10 , 0.2 , 2.0 , Decimal ("2.0" ),
3166+ datetime (1969 , 1 , 1 ), datetime (1969 , 1 , 1 , 1 , 1 , 1 )),
3167+ (u"b" , 2 , 20 , 0.4 , 4.0 , Decimal ("4.0" ),
3168+ datetime (2012 , 2 , 2 ), datetime (2012 , 2 , 2 , 2 , 2 , 2 )),
3169+ (u"c" , 3 , 30 , 0.8 , 6.0 , Decimal ("6.0" ),
3170+ datetime (2100 , 3 , 3 ), datetime (2100 , 3 , 3 , 3 , 3 , 3 ))]
31663171
31673172 @classmethod
31683173 def tearDownClass (cls ):
@@ -3190,10 +3195,11 @@ def create_pandas_data_frame(self):
31903195 return pd .DataFrame (data = data_dict )
31913196
31923197 def test_unsupported_datatype (self ):
3193- schema = StructType ([StructField ("decimal " , DecimalType ( ), True )])
3198+ schema = StructType ([StructField ("map " , MapType ( StringType (), IntegerType () ), True )])
31943199 df = self .spark .createDataFrame ([(None ,)], schema = schema )
31953200 with QuietTest (self .sc ):
3196- self .assertRaises (Exception , lambda : df .toPandas ())
3201+ with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
3202+ df .toPandas ()
31973203
31983204 def test_null_conversion (self ):
31993205 df_null = self .spark .createDataFrame ([tuple ([None for _ in range (len (self .data [0 ]))])] +
@@ -3293,7 +3299,7 @@ def test_createDataFrame_respect_session_timezone(self):
32933299 self .assertNotEqual (result_ny , result_la )
32943300
32953301 # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
3296- result_la_corrected = [Row (** {k : v - timedelta (hours = 3 ) if k == '7_timestamp_t ' else v
3302+ result_la_corrected = [Row (** {k : v - timedelta (hours = 3 ) if k == '8_timestamp_t ' else v
32973303 for k , v in row .asDict ().items ()})
32983304 for row in result_la ]
32993305 self .assertEqual (result_ny , result_la_corrected )
@@ -3317,11 +3323,11 @@ def test_createDataFrame_with_incorrect_schema(self):
33173323 def test_createDataFrame_with_names (self ):
33183324 pdf = self .create_pandas_data_frame ()
33193325 # Test that schema as a list of column names gets applied
3320- df = self .spark .createDataFrame (pdf , schema = list ('abcdefg ' ))
3321- self .assertEquals (df .schema .fieldNames (), list ('abcdefg ' ))
3326+ df = self .spark .createDataFrame (pdf , schema = list ('abcdefgh ' ))
3327+ self .assertEquals (df .schema .fieldNames (), list ('abcdefgh ' ))
33223328 # Test that schema as tuple of column names gets applied
3323- df = self .spark .createDataFrame (pdf , schema = tuple ('abcdefg ' ))
3324- self .assertEquals (df .schema .fieldNames (), list ('abcdefg ' ))
3329+ df = self .spark .createDataFrame (pdf , schema = tuple ('abcdefgh ' ))
3330+ self .assertEquals (df .schema .fieldNames (), list ('abcdefgh ' ))
33253331
33263332 def test_createDataFrame_column_name_encoding (self ):
33273333 import pandas as pd
@@ -3344,7 +3350,7 @@ def test_createDataFrame_does_not_modify_input(self):
33443350 # Some series get converted for Spark to consume, this makes sure input is unchanged
33453351 pdf = self .create_pandas_data_frame ()
33463352 # Use a nanosecond value to make sure it is not truncated
3347- pdf .ix [0 , '7_timestamp_t ' ] = pd .Timestamp (1 )
3353+ pdf .ix [0 , '8_timestamp_t ' ] = pd .Timestamp (1 )
33483354 # Integers with nulls will get NaNs filled with 0 and will be casted
33493355 pdf .ix [1 , '2_int_t' ] = None
33503356 pdf_copy = pdf .copy (deep = True )
@@ -3514,17 +3520,20 @@ def test_vectorized_udf_basic(self):
35143520 col ('id' ).alias ('long' ),
35153521 col ('id' ).cast ('float' ).alias ('float' ),
35163522 col ('id' ).cast ('double' ).alias ('double' ),
3523+ col ('id' ).cast ('decimal' ).alias ('decimal' ),
35173524 col ('id' ).cast ('boolean' ).alias ('bool' ))
35183525 f = lambda x : x
35193526 str_f = pandas_udf (f , StringType ())
35203527 int_f = pandas_udf (f , IntegerType ())
35213528 long_f = pandas_udf (f , LongType ())
35223529 float_f = pandas_udf (f , FloatType ())
35233530 double_f = pandas_udf (f , DoubleType ())
3531+ decimal_f = pandas_udf (f , DecimalType ())
35243532 bool_f = pandas_udf (f , BooleanType ())
35253533 res = df .select (str_f (col ('str' )), int_f (col ('int' )),
35263534 long_f (col ('long' )), float_f (col ('float' )),
3527- double_f (col ('double' )), bool_f (col ('bool' )))
3535+ double_f (col ('double' )), decimal_f ('decimal' ),
3536+ bool_f (col ('bool' )))
35283537 self .assertEquals (df .collect (), res .collect ())
35293538
35303539 def test_vectorized_udf_null_boolean (self ):
@@ -3590,6 +3599,16 @@ def test_vectorized_udf_null_double(self):
35903599 res = df .select (double_f (col ('double' )))
35913600 self .assertEquals (df .collect (), res .collect ())
35923601
3602+ def test_vectorized_udf_null_decimal (self ):
3603+ from decimal import Decimal
3604+ from pyspark .sql .functions import pandas_udf , col
3605+ data = [(Decimal (3.0 ),), (Decimal (5.0 ),), (Decimal (- 1.0 ),), (None ,)]
3606+ schema = StructType ().add ("decimal" , DecimalType (38 , 18 ))
3607+ df = self .spark .createDataFrame (data , schema )
3608+ decimal_f = pandas_udf (lambda x : x , DecimalType (38 , 18 ))
3609+ res = df .select (decimal_f (col ('decimal' )))
3610+ self .assertEquals (df .collect (), res .collect ())
3611+
35933612 def test_vectorized_udf_null_string (self ):
35943613 from pyspark .sql .functions import pandas_udf , col
35953614 data = [("foo" ,), (None ,), ("bar" ,), ("bar" ,)]
@@ -3607,17 +3626,20 @@ def test_vectorized_udf_datatype_string(self):
36073626 col ('id' ).alias ('long' ),
36083627 col ('id' ).cast ('float' ).alias ('float' ),
36093628 col ('id' ).cast ('double' ).alias ('double' ),
3629+ col ('id' ).cast ('decimal' ).alias ('decimal' ),
36103630 col ('id' ).cast ('boolean' ).alias ('bool' ))
36113631 f = lambda x : x
36123632 str_f = pandas_udf (f , 'string' )
36133633 int_f = pandas_udf (f , 'integer' )
36143634 long_f = pandas_udf (f , 'long' )
36153635 float_f = pandas_udf (f , 'float' )
36163636 double_f = pandas_udf (f , 'double' )
3637+ decimal_f = pandas_udf (f , 'decimal(38, 18)' )
36173638 bool_f = pandas_udf (f , 'boolean' )
36183639 res = df .select (str_f (col ('str' )), int_f (col ('int' )),
36193640 long_f (col ('long' )), float_f (col ('float' )),
3620- double_f (col ('double' )), bool_f (col ('bool' )))
3641+ double_f (col ('double' )), decimal_f ('decimal' ),
3642+ bool_f (col ('bool' )))
36213643 self .assertEquals (df .collect (), res .collect ())
36223644
36233645 def test_vectorized_udf_complex (self ):
@@ -3713,12 +3735,12 @@ def test_vectorized_udf_varargs(self):
37133735
37143736 def test_vectorized_udf_unsupported_types (self ):
37153737 from pyspark .sql .functions import pandas_udf , col
3716- schema = StructType ([StructField ("dt " , DecimalType ( ), True )])
3738+ schema = StructType ([StructField ("map " , MapType ( StringType (), IntegerType () ), True )])
37173739 df = self .spark .createDataFrame ([(None ,)], schema = schema )
3718- f = pandas_udf (lambda x : x , DecimalType ( ))
3740+ f = pandas_udf (lambda x : x , MapType ( StringType (), IntegerType () ))
37193741 with QuietTest (self .sc ):
37203742 with self .assertRaisesRegexp (Exception , 'Unsupported data type' ):
3721- df .select (f (col ('dt ' ))).collect ()
3743+ df .select (f (col ('map ' ))).collect ()
37223744
37233745 def test_vectorized_udf_null_date (self ):
37243746 from pyspark .sql .functions import pandas_udf , col
@@ -4012,7 +4034,8 @@ def test_wrong_args(self):
40124034 def test_unsupported_types (self ):
40134035 from pyspark .sql .functions import pandas_udf , col , PandasUDFType
40144036 schema = StructType (
4015- [StructField ("id" , LongType (), True ), StructField ("dt" , DecimalType (), True )])
4037+ [StructField ("id" , LongType (), True ),
4038+ StructField ("map" , MapType (StringType (), IntegerType ()), True )])
40164039 df = self .spark .createDataFrame ([(1 , None ,)], schema = schema )
40174040 f = pandas_udf (lambda x : x , df .schema , PandasUDFType .GROUP_MAP )
40184041 with QuietTest (self .sc ):
0 commit comments