Skip to content

Commit a9da924

Browse files
grundprinzip authored and HyukjinKwon committed
[SPARK-40538][CONNECT] Improve built-in function support for Python client
### What changes were proposed in this pull request? This patch changes the way simple scalar built-in functions are resolved in the Python Spark Connect client. Previously, it was trying to manually load specific functions. With the changes in this patch, the trivial binary operators like `<`, `+`, ... are mapped to their name equivalents in Spark so that the dynamic function lookup works. In addition, it cleans up the Scala planner side to remove the now unnecessary code translating the trivial binary expressions into their equivalent functions. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT, E2E Closes #38270 from grundprinzip/spark-40538. Authored-by: Martin Grund <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent fc4643b commit a9da924

File tree

8 files changed

+156
-60
lines changed

8 files changed

+156
-60
lines changed

connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,44 @@ package object dsl {
9292
.build()
9393
}
9494

95+
/**
96+
* Create an unresolved function from name parts.
97+
*
98+
* @param nameParts the parts of the function identifier, e.g. Seq("default", "hex")
99+
* @param args the argument expressions passed to the function
100+
* @return
101+
* Expression wrapping the unresolved function.
102+
*/
103+
def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = {
104+
proto.Expression
105+
.newBuilder()
106+
.setUnresolvedFunction(
107+
proto.Expression.UnresolvedFunction
108+
.newBuilder()
109+
.addAllParts(nameParts.asJava)
110+
.addAllArguments(args.asJava))
111+
.build()
112+
}
113+
114+
/**
115+
* Creates an UnresolvedFunction from a single identifier.
116+
*
117+
* @param name the function name, used as the single name part of the identifier
118+
* @param args the argument expressions passed to the function
119+
* @return
120+
* Expression wrapping the unresolved function.
121+
*/
122+
def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = {
123+
proto.Expression
124+
.newBuilder()
125+
.setUnresolvedFunction(
126+
proto.Expression.UnresolvedFunction
127+
.newBuilder()
128+
.addParts(name)
129+
.addAllArguments(args.asJava))
130+
.build()
131+
}
132+
95133
implicit def intToLiteral(i: Int): proto.Expression =
96134
proto.Expression
97135
.newBuilder()

connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,6 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
197197
limitExpr = expressions.Literal(limit.getLimit, IntegerType))
198198
}
199199

200-
private def lookupFunction(name: String, args: Seq[Expression]): Expression = {
201-
UnresolvedFunction(Seq(name), args, isDistinct = false)
202-
}
203-
204200
/**
205201
* Translates a scalar function from proto to the Catalyst expression.
206202
*
@@ -211,21 +207,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
211207
* @return
212208
*/
213209
private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
214-
val funName = fun.getPartsList.asScala.mkString(".")
215-
funName match {
216-
case "gt" =>
217-
assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.")
218-
expressions.GreaterThan(
219-
transformExpression(fun.getArguments(0)),
220-
transformExpression(fun.getArguments(1)))
221-
case "eq" =>
222-
assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.")
223-
expressions.EqualTo(
224-
transformExpression(fun.getArguments(0)),
225-
transformExpression(fun.getArguments(1)))
226-
case _ =>
227-
lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq)
210+
if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) {
211+
throw new IllegalArgumentException(
212+
"Function identifier must be passed as sequence of name parts.")
228213
}
214+
UnresolvedFunction(
215+
fun.getPartsList.asScala.toSeq,
216+
fun.getArgumentsList.asScala.map(transformExpression).toSeq,
217+
isDistinct = false)
229218
}
230219

231220
private def transformAlias(alias: proto.Expression.Alias): Expression = {

connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
197197

198198
val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
199199
proto.Expression.UnresolvedFunction.newBuilder
200-
.addAllParts(Seq("eq").asJava)
200+
.addAllParts(Seq("==").asJava)
201201
.addArguments(unresolvedAttribute)
202202
.addArguments(unresolvedAttribute)
203203
.build())

connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
5050
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
5151
}
5252

