-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-40538] [CONNECT] Improve built-in function support for Python client. #38270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dee4267
341d44e
b7902cb
b327268
4fe501e
bac0664
b7acacb
1c7b7ef
5675ff6
e03c60a
09f1540
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { | |
| comparePlans(connectPlan.analyze, sparkPlan.analyze, false) | ||
| } | ||
|
|
||
| test("UnresolvedFunction resolution.") { | ||
| { | ||
| import org.apache.spark.sql.connect.dsl.expressions._ | ||
| import org.apache.spark.sql.connect.dsl.plans._ | ||
| assertThrows[IllegalArgumentException] { | ||
| transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr)))) | ||
| } | ||
| } | ||
|
|
||
| val connectPlan = { | ||
| import org.apache.spark.sql.connect.dsl.expressions._ | ||
| import org.apache.spark.sql.connect.dsl.plans._ | ||
| transform( | ||
| connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr)))) | ||
| } | ||
|
|
||
| assertThrows[UnsupportedOperationException] { | ||
| connectPlan.analyze | ||
| } | ||
|
|
||
| val validPlan = { | ||
| import org.apache.spark.sql.connect.dsl.expressions._ | ||
| import org.apache.spark.sql.connect.dsl.plans._ | ||
| transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr)))) | ||
| } | ||
| assert(validPlan.analyze != null) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's better to compare it with the catalyst plan
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not to validate that the catalyst plan exists, but really just that existing functions are actually resolved. The |
||
| } | ||
|
|
||
| test("Basic filter") { | ||
| val connectPlan = { | ||
| import org.apache.spark.sql.connect.dsl.expressions._ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,11 +26,51 @@ | |
| import pyspark.sql.connect.proto as proto | ||
|
|
||
|
|
||
| def _bin_op( | ||
| name: str, doc: str = "binary function", reverse: bool = False | ||
| ) -> Callable[["ColumnRef", Any], "Expression"]: | ||
| def _(self: "ColumnRef", other: Any) -> "Expression": | ||
| if isinstance(other, get_args(PrimitiveType)): | ||
| other = LiteralExpression(other) | ||
| if not reverse: | ||
| return ScalarFunctionExpression(name, self, other) | ||
| else: | ||
| return ScalarFunctionExpression(name, other, self) | ||
|
|
||
| return _ | ||
|
|
||
|
|
||
| class Expression(object): | ||
| """ | ||
| Expression base class. | ||
| """ | ||
|
|
||
| __gt__ = _bin_op(">") | ||
grundprinzip marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| __lt__ = _bin_op("<") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think this was a mistake.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| __add__ = _bin_op("+") | ||
| __sub__ = _bin_op("-") | ||
| __mul__ = _bin_op("*") | ||
| __div__ = _bin_op("/") | ||
| __truediv__ = _bin_op("/") | ||
| __mod__ = _bin_op("%") | ||
| __radd__ = _bin_op("+", reverse=True) | ||
| __rsub__ = _bin_op("-", reverse=True) | ||
| __rmul__ = _bin_op("*", reverse=True) | ||
| __rdiv__ = _bin_op("/", reverse=True) | ||
| __rtruediv__ = _bin_op("/", reverse=True) | ||
| __pow__ = _bin_op("pow") | ||
| __rpow__ = _bin_op("pow", reverse=True) | ||
| __ge__ = _bin_op(">=") | ||
| __le__ = _bin_op("<=") | ||
|
|
||
| def __eq__(self, other: Any) -> "Expression": # type: ignore[override] | ||
| """Returns a binary expression with the current column as the left | ||
| side and the other expression as the right side. | ||
| """ | ||
| if isinstance(other, get_args(PrimitiveType)): | ||
| other = LiteralExpression(other) | ||
| return ScalarFunctionExpression("==", self, other) | ||
|
|
||
| def __init__(self) -> None: | ||
| pass | ||
|
|
||
|
|
@@ -73,20 +113,6 @@ def __str__(self) -> str: | |
| return f"Literal({self._value})" | ||
|
|
||
|
|
||
| def _bin_op( | ||
| name: str, doc: str = "binary function", reverse: bool = False | ||
| ) -> Callable[["ColumnRef", Any], Expression]: | ||
| def _(self: "ColumnRef", other: Any) -> Expression: | ||
| if isinstance(other, get_args(PrimitiveType)): | ||
| other = LiteralExpression(other) | ||
| if not reverse: | ||
| return ScalarFunctionExpression(name, self, other) | ||
| else: | ||
| return ScalarFunctionExpression(name, other, self) | ||
|
|
||
| return _ | ||
|
|
||
|
|
||
| class ColumnRef(Expression): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should rename this. It would be better to keep it matched with either Catalyst internal types or user-facing Spark SQL interface classes.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's have a discussion about this, but this is an unrelated change to this one. I think we should probably call
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 I have been thinking this |
||
| """Represents a column reference. There is no guarantee that this column | ||
| actually exists. In the context of this project, we refer by its name and | ||
|
|
@@ -105,32 +131,6 @@ def name(self) -> str: | |
| """Returns the qualified name of the column reference.""" | ||
| return ".".join(self._parts) | ||
|
|
||
| __gt__ = _bin_op("gt") | ||
| __lt__ = _bin_op("lt") | ||
| __add__ = _bin_op("plus") | ||
| __sub__ = _bin_op("minus") | ||
| __mul__ = _bin_op("multiply") | ||
| __div__ = _bin_op("divide") | ||
| __truediv__ = _bin_op("divide") | ||
| __mod__ = _bin_op("modulo") | ||
| __radd__ = _bin_op("plus", reverse=True) | ||
| __rsub__ = _bin_op("minus", reverse=True) | ||
| __rmul__ = _bin_op("multiply", reverse=True) | ||
| __rdiv__ = _bin_op("divide", reverse=True) | ||
| __rtruediv__ = _bin_op("divide", reverse=True) | ||
| __pow__ = _bin_op("pow") | ||
| __rpow__ = _bin_op("pow", reverse=True) | ||
| __ge__ = _bin_op("greterEquals") | ||
| __le__ = _bin_op("lessEquals") | ||
|
|
||
| def __eq__(self, other: Any) -> Expression: # type: ignore[override] | ||
| """Returns a binary expression with the current column as the left | ||
| side and the other expression as the right side. | ||
| """ | ||
| if isinstance(other, get_args(PrimitiveType)): | ||
| other = LiteralExpression(other) | ||
| return ScalarFunctionExpression("eq", self, other) | ||
|
|
||
| def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression: | ||
| """Returns the Proto representation of the expression.""" | ||
| expr = proto.Expression() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,9 +19,12 @@ | |
| import unittest | ||
| import tempfile | ||
|
|
||
| import pandas | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm .. we gotta fix this or do something. pandas isn't a required library for the SQL package. We should probably skip these tests when pandas is not installed, for now, until we have a clear way to handle this. (see
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interestingly, nothing in Spark Connect will work atm without pandas because we always call |
||
|
|
||
| from pyspark.sql import SparkSession, Row | ||
| from pyspark.sql.connect.client import RemoteSparkSession | ||
| from pyspark.sql.connect.function_builder import udf | ||
| from pyspark.sql.connect.functions import lit | ||
| from pyspark.testing.connectutils import should_test_connect, connect_requirement_message | ||
| from pyspark.testing.utils import ReusedPySparkTestCase | ||
|
|
||
|
|
@@ -79,6 +82,15 @@ def test_simple_explain_string(self): | |
| result = df.explain() | ||
| self.assertGreater(len(result), 0) | ||
|
|
||
| def test_simple_binary_expressions(self): | ||
| """Test complex expression""" | ||
| df = self.connect.read.table(self.tbl_name) | ||
| pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas() | ||
| self.assertEqual(len(pd.index), 4) | ||
|
|
||
| res = pandas.DataFrame(data={"id": [0, 30, 60, 90]}) | ||
| self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: only import them once in
test("UnresolvedFunction resolution.")?