Skip to content

Commit 23a19e6

Browse files
committed
[SPARK-52905][PYTHON] Arrow UDF for window
### What changes were proposed in this pull request? Arrow UDF for window ### Why are the changes needed? to make Arrow UDF support window operation ### Does this PR introduce _any_ user-facing change? Not, yet. Will make Arrow UDF public soon ```py In [1]: from typing import Iterator, Tuple ...: import pyarrow as pa ...: from pyspark.sql import Window ...: from pyspark.sql import functions as sf ...: from pyspark.sql.pandas.functions import arrow_udf ...: ...: import pandas as pd ...: from pyspark.sql.functions import pandas_udf ...: from pyspark.sql import Window ...: ...: df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) ...: ...: w = Window.partitionBy('id').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) ...: ...: In [2]: arrow_udf("double") ...: def arrow_mean_udf(v: pa.Array) -> float: ...: assert isinstance(v, pa.Array), str(type(v)) ...: return pa.compute.mean(v) ...: ...: # df.select(arrow_mean_udf(df['v'])).show() ...: # df.groupby("id").agg(arrow_mean_udf('v')).show() ...: ...: df.withColumn('mean_v', arrow_mean_udf(df['v']).over(w)).show() ...: ...: +---+----+------+ | id| v|mean_v| +---+----+------+ | 1| 1.0| 1.5| | 1| 2.0| 1.5| | 2| 3.0| 6.0| | 2| 5.0| 6.0| | 2|10.0| 6.0| +---+----+------+ ``` ### How was this patch tested? New tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #51593 from zhengruifeng/arrow_udf_win. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent f345634 commit 23a19e6

File tree

11 files changed

+686
-9
lines changed

11 files changed

+686
-9
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ private[spark] object PythonEvalType {
7171
val SQL_SCALAR_ARROW_UDF = 250
7272
val SQL_SCALAR_ARROW_ITER_UDF = 251
7373
val SQL_GROUPED_AGG_ARROW_UDF = 252
74+
val SQL_WINDOW_AGG_ARROW_UDF = 253
7475

7576
val SQL_TABLE_UDF = 300
7677
val SQL_ARROW_TABLE_UDF = 301
@@ -103,6 +104,7 @@ private[spark] object PythonEvalType {
103104
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
104105
case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
105106
case SQL_GROUPED_AGG_ARROW_UDF => "SQL_GROUPED_AGG_ARROW_UDF"
107+
case SQL_WINDOW_AGG_ARROW_UDF => "SQL_WINDOW_AGG_ARROW_UDF"
106108
}
107109
}
108110

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def __hash__(self):
545545
"pyspark.sql.tests.arrow.test_arrow_udf",
546546
"pyspark.sql.tests.arrow.test_arrow_udf_grouped_agg",
547547
"pyspark.sql.tests.arrow.test_arrow_udf_scalar",
548+
"pyspark.sql.tests.arrow.test_arrow_udf_window",
548549
"pyspark.sql.tests.pandas.test_pandas_cogrouped_map",
549550
"pyspark.sql.tests.pandas.test_pandas_grouped_map",
550551
"pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]
6464
ArrowScalarUDFType = Literal[250]
6565
ArrowScalarIterUDFType = Literal[251]
6666
ArrowGroupedAggUDFType = Literal[252]
67+
ArrowWindowAggUDFType = Literal[253]
6768

6869
class ArrowVariadicScalarToScalarFunction(Protocol):
6970
def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...

0 commit comments

Comments
 (0)