53+
test("UnresolvedFunction resolution.") {
54+
{
55+
import org.apache.spark.sql.connect.dsl.expressions._
56+
import org.apache.spark.sql.connect.dsl.plans._
57+
assertThrows[IllegalArgumentException] {
58+
transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
59+
}
60+
}
61+
62+
val connectPlan = {
63+
import org.apache.spark.sql.connect.dsl.expressions._
64+
import org.apache.spark.sql.connect.dsl.plans._
65+
transform(
66+
connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))))
67+
}
68+
69+
assertThrows[UnsupportedOperationException] {
70+
connectPlan.analyze
71+
}
72+
73+
val validPlan = {
74+
import org.apache.spark.sql.connect.dsl.expressions._
75+
import org.apache.spark.sql.connect.dsl.plans._
76+
transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr))))
77+
}
78+
assert(validPlan.analyze != null)
79+
}
80+
5381
test("Basic filter") {
5482
val connectPlan = {
5583
import org.apache.spark.sql.connect.dsl.expressions._

python/pyspark/sql/connect/column.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,51 @@
2626
import pyspark.sql.connect.proto as proto
2727

2828

29+
def _bin_op(
30+
name: str, doc: str = "binary function", reverse: bool = False
31+
) -> Callable[["ColumnRef", Any], "Expression"]:
32+
def _(self: "ColumnRef", other: Any) -> "Expression":
33+
if isinstance(other, get_args(PrimitiveType)):
34+
other = LiteralExpression(other)
35+
if not reverse:
36+
return ScalarFunctionExpression(name, self, other)
37+
else:
38+
return ScalarFunctionExpression(name, other, self)
39+
40+
return _
41+
42+
2943
class Expression(object):
3044
"""
3145
Expression base class.
3246
"""
3347

48+
__gt__ = _bin_op(">")
49+
__lt__ = _bin_op("<")  # less-than must map to the "<" operator, not ">"
50+
__add__ = _bin_op("+")
51+
__sub__ = _bin_op("-")
52+
__mul__ = _bin_op("*")
53+
__div__ = _bin_op("/")
54+
__truediv__ = _bin_op("/")
55+
__mod__ = _bin_op("%")
56+
__radd__ = _bin_op("+", reverse=True)
57+
__rsub__ = _bin_op("-", reverse=True)
58+
__rmul__ = _bin_op("*", reverse=True)
59+
__rdiv__ = _bin_op("/", reverse=True)
60+
__rtruediv__ = _bin_op("/", reverse=True)
61+
__pow__ = _bin_op("pow")
62+
__rpow__ = _bin_op("pow", reverse=True)
63+
__ge__ = _bin_op(">=")
64+
__le__ = _bin_op("<=")
65+
66+
def __eq__(self, other: Any) -> "Expression": # type: ignore[override]
67+
"""Returns a binary expression with the current column as the left
68+
side and the other expression as the right side.
69+
"""
70+
if isinstance(other, get_args(PrimitiveType)):
71+
other = LiteralExpression(other)
72+
return ScalarFunctionExpression("==", self, other)
73+
3474
def __init__(self) -> None:
3575
pass
3676

@@ -73,20 +113,6 @@ def __str__(self) -> str:
73113
return f"Literal({self._value})"
74114

75115

76-
def _bin_op(
77-
name: str, doc: str = "binary function", reverse: bool = False
78-
) -> Callable[["ColumnRef", Any], Expression]:
79-
def _(self: "ColumnRef", other: Any) -> Expression:
80-
if isinstance(other, get_args(PrimitiveType)):
81-
other = LiteralExpression(other)
82-
if not reverse:
83-
return ScalarFunctionExpression(name, self, other)
84-
else:
85-
return ScalarFunctionExpression(name, other, self)
86-
87-
return _
88-
89-
90116
class ColumnRef(Expression):
91117
"""Represents a column reference. There is no guarantee that this column
92118
actually exists. In the context of this project, we refer by its name and
@@ -105,32 +131,6 @@ def name(self) -> str:
105131
"""Returns the qualified name of the column reference."""
106132
return ".".join(self._parts)
107133

108-
__gt__ = _bin_op("gt")
109-
__lt__ = _bin_op("lt")
110-
__add__ = _bin_op("plus")
111-
__sub__ = _bin_op("minus")
112-
__mul__ = _bin_op("multiply")
113-
__div__ = _bin_op("divide")
114-
__truediv__ = _bin_op("divide")
115-
__mod__ = _bin_op("modulo")
116-
__radd__ = _bin_op("plus", reverse=True)
117-
__rsub__ = _bin_op("minus", reverse=True)
118-
__rmul__ = _bin_op("multiply", reverse=True)
119-
__rdiv__ = _bin_op("divide", reverse=True)
120-
__rtruediv__ = _bin_op("divide", reverse=True)
121-
__pow__ = _bin_op("pow")
122-
__rpow__ = _bin_op("pow", reverse=True)
123-
__ge__ = _bin_op("greterEquals")
124-
__le__ = _bin_op("lessEquals")
125-
126-
def __eq__(self, other: Any) -> Expression: # type: ignore[override]
127-
"""Returns a binary expression with the current column as the left
128-
side and the other expression as the right side.
129-
"""
130-
if isinstance(other, get_args(PrimitiveType)):
131-
other = LiteralExpression(other)
132-
return ScalarFunctionExpression("eq", self, other)
133-
134134
def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
135135
"""Returns the Proto representation of the expression."""
136136
expr = proto.Expression()

python/pyspark/sql/tests/connect/test_connect_basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
import unittest
2020
import tempfile
2121

22+
import pandas
23+
2224
from pyspark.sql import SparkSession, Row
2325
from pyspark.sql.connect.client import RemoteSparkSession
2426
from pyspark.sql.connect.function_builder import udf
27+
from pyspark.sql.connect.functions import lit
2528
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
2629
from pyspark.testing.utils import ReusedPySparkTestCase
2730

@@ -79,6 +82,15 @@ def test_simple_explain_string(self):
7982
result = df.explain()
8083
self.assertGreater(len(result), 0)
8184

85+
def test_simple_binary_expressions(self):
86+
"""Test simple binary expressions (modulo and equality) end to end."""
87+
df = self.connect.read.table(self.tbl_name)
88+
pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
89+
self.assertEqual(len(pd.index), 4)
90+
91+
res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
92+
self.assertTrue(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
93+
8294

8395
if __name__ == "__main__":
8496
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401

python/pyspark/sql/tests/connect/test_connect_column_expressions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
from pyspark.testing.connectutils import PlanOnlyTestFixture
19+
from pyspark.sql.connect.proto import Expression as ProtoExpression
1920
import pyspark.sql.connect as c
2021
import pyspark.sql.connect.plan as p
2122
import pyspark.sql.connect.column as col
@@ -51,6 +52,34 @@ def test_column_literals(self):
5152
plan = fun.lit(10).to_plan(None)
5253
self.assertIs(plan.literal.i32, 10)
5354

55+
def test_column_expressions(self):
56+
"""Test a more complex combination of expressions and their translation into
57+
the protobuf structure."""
58+
df = c.DataFrame.withPlan(p.Read("table"))
59+
60+
expr = df.id % fun.lit(10) == fun.lit(10)
61+
expr_plan = expr.to_plan(None)
62+
self.assertIsNotNone(expr_plan.unresolved_function)
63+
self.assertEqual(expr_plan.unresolved_function.parts[0], "==")
64+
65+
lit_fun = expr_plan.unresolved_function.arguments[1]
66+
self.assertIsInstance(lit_fun, ProtoExpression)
67+
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
68+
self.assertEqual(lit_fun.literal.i32, 10)
69+
70+
mod_fun = expr_plan.unresolved_function.arguments[0]
71+
self.assertIsInstance(mod_fun, ProtoExpression)
72+
self.assertIsInstance(mod_fun.unresolved_function, ProtoExpression.UnresolvedFunction)
73+
self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
74+
self.assertIsInstance(mod_fun.unresolved_function.arguments[0], ProtoExpression)
75+
self.assertIsInstance(
76+
mod_fun.unresolved_function.arguments[0].unresolved_attribute,
77+
ProtoExpression.UnresolvedAttribute,
78+
)
79+
self.assertEqual(
80+
mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"]
81+
)
82+
5483

5584
if __name__ == "__main__":
5685
import unittest

python/pyspark/sql/tests/connect/test_connect_plan_only.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_filter(self):
4040
plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction
4141
)
4242
)
43-
self.assertEqual(plan.root.filter.condition.unresolved_function.parts, ["gt"])
43+
self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
4444
self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
4545

4646
def test_relation_alias(self):

0 commit comments

Comments
 (0)