From 223d0a06a755d3ceb59664b37a87af82f61f2ae4 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 5 Feb 2018 15:52:43 +0900
Subject: [PATCH 1/5] Use datetime.date for date type when converting Spark DataFrame to Pandas DataFrame.

---
 python/pyspark/serializers.py   |  9 ++++++---
 python/pyspark/sql/dataframe.py |  7 +++----
 python/pyspark/sql/tests.py     | 10 +++++-----
 python/pyspark/sql/types.py     | 15 +++++++++++++++
 4 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 88d6a191babc..4fadb574bd7d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -267,12 +267,15 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import _check_dataframe_localize_timestamps
+        from pyspark.sql.types import from_arrow_schema, _correct_date_of_dataframe_from_arrow, \
+            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
+        schema = from_arrow_schema(reader.schema)
         for batch in reader:
-            # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
-            pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
+            pdf = batch.to_pandas()
+            pdf = _correct_date_of_dataframe_from_arrow(pdf, schema)
+            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
             yield [c for _, c in pdf.iteritems()]
 
     def __repr__(self):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2e55407b5397..2f0d4b6deb44 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1923,7 +1923,8 @@ def toPandas(self):
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
-                from pyspark.sql.types import _check_dataframe_localize_timestamps
+                from pyspark.sql.types import _correct_date_of_dataframe_from_arrow, \
+                    _check_dataframe_localize_timestamps
                 from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
                 require_minimum_pyarrow_version()
@@ -1931,6 +1932,7 @@ def toPandas(self):
                 if tables:
                     table = pyarrow.concat_tables(tables)
                     pdf = table.to_pandas()
+                    pdf = _correct_date_of_dataframe_from_arrow(pdf, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
@@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
    This method gets the corrected data type for Pandas if that type may be inferred incorrectly.
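The conversion this first patch introduces can be exercised with pandas alone. The sketch below mirrors the loop in _correct_date_of_dataframe_from_arrow; DateType and MiniField here are hypothetical simplified stand-ins for pyspark.sql.types, used only so the snippet runs without Spark:

    import datetime
    import pandas as pd

    class DateType(object):
        """Hypothetical stand-in for pyspark.sql.types.DateType."""
        pass

    class MiniField(object):
        """Hypothetical stand-in for pyspark.sql.types.StructField."""
        def __init__(self, name, dataType):
            self.name = name
            self.dataType = dataType

    def convert_date_columns(pdf, schema):
        # Same loop as the patch: rewrite every DateType column from
        # datetime64[ns] values to plain datetime.date objects.
        for field in schema:
            if type(field.dataType) == DateType:
                pdf[field.name] = pdf[field.name].dt.date
        return pdf

    pdf = pd.DataFrame({"d": pd.to_datetime(["1969-01-01", "2012-02-02"])})
    pdf = convert_date_columns(pdf, [MiniField("d", DateType())])
    assert type(pdf["d"][0]) == datetime.date
    print(pdf.dtypes)  # "d" is now object dtype, holding datetime.date values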
-    - NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
     """
     import numpy as np
     if type(dt) == ByteType:
@@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt):
         return np.int32
     elif type(dt) == FloatType:
         return np.float32
-    elif type(dt) == DateType:
-        return 'datetime64[ns]'
     else:
         return None
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b27363023ae7..cbf57897970d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2816,7 +2816,7 @@ def test_to_pandas(self):
         self.assertEquals(types[1], np.object)
         self.assertEquals(types[2], np.bool)
         self.assertEquals(types[3], np.float32)
-        self.assertEquals(types[4], 'datetime64[ns]')
+        self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
     @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
@@ -3388,7 +3388,7 @@ class ArrowTests(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
-        from datetime import datetime
+        from datetime import date, datetime
         from decimal import Decimal
         ReusedSQLTestCase.setUpClass()
@@ -3410,11 +3410,11 @@ def setUpClass(cls):
             StructField("7_date_t", DateType(), True),
             StructField("8_timestamp_t", TimestampType(), True)])
         cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
-                     datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
                     (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
-                     datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
-                     datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
 
     @classmethod
     def tearDownClass(cls):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0dc5823f72a3..46e9352247c5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema):
                        for field in arrow_schema])
 
 
+def _correct_date_of_dataframe_from_arrow(pdf, schema):
+    """ Correct date type value to use datetime.date.
+
+    Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
+    use datetime.date to keep backward compatibility.
+
+    :param pdf: pandas.DataFrame
+    :param schema: a Spark schema of the pandas.DataFrame
+    """
+    for field in schema:
+        if type(field.dataType) == DateType:
+            pdf[field.name] = pdf[field.name].dt.date
+    return pdf
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone

From 57ab41b90dbdace4dc5ce71421c42cfff27d061c Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 5 Feb 2018 16:49:36 +0900
Subject: [PATCH 2/5] Modify a test for date type.

---
 python/pyspark/sql/tests.py | 43 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cbf57897970d..84824bbbb3d9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4062,18 +4062,42 @@ def test_vectorized_udf_unsupported_types(self):
         with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
             df.select(f(col('map'))).collect()
 
-    def test_vectorized_udf_null_date(self):
+    def test_vectorized_udf_dates(self):
         from pyspark.sql.functions import pandas_udf, col
         from datetime import date
-        schema = StructType().add("date", DateType())
-        data = [(date(1969, 1, 1),),
-                (date(2012, 2, 2),),
-                (None,),
-                (date(2100, 4, 4),)]
+        schema = StructType().add("idx", LongType()).add("date", DateType())
+        data = [(0, date(1969, 1, 1),),
+                (1, date(2012, 2, 2),),
+                (2, None,),
+                (3, date(2100, 4, 4),)]
         df = self.spark.createDataFrame(data, schema=schema)
-        date_f = pandas_udf(lambda t: t, returnType=DateType())
-        res = df.select(date_f(col("date")))
-        self.assertEquals(df.collect(), res.collect())
+
+        date_copy = pandas_udf(lambda t: t, returnType=DateType())
+        df = df.withColumn("date_copy", date_copy(col("date")))
+
+        @pandas_udf(returnType=StringType())
+        def check_data(idx, date, date_copy):
+            import pandas as pd
+            msgs = []
+            is_equal = date.isnull()
+            for i in range(len(idx)):
+                if (is_equal[i] and data[idx[i]][1] is None) or \
+                        date[i] == data[idx[i]][1]:
+                    msgs.append(None)
+                else:
+                    msgs.append(
+                        "date values are not equal (date='%s': data[%d][1]='%s')"
+                        % (date[i], idx[i], data[idx[i]][1]))
+            return pd.Series(msgs)
+
+        result = df.withColumn("check_data",
+                               check_data(col("idx"), col("date"), col("date_copy"))).collect()
+
+        self.assertEquals(len(data), len(result))
+        for i in range(len(result)):
+            self.assertEquals(data[i][1], result[i][1])  # "date" col
+            self.assertEquals(data[i][1], result[i][2])  # "date_copy" col
+            self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_timestamps(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4114,6 +4138,7 @@ def check_data(idx, timestamp, timestamp_copy):
         self.assertEquals(len(data), len(result))
         for i in range(len(result)):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertEquals(data[i][1], result[i][2])  # "timestamp_copy" col
             self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_return_timestamp_tz(self):
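The subtle point in the rewritten test is the null handling: a None date reaches the pandas_udf as NaT, and NaT compares unequal to everything, so a plain == check would flag the null row as a mismatch. The isnull() guard used above can be reproduced with pandas alone (the values below are made up for illustration):

    from datetime import date
    import pandas as pd

    expected = [date(1969, 1, 1), None, date(2100, 4, 4)]
    # .dt.date leaves missing entries as NaT, just like the udf input.
    actual = pd.Series(pd.to_datetime(["1969-01-01", None, "2100-04-04"])).dt.date

    is_null = actual.isnull()
    for i in range(len(expected)):
        # Null-safe comparison: both missing, or equal as datetime.date.
        if (is_null[i] and expected[i] is None) or actual[i] == expected[i]:
            print("row %d ok" % i)
        else:
            print("row %d mismatch: %r != %r" % (i, actual[i], expected[i]))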
From ebdbd8c4a06a4da52fc61b1dc98d6e2f2facdf9c Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Mon, 5 Feb 2018 23:34:21 +0900
Subject: [PATCH 3/5] Address a comment.

---
 python/pyspark/sql/types.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 46e9352247c5..7674f997134a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1698,7 +1698,7 @@ def _correct_date_of_dataframe_from_arrow(pdf, schema):
     """ Correct date type value to use datetime.date.
 
     Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
-    use datetime.date to keep backward compatibility.
+    use datetime.date to match the behavior when Arrow optimization is disabled.
 
     :param pdf: pandas.DataFrame
     :param schema: a Spark schema of the pandas.DataFrame

From 8823043ee4fd72d1cf126dd5cfa5eacbd0e775d0 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Tue, 6 Feb 2018 14:42:59 +0900
Subject: [PATCH 4/5] Modify a method name.

---
 python/pyspark/serializers.py   | 4 ++--
 python/pyspark/sql/dataframe.py | 4 ++--
 python/pyspark/sql/types.py     | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4fadb574bd7d..e870325d202c 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -267,14 +267,14 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import from_arrow_schema, _correct_date_of_dataframe_from_arrow, \
+        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
             _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
         schema = from_arrow_schema(reader.schema)
         for batch in reader:
             pdf = batch.to_pandas()
-            pdf = _correct_date_of_dataframe_from_arrow(pdf, schema)
+            pdf = _check_dataframe_convert_date(pdf, schema)
             pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
             yield [c for _, c in pdf.iteritems()]
 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2f0d4b6deb44..59a417015b94 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1923,7 +1923,7 @@ def toPandas(self):
 
         if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
             try:
-                from pyspark.sql.types import _correct_date_of_dataframe_from_arrow, \
+                from pyspark.sql.types import _check_dataframe_convert_date, \
                     _check_dataframe_localize_timestamps
                 from pyspark.sql.utils import require_minimum_pyarrow_version
                 import pyarrow
@@ -1932,7 +1932,7 @@ def toPandas(self):
                 if tables:
                     table = pyarrow.concat_tables(tables)
                     pdf = table.to_pandas()
-                    pdf = _correct_date_of_dataframe_from_arrow(pdf, self.schema)
+                    pdf = _check_dataframe_convert_date(pdf, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7674f997134a..093dae5a22e1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1694,7 +1694,7 @@ def from_arrow_schema(arrow_schema):
                        for field in arrow_schema])
 
 
-def _correct_date_of_dataframe_from_arrow(pdf, schema):
+def _check_dataframe_convert_date(pdf, schema):
     """ Correct date type value to use datetime.date.
 
     Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
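With the rename in place, the serializer applies two fixups to every batch in a fixed order: _check_dataframe_convert_date first, then _check_dataframe_localize_timestamps. The date step is sketched earlier; the timestamp step boils down to the tz_convert/tz_localize(None) idiom below, where localize_timestamps is a simplified stand-in that only handles the tz-aware case:

    import pandas as pd
    from pandas.api.types import is_datetime64tz_dtype

    def localize_timestamps(pdf, timezone):
        # Rewrite tz-aware timestamp columns as tz-naive values in `timezone`.
        for name in pdf.columns:
            if is_datetime64tz_dtype(pdf[name].dtype):
                pdf[name] = pdf[name].dt.tz_convert(timezone).dt.tz_localize(None)
        return pdf

    pdf = pd.DataFrame(
        {"t": pd.to_datetime(["2018-02-06 14:42:59"]).tz_localize("UTC")})
    pdf = localize_timestamps(pdf, "Asia/Tokyo")
    print(pdf)         # 2018-02-06 23:42:59, shifted to Tokyo time
    print(pdf.dtypes)  # t is plain datetime64[ns], no timezone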
From f151cdf492959d928025a51cabe9c4ba7a395460 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Tue, 6 Feb 2018 14:43:15 +0900
Subject: [PATCH 5/5] Modify a test to compare against its expected Pandas DataFrame.

---
 python/pyspark/sql/tests.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84824bbbb3d9..545ec5aee08f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3461,7 +3461,9 @@ def _toPandas_arrow_toggle(self, df):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-        self.assertPandasEqual(pdf_arrow, pdf)
+        expected = self.create_pandas_data_frame()
+        self.assertPandasEqual(expected, pdf)
+        self.assertPandasEqual(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
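The point of this last change is that both the Arrow and non-Arrow paths are now checked against an independently built expectation rather than only against each other. The sketch below shows why that matters, limited to the two columns visible in this diff; it uses DataFrame.equals directly as an assumed simplification of the suite's create_pandas_data_frame/assertPandasEqual helpers, whose definitions are not part of this patch:

    from datetime import date, datetime
    import pandas as pd

    # Expected dtypes after this series: dates as datetime.date (object
    # dtype), timestamps as datetime64[ns].
    expected = pd.DataFrame({
        "7_date_t": [date(1969, 1, 1), date(2012, 2, 2), date(2100, 3, 3)],
        "8_timestamp_t": [datetime(1969, 1, 1, 1, 1, 1),
                          datetime(2012, 2, 2, 2, 2, 2),
                          datetime(2100, 3, 3, 3, 3, 3)]})

    # The pre-fix Arrow path returned datetime64[ns] dates, which a strict
    # comparison rejects; the .dt.date conversion makes the frames equal.
    arrow_style = expected.copy()
    arrow_style["7_date_t"] = pd.to_datetime(arrow_style["7_date_t"])
    assert not expected.equals(arrow_style)
    arrow_style["7_date_t"] = arrow_style["7_date_t"].dt.date
    assert expected.equals(arrow_style)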