24 changes: 20 additions & 4 deletions python/pyspark/sql/functions.py
@@ -2781,14 +2781,14 @@ def pandas_udf(f=None, returnType=None, functionType=None):
+---+-------------------+

Alternatively, the user can define a function that takes two arguments.
In this case, the grouping key will be passed as the first argument and the data will
be passed as the second argument. The grouping key will be passed as a tuple of numpy
In this case, the grouping key(s) will be passed as the first argument and the data will
be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
This is useful when the user does not want to hardcode grouping key in the function.
This is useful when the user does not want to hardcode grouping key(s) in the function.

>>> import pandas as pd  # doctest: +SKIP
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType, ceil
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
@@ -2804,6 +2804,22 @@ def pandas_udf(f=None, returnType=None, functionType=None):
|  1|1.5|
|  2|6.0|
+---+---+
>>> @pandas_udf(
...     "id long, `ceil(v / 2)` long, v double",
...     PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
... def sum_udf(key, pdf):
...     # key is a tuple of two numpy.int64s, which are the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     return pd.DataFrame([key + (pdf.v.sum(),)])
>>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+
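
A minimal standalone sketch of the same multi-key example, for readers who want
to run it outside the doctest: it assumes a local SparkSession with pandas and
pyarrow available, and the `spark` session and DataFrame below are created here
purely for illustration rather than provided by the doctest environment.

# Standalone sketch of the multi-key GROUPED_MAP example above
# (assumptions: local Spark, pandas and pyarrow installed).
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, ceil

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

@pandas_udf("id long, `ceil(v / 2)` long, v double", PandasUDFType.GROUPED_MAP)
def sum_udf(key, pdf):
    # key holds the current group's values of 'id' and 'ceil(v / 2)' as
    # numpy scalars; pdf is a pandas.DataFrame with every column of the group.
    return pd.DataFrame([key + (pdf.v.sum(),)])

df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show()
spark.stop()

Because sum_udf returns a one-row pandas.DataFrame per group, the result has
one row per distinct (id, ceil(v / 2)) pair, matching the table above.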

.. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
recommended to explicitly index the columns by name to ensure the positions are correct,