
Commit a24c031

ueshin authored and cloud-fan committed
[SPARK-23290][SQL][PYTHON] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame.
## What changes were proposed in this pull request?

In #18664, there was a change in how `DateType` is returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code that works in Spark 2.2 to fail. See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example.

This PR modifies the conversion to use `datetime.date` for date type, as Spark 2.2 does.

## How was this patch tested?

Tests were modified to fit the new behavior, along with the existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #20506 from ueshin/issues/SPARK-23290.
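
For illustration, a minimal sketch of the restored behavior (assumes a running `SparkSession` bound to `spark`; the column alias is made up):

```python
import datetime

# A DateType column round-tripped through toPandas() holds datetime.date
# objects again, as in Spark 2.2, whether or not
# spark.sql.execution.arrow.enabled is set.
pdf = spark.sql("SELECT DATE '2012-02-02' AS d").toPandas()
assert isinstance(pdf["d"][0], datetime.date)
```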
1 parent f3f1e14 commit a24c031

File tree

4 files changed: +66 -22 lines


python/pyspark/serializers.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -267,12 +267,15 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import _check_dataframe_localize_timestamps
+        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
+            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
+        schema = from_arrow_schema(reader.schema)
         for batch in reader:
-            # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
-            pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
+            pdf = batch.to_pandas()
+            pdf = _check_dataframe_convert_date(pdf, schema)
+            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
             yield [c for _, c in pdf.iteritems()]
 
     def __repr__(self):
```
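
A design note on this hunk: the Spark schema is derived from the Arrow stream's schema once, up front, so the per-batch loop only performs the pandas-level fix-ups. A rough standalone sketch of that shape, using the pyarrow 0.8-era API the diff targets (`stream` stands for any file-like object carrying Arrow stream data and is not defined here):

```python
import pyarrow as pa
from pyspark.sql.types import from_arrow_schema

reader = pa.open_stream(stream)                  # `stream`: assumed input
spark_schema = from_arrow_schema(reader.schema)  # converted once, reused for every batch
for batch in reader:
    pdf = batch.to_pandas()
    # the date and timestamp fix-ups from the diff would run here, per batch
```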

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -1923,14 +1923,16 @@ def toPandas(self):
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
-                from pyspark.sql.types import _check_dataframe_localize_timestamps
+                from pyspark.sql.types import _check_dataframe_convert_date, \
+                    _check_dataframe_localize_timestamps
                 from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
                 require_minimum_pyarrow_version()
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
                     pdf = table.to_pandas()
+                    pdf = _check_dataframe_convert_date(pdf, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
@@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
     This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
-    NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
     """
     import numpy as np
     if type(dt) == ByteType:
@@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt):
         return np.int32
     elif type(dt) == FloatType:
         return np.float32
-    elif type(dt) == DateType:
-        return 'datetime64[ns]'
     else:
         return None
```
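
The removed `DateType` branch in `_to_corrected_pandas_type` is what previously coerced date columns to `datetime64[ns]` on the non-Arrow path. A pandas-only sketch of the difference, using values like those a collected `DateType` column produces:

```python
import datetime
import pandas as pd

s = pd.Series([datetime.date(1969, 1, 1), datetime.date(2012, 2, 2)])
print(s.dtype)                     # object: datetime.date values are now left as-is
print(s.astype('datetime64[ns]'))  # the old coercion turned them into pandas Timestamps
```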

python/pyspark/sql/tests.py

Lines changed: 42 additions & 15 deletions
```diff
@@ -2816,7 +2816,7 @@ def test_to_pandas(self):
         self.assertEquals(types[1], np.object)
         self.assertEquals(types[2], np.bool)
         self.assertEquals(types[3], np.float32)
-        self.assertEquals(types[4], 'datetime64[ns]')
+        self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
     @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
@@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
-        from datetime import datetime
+        from datetime import date, datetime
         from decimal import Decimal
         ReusedSQLTestCase.setUpClass()
 
@@ -3410,11 +3410,11 @@ def setUpClass(cls):
             StructField("7_date_t", DateType(), True),
             StructField("8_timestamp_t", TimestampType(), True)])
         cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
-                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
                     (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
-                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
-                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
@@ -3461,7 +3461,9 @@ def _toPandas_arrow_toggle(self, df):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-        self.assertPandasEqual(pdf_arrow, pdf)
+        expected = self.create_pandas_data_frame()
+        self.assertPandasEqual(expected, pdf)
+        self.assertPandasEqual(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -4062,18 +4064,42 @@ def test_vectorized_udf_unsupported_types(self):
         with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
             df.select(f(col('map'))).collect()
 
-    def test_vectorized_udf_null_date(self):
+    def test_vectorized_udf_dates(self):
         from pyspark.sql.functions import pandas_udf, col
         from datetime import date
-        schema = StructType().add("date", DateType())
-        data = [(date(1969, 1, 1),),
-                (date(2012, 2, 2),),
-                (None,),
-                (date(2100, 4, 4),)]
+        schema = StructType().add("idx", LongType()).add("date", DateType())
+        data = [(0, date(1969, 1, 1),),
+                (1, date(2012, 2, 2),),
+                (2, None,),
+                (3, date(2100, 4, 4),)]
         df = self.spark.createDataFrame(data, schema=schema)
-        date_f = pandas_udf(lambda t: t, returnType=DateType())
-        res = df.select(date_f(col("date")))
-        self.assertEquals(df.collect(), res.collect())
+
+        date_copy = pandas_udf(lambda t: t, returnType=DateType())
+        df = df.withColumn("date_copy", date_copy(col("date")))
+
+        @pandas_udf(returnType=StringType())
+        def check_data(idx, date, date_copy):
+            import pandas as pd
+            msgs = []
+            is_equal = date.isnull()
+            for i in range(len(idx)):
+                if (is_equal[i] and data[idx[i]][1] is None) or \
+                        date[i] == data[idx[i]][1]:
+                    msgs.append(None)
+                else:
+                    msgs.append(
+                        "date values are not equal (date='%s': data[%d][1]='%s')"
+                        % (date[i], idx[i], data[idx[i]][1]))
+            return pd.Series(msgs)
+
+        result = df.withColumn("check_data",
+                               check_data(col("idx"), col("date"), col("date_copy"))).collect()
+
+        self.assertEquals(len(data), len(result))
+        for i in range(len(result)):
+            self.assertEquals(data[i][1], result[i][1])  # "date" col
+            self.assertEquals(data[i][1], result[i][2])  # "date_copy" col
+            self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_timestamps(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4114,6 +4140,7 @@ def check_data(idx, timestamp, timestamp_copy):
         self.assertEquals(len(data), len(result))
         for i in range(len(result)):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertEquals(data[i][1], result[i][2])  # "timestamp_copy" col
             self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_return_timestamp_tz(self):
```

python/pyspark/sql/types.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema):
                         for field in arrow_schema])
 
 
+def _check_dataframe_convert_date(pdf, schema):
+    """ Correct date type value to use datetime.date.
+
+    Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
+    use datetime.date to match the behavior with when Arrow optimization is disabled.
+
+    :param pdf: pandas.DataFrame
+    :param schema: a Spark schema of the pandas.DataFrame
+    """
+    for field in schema:
+        if type(field.dataType) == DateType:
+            pdf[field.name] = pdf[field.name].dt.date
+    return pdf
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
```
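
The core of the new helper is pandas' `Series.dt.date` accessor. A minimal pandas-only sketch of what it does per `DateType` field (the column name is illustrative):

```python
import pandas as pd

pdf = pd.DataFrame({"d": pd.to_datetime(["1969-01-01", "2012-02-02"])})
print(pdf["d"].dtype)        # datetime64[ns], as PyArrow's to_pandas() produces for dates
pdf["d"] = pdf["d"].dt.date  # what _check_dataframe_convert_date applies per DateType field
print(type(pdf["d"][0]))     # <class 'datetime.date'>
```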
