Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/docs/source/user_guide/arrow_pandas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,9 @@ Supported SQL Types

.. currentmodule:: pyspark.sql.types

Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`,
Currently, all Spark SQL data types are supported by Arrow-based conversion except

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably mention MapType only for pyarrow 2.0.0..

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

:class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`.
:class: `MapType` is only supported when using PyArrow 2.0.0 and above.

Setting Arrow Batch Size
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions python/pyspark/sql/pandas/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, BooleanType, TimestampType, StructType, DataType
DoubleType, BooleanType, MapType, TimestampType, StructType, DataType
from pyspark.traceback_utils import SCCallSiteSync


Expand Down Expand Up @@ -100,7 +100,8 @@ def toPandas(self):
# of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
if use_arrow:
try:
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow
# Rename columns to avoid duplicated column names.
tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
Expand All @@ -117,6 +118,9 @@ def toPandas(self):
if isinstance(field.dataType, TimestampType):
pdf[field.name] = \
_check_series_localize_timestamps(pdf[field.name], timezone)
elif isinstance(field.dataType, MapType):
pdf[field.name] = \
_convert_map_items_to_dict(pdf[field.name])
return pdf
else:
return pd.DataFrame.from_records([], columns=self.columns)
Expand Down
1 change: 0 additions & 1 deletion python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
should be checked for accuracy by users.

Currently,
:class:`pyspark.sql.types.MapType`,
:class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
nested :class:`pyspark.sql.types.StructType`
are currently not supported as output types.
Expand Down
10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.pandas.types import _check_series_localize_timestamps
from pyspark.sql.pandas.types import _check_series_localize_timestamps, \
_convert_map_items_to_dict
import pyarrow

# If the given column is a date type column, creates a series of datetime.date directly
Expand All @@ -127,6 +128,8 @@ def arrow_to_pandas(self, arrow_column):

if pyarrow.types.is_timestamp(arrow_column.type):
return _check_series_localize_timestamps(s, self._timezone)
elif pyarrow.types.is_map(arrow_column.type):
return _convert_map_items_to_dict(s)
else:
return s

