5 changes: 2 additions & 3 deletions python/pyspark/serializers.py
@@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck):

     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
-            _check_series_convert_date, _check_series_localize_timestamps
+            _arrow_column_to_pandas, _check_series_localize_timestamps
 
-        s = arrow_column.to_pandas()
-        s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+        s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type))
         s = _check_series_localize_timestamps(s, self._timezone)
         return s

5 changes: 2 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2107,14 +2107,13 @@ def toPandas(self):
         # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
         if use_arrow:
             try:
-                from pyspark.sql.types import _check_dataframe_convert_date, \
+                from pyspark.sql.types import _arrow_table_to_pandas, \
                     _check_dataframe_localize_timestamps
                 import pyarrow
                 batches = self._collectAsArrow()
                 if len(batches) > 0:
                     table = pyarrow.Table.from_batches(batches)
-                    pdf = table.to_pandas()
-                    pdf = _check_dataframe_convert_date(pdf, self.schema)
+                    pdf = _arrow_table_to_pandas(table, self.schema)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
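For context, here is a minimal standalone sketch of the path this hunk rewires, assuming pyarrow >= 0.11. It is not Spark source: the hand-built batches list and the "day" column stand in for self._collectAsArrow() and the real schema, and table.to_pandas(date_as_object=True) stands in for what _arrow_table_to_pandas does on a recent pyarrow:

    import datetime
    import pandas as pd
    import pyarrow

    # Stand-in for self._collectAsArrow(): one batch with a single date column.
    batches = [pyarrow.RecordBatch.from_arrays(
        [pyarrow.array([datetime.date(2262, 4, 12)])], ["day"])]

    if len(batches) > 0:
        table = pyarrow.Table.from_batches(batches)
        # Keep dates as datetime.date end to end; no datetime64[ns] intermediate.
        pdf = table.to_pandas(date_as_object=True)
    else:
        pdf = pd.DataFrame.from_records([], columns=["day"])

    print(pdf["day"][0])  # datetime.date(2262, 4, 12), unrepresentable in datetime64[ns]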
5 changes: 4 additions & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -68,14 +68,17 @@ def setUpClass(cls):
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
(u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]

# TODO: remove version check once minimum pyarrow version is 0.10.0
if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
cls.schema.add(StructField("9_binary_t", BinaryType(), True))
cls.data[0] = cls.data[0] + (bytearray(b"a"),)
cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)

@classmethod
def tearDownClass(cls):
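The new row's date(2262, 4, 12) is not an arbitrary choice: pandas' datetime64[ns] representation tops out at pd.Timestamp.max, 2262-04-11 23:47:16.854775807, so a date one day later cannot pass through a datetime64[ns] intermediate. A quick check, assuming only that pandas is installed:

    import datetime
    import pandas as pd

    print(pd.Timestamp.max)  # 2262-04-11 23:47:16.854775807

    try:
        pd.Timestamp(datetime.date(2262, 4, 12))
    except (OverflowError, pd.errors.OutOfBoundsDatetime) as exc:
        print("outside datetime64[ns] range:", exc)

This is exactly the value the old datetime64[ns]-based conversion could not round-trip, which is what the added rows exercise.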
3 changes: 2 additions & 1 deletion python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self):
         data = [(0, date(1969, 1, 1),),
                 (1, date(2012, 2, 2),),
                 (2, None,),
-                (3, date(2100, 4, 4),)]
+                (3, date(2100, 4, 4),),
+                (4, date(2262, 4, 12),)]
         df = self.spark.createDataFrame(data, schema=schema)
 
         date_copy = pandas_udf(lambda t: t, returnType=DateType())
55 changes: 35 additions & 20 deletions python/pyspark/sql/types.py
@@ -1681,38 +1681,53 @@ def from_arrow_schema(arrow_schema):
                        for field in arrow_schema])
 
 
-def _check_series_convert_date(series, data_type):
-    """
-    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+def _arrow_column_to_pandas(column, data_type):
+    """ Convert Arrow Column to pandas Series.
+
+    If the given column is a date type column, creates a series of datetime.date directly instead
+    of creating datetime64[ns] as intermediate data.

Review comment (Contributor): minor: I think these details belong as an internal comment rather than in the docstring.

Review comment (Member): It would be nice to say that for dates this will return datetime.date, but maybe move the part about datetime64[ns] as intermediate data to an internal comment. _arrow_table_to_pandas has a comment noting that the reason is to match PySpark without Arrow; it would be good to add that here as well.


-    :param series: pandas.Series
-    :param data_type: a Spark data type for the series
+    :param column: pyarrow.lib.Column
+    :param data_type: a Spark data type for the column
     """
-    import pyarrow
+    import pandas as pd
+    import pyarrow as pa
     from distutils.version import LooseVersion
-    # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
-    if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType:
-        return series.dt.date
+    # Since Arrow 0.11.0, date_as_object is supported to return datetime.date instead of np.datetime64.

Review comment (Contributor): Include a comment about the overflow here so we know why we are avoiding np.datetime64.
+    if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+        if type(data_type) == DateType:
+            return pd.Series(column.to_pylist(), name=column.name)
+        else:
+            return column.to_pandas()
     else:
-        return series
+        return column.to_pandas(date_as_object=True)


-def _check_dataframe_convert_date(pdf, schema):
-    """ Correct date type value to use datetime.date.
+def _arrow_table_to_pandas(table, schema):
+    """ Convert Arrow Table to pandas DataFrame.
+
+    If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11
+    or the `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as
+    intermediate data.
 
     Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
     use datetime.date to match the behavior when Arrow optimization is disabled.
 
-    :param pdf: pandas.DataFrame
-    :param schema: a Spark schema of the pandas.DataFrame
+    :param table: pyarrow.lib.Table
+    :param schema: a Spark schema of the pyarrow.lib.Table
     """
-    import pyarrow
+    import pandas as pd
+    import pyarrow as pa
     from distutils.version import LooseVersion
-    # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910
-    if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"):
-        for field in schema:
-            pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
-    return pdf
+    # Since Arrow 0.11.0, date_as_object is supported to return datetime.date instead of np.datetime64.
+    if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):

Review comment (Member): Looks good @ueshin. BTW @ueshin, @BryanCutler, which version of PyArrow do you think we should bump up to in Spark 3.0.0? I was thinking of matching it to 0.12.0 or 0.11.0; it is too much overhead to test all the pyarrow versions.

Review comment (Member): It would be nice to bump to 0.12.0 because I think that would allow us to clean up the code the most, but since an error is raised if the user doesn't have that version, it might be too restrictive. Let's definitely make a JIRA to discuss this more.
+        if any(type(field.dataType) == DateType for field in schema):
+            return pd.concat([_arrow_column_to_pandas(column, field.dataType)
+                              for column, field in zip(table.itercolumns(), schema)], axis=1)
+        else:
+            return table.to_pandas()
+    else:
+        return table.to_pandas(date_as_object=True)


 def _get_local_timezone():
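To illustrate the two branches of the new helpers, here is a hedged standalone sketch. It goes through a Table rather than a pyarrow.lib.Column, since later pyarrow releases dropped the Column class this patch handles, and the "d" column is made up for illustration:

    import datetime
    import pandas as pd
    import pyarrow as pa

    table = pa.Table.from_arrays(
        [pa.array([datetime.date(2100, 3, 3), None])], ["d"])

    # pyarrow < 0.11 fallback mirrored from _arrow_column_to_pandas: rebuild the
    # column from Python objects, so no datetime64[ns] intermediate is created.
    s_old = pd.Series(table.column(0).to_pylist(), name="d")

    # pyarrow >= 0.11 path: Arrow hands back datetime.date objects directly.
    pdf = table.to_pandas(date_as_object=True)

    print(type(s_old[0]), type(pdf["d"][0]))  # both <class 'datetime.date'>

Either way the result matches non-Arrow PySpark, which returns datetime.date for DateType columns.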