diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 043734aff71f5..29f526b2d395a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -67,6 +67,9 @@ private[spark] object PythonEvalType { val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213 val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214 + // Arrow UDFs + val SQL_SCALAR_ARROW_UDF = 250 + val SQL_TABLE_UDF = 300 val SQL_ARROW_TABLE_UDF = 301 @@ -93,6 +96,7 @@ private[spark] object PythonEvalType { case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF" case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF" + case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF" } } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 88fff94efee6d..f6f2db8fb9d6d 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -541,6 +541,8 @@ def __hash__(self): "pyspark.sql.tests.arrow.test_arrow_cogrouped_map", "pyspark.sql.tests.arrow.test_arrow_grouped_map", "pyspark.sql.tests.arrow.test_arrow_python_udf", + "pyspark.sql.tests.arrow.test_arrow_udf", + "pyspark.sql.tests.arrow.test_arrow_udf_scalar", "pyspark.sql.tests.pandas.test_pandas_cogrouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state", @@ -1100,6 +1102,8 @@ def __hash__(self): "pyspark.sql.tests.connect.arrow.test_parity_arrow_grouped_map", "pyspark.sql.tests.connect.arrow.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.arrow.test_parity_arrow_python_udf", + "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf", + "pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_scalar", "pyspark.sql.tests.connect.pandas.test_parity_pandas_map", "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map", "pyspark.sql.tests.connect.pandas.test_parity_pandas_grouped_map_with_state", diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 01566644071c8..101ae06f10edc 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -85,7 +85,7 @@ ) from pyspark.sql.connect.functions import builtin as F from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema -from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] +from pyspark.sql.pandas.functions import _validate_vectorized_udf # type: ignore[attr-defined] from pyspark.sql.table_arg import TableArg @@ -2054,7 +2054,7 @@ def _map_partitions( ) -> ParentDataFrame: from pyspark.sql.connect.udf import UserDefinedFunction - _validate_pandas_udf(func, evalType) + _validate_vectorized_udf(func, evalType) if isinstance(schema, str): schema = cast(StructType, self._session._parse_ddl(schema)) udf_obj = UserDefinedFunction( diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index 8e1d46d66abd2..b7b0473c13ceb 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -34,7 +34,7 @@ from pyspark.util import PythonEvalType from pyspark.sql.group import GroupedData as PySparkGroupedData from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps -from pyspark.sql.pandas.functions import 
_validate_pandas_udf # type: ignore[attr-defined] +from pyspark.sql.pandas.functions import _validate_vectorized_udf # type: ignore[attr-defined] from pyspark.sql.types import NumericType, StructType import pyspark.sql.connect.plan as plan @@ -294,7 +294,7 @@ def applyInPandas( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) + _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF) if isinstance(schema, str): schema = cast(StructType, self._df._session._parse_ddl(schema)) udf_obj = UserDefinedFunction( @@ -329,7 +329,7 @@ def applyInPandasWithState( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE) + _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE) udf_obj = UserDefinedFunction( func, returnType=outputStructType, @@ -472,7 +472,7 @@ def applyInArrow( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - _validate_pandas_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF) + _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF) if isinstance(schema, str): schema = cast(StructType, self._df._session._parse_ddl(schema)) udf_obj = UserDefinedFunction( @@ -517,7 +517,7 @@ def applyInPandas( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF) + _validate_vectorized_udf(func, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF) if isinstance(schema, str): schema = cast(StructType, self._gd1._df._session._parse_ddl(schema)) udf_obj = UserDefinedFunction( @@ -548,7 +548,7 @@ def applyInArrow( from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.dataframe import DataFrame - _validate_pandas_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF) + _validate_vectorized_udf(func, PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF) if isinstance(schema, str): schema = cast(StructType, self._gd1._df._session._parse_ddl(schema)) udf_obj = UserDefinedFunction( diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index cd87a3ef74eaf..46ba6bf8e8e4a 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -276,6 +276,7 @@ def register( PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, ]: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 193d368f6ebde..07ffcb7cd4fd4 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -60,6 +60,81 @@ PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212] GroupedMapUDFTransformWithStateType = Literal[213] GroupedMapUDFTransformWithStateInitStateType = Literal[214] +# Arrow UDFs +ArrowScalarUDFType = Literal[250] + +class ArrowVariadicScalarToScalarFunction(Protocol): + def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ... 
+ +ArrowScalarToScalarFunction = Union[ + ArrowVariadicScalarToScalarFunction, + Callable[[pyarrow.Array], pyarrow.Array], + Callable[[pyarrow.Array, pyarrow.Array], pyarrow.Array], + Callable[[pyarrow.Array, pyarrow.Array, pyarrow.Array], pyarrow.Array], + Callable[[pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array], pyarrow.Array], + Callable[ + [pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array], pyarrow.Array + ], + Callable[ + [pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array, pyarrow.Array], + pyarrow.Array, + ], + Callable[ + [ + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + ], + pyarrow.Array, + ], + Callable[ + [ + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + ], + pyarrow.Array, + ], + Callable[ + [ + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + ], + pyarrow.Array, + ], + Callable[ + [ + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + pyarrow.Array, + ], + pyarrow.Array, + ], +] + class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index be8ffacfa3d7b..e7147c69ed08c 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -41,9 +41,28 @@ class PandasUDFType: GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF +class ArrowUDFType: + """Arrow UDF Types. See :meth:`pyspark.sql.functions.arrow_udf`.""" + + SCALAR = PythonEvalType.SQL_SCALAR_ARROW_UDF + + +def arrow_udf(f=None, returnType=None, functionType=None): + return vectorized_udf(f, returnType, functionType, "arrow") + + def pandas_udf(f=None, returnType=None, functionType=None): + return vectorized_udf(f, returnType, functionType, "pandas") + + +def vectorized_udf( + f=None, + returnType=None, + functionType=None, + kind: str = "pandas", +): """ - Creates a pandas user defined function (a.k.a. vectorized user defined function). + Creates a vectorized user defined function. Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF @@ -372,6 +391,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: require_minimum_pandas_version() require_minimum_pyarrow_version() + assert kind in ["pandas", "arrow"], "kind should be either 'pandas' or 'arrow'" + # decorator @pandas_udf(returnType, functionType) is_decorator = f is None or isinstance(f, (str, DataType)) @@ -404,7 +425,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: messageParameters={"arg_name": "returnType"}, ) - if eval_type not in [ + if kind == "pandas" and eval_type not in [ PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, @@ -428,15 +449,38 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: "arg_type": str(eval_type), }, ) + if kind == "arrow" and eval_type not in [ + PythonEvalType.SQL_SCALAR_ARROW_UDF, + None, + ]: # None means it should infer the type from type hints. 
+ raise PySparkTypeError( + errorClass="INVALID_PANDAS_UDF_TYPE", + messageParameters={ + "arg_name": "functionType", + "arg_type": str(eval_type), + }, + ) if is_decorator: - return functools.partial(_create_pandas_udf, returnType=return_type, evalType=eval_type) + return functools.partial( + _create_vectorized_udf, + returnType=return_type, + evalType=eval_type, + kind=kind, + ) else: - return _create_pandas_udf(f=f, returnType=return_type, evalType=eval_type) + return _create_vectorized_udf( + f=f, + returnType=return_type, + evalType=eval_type, + kind=kind, + ) # validate the pandas udf and return the adjusted eval type -def _validate_pandas_udf(f, evalType) -> int: +def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int: + assert kind in ["pandas", "arrow"], "kind should be either 'pandas' or 'arrow'" + argspec = getfullargspec(f) # pandas UDF by type hints. @@ -482,11 +526,17 @@ def _validate_pandas_udf(f, evalType) -> int: if evalType is None: # Set default is scalar UDF. - evalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF + if kind == "pandas": + evalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF + else: + evalType = PythonEvalType.SQL_SCALAR_ARROW_UDF + + kind_str = "pandas_udfs" if kind == "pandas" else "arrow_udfs" if ( ( evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF + or evalType == PythonEvalType.SQL_SCALAR_ARROW_UDF or evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF ) and len(argspec.args) == 0 @@ -495,7 +545,7 @@ def _validate_pandas_udf(f, evalType) -> int: raise PySparkValueError( errorClass="INVALID_PANDAS_UDF", messageParameters={ - "detail": "0-arg pandas_udfs are not supported. " + "detail": f"0-arg {kind_str} are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function.", }, ) @@ -504,7 +554,7 @@ def _validate_pandas_udf(f, evalType) -> int: raise PySparkValueError( errorClass="INVALID_PANDAS_UDF", messageParameters={ - "detail": "pandas_udf with function type GROUPED_MAP or the function in " + "detail": f"{kind_str} with function type GROUPED_MAP or the function in " "groupby.applyInPandas must take either one argument (data) or " "two arguments (key, data).", }, @@ -540,8 +590,8 @@ def _validate_pandas_udf(f, evalType) -> int: return evalType -def _create_pandas_udf(f, returnType, evalType): - evalType = _validate_pandas_udf(f, evalType) +def _create_vectorized_udf(f, returnType, evalType, kind): + evalType = _validate_vectorized_udf(f, evalType, kind) if is_remote(): from pyspark.sql.connect.udf import _create_udf as _create_connect_udf diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi index b053b93a278e1..b82a39fff2ae4 100644 --- a/python/pyspark/sql/pandas/functions.pyi +++ b/python/pyspark/sql/pandas/functions.pyi @@ -22,6 +22,7 @@ from typing import Union, Callable from pyspark.sql._typing import ( AtomicDataTypeOrString, UserDefinedFunctionLike, + DataTypeOrString, ) from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, @@ -34,11 +35,13 @@ from pyspark.sql.pandas._typing import ( PandasScalarToScalarFunction, PandasScalarToStructFunction, PandasScalarUDFType, + ArrowScalarToScalarFunction, + ArrowScalarUDFType, ) from pyspark import since as since # noqa: F401 from pyspark.util import PythonEvalType as PythonEvalType # noqa: F401 -from pyspark.sql.types import ArrayType, StructType +from pyspark.sql.types import ArrayType, StructType, DataType class PandasUDFType: SCALAR: PandasScalarUDFType @@ -46,6 +49,27 @@ class PandasUDFType: 
GROUPED_MAP: PandasGroupedMapUDFType GROUPED_AGG: PandasGroupedAggUDFType +class ArrowUDFType: + SCALAR: ArrowScalarUDFType + +@overload +def arrow_udf( + f: ArrowScalarToScalarFunction, + returnType: DataTypeOrString, + functionType: ArrowScalarUDFType, +) -> UserDefinedFunctionLike: ... +@overload +def arrow_udf( + f: DataTypeOrString, returnType: ArrowScalarUDFType +) -> Callable[[ArrowScalarToScalarFunction], UserDefinedFunctionLike]: ... +@overload +def arrow_udf( + f: DataTypeOrString, *, functionType: ArrowScalarUDFType +) -> Callable[[ArrowScalarToScalarFunction], UserDefinedFunctionLike]: ... +@overload +def arrow_udf( + *, returnType: DataTypeOrString, functionType: ArrowScalarUDFType +) -> Callable[[ArrowScalarToScalarFunction], UserDefinedFunctionLike]: ... @overload def pandas_udf( f: PandasScalarToScalarFunction, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index b3fa77f8ba204..b154318dc430c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -629,6 +629,72 @@ def __repr__(self): return "ArrowStreamPandasUDFSerializer" +class ArrowStreamArrowUDFSerializer(ArrowStreamSerializer): + """ + Serializer used by Python worker to evaluate Arrow UDFs + """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + arrow_cast, + ): + super(ArrowStreamArrowUDFSerializer, self).__init__() + self._timezone = timezone + self._safecheck = safecheck + self._assign_cols_by_name = assign_cols_by_name + self._arrow_cast = arrow_cast + + def _create_array(self, arr, arrow_type, arrow_cast): + import pyarrow as pa + + assert isinstance(arr, pa.Array) + assert isinstance(arrow_type, pa.DataType) + + # TODO: should we handle timezone here? + + try: + return arr + except pa.lib.ArrowException: + if arrow_cast: + return arr.cast(target_type=arrow_type, safe=self._safecheck) + else: + raise + + def dump_stream(self, iterator, stream): + """ + Override because Arrow UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + import pyarrow as pa + + def wrap_and_init_stream(): + should_write_start_length = True + for packed in iterator: + if len(packed) == 2 and isinstance(packed[1], pa.DataType): + # single array UDF in a projection + arrs = [self._create_array(packed[0], packed[1], self._arrow_cast)] + else: + # multiple array UDFs in a projection + arrs = [self._create_array(t[0], t[1], self._arrow_cast) for t in packed] + + batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + + # Write the first record batch with initialization. + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream) + + def __repr__(self): + return "ArrowStreamArrowUDFSerializer" + + class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): """ Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs. 
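For orientation, a minimal usage sketch of the arrow_udf API added in functions.py above (illustrative only, not part of the patch; it assumes an active SparkSession bound to `spark` and imports arrow_udf from its current location in pyspark.sql.pandas.functions, as the new tests do; pa.compute.add and pa.compute.ascii_upper are standard pyarrow compute functions):

    import pyarrow as pa
    from pyspark.sql import functions as F
    from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType

    # Explicit function type: eval type SQL_SCALAR_ARROW_UDF (250).
    plus_one = arrow_udf(lambda v: pa.compute.add(v, 1), "long", ArrowUDFType.SCALAR)

    # Decorator form; the eval type is inferred from the pa.Array -> pa.Array
    # type hints (see the typehints.py change that follows).
    @arrow_udf("string")
    def to_upper(s: pa.Array) -> pa.Array:
        return pa.compute.ascii_upper(s)

    df = spark.range(3).withColumn("s", F.lit("abc"))
    df.select(plus_one("id").alias("id_plus_one"), to_upper("s").alias("s_upper")).show()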
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py index 5335bcf03412b..fb22c11dccb7e 100644 --- a/python/pyspark/sql/pandas/typehints.py +++ b/python/pyspark/sql/pandas/typehints.py @@ -17,7 +17,7 @@ from inspect import Signature from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING -from pyspark.sql.pandas.utils import require_minimum_pandas_version +from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.errors import PySparkNotImplementedError, PySparkValueError if TYPE_CHECKING: @@ -25,21 +25,29 @@ PandasScalarUDFType, PandasScalarIterUDFType, PandasGroupedAggUDFType, + ArrowScalarUDFType, ) def infer_eval_type( sig: Signature, type_hints: Dict[str, Any] -) -> Union["PandasScalarUDFType", "PandasScalarIterUDFType", "PandasGroupedAggUDFType"]: +) -> Union[ + "PandasScalarUDFType", + "PandasScalarIterUDFType", + "PandasGroupedAggUDFType", + "ArrowScalarUDFType", +]: """ Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from :class:`inspect.Signature` instance and type hints. """ - from pyspark.sql.pandas.functions import PandasUDFType + from pyspark.sql.pandas.functions import PandasUDFType, ArrowUDFType require_minimum_pandas_version() + require_minimum_pyarrow_version() import pandas as pd + import pyarrow as pa annotations = {} for param in sig.parameters.values(): @@ -74,6 +82,12 @@ def infer_eval_type( for a in parameters_sig ) and (return_annotation == pd.Series or return_annotation == pd.DataFrame) + # pa.Array, ... -> pa.Array + is_arrow_array = all( + a == pa.Array or check_union_annotation(a, parameter_check_func=lambda na: na == pa.Array) + for a in parameters_sig + ) and (return_annotation == pa.Array) + # Iterator[Tuple[Series, Frame or Union[DataFrame, Series], ...] -> Iterator[Series or Frame] is_iterator_tuple_series_or_frame = ( len(parameters_sig) == 1 @@ -134,6 +148,8 @@ def infer_eval_type( if is_series_or_frame: return PandasUDFType.SCALAR + elif is_arrow_array: + return ArrowUDFType.SCALAR elif is_iterator_tuple_series_or_frame or is_iterator_series_or_frame: return PandasUDFType.SCALAR_ITER elif is_series_or_frame_agg: diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_udf.py new file mode 100644 index 0000000000000..052ee1c3be830 --- /dev/null +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf.py @@ -0,0 +1,222 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +import datetime + +# TODO: import arrow_udf from public API +from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType, PandasUDFType +from pyspark.sql import functions as F +from pyspark.sql.types import ( + DoubleType, + StructType, + StructField, + LongType, + DayTimeIntervalType, + VariantType, +) +from pyspark.errors import ParseException, PySparkTypeError +from pyspark.util import PythonEvalType +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pyarrow, + pyarrow_requirement_message, +) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class ArrowUDFTestsMixin: + def test_arrow_udf_basic(self): + udf = arrow_udf(lambda x: x, DoubleType()) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + udf = arrow_udf(lambda x: x, VariantType()) + self.assertEqual(udf.returnType, VariantType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + udf = arrow_udf(lambda x: x, DoubleType(), ArrowUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + udf = arrow_udf(lambda x: x, VariantType(), ArrowUDFType.SCALAR) + self.assertEqual(udf.returnType, VariantType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + def test_arrow_udf_basic_with_return_type_string(self): + udf = arrow_udf(lambda x: x, "double", ArrowUDFType.SCALAR) + self.assertEqual(udf.returnType, DoubleType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + udf = arrow_udf(lambda x: x, "variant", ArrowUDFType.SCALAR) + self.assertEqual(udf.returnType, VariantType()) + self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + def test_arrow_udf_decorator(self): + @arrow_udf(DoubleType()) + def foo(x): + return x + + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + @arrow_udf(returnType=DoubleType()) + def foo(x): + return x + + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + def test_arrow_udf_decorator_with_return_type_string(self): + schema = StructType([StructField("v", DoubleType())]) + + @arrow_udf("v double", ArrowUDFType.SCALAR) + def foo(x): + return x + + self.assertEqual(foo.returnType, schema) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + @arrow_udf(returnType="double", functionType=ArrowUDFType.SCALAR) + def foo(x): + return x + + self.assertEqual(foo.returnType, DoubleType()) + self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + + def test_arrow_udf_wrong_arg(self): + with self.quiet(): + with self.assertRaises(ParseException): + + @arrow_udf("blah") + def foo(x): + return x + + with self.assertRaises(PySparkTypeError) as pe: + + @arrow_udf(returnType="double", functionType=PandasUDFType.SCALAR) + def foo(df): + return df + + self.check_error( + exception=pe.exception, + errorClass="INVALID_PANDAS_UDF_TYPE", + messageParameters={ + "arg_name": "functionType", + "arg_type": "200", + }, + ) + + with self.assertRaises(PySparkTypeError) as pe: + + @arrow_udf(functionType=ArrowUDFType.SCALAR) + def foo(x): + return x + + self.check_error( + exception=pe.exception, + errorClass="CANNOT_BE_NONE", + messageParameters={"arg_name": "returnType"}, + ) + + with self.assertRaises(PySparkTypeError) as pe: + + @arrow_udf("double", 100) + def 
foo(x): + return x + + self.check_error( + exception=pe.exception, + errorClass="INVALID_PANDAS_UDF_TYPE", + messageParameters={"arg_name": "functionType", "arg_type": "100"}, + ) + + with self.assertRaises(PySparkTypeError) as pe: + + @arrow_udf(returnType=PandasUDFType.GROUPED_MAP) + def foo(df): + return df + + self.check_error( + exception=pe.exception, + errorClass="INVALID_PANDAS_UDF_TYPE", + messageParameters={"arg_name": "functionType", "arg_type": "201"}, + ) + + with self.assertRaisesRegex(ValueError, "0-arg arrow_udfs.*not.*supported"): + arrow_udf(lambda: 1, LongType(), ArrowUDFType.SCALAR) + + with self.assertRaisesRegex(ValueError, "0-arg arrow_udfs.*not.*supported"): + + @arrow_udf(LongType(), ArrowUDFType.SCALAR) + def zero_with_type(): + return 1 + + def test_arrow_udf_timestamp_ntz(self): + import pyarrow as pa + + @arrow_udf(returnType="timestamp_ntz") + def noop(s): + assert isinstance(s, pa.Array) + assert s[0].as_py() == datetime.datetime(1970, 1, 1, 0, 0) + return s + + with self.sql_conf({"spark.sql.session.timeZone": "Asia/Hong_Kong"}): + df = self.spark.createDataFrame( + [(datetime.datetime(1970, 1, 1, 0, 0),)], schema="dt timestamp_ntz" + ).select(noop("dt").alias("dt")) + + df.selectExpr("assert_true('1970-01-01 00:00:00' == CAST(dt AS STRING))").collect() + self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz") + self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0)) + + def test_arrow_udf_day_time_interval_type(self): + import pyarrow as pa + + @arrow_udf(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND)) + def noop(s: pa.Array) -> pa.Array: + assert isinstance(s, pa.Array) + assert s[0].as_py() == datetime.timedelta(microseconds=123) + return s + + df = self.spark.createDataFrame( + [(datetime.timedelta(microseconds=123),)], schema="td interval day to second" + ).select(noop("td").alias("td")) + + df.select( + F.assert_true( + F.lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.td.cast("string") + ) + ).collect() + self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second") + self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123)) + + +class ArrowUDFTests(ArrowUDFTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.arrow.test_arrow_udf import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py new file mode 100644 index 0000000000000..90f42d5a570ba --- /dev/null +++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py @@ -0,0 +1,639 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import random +import time +import unittest +from datetime import date, datetime, timezone +from decimal import Decimal + +from pyspark.util import PythonEvalType + +# TODO: import arrow_udf from public API +from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType +from pyspark.sql import functions as F +from pyspark.sql.types import ( + IntegerType, + ByteType, + StructType, + ShortType, + BooleanType, + LongType, + FloatType, + DoubleType, + DecimalType, + StringType, + ArrayType, + StructField, + Row, + MapType, + BinaryType, +) +from pyspark.errors import AnalysisException +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pyarrow, + pyarrow_requirement_message, +) + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class ScalarArrowUDFTestsMixin: + @property + def nondeterministic_arrow_udf(self): + import pyarrow as pa + import numpy as np + + @arrow_udf("double") + def random_udf(v): + return pa.array(np.random.random(len(v))) + + return random_udf.asNondeterministic() + + def test_arrow_udf_tokenize(self): + import pyarrow as pa + + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + + tokenize = arrow_udf( + lambda s: pa.compute.ascii_split_whitespace(s), + ArrayType(StringType()), + ) + + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + self.assertEqual([Row(hi=["hi", "boo"]), Row(hi=["bye", "boo"])], result.collect()) + + def test_arrow_udf_output_nested_arrays(self): + import pyarrow as pa + + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + + tokenize = arrow_udf( + lambda s: pa.array([pa.compute.ascii_split_whitespace(s).to_pylist()]), + ArrayType(ArrayType(StringType())), + ) + + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + self.assertEqual([Row(hi=[["hi", "boo"]]), Row(hi=[["bye", "boo"]])], result.collect()) + + def test_arrow_udf_output_structs(self): + import pyarrow as pa + + df = self.spark.range(10).select("id", F.lit("foo").alias("name")) + + create_struct = arrow_udf( + lambda x, y: pa.StructArray.from_arrays([x, y], names=["c1", "c2"]), + "struct", + ) + + self.assertEqual( + df.select(create_struct("id", "name").alias("res")).first(), + Row(res=Row(c1=0, c2="foo")), + ) + + def test_arrow_udf_output_nested_structs(self): + import pyarrow as pa + + df = self.spark.range(10).select("id", F.lit("foo").alias("name")) + + create_struct = arrow_udf( + lambda x, y: pa.StructArray.from_arrays( + [x, pa.StructArray.from_arrays([x, y], names=["c3", "c4"])], names=["c1", "c2"] + ), + "struct>", + ) + + self.assertEqual( + df.select(create_struct("id", "name").alias("res")).first(), + Row(res=Row(c1=0, c2=Row(c3=0, c4="foo"))), + ) + + def test_arrow_udf_input_output_nested_structs(self): + # root + # |-- s: struct (nullable = false) + # | |-- a: integer (nullable = false) + # | |-- y: struct (nullable = false) + # | | |-- b: integer (nullable = false) + # | | |-- x: struct (nullable = false) + # | | | |-- c: integer (nullable = false) + # | | | |-- d: integer (nullable = false) + df = self.spark.sql( + """ + SELECT STRUCT(a, STRUCT(b, STRUCT(c, d) AS x) AS y) AS s + FROM VALUES + (1, 2, 3, 4), + (-1, -2, -3, -4) + AS tab(a, b, c, d) + """ + ) + + schema = StructType( + [ + StructField("a", IntegerType(), False), + 
StructField( + "y", + StructType( + [ + StructField("b", IntegerType(), False), + StructField( + "x", + StructType( + [ + StructField("c", IntegerType(), False), + StructField("d", IntegerType(), False), + ] + ), + False, + ), + ] + ), + False, + ), + ] + ) + + extract_y = arrow_udf(lambda s: s.field("y"), schema["y"].dataType) + result = df.select(extract_y("s").alias("y")) + + self.assertEqual( + [Row(y=Row(b=2, x=Row(c=3, d=4))), Row(y=Row(b=-2, x=Row(c=-3, d=-4)))], + result.collect(), + ) + + def test_arrow_udf_input_nested_maps(self): + import pyarrow as pa + + schema = StructType( + [ + StructField("id", StringType(), True), + StructField( + "attributes", MapType(StringType(), MapType(StringType(), StringType())), True + ), + ] + ) + data = [("1", {"personal": {"name": "John", "city": "New York"}})] + # root + # |-- id: string (nullable = true) + # |-- attributes: map (nullable = true) + # | |-- key: string + # | |-- value: map (valueContainsNull = true) + # | | |-- key: string + # | | |-- value: string (valueContainsNull = true) + df = self.spark.createDataFrame(data, schema) + + str_repr = arrow_udf(lambda s: pa.array(str(x.as_py()) for x in s), StringType()) + result = df.select(str_repr("attributes").alias("s")) + + self.assertEqual( + [Row(s="[('personal', [('name', 'John'), ('city', 'New York')])]")], + result.collect(), + ) + + def test_arrow_udf_input_nested_arrays(self): + import pyarrow as pa + + # root + # |-- a: array (nullable = false) + # | |-- element: array (containsNull = false) + # | | |-- element: integer (containsNull = true) + df = self.spark.sql( + """ + SELECT ARRAY(ARRAY(1,2,3),ARRAY(4,NULL),ARRAY(5,6,NULL)) AS arr + """ + ) + + str_repr = arrow_udf(lambda s: pa.array(str(x.as_py()) for x in s), StringType()) + result = df.select(str_repr("arr").alias("s")) + + self.assertEqual( + [Row(s="[[1, 2, 3], [4, None], [5, 6, None]]")], + result.collect(), + ) + + def test_arrow_udf_input_arrow_array_struct(self): + import pyarrow as pa + + df = self.spark.createDataFrame( + [[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]], + "array_struct_col array>", + ) + + @arrow_udf("array>") + def return_cols(cols): + assert isinstance(cols, pa.Array) + return cols + + result = df.select(return_cols("array_struct_col")) + + self.assertEqual( + [ + Row(output=[Row(col1="a", col2=2, col3=3.0), Row(col1="a", col2=2, col3=3.0)]), + Row(output=[Row(col1="b", col2=5, col3=6.0), Row(col1="b", col2=5, col3=6.0)]), + ], + result.collect(), + ) + + def test_arrow_udf_input_dates(self): + import pyarrow as pa + + df = self.spark.sql( + """ + SELECT * FROM VALUES + (1, DATE('2022-02-22')), + (2, DATE('2023-02-22')), + (3, DATE('2024-02-22')) + AS tab(i, date) + """ + ) + + @arrow_udf("int") + def extract_year(d): + assert isinstance(d, pa.Array) + assert isinstance(d, pa.Date32Array) + return pa.array([x.as_py().year for x in d], pa.int32()) + + result = df.select(extract_year("date").alias("year")) + self.assertEqual( + [Row(year=2022), Row(year=2023), Row(year=2024)], + result.collect(), + ) + + def test_arrow_udf_output_dates(self): + import pyarrow as pa + + df = self.spark.sql( + """ + SELECT * FROM VALUES + (2022, 1, 5), + (2023, 2, 6), + (2024, 3, 7) + AS tab(y, m, d) + """ + ) + + @arrow_udf("date") + def build_date(y, m, d): + assert all(isinstance(x, pa.Array) for x in [y, m, d]) + dates = [ + date(int(y[i].as_py()), int(m[i].as_py()), int(d[i].as_py())) for i in range(len(y)) + ] + return pa.array(dates, pa.date32()) + + result = df.select(build_date("y", 
"m", "d").alias("date")) + self.assertEqual( + [ + Row(date=date(2022, 1, 5)), + Row(date=date(2023, 2, 6)), + Row(date=date(2024, 3, 7)), + ], + result.collect(), + ) + + def test_arrow_udf_input_timestamps(self): + import pyarrow as pa + + df = self.spark.sql( + """ + SELECT * FROM VALUES + (1, TIMESTAMP('2019-04-12 15:50:01')), + (2, TIMESTAMP('2020-04-12 15:50:02')), + (3, TIMESTAMP('2021-04-12 15:50:03')) + AS tab(i, ts) + """ + ) + + @arrow_udf("int") + def extract_second(d): + assert isinstance(d, pa.Array) + assert isinstance(d, pa.TimestampArray) + return pa.array([x.as_py().second for x in d], pa.int32()) + + result = df.select(extract_second("ts").alias("second")) + self.assertEqual( + [Row(second=1), Row(second=2), Row(second=3)], + result.collect(), + ) + + def test_arrow_udf_output_timestamps_ltz(self): + import pyarrow as pa + + df = self.spark.sql( + """ + SELECT * FROM VALUES + (2022, 1, 5, 15, 0, 1), + (2023, 2, 6, 16, 1, 2), + (2024, 3, 7, 17, 2, 3) + AS tab(y, m, d, h, mi, s) + """ + ) + + @arrow_udf("timestamp") + def build_ts(y, m, d, h, mi, s): + assert all(isinstance(x, pa.Array) for x in [y, m, d, h, mi, s]) + dates = [ + datetime( + int(y[i].as_py()), + int(m[i].as_py()), + int(d[i].as_py()), + int(h[i].as_py()), + int(mi[i].as_py()), + int(s[i].as_py()), + tzinfo=timezone.utc, + ) + for i in range(len(y)) + ] + return pa.array(dates, pa.timestamp("us", "UTC")) + + result = df.select(build_ts("y", "m", "d", "h", "mi", "s").alias("ts")) + self.assertEqual( + [ + Row(ts=datetime(2022, 1, 5, 7, 0, 1)), + Row(ts=datetime(2023, 2, 6, 8, 1, 2)), + Row(ts=datetime(2024, 3, 7, 9, 2, 3)), + ], + result.collect(), + ) + + def test_arrow_udf_output_timestamps_ntz(self): + import pyarrow as pa + + df = self.spark.sql( + """ + SELECT * FROM VALUES + (2022, 1, 5, 15, 0, 1), + (2023, 2, 6, 16, 1, 2), + (2024, 3, 7, 17, 2, 3) + AS tab(y, m, d, h, mi, s) + """ + ) + + @arrow_udf("timestamp_ntz") + def build_ts(y, m, d, h, mi, s): + assert all(isinstance(x, pa.Array) for x in [y, m, d, h, mi, s]) + dates = [ + datetime( + int(y[i].as_py()), + int(m[i].as_py()), + int(d[i].as_py()), + int(h[i].as_py()), + int(mi[i].as_py()), + int(s[i].as_py()), + ) + for i in range(len(y)) + ] + return pa.array(dates, pa.timestamp("us")) + + result = df.select(build_ts("y", "m", "d", "h", "mi", "s").alias("ts")) + self.assertEqual( + [ + Row(ts=datetime(2022, 1, 5, 15, 0, 1)), + Row(ts=datetime(2023, 2, 6, 16, 1, 2)), + Row(ts=datetime(2024, 3, 7, 17, 2, 3)), + ], + result.collect(), + ) + + def test_arrow_udf_null_boolean(self): + data = [(True,), (True,), (None,), (False,)] + schema = StructType().add("bool", BooleanType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + bool_f = arrow_udf(lambda x: x, BooleanType(), udf_type) + res = df.select(bool_f(F.col("bool"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_byte(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("byte", ByteType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + byte_f = arrow_udf(lambda x: x, ByteType(), udf_type) + res = df.select(byte_f(F.col("byte"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_short(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("short", ShortType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + short_f = arrow_udf(lambda x: x, ShortType(), udf_type) + res = 
df.select(short_f(F.col("short"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_int(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("int", IntegerType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + int_f = arrow_udf(lambda x: x, IntegerType(), udf_type) + res = df.select(int_f(F.col("int"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_long(self): + data = [(None,), (2,), (3,), (4,)] + schema = StructType().add("long", LongType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + long_f = arrow_udf(lambda x: x, LongType(), udf_type) + res = df.select(long_f(F.col("long"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_float(self): + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("float", FloatType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + float_f = arrow_udf(lambda x: x, FloatType(), udf_type) + res = df.select(float_f(F.col("float"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_double(self): + data = [(3.0,), (5.0,), (-1.0,), (None,)] + schema = StructType().add("double", DoubleType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + double_f = arrow_udf(lambda x: x, DoubleType(), udf_type) + res = df.select(double_f(F.col("double"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_decimal(self): + data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)] + schema = StructType().add("decimal", DecimalType(38, 18)) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + decimal_f = arrow_udf(lambda x: x, DecimalType(38, 18), udf_type) + res = df.select(decimal_f(F.col("decimal"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_string(self): + data = [("foo",), (None,), ("bar",), ("bar",)] + schema = StructType().add("str", StringType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + str_f = arrow_udf(lambda x: x, StringType(), udf_type) + res = df.select(str_f(F.col("str"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_binary(self): + data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)] + schema = StructType().add("binary", BinaryType()) + df = self.spark.createDataFrame(data, schema) + for udf_type in [ArrowUDFType.SCALAR]: + binary_f = arrow_udf(lambda x: x, BinaryType(), udf_type) + res = df.select(binary_f(F.col("binary"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_null_array(self): + data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)] + array_schema = StructType([StructField("array", ArrayType(IntegerType()))]) + df = self.spark.createDataFrame(data, schema=array_schema) + for udf_type in [ArrowUDFType.SCALAR]: + array_f = arrow_udf(lambda x: x, ArrayType(IntegerType()), udf_type) + result = df.select(array_f(F.col("array"))) + self.assertEqual(df.collect(), result.collect()) + + def test_arrow_udf_empty_partition(self): + df = self.spark.createDataFrame([Row(id=1)]).repartition(2) + for udf_type in [ArrowUDFType.SCALAR]: + f = arrow_udf(lambda x: x, LongType(), udf_type) + res = df.select(f(F.col("id"))) + self.assertEqual(df.collect(), res.collect()) + + def test_arrow_udf_datatype_string(self): + df = 
self.spark.range(10).select( + F.col("id").cast("string").alias("str"), + F.col("id").cast("int").alias("int"), + F.col("id").alias("long"), + F.col("id").cast("float").alias("float"), + F.col("id").cast("double").alias("double"), + # F.col("id").cast("decimal").alias("decimal"), + F.col("id").cast("boolean").alias("bool"), + ) + + def f(x): + return x + + for udf_type in [ArrowUDFType.SCALAR]: + str_f = arrow_udf(f, "string", udf_type) + int_f = arrow_udf(f, "integer", udf_type) + long_f = arrow_udf(f, "long", udf_type) + float_f = arrow_udf(f, "float", udf_type) + double_f = arrow_udf(f, "double", udf_type) + # decimal_f = arrow_udf(f, "decimal(38, 18)", udf_type) + bool_f = arrow_udf(f, "boolean", udf_type) + res = df.select( + str_f(F.col("str")), + int_f(F.col("int")), + long_f(F.col("long")), + float_f(F.col("float")), + double_f(F.col("double")), + # decimal_f("decimal"), + bool_f(F.col("bool")), + ) + self.assertEqual(df.collect(), res.collect()) + + def test_register_nondeterministic_arrow_udf(self): + import pyarrow as pa + + random_pandas_udf = arrow_udf( + lambda x: pa.compute.add(x, random.randint(6, 6)), LongType() + ).asNondeterministic() + self.assertEqual(random_pandas_udf.deterministic, False) + self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + nondeterministic_pandas_udf = self.spark.catalog.registerFunction( + "randomPandasUDF", random_pandas_udf + ) + self.assertEqual(nondeterministic_pandas_udf.deterministic, False) + self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF) + [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect() + self.assertEqual(row[0], 7) + + def test_nondeterministic_arrow_udf(self): + import pyarrow as pa + + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + @arrow_udf("double") + def scalar_plus_ten(v): + return pa.compute.add(v, 10) + + # @arrow_udf("double", ArrowUDFType.SCALAR_ITER) + # def iter_plus_ten(it): + # for v in it: + # yield pa.compute.add(v, 10) + + for plus_ten in [scalar_plus_ten]: + random_udf = self.nondeterministic_arrow_udf + + df = self.spark.range(10).withColumn("rand", random_udf("id")) + result1 = df.withColumn("plus_ten(rand)", plus_ten(df["rand"])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 10)) + + def test_nondeterministic_arrow_udf_in_aggregate(self): + with self.quiet(): + df = self.spark.range(10) + for random_udf in [ + self.nondeterministic_arrow_udf, + # self.nondeterministic_vectorized_iter_udf, + ]: + with self.assertRaisesRegex(AnalysisException, "Non-deterministic"): + df.groupby("id").agg(F.sum(random_udf("id"))).collect() + with self.assertRaisesRegex(AnalysisException, "Non-deterministic"): + df.agg(F.sum(random_udf("id"))).collect() + + # TODO: add tests for registering Arrow UDF + # TODO: add tests for chained Arrow UDFs + # TODO: add tests for named arguments + + +class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = 
cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + +if __name__ == "__main__": + from pyspark.sql.tests.arrow.test_arrow_udf_scalar import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf.py new file mode 100644 index 0000000000000..567ab6913d30b --- /dev/null +++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.tests.arrow.test_arrow_udf import ArrowUDFTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class ArrowPythonUDFParityTests(ArrowUDFTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.arrow.test_parity_arrow_udf import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf_scalar.py b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf_scalar.py new file mode 100644 index 0000000000000..ff18355775ef5 --- /dev/null +++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow_udf_scalar.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import time + +from pyspark.sql.tests.arrow.test_arrow_udf_scalar import ScalarArrowUDFTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class ScalarArrowPythonUDFParityTests(ScalarArrowUDFTestsMixin, ReusedConnectTestCase): + @classmethod + def setUpClass(cls): + ReusedConnectTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedConnectTestCase.tearDownClass() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.arrow.test_parity_arrow_udf_scalar import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index abfd0898fd545..6c166392f8cea 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -652,6 +652,7 @@ def register( PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, ]: diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 7e07d95538e4a..67ee88d3fb81a 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -65,6 +65,7 @@ PandasGroupedMapUDFTransformWithStateInitStateType, GroupedMapUDFTransformWithStateType, GroupedMapUDFTransformWithStateInitStateType, + ArrowScalarUDFType, ) from pyspark.sql._typing import ( SQLArrowBatchedUDFType, @@ -645,6 +646,10 @@ class PythonEvalType: SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: "GroupedMapUDFTransformWithStateInitStateType" = ( # noqa: E501 214 ) + + # Arrow UDFs + SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250 + SQL_TABLE_UDF: "SQLTableUDFType" = 300 SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301 diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 18da7bcc77040..2209f393cdf02 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -52,6 +52,7 @@ from pyspark.sql.functions import SkipRestOfInputTableException from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, + ArrowStreamArrowUDFSerializer, ArrowStreamPandasUDTFSerializer, CogroupArrowUDFSerializer, CogroupPandasUDFSerializer, @@ -160,6 +161,46 @@ def verify_result_length(result, length): ) +def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): + func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) + + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) + + def verify_result_type(result): + if not hasattr(result, "__len__"): + pd_type = "pyarrow.Array" + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": pd_type, + "actual": type(result).__name__, + }, + ) + return result + + def verify_result_length(result, length): + if len(result) != length: + raise PySparkRuntimeError( + errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF", + messageParameters={ + 
"udf_type": "arrow_udf", + "expected": str(length), + "actual": str(len(result)), + }, + ) + return result + + return ( + args_kwargs_offsets, + lambda *a: ( + verify_result_length(verify_result_type(func(*a)), len(a[0])), + arrow_return_type, + ), + ) + + def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): import pandas as pd @@ -891,6 +932,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, # The below doesn't support named argument, but shares the same protocol. @@ -947,6 +989,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) + if eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF: + return wrap_scalar_arrow_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: @@ -1724,6 +1768,7 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type in ( PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_SCALAR_ARROW_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, @@ -1823,6 +1868,9 @@ def read_udfs(pickleSer, infile, eval_type): ser = ArrowStreamUDFSerializer() elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: ser = ArrowStreamGroupUDFSerializer(_assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF: + # Arrow cast for type coercion is disabled by default + ser = ArrowStreamArrowUDFSerializer(timezone, safecheck, _assign_cols_by_name, False) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 53273b29a7c17..f036ef42690af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -36,7 +36,8 @@ object PythonUDF { PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_SCALAR_ARROW_UDF ) def isScalarPythonUDF(e: Expression): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 6a81eaffd20a7..94a53ddc79a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -301,8 +301,10 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging { log"Falling back to non-Arrow-optimized UDF execution.") } BatchEvalPython(validUdfs, resultAttrs, child) - case PythonEvalType.SQL_SCALAR_PANDAS_UDF | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - | PythonEvalType.SQL_ARROW_BATCHED_UDF => + case PythonEvalType.SQL_SCALAR_PANDAS_UDF + | PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + | PythonEvalType.SQL_ARROW_BATCHED_UDF + | PythonEvalType.SQL_SCALAR_ARROW_UDF => ArrowEvalPython(validUdfs, resultAttrs, child, evalType) case _ => throw SparkException.internalError("Unexpected UDF evalType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 5b508d1a0a722..c43cbad7c395f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -49,7 +49,8 @@ case class UserDefinedPythonFunction( if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF || pythonEvalType ==PythonEvalType.SQL_ARROW_BATCHED_UDF || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF - || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) { + || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF + || pythonEvalType == PythonEvalType.SQL_SCALAR_ARROW_UDF) { /* * Check if the named arguments: * - don't have duplicated names
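Finally, a quick way to observe the planner changes above from the Python side (a sketch under the same assumptions as the earlier examples; the exact explain() output is version-dependent):

    import pyarrow as pa
    from pyspark.sql.pandas.functions import arrow_udf

    negate = arrow_udf(lambda v: pa.compute.negate(v), "long")

    # isScalarPythonUDF now matches eval type 250, and ExtractPythonUDFs plans it
    # with ArrowEvalPython (rather than BatchEvalPython), which should be visible
    # in the physical plan.
    spark.range(3).select(negate("id").alias("neg")).explain()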