Expand All @@ -147,7 +150,8 @@ def _create_batch(self, series):
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
_convert_dict_to_map_items
from pandas.api.types import is_categorical_dtype
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
Expand All @@ -160,6 +164,8 @@ def create_array(s, t):
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif t is not None and pa.types.is_map(t):
s = _convert_dict_to_map_items(s)
elif is_categorical_dtype(s.dtype):
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
s = s.astype(s.dtypes.categories.dtype)
Expand Down
42 changes: 39 additions & 3 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
pandas instances during the type conversion.
"""

from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \
StructType, StructField, BooleanType
from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \
ArrayType, MapType, StructType, StructField


def to_arrow_type(dt):
""" Convert Spark data type to pyarrow type
"""
from distutils.version import LooseVersion
import pyarrow as pa
if type(dt) == BooleanType:
arrow_type = pa.bool_()
Expand Down Expand Up @@ -58,6 +59,13 @@ def to_arrow_type(dt):
if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
elif type(dt) == MapType:
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if type(dt.keyType) in [StructType, TimestampType] or \
type(dt.valueType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
elif type(dt) == StructType:
if any(type(field.dataType) == StructType for field in dt):
raise TypeError("Nested StructType not supported in conversion to Arrow")
Expand All @@ -81,6 +89,8 @@ def to_arrow_schema(schema):
def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
from distutils.version import LooseVersion
import pyarrow as pa
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
Expand Down Expand Up @@ -110,6 +120,12 @@ def from_arrow_type(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_map(at):
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise TypeError("MapType is only supported with pyarrow 2.0.0 and above")
if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
Expand Down Expand Up @@ -306,3 +322,23 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
`pandas.Series` where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)


def _convert_map_items_to_dict(s):

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: these conversion functions are because pyarrow expects map items as a list of (key, value) pairs, and has this format when converting to Pandas also. The reason is that the arrow spec could allow for duplicate key values in a row, and doesn't say how these should be handled exactly. So by having these conversions, we match the non-arrow behavior for maps, with a dictionary as input/output.

"""
Convert a series with items as list of (key, value), as made from an Arrow column of map type,
to dict for compatibility with non-arrow MapType columns.
:param s: pandas.Series of lists of (key, value) pairs
:return: pandas.Series of dictionaries
"""
return s.apply(lambda m: None if m is None else {k: v for k, v in m})


def _convert_dict_to_map_items(s):
"""
Convert a series of dictionaries to list of (key, value) pairs to match expected data
for Arrow column of map type.
:param s: pandas.Series of dictionaries
:return: pandas.Series of lists of (key, value) pairs
"""
return s.apply(lambda d: list(d.items()) if d is not None else None)
77 changes: 68 additions & 9 deletions python/pyspark/sql/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
import time
import unittest
import warnings
from distutils.version import LooseVersion

from pyspark import SparkContext, SparkConf
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \
ArrayType
FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
Expand Down Expand Up @@ -114,9 +114,10 @@ def create_pandas_data_frame(self):
return pd.DataFrame(data=data_dict)

def test_toPandas_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([([ts],)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with warnings.catch_warnings(record=True) as warns:
Expand All @@ -129,10 +130,10 @@ def test_toPandas_fallback_enabled(self):
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]}))

def test_toPandas_fallback_disabled(self):
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
schema = StructType([StructField("a", ArrayType(TimestampType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
Expand Down Expand Up @@ -336,6 +337,62 @@ def test_toPandas_with_array_type(self):
self.assertTrue(expected[r][e] == result_arrow[r][e] and
result[r][e] == result_arrow[r][e])

def test_createDataFrame_with_map_type(self):
map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]

pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
schema = "id long, m map<string, long>"

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema=schema)

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
self.spark.createDataFrame(pdf, schema=schema)
else:
df_arrow = self.spark.createDataFrame(pdf, schema=schema)

result = df.collect()
result_arrow = df_arrow.collect()

self.assertEqual(len(result), len(result_arrow))
for row, row_arrow in zip(result, result_arrow):
i, m = row
_, m_arrow = row_arrow
self.assertEqual(m, map_data[i])
self.assertEqual(m_arrow, map_data[i])

def test_toPandas_with_map_type(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3],
"m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)

def test_toPandas_with_map_type_nulls(self):
pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4],
"m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]})

with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")

if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
df.toPandas()
else:
pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
assert_frame_equal(pdf_arrow, pdf_non)

def test_createDataFrame_with_int_col_names(self):
import numpy as np
pdf = pd.DataFrame(np.random.rand(4, 2))
Expand All @@ -345,26 +402,28 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df_arrow.columns)

def test_createDataFrame_fallback_enabled(self):
ts = datetime.datetime(2015, 11, 1, 0, 30)
with QuietTest(self.sc):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}):
with warnings.catch_warnings(record=True) as warns:
# we want the warnings to appear even if this test is run from a subclass
warnings.simplefilter("always")
df = self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[ts]]}), "a: array<timestamp>")
# Catch and check the last UserWarning.
user_warns = [
warn.message for warn in warns if isinstance(warn.message, UserWarning)]
self.assertTrue(len(user_warns) > 0)
self.assertTrue(
"Attempting non-optimization" in str(user_warns[-1]))
self.assertEqual(df.collect(), [Row(a={u'a': 1})])
self.assertEqual(df.collect(), [Row(a=[ts])])

def test_createDataFrame_fallback_disabled(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
"a: array<timestamp>")

# Regression test for SPARK-23314
def test_timestamp_dst(self):
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_pandas_cogrouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*MapType'):
'Invalid return type.*ArrayType.*TimestampType'):
left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
lambda l, r: l, 'id long, v map<int, int>')
lambda l, r: l, 'id long, v array<timestamp>')

def test_wrong_args(self):
left = self.data1
Expand Down
7 changes: 3 additions & 4 deletions python/pyspark/sql/tests/test_pandas_grouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
window
from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \
LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \
StructField, NullType, MapType, TimestampType
StructField, NullType, TimestampType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
Expand Down Expand Up @@ -246,10 +246,10 @@ def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
'Invalid return type.*grouped map Pandas UDF.*MapType'):
'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
'id long, v array<timestamp>',
PandasUDFType.GROUPED_MAP)

def test_wrong_args(self):
Expand All @@ -276,7 +276,6 @@ def test_wrong_args(self):
def test_unsupported_types(self):
common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
unsupported_types = [
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
StructField('struct', StructType([StructField('l', LongType())])),
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType
from pyspark.sql.types import ArrayType, TimestampType
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
Expand Down Expand Up @@ -159,7 +159,7 @@ def mean_and_std_udf(v):

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
@pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return {v.mean(): v.std()}

Expand Down
Loading