Commit f29502a

xinrong-meng authored and HyukjinKwon committed
[SPARK-43082][CONNECT][PYTHON] Arrow-optimized Python UDFs in Spark Connect
### What changes were proposed in this pull request?

Implement Arrow-optimized Python UDFs in Spark Connect.

Please see #39384 for the motivation and performance improvements of Arrow-optimized Python UDFs.

### Why are the changes needed?

Parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?

Yes. In the Spark Connect Python Client, users can:

1. Set the `useArrow` parameter to `True` to enable Arrow optimization for a specific Python UDF.

```sh
>>> df = spark.range(2)
>>> df.select(udf(lambda x: x + 1, useArrow=True)('id')).show()
+------------+
|<lambda>(id)|
+------------+
|           1|
|           2|
+------------+

# ArrowEvalPython indicates Arrow optimization
>>> df.select(udf(lambda x: x + 1, useArrow=True)('id')).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#18 AS <lambda>(id)#16]
+- ArrowEvalPython [<lambda>(id#14L)#15], [pythonUDF0#18], 200
   +- *(1) Range (0, 2, step=1, splits=1)
```

2. Enable the `spark.sql.execution.pythonUDF.arrow.enabled` Spark conf to make all Python UDFs Arrow-optimized.

```sh
>>> spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)
>>> df.select(udf(lambda x: x + 1)('id')).show()
+------------+
|<lambda>(id)|
+------------+
|           1|
|           2|
+------------+

# ArrowEvalPython indicates Arrow optimization
>>> df.select(udf(lambda x: x + 1)('id')).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#30 AS <lambda>(id)#28]
+- ArrowEvalPython [<lambda>(id#26L)#27], [pythonUDF0#30], 200
   +- *(1) Range (0, 2, step=1, splits=1)
```

### How was this patch tested?

Parity unit tests.

Closes #40725 from xinrong-meng/connect_arrow_py_udf.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
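
The interaction between these two switches is worth spelling out. A minimal sketch, not part of the PR text but following directly from the `if useArrow is None` check in `_create_py_udf` below (assumes an active Spark Connect session bound to `spark`):

```python
from pyspark.sql.connect.functions import udf

# The global conf enables Arrow only for UDFs that leave useArrow unset;
# an explicit per-UDF flag takes precedence over the conf.
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)

arrowed = udf(lambda x: str(x))                  # conf applies: Arrow-optimized
regular = udf(lambda x: str(x), useArrow=False)  # explicit flag wins: regular UDF
```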
1 parent: fece7ed · commit: f29502a

File tree: 6 files changed, +171 −90 lines

python/pyspark/sql/connect/functions.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -50,7 +50,7 @@
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
 )
-from pyspark.sql.connect.udf import _create_udf
+from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
@@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
 def udf(
     f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
     returnType: "DataTypeOrString" = StringType(),
+    useArrow: Optional[bool] = None,
 ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
     from pyspark.rdd import PythonEvalType

@@ -2469,10 +2470,15 @@ def udf(
         # for decorator use it as a returnType
         return_type = f or returnType
         return functools.partial(
-            _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+            _create_py_udf,
+            returnType=return_type,
+            evalType=PythonEvalType.SQL_BATCHED_UDF,
+            useArrow=useArrow,
         )
     else:
-        return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+        return _create_py_udf(
+            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
+        )


 udf.__doc__ = pysparkfuncs.udf.__doc__
```
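
Because the decorator path returns a `functools.partial` over `_create_py_udf`, the `useArrow` argument flows through decorator usage unchanged. A usage sketch, not taken from this diff (assumes a Spark Connect session bound to `spark`):

```python
from pyspark.sql.connect.functions import udf

# Decorator arguments, including useArrow, are captured by the
# functools.partial above and forwarded to _create_py_udf.
@udf(returnType="long", useArrow=True)
def plus_one(x):
    return x + 1

spark.range(2).select(plus_one("id")).show()
```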

python/pyspark/sql/connect/udf.py

Lines changed: 45 additions & 1 deletion

```diff
@@ -23,6 +23,8 @@

 import sys
 import functools
+import warnings
+from inspect import getfullargspec
 from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union

 from pyspark.rdd import PythonEvalType
@@ -33,7 +35,7 @@
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
-from pyspark.sql.types import DataType, StringType
+from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration


@@ -47,6 +49,48 @@
     from pyspark.sql.types import StringType


+def _create_py_udf(
+    f: Callable[..., Any],
+    returnType: "DataTypeOrString",
+    evalType: int,
+    useArrow: Optional[bool] = None,
+) -> "UserDefinedFunctionLike":
+    from pyspark.sql.udf import _create_arrow_py_udf
+    from pyspark.sql.connect.session import _active_spark_session
+
+    if _active_spark_session is None:
+        is_arrow_enabled = False
+    else:
+        is_arrow_enabled = (
+            _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
+            if useArrow is None
+            else useArrow
+        )
+
+    regular_udf = _create_udf(f, returnType, evalType)
+    return_type = regular_udf.returnType
+    try:
+        is_func_with_args = len(getfullargspec(f).args) > 0
+    except TypeError:
+        is_func_with_args = False
+    is_output_atomic_type = (
+        not isinstance(return_type, StructType)
+        and not isinstance(return_type, MapType)
+        and not isinstance(return_type, ArrayType)
+    )
+    if is_arrow_enabled:
+        if is_output_atomic_type and is_func_with_args:
+            return _create_arrow_py_udf(regular_udf)
+        else:
+            warnings.warn(
+                "Arrow optimization for Python UDFs cannot be enabled.",
+                UserWarning,
+            )
+            return regular_udf
+    else:
+        return regular_udf
+
+
 def _create_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
```
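
Note the fallback path above: with Arrow requested but a composite return type or a zero-argument function, the UDF degrades to a regular one after a `UserWarning`. A sketch of the observable behavior, not from the diff (requires an active Spark Connect session, since `is_arrow_enabled` is forced to `False` otherwise):

```python
from pyspark.sql.connect.functions import udf
from pyspark.sql.types import ArrayType, LongType

# Composite (ArrayType) return type fails the is_output_atomic_type check,
# so this warns "Arrow optimization for Python UDFs cannot be enabled."
# and returns a regular, non-Arrow Python UDF instead.
to_pair = udf(lambda x: [x, x], ArrayType(LongType()), useArrow=True)
```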

python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
+from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin
+
+
+class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
+    @classmethod
+    def setUpClass(cls):
+        super(ArrowPythonUDFParityTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
+        finally:
+            super(ArrowPythonUDFParityTests, cls).tearDownClass()
+
+
+if __name__ == "__main__":
+    import unittest
+
+    from pyspark.sql.tests.connect.test_parity_arrow_python_udf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
```

python/pyspark/sql/tests/test_arrow_python_udf.py

Lines changed: 15 additions & 6 deletions

```diff
@@ -31,12 +31,7 @@
 @unittest.skipIf(
     not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
 )
-class PythonUDFArrowTests(BaseUDFTestsMixin, ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(PythonUDFArrowTests, cls).setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
-
+class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
     @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
     def test_broadcast_in_udf(self):
         super(PythonUDFArrowTests, self).test_broadcast_in_udf()
@@ -118,6 +113,20 @@ def test_use_arrow(self):
         self.assertEquals(row_false[0], "[1, 2, 3]")


+class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(PythonUDFArrowTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
+        finally:
+            super(PythonUDFArrowTests, cls).tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow_python_udf import *  # noqa: F401
```

python/pyspark/sql/tests/test_udf.py

Lines changed: 0 additions & 41 deletions

```diff
@@ -838,47 +838,6 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")


-    def test_use_arrow(self):
-        # useArrow=True
-        row_true = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=True)("array"),
-            )
-            .first()
-        )
-        # The input is a NumPy array when the Arrow optimization is on.
-        self.assertEquals(row_true[0], "[1 2 3]")
-
-        # useArrow=None
-        row_none = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=None)("array"),
-            )
-            .first()
-        )
-
-        # useArrow=False
-        row_false = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=False)("array"),
-            )
-            .first()
-        )
-        self.assertEquals(row_false[0], row_none[0])  # "[1, 2, 3]"
-
-
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:
```
python/pyspark/sql/udf.py

Lines changed: 54 additions & 39 deletions

```diff
@@ -75,6 +75,7 @@ def _create_udf(
     name: Optional[str] = None,
     deterministic: bool = True,
 ) -> "UserDefinedFunctionLike":
+    """Create a regular(non-Arrow-optimized) Python UDF."""
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(
         f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -88,6 +89,7 @@ def _create_py_udf(
     evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
+    """Create a regular/Arrow-optimized Python UDF."""
     # The following table shows the results when the type coercion in Arrow is needed, that is,
     # when the user-specified return type(SQL Type) of the UDF and the actual instance(Python
     # Value(Type)) that the UDF returns are different.
@@ -138,49 +140,62 @@ def _create_py_udf(
         and not isinstance(return_type, MapType)
         and not isinstance(return_type, ArrayType)
     )
-    if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-
-        import pandas as pd
-        from pyspark.sql.pandas.functions import _create_pandas_udf  # type: ignore[attr-defined]
-
-        # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
-        # optimization.
-        # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
-        # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
-        # successfully.
-        result_func = lambda pdf: pdf  # noqa: E731
-        if type(return_type) == StringType:
-            result_func = lambda r: str(r) if r is not None else r  # noqa: E731
-        elif type(return_type) == BinaryType:
-            result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
-
-        def vectorized_udf(*args: pd.Series) -> pd.Series:
-            if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
-                raise NotImplementedError(
-                    "Struct input type are not supported with Arrow optimization "
-                    "enabled in Python UDFs. Disable "
-                    "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
-                )
-            return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-        # Regular UDFs can take callable instances too.
-        vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
-        vectorized_udf.__module__ = (
-            f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
-        )
-        vectorized_udf.__doc__ = f.__doc__
-        pudf = _create_pandas_udf(vectorized_udf, returnType, None)
-        # Keep the attributes as if this is a regular Python UDF.
-        pudf.func = f
-        pudf.returnType = return_type
-        pudf.evalType = regular_udf.evalType
-        return pudf
+    if is_arrow_enabled:
+        if is_output_atomic_type and is_func_with_args:
+            return _create_arrow_py_udf(regular_udf)
+        else:
+            warnings.warn(
+                "Arrow optimization for Python UDFs cannot be enabled.",
+                UserWarning,
+            )
+            return regular_udf
     else:
         return regular_udf


+def _create_arrow_py_udf(regular_udf):  # type: ignore
+    """Create an Arrow-optimized Python UDF out of a regular Python UDF."""
+    require_minimum_pandas_version()
+    require_minimum_pyarrow_version()
+
+    import pandas as pd
+    from pyspark.sql.pandas.functions import _create_pandas_udf
+
+    f = regular_udf.func
+    return_type = regular_udf.returnType
+
+    # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
+    # optimization.
+    # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
+    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
+    # successfully.
+    result_func = lambda pdf: pdf  # noqa: E731
+    if type(return_type) == StringType:
+        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
+    elif type(return_type) == BinaryType:
+        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
+
+    def vectorized_udf(*args: pd.Series) -> pd.Series:
+        if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
+            raise NotImplementedError(
+                "Struct input type are not supported with Arrow optimization "
+                "enabled in Python UDFs. Disable "
+                "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
+            )
+        return pd.Series(result_func(f(*a)) for a in zip(*args))
+
+    # Regular UDFs can take callable instances too.
+    vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
+    vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
+    vectorized_udf.__doc__ = f.__doc__
+    pudf = _create_pandas_udf(vectorized_udf, return_type, None)
+    # Keep the attributes as if this is a regular Python UDF.
+    pudf.func = f
+    pudf.returnType = return_type
+    pudf.evalType = regular_udf.evalType
+    return pudf
+
+
 class UserDefinedFunction:
     """
     User defined function in Python
```
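
The `result_func` coercion has a user-visible consequence, documented by the `test_use_arrow` test that this commit moves out of `test_udf.py`. A condensed sketch of the same check (assumes a session `spark` with pandas and PyArrow available):

```python
from pyspark.sql.functions import udf

df = spark.range(1).selectExpr("array(1, 2, 3) as array")

# Arrow on: the UDF input arrives as a NumPy array, and the StringType
# result_func applies str() to the output -> "[1 2 3]"
df.select(udf(lambda x: str(x), useArrow=True)("array")).first()

# Arrow off: the input is a plain Python list -> "[1, 2, 3]"
df.select(udf(lambda x: str(x), useArrow=False)("array")).first()
```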
