From 642934f718b1e3f18378d940f0015b1db64c8419 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 24 Sep 2020 14:48:02 -0700 Subject: [PATCH 01/12] Adding map type support for pyspark, createDataFrame working, toPandas fails on unimplemented for arrow to pandas conversion --- python/pyspark/sql/pandas/types.py | 15 ++++++-- python/pyspark/sql/tests/test_arrow.py | 34 +++++++++++++++++++ .../sql/execution/arrow/ArrowWriter.scala | 12 ++++--- .../execution/arrow/ArrowWriterSuite.scala | 2 ++ 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 67557120715ac..5fbdc902b9134 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -20,9 +20,9 @@ pandas instances during the type conversion. """ -from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \ - DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, ArrayType, \ - StructType, StructField, BooleanType +from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, LongType, \ + FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, TimestampType, \ + ArrayType, MapType, StructType, StructField def to_arrow_type(dt): @@ -58,6 +58,11 @@ def to_arrow_type(dt): if type(dt.elementType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) + elif type(dt) == MapType: + if type(dt.keyType) in [StructType, TimestampType] or \ + type(dt.valueType) in [StructType, TimestampType]: + raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType)) elif type(dt) == StructType: if any(type(field.dataType) == StructType for field in dt): raise TypeError("Nested StructType not supported in conversion to Arrow") @@ -110,6 +115,10 @@ def 
from_arrow_type(at): if types.is_timestamp(at.value_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) + elif types.is_map(at): + if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type): + raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type)) elif types.is_struct(at): if any(types.is_struct(field.type) for field in at): raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 55d5e9017b345..f735e57ed9ef9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -336,6 +336,40 @@ def test_toPandas_with_array_type(self): self.assertTrue(expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]) + def test_createDataFrame_with_map_type(self): + map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}] + + pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data}) + schema = "id long, m map" + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf, schema=schema) + + # pyarrow currently requires a list of pairs for each map + pdf["m"] = [list(m.items()) if m is not None else None for m in map_data] + + df_arrow = self.spark.createDataFrame(pdf, schema=schema) + + result = df.collect() + result_arrow = df_arrow.collect() + + self.assertEqual(len(result), len(result_arrow)) + for row, row_arrow in zip(result, result_arrow): + i, m = row + _, m_arrow = row_arrow + self.assertEqual(m, map_data[i]) + self.assertEqual(m_arrow, map_data[i]) + + def test_toPandas_with_map_type(self): + pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], + "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]}) + + with 
self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf, schema="id long, m map") + + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) + def test_createDataFrame_with_int_col_names(self): import numpy as np pdf = pd.DataFrame(np.random.rand(4, 2)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 501e1c460f9c9..f62aa5db0872f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -63,10 +63,10 @@ object ArrowWriter { val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) case (MapType(_, _, _), vector: MapVector) => - val entryWriter = createFieldWriter(vector.getDataVector).asInstanceOf[StructWriter] - val keyWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.KEY_NAME)) - val valueWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.VALUE_NAME)) - new MapWriter(vector, keyWriter, valueWriter) + val structVector = vector.getDataVector.asInstanceOf[StructVector] + val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) + val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) + new MapWriter(vector, structVector, keyWriter, valueWriter) case (StructType(_), vector: StructVector) => val children = (0 until vector.size()).map { ordinal => createFieldWriter(vector.getChildByOrdinal(ordinal)) @@ -331,11 +331,11 @@ private[arrow] class StructWriter( override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val struct = input.getStruct(ordinal, children.length) var i = 0 + valueVector.setIndexDefined(count) while (i < struct.numFields) { children(i).write(struct, i) i += 1 } - 
valueVector.setIndexDefined(count) } override def finish(): Unit = { @@ -351,6 +351,7 @@ private[arrow] class StructWriter( private[arrow] class MapWriter( val valueVector: MapVector, + val structVector: StructVector, val keyWriter: ArrowFieldWriter, val valueWriter: ArrowFieldWriter) extends ArrowFieldWriter { @@ -363,6 +364,7 @@ private[arrow] class MapWriter( val values = map.valueArray() var i = 0 while (i < map.numElements()) { + structVector.setIndexDefined(keyWriter.count) keyWriter.write(keys, i) valueWriter.write(values, i) i += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index bdc3b5eed7d8d..5b021e5917319 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -309,6 +309,8 @@ class ArrowWriterSuite extends SparkFunSuite { val map3 = reader.getMap(3) assert(map3 == null) writer.root.close() + + // TODO: figure why writer doesn't fail without setIndexDefined } test("empty map") { From 8a563ca73552ff9311bad3bc89a29d83e1247efe Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 21 Oct 2020 14:42:02 -0700 Subject: [PATCH 02/12] Add todo notes --- python/pyspark/sql/tests/test_arrow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index f735e57ed9ef9..962e2115759c5 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -115,6 +115,7 @@ def create_pandas_data_frame(self): def test_toPandas_fallback_enabled(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}): + # TODO update with type that will fail schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([({u'a': 
1},)], schema=schema) with QuietTest(self.sc): @@ -132,6 +133,7 @@ def test_toPandas_fallback_enabled(self): assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) def test_toPandas_fallback_disabled(self): + # TODO update with type that will fail schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): From 3ee35b87c98e398bfde97db11a4f13c5cc304209 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 29 Oct 2020 14:59:33 -0700 Subject: [PATCH 03/12] added conversion function, toPandas working now --- python/pyspark/sql/pandas/conversion.py | 8 ++++++-- python/pyspark/sql/pandas/serializers.py | 5 ++++- python/pyspark/sql/pandas/types.py | 10 ++++++++++ python/pyspark/sql/tests/test_arrow.py | 10 ++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 3456c12e59c09..d8a241417532e 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -22,7 +22,7 @@ from pyspark.sql.pandas.serializers import ArrowCollectSerializer from pyspark.sql.types import IntegralType from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \ - DoubleType, BooleanType, TimestampType, StructType, DataType + DoubleType, BooleanType, MapType, TimestampType, StructType, DataType from pyspark.traceback_utils import SCCallSiteSync @@ -100,7 +100,8 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled. if use_arrow: try: - from pyspark.sql.pandas.types import _check_series_localize_timestamps + from pyspark.sql.pandas.types import _check_series_localize_timestamps, \ + _convert_map_items_to_dict import pyarrow # Rename columns to avoid duplicated column names. 
tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))] @@ -117,6 +118,9 @@ def toPandas(self): if isinstance(field.dataType, TimestampType): pdf[field.name] = \ _check_series_localize_timestamps(pdf[field.name], timezone) + elif isinstance(field.dataType, MapType): + pdf[field.name] = \ + _convert_map_items_to_dict(pdf[field.name]) return pdf else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 73d36ee555fb5..c2cdbad1dcc39 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -117,7 +117,8 @@ def __init__(self, timezone, safecheck, assign_cols_by_name): self._assign_cols_by_name = assign_cols_by_name def arrow_to_pandas(self, arrow_column): - from pyspark.sql.pandas.types import _check_series_localize_timestamps + from pyspark.sql.pandas.types import _check_series_localize_timestamps, \ + _convert_map_items_to_dict import pyarrow # If the given column is a date type column, creates a series of datetime.date directly @@ -127,6 +128,8 @@ def arrow_to_pandas(self, arrow_column): if pyarrow.types.is_timestamp(arrow_column.type): return _check_series_localize_timestamps(s, self._timezone) + elif pyarrow.types.is_map(arrow_column.type): + return _convert_map_items_to_dict(s) else: return s diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 5fbdc902b9134..7443151ee561b 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -315,3 +315,13 @@ def _check_series_convert_timestamps_tz_local(s, timezone): `pandas.Series` where if it is a timestamp, has been converted to tz-naive """ return _check_series_convert_timestamps_localize(s, timezone, None) + + +def _convert_map_items_to_dict(s): + """ + Convert a series with items as list of (key, value), as made from an Arrow column of map type, + to dict for compatibility with 
non-arrow MapType columns. + :param s: a pandas.Series + :return: pandas.Series with each record as a dict + """ + return s.apply(lambda m: None if m is None else {k: v for k, v in m}) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 962e2115759c5..7da9ed0ff64db 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -363,6 +363,16 @@ def test_createDataFrame_with_map_type(self): self.assertEqual(m_arrow, map_data[i]) def test_toPandas_with_map_type(self): + pdf = pd.DataFrame({"id": [0, 1, 2, 3], + "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]}) + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf, schema="id long, m map") + + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) + + def test_toPandas_with_map_type_nulls(self): pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]}) From dfec9a5bdfbe9a658c95f4a1787c5611ea852829 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 10 Nov 2020 14:27:54 -0800 Subject: [PATCH 04/12] Fixed up remaining ArrowTests --- python/pyspark/sql/tests/test_arrow.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 7da9ed0ff64db..3698a3dbbea42 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -114,10 +114,10 @@ def create_pandas_data_frame(self): return pd.DataFrame(data=data_dict) def test_toPandas_fallback_enabled(self): + ts = datetime.datetime(2015, 11, 1, 0, 30) with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}): - # TODO update with type that will fail - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) - df = 
self.spark.createDataFrame([({u'a': 1},)], schema=schema) + schema = StructType([StructField("a", ArrayType(TimestampType()), True)]) + df = self.spark.createDataFrame([([ts],)], schema=schema) with QuietTest(self.sc): with self.warnings_lock: with warnings.catch_warnings(record=True) as warns: @@ -130,11 +130,10 @@ def test_toPandas_fallback_enabled(self): self.assertTrue(len(user_warns) > 0) self.assertTrue( "Attempting non-optimization" in str(user_warns[-1])) - assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]})) + assert_frame_equal(pdf, pd.DataFrame({"a": [[ts]]})) def test_toPandas_fallback_disabled(self): - # TODO update with type that will fail - schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)]) + schema = StructType([StructField("a", ArrayType(TimestampType()), True)]) df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.warnings_lock: @@ -391,26 +390,28 @@ def test_createDataFrame_with_int_col_names(self): self.assertEqual(pdf_col_names, df_arrow.columns) def test_createDataFrame_fallback_enabled(self): + ts = datetime.datetime(2015, 11, 1, 0, 30) with QuietTest(self.sc): with self.sql_conf({"spark.sql.execution.arrow.pyspark.fallback.enabled": True}): with warnings.catch_warnings(record=True) as warns: # we want the warnings to appear even if this test is run from a subclass warnings.simplefilter("always") df = self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") + pd.DataFrame({"a": [[ts]]}), "a: array") # Catch and check the last UserWarning. 
user_warns = [ warn.message for warn in warns if isinstance(warn.message, UserWarning)] self.assertTrue(len(user_warns) > 0) self.assertTrue( "Attempting non-optimization" in str(user_warns[-1])) - self.assertEqual(df.collect(), [Row(a={u'a': 1})]) + self.assertEqual(df.collect(), [Row(a=[ts])]) def test_createDataFrame_fallback_disabled(self): with QuietTest(self.sc): with self.assertRaisesRegexp(TypeError, 'Unsupported type'): self.spark.createDataFrame( - pd.DataFrame([[{u'a': 1}]]), "a: map") + pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}), + "a: array") # Regression test for SPARK-23314 def test_timestamp_dst(self): From aed88b294ccad388bcfa0dbea28050304860c258 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 13 Nov 2020 15:28:43 -0800 Subject: [PATCH 05/12] Added conversion from Pandas series of dicts to arrow map --- python/pyspark/sql/pandas/serializers.py | 5 ++++- python/pyspark/sql/pandas/types.py | 14 ++++++++++++-- python/pyspark/sql/tests/test_arrow.py | 3 --- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index c2cdbad1dcc39..2dcfdc1046049 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -150,7 +150,8 @@ def _create_batch(self, series): """ import pandas as pd import pyarrow as pa - from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal + from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \ + _convert_dict_to_map_items from pandas.api.types import is_categorical_dtype # Make input conform to [(series1, type1), (series2, type2), ...] 
if not isinstance(series, (list, tuple)) or \ @@ -163,6 +164,8 @@ def create_array(s, t): # Ensure timestamp series are in expected form for Spark internal representation if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s, self._timezone) + elif t is not None and pa.types.is_map(t): + s = _convert_dict_to_map_items(s) elif is_categorical_dtype(s.dtype): # Note: This can be removed once minimum pyarrow version is >= 0.16.1 s = s.astype(s.dtypes.categories.dtype) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 7443151ee561b..edf2208969878 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -321,7 +321,17 @@ def _convert_map_items_to_dict(s): """ Convert a series with items as list of (key, value), as made from an Arrow column of map type, to dict for compatibility with non-arrow MapType columns. - :param s: a pandas.Series - :return: pandas.Series with each record as a dict + :param s: pandas.Series of lists of (key, value) pairs + :return: pandas.Series of dictionaries """ return s.apply(lambda m: None if m is None else {k: v for k, v in m}) + + +def _convert_dict_to_map_items(s): + """ + Convert a series of dictionaries to list of (key, value) pairs to match expected data + for Arrow column of map type. 
+ :param s: pandas.Series of dictionaries + :return: pandas.Series of lists of (key, value) pairs + """ + return s.apply(lambda d: list(d.items()) if d is not None else None) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 3698a3dbbea42..a9c1ca3eaf07a 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -346,9 +346,6 @@ def test_createDataFrame_with_map_type(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): df = self.spark.createDataFrame(pdf, schema=schema) - # pyarrow currently requires a list of pairs for each map - pdf["m"] = [list(m.items()) if m is not None else None for m in map_data] - df_arrow = self.spark.createDataFrame(pdf, schema=schema) result = df.collect() From cbe6c2381131a4a88dee3b6aa4221ce1de51c674 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 13 Nov 2020 15:31:06 -0800 Subject: [PATCH 06/12] Added test for pandas scalar udf --- python/pyspark/sql/tests/test_pandas_udf_scalar.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 6d325c9085ce1..aad0e902d88da 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -379,6 +379,15 @@ def test_vectorized_udf_nested_struct(self): 'Invalid return type with scalar Pandas UDFs'): pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type) + def test_vectorized_udf_map_type(self): + data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)] + schema = StructType([StructField("map", MapType(StringType(), LongType()))]) + df = self.spark.createDataFrame(data, schema=schema) + for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: + map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type) + result = df.select(map_f(col('map'))) + 
self.assertEquals(df.collect(), result.collect()) + def test_vectorized_udf_complex(self): df = self.spark.range(10).select( col('id').cast('int').alias('a'), From 2db803b29a1b79551b055e19fc79235fd26e9a16 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 13 Nov 2020 22:38:26 -0800 Subject: [PATCH 07/12] Added checks to disable map type for pyarrow < 2 --- python/pyspark/sql/pandas/types.py | 7 ++++ python/pyspark/sql/tests/test_arrow.py | 42 +++++++++++++------ .../sql/tests/test_pandas_udf_scalar.py | 12 ++++-- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index edf2208969878..7e4d61b0d21b8 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -28,6 +28,7 @@ def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ + from distutils.version import LooseVersion import pyarrow as pa if type(dt) == BooleanType: arrow_type = pa.bool_() @@ -59,6 +60,8 @@ def to_arrow_type(dt): raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) arrow_type = pa.list_(to_arrow_type(dt.elementType)) elif type(dt) == MapType: + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") if type(dt.keyType) in [StructType, TimestampType] or \ type(dt.valueType) in [StructType, TimestampType]: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) @@ -86,6 +89,8 @@ def to_arrow_schema(schema): def from_arrow_type(at): """ Convert pyarrow type to Spark data type. 
""" + from distutils.version import LooseVersion + import pyarrow as pa import pyarrow.types as types if types.is_boolean(at): spark_type = BooleanType() @@ -116,6 +121,8 @@ def from_arrow_type(at): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = ArrayType(from_arrow_type(at.value_type)) elif types.is_map(at): + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type): raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type)) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index a9c1ca3eaf07a..4d3cb759d30dd 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -21,6 +21,7 @@ import time import unittest import warnings +from distutils.version import LooseVersion from pyspark import SparkContext, SparkConf from pyspark.sql import Row, SparkSession @@ -346,17 +347,22 @@ def test_createDataFrame_with_map_type(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): df = self.spark.createDataFrame(pdf, schema=schema) - df_arrow = self.spark.createDataFrame(pdf, schema=schema) + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"): + self.spark.createDataFrame(pdf, schema=schema) + else: + df_arrow = self.spark.createDataFrame(pdf, schema=schema) - result = df.collect() - result_arrow = df_arrow.collect() + result = df.collect() + result_arrow = df_arrow.collect() - self.assertEqual(len(result), len(result_arrow)) - for row, row_arrow in zip(result, result_arrow): - i, m = row - _, m_arrow = row_arrow - self.assertEqual(m, map_data[i]) - self.assertEqual(m_arrow, map_data[i]) + 
self.assertEqual(len(result), len(result_arrow)) + for row, row_arrow in zip(result, result_arrow): + i, m = row + _, m_arrow = row_arrow + self.assertEqual(m, map_data[i]) + self.assertEqual(m_arrow, map_data[i]) def test_toPandas_with_map_type(self): pdf = pd.DataFrame({"id": [0, 1, 2, 3], @@ -365,8 +371,13 @@ def test_toPandas_with_map_type(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): df = self.spark.createDataFrame(pdf, schema="id long, m map") - pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) - assert_frame_equal(pdf_arrow, pdf_non) + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"): + df.toPandas() + else: + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) def test_toPandas_with_map_type_nulls(self): pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], @@ -375,8 +386,13 @@ def test_toPandas_with_map_type_nulls(self): with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): df = self.spark.createDataFrame(pdf, schema="id long, m map") - pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) - assert_frame_equal(pdf_arrow, pdf_non) + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"): + df.toPandas() + else: + pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df) + assert_frame_equal(pdf_arrow, pdf_non) def test_createDataFrame_with_int_col_names(self): import numpy as np diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index aad0e902d88da..72972b89aba62 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -22,6 +22,7 @@ import unittest from datetime import date, datetime from decimal import Decimal +from distutils.version 
import LooseVersion from pyspark import TaskContext from pyspark.rdd import PythonEvalType @@ -384,9 +385,14 @@ def test_vectorized_udf_map_type(self): schema = StructType([StructField("map", MapType(StringType(), LongType()))]) df = self.spark.createDataFrame(data, schema=schema) for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: - map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type) - result = df.select(map_f(col('map'))) - self.assertEquals(df.collect(), result.collect()) + if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): + with QuietTest(self.sc): + with self.assertRaisesRegex(Exception, "MapType.*not supported"): + pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type) + else: + map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type) + result = df.select(map_f(col('map'))) + self.assertEquals(df.collect(), result.collect()) def test_vectorized_udf_complex(self): df = self.spark.range(10).select( From a92af2fe5e48239b154d997e055e32d34b2c61ce Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 16 Nov 2020 23:03:55 -0800 Subject: [PATCH 08/12] Remove TODO, MapAccessor only checks key/value vectors, not StructVector in Java --- .../org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 5b021e5917319..bdc3b5eed7d8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -309,8 +309,6 @@ class ArrowWriterSuite extends SparkFunSuite { val map3 = reader.getMap(3) assert(map3 == null) writer.root.close() - - // TODO: figure why writer doesn't fail without setIndexDefined } test("empty map") { From 
dec27974b65eb8fc55f6d206be54e0f4f527587a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Nov 2020 09:20:48 -0800 Subject: [PATCH 09/12] Remove import no longer used --- python/pyspark/sql/tests/test_arrow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 4d3cb759d30dd..e764c42d88a31 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -27,8 +27,7 @@ from pyspark.sql import Row, SparkSession from pyspark.sql.functions import udf from pyspark.sql.types import StructType, StringType, IntegerType, LongType, \ - FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, MapType, \ - ArrayType + FloatType, DoubleType, DecimalType, DateType, TimestampType, BinaryType, StructField, ArrayType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest From 78b2604713a56e80c8b7f4e9860fdabf366b9004 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Nov 2020 11:49:09 -0800 Subject: [PATCH 10/12] Corrected remaining tests that use map as unsupported type --- python/pyspark/sql/tests/test_pandas_cogrouped_map.py | 4 ++-- python/pyspark/sql/tests/test_pandas_grouped_map.py | 7 +++---- python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 4 ++-- python/pyspark/sql/tests/test_pandas_udf_scalar.py | 8 ++++---- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index f9a7dd69b61fb..4afc1dfcc1c6e 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -176,9 +176,9 @@ def test_wrong_return_type(self): with QuietTest(self.sc): with self.assertRaisesRegexp( 
NotImplementedError, - 'Invalid return type.*MapType'): + 'Invalid return type.*ArrayType.*TimestampType'): left.groupby('id').cogroup(right.groupby('id')).applyInPandas( - lambda l, r: l, 'id long, v map') + lambda l, r: l, 'id long, v array') def test_wrong_args(self): left = self.data1 diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 93e37125eaa33..ee68b95fc478d 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -26,7 +26,7 @@ window from pyspark.sql.types import IntegerType, DoubleType, ArrayType, BinaryType, ByteType, \ LongType, DecimalType, ShortType, FloatType, StringType, BooleanType, StructType, \ - StructField, NullType, MapType, TimestampType + StructField, NullType, TimestampType from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message from pyspark.testing.utils import QuietTest @@ -246,10 +246,10 @@ def test_wrong_return_type(self): with QuietTest(self.sc): with self.assertRaisesRegexp( NotImplementedError, - 'Invalid return type.*grouped map Pandas UDF.*MapType'): + 'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'): pandas_udf( lambda pdf: pdf, - 'id long, v map', + 'id long, v array', PandasUDFType.GROUPED_MAP) def test_wrong_args(self): @@ -276,7 +276,6 @@ def test_wrong_args(self): def test_unsupported_types(self): common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*' unsupported_types = [ - StructField('map', MapType(StringType(), IntegerType())), StructField('arr_ts', ArrayType(TimestampType())), StructField('null', NullType()), StructField('struct', StructType([StructField('l', LongType())])), diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 451308927629b..2cbcf31f6e7b3 100644 --- 
a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -21,7 +21,7 @@ from pyspark.sql import Row from pyspark.sql.functions import array, explode, col, lit, mean, sum, \ udf, pandas_udf, PandasUDFType -from pyspark.sql.types import ArrayType, TimestampType, DoubleType, MapType +from pyspark.sql.types import ArrayType, TimestampType from pyspark.sql.utils import AnalysisException from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ pandas_requirement_message, pyarrow_requirement_message @@ -159,7 +159,7 @@ def mean_and_std_udf(v): with QuietTest(self.sc): with self.assertRaisesRegexp(NotImplementedError, 'not supported'): - @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG) + @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG) def mean_and_std_udf(v): return {v.mean(): v.std()} diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 72972b89aba62..5da5d043ceca4 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -519,8 +519,8 @@ def test_vectorized_udf_wrong_return_type(self): for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: with self.assertRaisesRegexp( NotImplementedError, - 'Invalid return type.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type) + 'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type) def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) @@ -592,8 +592,8 @@ def test_vectorized_udf_unsupported_types(self): for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: with self.assertRaisesRegexp( NotImplementedError, - 'Invalid return type.*scalar Pandas UDF.*MapType'): - pandas_udf(lambda x: x, 
MapType(StringType(), IntegerType()), udf_type) + 'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'): + pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type) with self.assertRaisesRegexp( NotImplementedError, 'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'): From b257470bccb363ca1ae10a96126c3d49d10ad8c6 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Nov 2020 22:20:46 -0800 Subject: [PATCH 11/12] Update docs to remove MapType from unsupported --- python/docs/source/user_guide/arrow_pandas.rst | 2 +- python/pyspark/sql/pandas/functions.py | 1 - .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/docs/source/user_guide/arrow_pandas.rst b/python/docs/source/user_guide/arrow_pandas.rst index fe04315f87ad5..4eece98d886f4 100644 --- a/python/docs/source/user_guide/arrow_pandas.rst +++ b/python/docs/source/user_guide/arrow_pandas.rst @@ -341,7 +341,7 @@ Supported SQL Types .. currentmodule:: pyspark.sql.types -Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`MapType`, +Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`. Setting Arrow Batch Size diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 16462e8702a0b..750aa4b0e6c56 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -284,7 +284,6 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: should be checked for accuracy by users. Currently, - :class:`pyspark.sql.types.MapType`, :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and nested :class:`pyspark.sql.types.StructType` are currently not supported as output types. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 21357a492e39e..d7fd2511cc979 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1867,7 +1867,7 @@ object SQLConf { "1. pyspark.sql.DataFrame.toPandas " + "2. pyspark.sql.SparkSession.createDataFrame when its input is a Pandas DataFrame " + "The following data types are unsupported: " + - "MapType, ArrayType of TimestampType, and nested StructType.") + "ArrayType of TimestampType, and nested StructType.") .version("3.0.0") .fallbackConf(ARROW_EXECUTION_ENABLED) From 3f2ef9867d280b232010cae6b4bee76e8ca1e25d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Nov 2020 22:48:02 -0800 Subject: [PATCH 12/12] Add note only supported for pyarrow 2 --- python/docs/source/user_guide/arrow_pandas.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/python/docs/source/user_guide/arrow_pandas.rst b/python/docs/source/user_guide/arrow_pandas.rst index 4eece98d886f4..91d8155523391 100644 --- a/python/docs/source/user_guide/arrow_pandas.rst +++ b/python/docs/source/user_guide/arrow_pandas.rst @@ -343,6 +343,7 @@ Supported SQL Types Currently, all Spark SQL data types are supported by Arrow-based conversion except :class:`ArrayType` of :class:`TimestampType`, and nested :class:`StructType`. +:class:`MapType` is only supported when using PyArrow 2.0.0 and above. Setting Arrow Batch Size ~~~~~~~~~~~~~~~~~~~~~~~~