@@ -7905,4 +7905,16 @@ object functions {
}
// scalastyle:off line.size.limit

/**
* Call a built-in or temporary function.
*
* @param funcName
* function name
* @param cols
* the expression parameters of the function
* @since 3.5.0
*/
@scala.annotation.varargs
def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)

}
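
The Connect Scala client simply delegates to Column.fn, so the new API is usable from client code as-is. A minimal, hedged sketch (the sc:// endpoint and data are illustrative assumptions, not part of this PR):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{call_function, col, lit}

// Hedged sketch for the Spark Connect Scala client; assumes a Connect
// server is listening on the default local port.
object ConnectCallFunctionSketch extends App {
  val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

  val df = spark.range(3).withColumn("g", lit("AbC"))
  // Column.fn builds the UnresolvedFunction message captured in the
  // golden plan files below.
  df.select(call_function("lower", col("g"))).show()
}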
@@ -2873,6 +2873,10 @@ class PlanGenerationTestSuite
fn.random(lit(1))
}

functionTest("call_function") {
fn.call_function("lower", fn.col("g"))
}

test("hll_sketch_agg with column lgConfigK") {
binary.select(fn.hll_sketch_agg(fn.col("bytes"), lit(0)))
}
@@ -0,0 +1,2 @@
Project [lower(g#0) AS lower(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,25 @@
{
  "common": {
    "planId": "1"
  },
  "project": {
    "input": {
      "common": {
        "planId": "0"
      },
      "localRelation": {
        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
      }
    },
    "expressions": [{
      "unresolvedFunction": {
        "functionName": "lower",
        "arguments": [{
          "unresolvedAttribute": {
            "unparsedIdentifier": "g"
          }
        }]
      }
    }]
  }
}
5 changes: 3 additions & 2 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -460,11 +460,12 @@ Bitwise Functions
getbit


-UDF
----
+Call Functions
+--------------
.. autosummary::
    :toctree: api/

+    call_function
    call_udf
    pandas_udf
    udf
9 changes: 8 additions & 1 deletion python/pyspark/sql/connect/functions.py
@@ -3853,7 +3853,7 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column:
bitmap_or_agg.__doc__ = pysparkfuncs.bitmap_or_agg.__doc__


-# User Defined Function
+# Call Functions


def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
@@ -3891,6 +3891,13 @@ def udf(
udf.__doc__ = pysparkfuncs.udf.__doc__


def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
    return _invoke_function(udfName, *[_to_col(c) for c in cols])


call_function.__doc__ = pysparkfuncs.call_function.__doc__


def _test() -> None:
import sys
import doctest
53 changes: 53 additions & 0 deletions python/pyspark/sql/functions.py
@@ -14394,6 +14394,59 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column))


@try_remote_functions
def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
    """
    Call a built-in or temporary function.

    .. versionadded:: 3.5.0

    Parameters
    ----------
    udfName : str
        name of the function
    cols : :class:`~pyspark.sql.Column` or str
        column names or :class:`~pyspark.sql.Column`\\s to be used in the function

    Returns
    -------
    :class:`~pyspark.sql.Column`
        result of the executed function.

    Examples
    --------
    >>> from pyspark.sql.functions import call_function, col
    >>> from pyspark.sql.types import IntegerType, StringType
    >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "name"])
    >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType())
    >>> df.select(call_function("intX2", "id")).show()
    +---------+
    |intX2(id)|
    +---------+
    |        2|
    |        4|
    |        6|
    +---------+
    >>> _ = spark.udf.register("strX2", lambda s: s * 2, StringType())
    >>> df.select(call_function("strX2", col("name"))).show()
    +-----------+
    |strX2(name)|
    +-----------+
    |         aa|
    |         bb|
    |         cc|
    +-----------+
    >>> df.select(call_function("avg", col("id"))).show()
    +-------+
    |avg(id)|
    +-------+
    |    2.0|
    +-------+
    """
    sc = get_active_spark_context()
    return _invoke_function("call_function", udfName, _to_seq(sc, cols, _to_java_column))


@try_remote_functions
def unwrap_udt(col: "ColumnOrName") -> Column:
"""
90 changes: 37 additions & 53 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1936,9 +1936,7 @@ object functions {
* @group math_funcs
* @since 3.5.0
*/
-def try_add(left: Column, right: Column): Column = withExpr {
-  UnresolvedFunction("try_add", Seq(left.expr, right.expr), isDistinct = false)
-}
+def try_add(left: Column, right: Column): Column = call_function("try_add", left, right)

/**
* Returns the mean calculated from values of a group and the result is null on overflow.
@@ -1957,9 +1955,8 @@
* @group math_funcs
* @since 3.5.0
*/
-def try_divide(dividend: Column, divisor: Column): Column = withExpr {
-  UnresolvedFunction("try_divide", Seq(dividend.expr, divisor.expr), isDistinct = false)
-}
+def try_divide(dividend: Column, divisor: Column): Column =
+  call_function("try_divide", dividend, divisor)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are
@@ -1968,9 +1965,8 @@
* @group math_funcs
* @since 3.5.0
*/
-def try_multiply(left: Column, right: Column): Column = withExpr {
-  UnresolvedFunction("try_multiply", Seq(left.expr, right.expr), isDistinct = false)
-}
+def try_multiply(left: Column, right: Column): Column =
+  call_function("try_multiply", left, right)

/**
* Returns `left``-``right` and the result is null on overflow. The acceptable input types are
@@ -1979,9 +1975,8 @@
* @group math_funcs
* @since 3.5.0
*/
-def try_subtract(left: Column, right: Column): Column = withExpr {
-  UnresolvedFunction("try_subtract", Seq(left.expr, right.expr), isDistinct = false)
-}
+def try_subtract(left: Column, right: Column): Column =
+  call_function("try_subtract", left, right)

/**
* Returns the sum calculated from values of a group and the result is null on overflow.
@@ -2366,19 +2361,15 @@
* @group math_funcs
* @since 3.3.0
*/
-def ceil(e: Column, scale: Column): Column = withExpr {
-  UnresolvedFunction(Seq("ceil"), Seq(e.expr, scale.expr), isDistinct = false)
-}
+def ceil(e: Column, scale: Column): Column = call_function("ceil", e, scale)

Review comment (Contributor): I just realized that this may be problematic in such cases, if some users happen to register a UDF with the same name, ceil.
Review comment (Contributor Author): Good point. If we want to avoid this issue, it seems we should provide built-in-only, UDF-only, and global variants, as you said.
Review comment (Contributor): OK. If the goal of this PR is to replace call_udf with call_function, we can resolve this naming-conflict issue in other PRs.
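
To make the concern in this thread concrete, here is a hedged sketch of the shadowing scenario (the session setup and the UDF are illustrative assumptions, not part of this PR):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Hedged sketch of the naming-conflict concern raised above.
object CeilShadowingSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  // Register a temporary function that reuses the built-in name "ceil".
  spark.udf.register("ceil", (d: Double) => -1.0)

  val df = Seq(1.2, 3.7).toDF("x")
  // Name-based lookup in this session may now resolve to the registered
  // UDF rather than the built-in ceil, which is exactly the worry here.
  df.select(call_function("ceil", $"x")).show()

  spark.stop()
}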

/**
* Computes the ceiling of the given value of `e` to 0 decimal places.
*
* @group math_funcs
* @since 1.4.0
*/
-def ceil(e: Column): Column = withExpr {
-  UnresolvedFunction(Seq("ceil"), Seq(e.expr), isDistinct = false)
-}
+def ceil(e: Column): Column = call_function("ceil", e)

/**
* Computes the ceiling of the given value of `e` to 0 decimal places.
@@ -2522,19 +2513,15 @@
* @group math_funcs
* @since 3.3.0
*/
-def floor(e: Column, scale: Column): Column = withExpr {
-  UnresolvedFunction(Seq("floor"), Seq(e.expr, scale.expr), isDistinct = false)
-}
+def floor(e: Column, scale: Column): Column = call_function("floor", e, scale)

/**
* Computes the floor of the given value of `e` to 0 decimal places.
*
* @group math_funcs
* @since 1.4.0
*/
-def floor(e: Column): Column = withExpr {
-  UnresolvedFunction(Seq("floor"), Seq(e.expr), isDistinct = false)
-}
+def floor(e: Column): Column = call_function("floor", e)

/**
* Computes the floor of the given column value to 0 decimal places.
@@ -4007,9 +3994,8 @@
* @group string_funcs
* @since 3.3.0
*/
-def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
-  UnresolvedFunction("lpad", Seq(str.expr, lit(len).expr, lit(pad).expr), isDistinct = false)
-}
+def lpad(str: Column, len: Int, pad: Array[Byte]): Column =
+  call_function("lpad", str, lit(len), lit(pad))

/**
* Trim the spaces from left end for the specified string value.
@@ -4190,9 +4176,8 @@
* @group string_funcs
* @since 3.3.0
*/
-def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
-  UnresolvedFunction("rpad", Seq(str.expr, lit(len).expr, lit(pad).expr), isDistinct = false)
-}
+def rpad(str: Column, len: Int, pad: Array[Byte]): Column =
+  call_function("rpad", str, lit(len), lit(pad))

/**
* Repeats a string column n times, and returns it as a new string column.
@@ -4628,9 +4613,7 @@
* @group string_funcs
* @since 3.5.0
*/
-def endswith(str: Column, suffix: Column): Column = withExpr {
-  UnresolvedFunction(Seq("endswith"), Seq(str.expr, suffix.expr), isDistinct = false)
-}
+def endswith(str: Column, suffix: Column): Column = call_function("endswith", str, suffix)

/**
* Returns a boolean. The value is True if str starts with prefix.
Expand All @@ -4640,9 +4623,7 @@ object functions {
* @group string_funcs
* @since 3.5.0
*/
-def startswith(str: Column, prefix: Column): Column = withExpr {
-  UnresolvedFunction(Seq("startswith"), Seq(str.expr, prefix.expr), isDistinct = false)
-}
+def startswith(str: Column, prefix: Column): Column = call_function("startswith", str, prefix)

/**
* Returns the ASCII character having the binary equivalent to `n`.
@@ -4752,9 +4733,7 @@
* @group string_funcs
* @since 3.5.0
*/
-def contains(left: Column, right: Column): Column = withExpr {
-  UnresolvedFunction(Seq("contains"), Seq(left.expr, right.expr), isDistinct = false)
-}
+def contains(left: Column, right: Column): Column = call_function("contains", left, right)

/**
* Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
@@ -5167,9 +5146,7 @@
* @group datetime_funcs
* @since 3.5.0
*/
-def extract(field: Column, source: Column): Column = withExpr {
-  UnresolvedFunction("extract", Seq(field.expr, source.expr), isDistinct = false)
-}
+def extract(field: Column, source: Column): Column = call_function("extract", field, source)

/**
* Extracts a part of the date/timestamp or interval source.
Expand All @@ -5181,9 +5158,7 @@ object functions {
* @group datetime_funcs
* @since 3.5.0
*/
-def date_part(field: Column, source: Column): Column = withExpr {
-  UnresolvedFunction("date_part", Seq(field.expr, source.expr), isDistinct = false)
-}
+def date_part(field: Column, source: Column): Column = call_function("date_part", field, source)

/**
* Extracts a part of the date/timestamp or interval source.
Expand All @@ -5195,9 +5170,7 @@ object functions {
* @group datetime_funcs
* @since 3.5.0
*/
-def datepart(field: Column, source: Column): Column = withExpr {
-  UnresolvedFunction("datepart", Seq(field.expr, source.expr), isDistinct = false)
-}
+def datepart(field: Column, source: Column): Column = call_function("datepart", field, source)

/**
* Returns the last day of the month which the given date belongs to.
@@ -8363,9 +8336,9 @@
* @since 1.5.0
*/
@scala.annotation.varargs
-@deprecated("Use call_udf")
+@deprecated("Use call_function")
def callUDF(udfName: String, cols: Column*): Column =
-  call_udf(udfName, cols: _*)
+  call_function(udfName, cols: _*)

/**
* Call an user-defined function.
Expand All @@ -8383,9 +8356,20 @@ object functions {
* @since 3.2.0
*/
@scala.annotation.varargs
-def call_udf(udfName: String, cols: Column*): Column = withExpr {
-  UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
-}
+@deprecated("Use call_function")
+def call_udf(udfName: String, cols: Column*): Column =
+  call_function(udfName, cols: _*)
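
With both callUDF and call_udf now delegating to call_function, migration is just a rename. A hedged sketch of the equivalence (the session and data are illustrative assumptions, not taken from the PR):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, call_function, call_udf}

// Hedged migration sketch; all three entry points now build the same
// UnresolvedFunction plan node.
object CallFunctionMigrationSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  val df = Seq("a", "b").toDF("name")
  df.select(callUDF("upper", $"name")).show()        // deprecated earlier
  df.select(call_udf("upper", $"name")).show()       // deprecated by this PR
  df.select(call_function("upper", $"name")).show()  // preferred

  spark.stop()
}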

+/**
+ * Call a built-in or temporary function.
+ *
+ * @param funcName function name
+ * @param cols the expression parameters of the function
+ * @since 3.5.0
+ */
+@scala.annotation.varargs
+def call_function(funcName: String, cols: Column*): Column =
+  withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }

Review comment (Member): So, this new method is not in the udf_funcs group, right?
Review comment (Contributor Author): Yes, it's not.
Review comment (Member, @dongjoon-hyun, Jun 26, 2023): It seems that we need @scala.annotation.varargs.
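
A minimal Scala usage sketch of the new entry point, mirroring the Python doctest above (the session and data are assumptions for illustration, not part of the diff):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{call_function, col}

// Minimal usage sketch: one entry point reaches both built-ins and
// registered temporary functions by name.
object CallFunctionUsageSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")

  // Built-in aggregate reached by name.
  df.select(call_function("avg", col("id"))).show()

  // Registered temporary function reached through the same entry point.
  spark.udf.register("intX2", (i: Int) => i * 2)
  df.select(call_function("intX2", col("id"))).show()

  spark.stop()
}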

/**
* Unwrap UDT data type column into its underlying type.
@@ -72,7 +72,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
"countDistinct", "count_distinct", // equivalent to count(distinct foo)
"sum_distinct", // equivalent to sum(distinct foo)
"typedLit", "typedlit", // Scala only
-"udaf", "udf" // create function statement in sql
+"udaf", "udf", // create function statement in sql
+"call_function" // moot in SQL as you just call the function directly
)

val excludedSqlFunctions = Set.empty[String]
@@ -5914,6 +5915,10 @@
parameters = Map.empty
)
}

test("call_function") {
checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))
}
}

object DataFrameFunctionsSuite {