diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 239ff27813be..2f8c1cd21364 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -38,9 +38,9 @@
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 class BaseUDFTestsMixin(object):
@@ -898,6 +898,21 @@ def test_complex_return_types(self):
         self.assertEquals(row[1], {"a": "b"})
         self.assertEquals(row[2], Row(col1=1, col2=2))
 
+    def test_raise_stop_iteration(self):
+        @udf("int")
+        def test_udf(a):
+            if a < 5:
+                return a
+            else:
+                raise StopIteration()
+
+        assertDataFrameEqual(
+            self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in range(5)]
+        )
+
+        with self.assertRaisesRegex(PythonException, "StopIteration"):
+            self.spark.range(10).select(test_udf(col("id"))).show()
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index edbfad4a5dc2..d2ea18c45c92 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -154,20 +154,9 @@ def wrap_arrow_batch_udf(f, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
+    @fail_on_stopiteration
     def evaluate(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
+        return pd.Series([result_func(f(*a)) for a in zip(*args)])
 
     def verify_result_length(result, length):
         if len(result) != length:
@@ -181,7 +170,7 @@ def verify_result_length(result, length):
         return result
 
     return lambda *a: (
-        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+        verify_result_length(evaluate(*a), len(a[0])),
         arrow_return_type,
     )
 
@@ -543,7 +532,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     else:
         chained_func = chain(chained_func, f)
 
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+    if eval_type in (
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
+    ):
         func = chained_func
     else:
         # make sure StopIteration's raised in the user code are not ignored