diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cd4da39d62fe..e9572435ecb7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -480,6 +480,11 @@ class SparkConnectPlanner(val session: SparkSession) {
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
           transformRelation(rel.getInput))
+      case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
+        logical.PythonMapInArrow(
+          pythonUdf,
+          pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+          transformRelation(rel.getInput))
       case _ =>
         throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
     }
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index d8dee651c2b5..751f0687f2c8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -535,6 +535,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
+        "pyspark.sql.tests.connect.test_parity_arrow_map",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index c91d4e629d84..6df3f15d87dd 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -26,6 +26,7 @@
 import datetime
 import decimal
 
+import pyarrow
 from pandas.core.frame import DataFrame as PandasDataFrame
 
 from pyspark.sql.connect.column import Column
@@ -50,6 +51,8 @@
 
 PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]
 
+ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 69921896f467..0e114f9fedbb 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,6 +76,7 @@
         PrimitiveType,
         OptionalPrimitiveType,
         PandasMapIterFunction,
+        ArrowMapIterFunction,
     )
     from pyspark.sql.connect.session import SparkSession
 
@@ -1572,8 +1573,11 @@ def registerTempTable(self, *args: Any, **kwargs: Any) -> None:
     def storageLevel(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("storageLevel() is not implemented.")
 
-    def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    def _map_partitions(
+        self,
+        func: Union["PandasMapIterFunction", "ArrowMapIterFunction"],
+        schema: Union[StructType, str],
+        evalType: int,
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
 
@@ -1581,7 +1585,9 @@ def mapInPandas(
             raise Exception("Cannot mapInPandas when self._plan is empty.")
 
         udf_obj = UserDefinedFunction(
-            func, returnType=schema, evalType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+            func,
+            returnType=schema,
+            evalType=evalType,
         )
 
         return DataFrame.withPlan(
@@ -1589,10 +1595,19 @@ def mapInPandas(
             session=self._session,
         )
 
+    def mapInPandas(
+        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
+
     mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
 
-    def mapInArrow(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("mapInArrow() is not implemented.")
+    def mapInArrow(
+        self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
+
+    mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__
 
     def writeStream(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("writeStream() is not implemented.")
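(Illustration, not part of the patch.) With the planner case and the shared `_map_partitions` dispatch above in place, `mapInArrow` on a Connect DataFrame accepts a function over iterators of `pyarrow.RecordBatch`, matching the `ArrowMapIterFunction` alias added to `_typing.py`. A minimal client-side sketch; the endpoint `sc://localhost`, the function name `filter_even`, and the schema string are illustrative assumptions, not taken from the diff:

```python
import pyarrow as pa
from pyspark.sql import SparkSession

# Assumption: a Spark Connect server is reachable at this address.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.range(10)


def filter_even(iterator):
    # Each element is a pyarrow.RecordBatch covering a slice of a partition;
    # yielded batches must match the schema declared in the mapInArrow call.
    for batch in iterator:
        pdf = batch.to_pandas()
        yield pa.RecordBatch.from_pandas(pdf[pdf.id % 2 == 0], preserve_index=False)


# The client sends SQL_MAP_ARROW_ITER_UDF in the plan; the planner case above
# turns it into a PythonMapInArrow node on the server.
df.mapInArrow(filter_even, "id long").show()
```

Both `mapInPandas` and `mapInArrow` now build the same MapPartitions relation; only the eval type differs, and the server-side planner uses it to choose the logical node.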
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
new file mode 100644
index 000000000000..ed51d0d3d199
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_arrow_map import MapInArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_arrow_map import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index 6166cc5dcc8d..ff3d9b96b6b5 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -37,28 +37,7 @@
     not have_pandas or not have_pyarrow,
     pandas_requirement_message or pyarrow_requirement_message,
 )
-class MapInArrowTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
+class MapInArrowTestsMixin(object):
     def test_map_in_arrow(self):
         def func(iterator):
             for batch in iterator:
@@ -126,6 +105,29 @@ def test_self_join(self):
         self.assertEqual(sorted(actual), sorted(expected))
 
 
+class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow_map import *  # noqa: F401
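(Illustration, not part of the patch.) The refactor moves the test bodies into `MapInArrowTestsMixin` so the same cases run against both classic PySpark (`MapInArrowTests`) and Spark Connect (`ArrowMapParityTests`); the timezone fixture stays on the classic class because it touches `cls.sc`, which a Connect session does not expose. The usual way to run the new module is `python/run-tests --testnames pyspark.sql.tests.connect.test_parity_arrow_map`; a plain-unittest sketch for running it directly, assuming the modules from this diff are importable and `ReusedConnectTestCase` can start its local Spark Connect server:

```python
# Load and run the new parity suite with stock unittest.
import unittest

from pyspark.sql.tests.connect.test_parity_arrow_map import ArrowMapParityTests

suite = unittest.TestLoader().loadTestsFromTestCase(ArrowMapParityTests)
unittest.TextTestRunner(verbosity=2).run(suite)
```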