21 changes: 18 additions & 3 deletions python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@
 
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -38,9 +38,9 @@
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 class BaseUDFTestsMixin(object):
@@ -898,6 +898,21 @@ def test_complex_return_types(self):
         self.assertEquals(row[1], {"a": "b"})
         self.assertEquals(row[2], Row(col1=1, col2=2))
 
+    def test_raise_stop_iteration(self):
+        @udf("int")
+        def test_udf(a):
+            if a < 5:
+                return a
+            else:
+                raise StopIteration()
+
+        assertDataFrameEqual(
+            self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in range(5)]
+        )
+
+        with self.assertRaisesRegex(PythonException, "StopIteration"):
+            self.spark.range(10).select(test_udf(col("id"))).show()
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
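The new test pins down the user-visible behavior. A minimal interactive repro, assuming a running spark session with Arrow-optimized Python UDFs enabled (spark.sql.execution.pythonUDF.arrow.enabled=true); the name bad_udf is illustrative and not part of this patch:

from pyspark.sql.functions import col, udf

@udf("int")
def bad_udf(x):
    # user code that leaks a StopIteration instead of returning a value
    raise StopIteration()

# With this patch the action fails with a PythonException whose message
# mentions StopIteration, instead of masking the original exception.
spark.range(3).select(bad_udf(col("id"))).show()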
22 changes: 7 additions & 15 deletions python/pyspark/worker.py
@@ -154,20 +154,9 @@ def wrap_arrow_batch_udf(f, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
+    @fail_on_stopiteration
     def evaluate(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
+        return pd.Series([result_func(f(*a)) for a in zip(*args)])
 
     def verify_result_length(result, length):
         if len(result) != length:
@@ -181,7 +170,7 @@ def verify_result_length(result, length):
         return result
 
     return lambda *a: (
-        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+        verify_result_length(evaluate(*a), len(a[0])),
         arrow_return_type,
     )
 
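Why evaluate switched from a generator expression to a list comprehension: under PEP 479 (the default since Python 3.7), a StopIteration escaping a generator frame is replaced by RuntimeError("generator raised StopIteration") before the new @fail_on_stopiteration decorator could see it. A list comprehension runs in an ordinary frame, so the original StopIteration reaches the decorator intact. A standalone sketch of the difference, plain Python with no Spark required:

def boom():
    # stands in for user code that raises StopIteration
    raise StopIteration("from user code")

# Generator expression: PEP 479 masks the original exception.
try:
    list(boom() for _ in range(3))
except RuntimeError as e:
    print(type(e).__name__, e)  # RuntimeError generator raised StopIteration

# List comprehension: the StopIteration propagates unchanged, so a wrapper
# like fail_on_stopiteration can convert it into a descriptive error.
try:
    [boom() for _ in range(3)]
except StopIteration as e:
    print(type(e).__name__, e)  # StopIteration from user code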
@@ -543,7 +532,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     else:
         chained_func = chain(chained_func, f)
 
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+    if eval_type in (
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
+    ):
         func = chained_func
     else:
         # make sure StopIteration's raised in the user code are not ignored
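For reference, fail_on_stopiteration (defined in pyspark.util) behaves roughly like the sketch below; this is a paraphrase rather than the exact implementation (current PySpark raises a PySparkRuntimeError carrying an error class instead of a bare RuntimeError):

from typing import Any, Callable

def fail_on_stopiteration(f: Callable) -> Callable:
    # Re-raise a StopIteration escaping f as a RuntimeError so that callers
    # iterating over f's results cannot be silently terminated early.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task"
            ) from exc

    return wrapper

Decorating evaluate applies this try/except once per Arrow batch rather than once per row, as the previous wrapping of chained_func in read_single_udf did, which presumably also trims per-row overhead.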