3 changes: 2 additions & 1 deletion python/pyspark/sql/functions.py
@@ -2718,9 +2718,10 @@ def pandas_udf(f=None, returnType=None, functionType=None):
1. SCALAR

A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The returnType should be a primitive data type, e.g., :class:`DoubleType`.
The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.

:class:`MapType` and :class:`StructType` are currently not supported as output types.

Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
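
For illustration, a minimal scalar pandas UDF returning an array type might look like this (a sketch, not part of the diff; it assumes an active :class:`SparkSession` bound to `spark`, and `split_words` is an illustrative name):

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import ArrayType, StringType

    # One or more pandas.Series in, one pandas.Series out; the declared
    # return type is ArrayType(StringType()), i.e. a list of tokens per row.
    split_words = pandas_udf(lambda s: s.str.split(' '), ArrayType(StringType()))

    df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
    df.select(split_words("vals")).show()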

19 changes: 19 additions & 0 deletions python/pyspark/sql/tests.py
@@ -4358,6 +4358,7 @@ def test_timestamp_dst(self):
    not _have_pandas or not _have_pyarrow,
    _pandas_requirement_message or _pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):

Member: seems unrelated ..

Contributor Author: It is, but it makes this fit the style of the rest of the file.

    def test_pandas_udf_basic(self):
        from pyspark.rdd import PythonEvalType
        from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4573,6 +4574,24 @@ def random_udf(v):
        random_udf = random_udf.asNondeterministic()
        return random_udf

    def test_pandas_udf_tokenize(self):
        from pyspark.sql.functions import pandas_udf
        tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')),
Member: Instead of `s.apply` with a lambda, you could do `lambda s: s.str.split(' ')`; both do the same thing, just a little more compact.
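
For reference, a minimal plain-pandas sketch of the two equivalent forms (illustrative names, not from the PR):

    import pandas as pd

    s = pd.Series(["hi boo", "bye boo"])

    # Vectorized string accessor, as suggested in the review.
    via_accessor = s.str.split(' ')

    # apply with a lambda, as written in the PR.
    via_apply = s.apply(lambda x: x.split(' '))

    assert list(via_accessor) == list(via_apply)  # both yield lists of tokens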

Member: Hm, I thought this PR aims to clarify array types with primitive types. Can we improve the test case at https://github.com/holdenk/spark/blob/342d2228a5c68fd2c07bd8c1b518da6135ce1bf6/python/pyspark/sql/tests.py#L3998 and remove this test case?

Contributor: I prefer to have a test case for the primitive array type as well.

Member: I think tokenizing is a pretty common use case, so it's fine to have an explicit test for this.

Member: I don't think this PR aims to fix or support tokenizing in a UDF.

Contributor Author: @HyukjinKwon It doesn't, but given that the old documentation implied that the tokenization use case wouldn't work, I thought it would be good to illustrate in a test that it does.

                              ArrayType(StringType()))
        self.assertEqual(tokenize.returnType, ArrayType(StringType()))
        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
        result = df.select(tokenize("vals").alias("hi"))
        self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())

    def test_pandas_udf_nested_arrays(self):
        from pyspark.sql.functions import pandas_udf
        tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]),
                              ArrayType(ArrayType(StringType())))
        self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
        result = df.select(tokenize("vals").alias("hi"))
        self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())

    def test_vectorized_udf_basic(self):
        from pyspark.sql.functions import pandas_udf, col, array
        df = self.spark.range(10).select(