15 changes: 15 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
@@ -1005,6 +1005,21 @@ def test_udf(a, b=0):
             with self.subTest(with_b=True, query_no=i):
                 assertDataFrameEqual(df, [Row(0), Row(101)])
 
+    def test_raise_stop_iteration(self):
+        @udf("int")
+        def test_udf(a):
+            if a < 5:
+                return a
+            else:
+                raise StopIteration()
+
+        assertDataFrameEqual(
+            self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in range(5)]
+        )
+
+        with self.assertRaisesRegex(PythonException, "StopIteration"):
+            self.spark.range(10).select(test_udf(col("id"))).show()
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
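
Reviewer context: StopIteration needs this special handling because Python's iteration protocol treats it as an ordinary end-of-data signal, so a UDF that leaks one can silently truncate results instead of failing the task. A standalone illustration in plain Python (no Spark involved; buggy_udf is a made-up name for this note):

# Illustration only, not part of this diff: a StopIteration escaping a user
# function is swallowed by map-style iteration and silently drops rows.
def buggy_udf(x):
    if x >= 5:
        raise StopIteration()  # user bug
    return x

# list() interprets the StopIteration leaking out of map() as "iterator
# exhausted", so the result is cut short instead of raising an error:
print(list(map(buggy_udf, range(10))))  # [0, 1, 2, 3, 4]

This silent-truncation failure mode is what the new test guards against: the first five rows must come back intact, and the leaked StopIteration must surface as a PythonException.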
22 changes: 6 additions & 16 deletions python/pyspark/worker.py
@@ -139,6 +139,7 @@ def wrap_arrow_batch_udf(f, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
+    @fail_on_stopiteration
     def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series:
         keys = list(kwargs.keys())
         len_args = len(args)
@@ -151,18 +152,6 @@ def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series:
             ]
         )
 
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
     def verify_result_length(result, length):
         if len(result) != length:
             raise PySparkRuntimeError(
@@ -175,9 +164,7 @@ def verify_result_length(result, length):
         return result
 
     return lambda *a, **kw: (
-        verify_result_length(
-            verify_result_type(evaluate(*a, **kw)), len((list(a) + list(kw.values()))[0])
-        ),
+        verify_result_length(evaluate(*a, **kw), len((list(a) + list(kw.values()))[0])),
         arrow_return_type,
     )
 
@@ -562,7 +549,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     else:
         chained_func = chain(chained_func, f)
 
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+    if eval_type in (
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
+    ):
         func = chained_func
     else:
         # make sure StopIteration's raised in the user code are not ignored
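
For reference, fail_on_stopiteration is PySpark's existing wrapper (defined in pyspark.util) that converts an escaping StopIteration into a regular error. A rough sketch of the idea, not the exact PySpark source:

# Sketch of the fail_on_stopiteration idea (illustrative; the real helper in
# pyspark.util differs in its error type and message).
import functools

def fail_on_stopiteration(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            # Re-raise as a plain RuntimeError so map-style consumers cannot
            # mistake it for normal end-of-iteration.
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task"
            ) from exc
    return wrapper

With the decorator applied once to evaluate, the conversion happens at the batch level, which is presumably why read_single_udf no longer wraps the chained function for SQL_ARROW_BATCHED_UDF: wrapping each per-row call as well would be redundant.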