-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-41201][CONNECT][PYTHON] Implement DataFrame.SelectExpr in Python client
#38723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -263,6 +263,22 @@ def __str__(self) -> str: | |
| return f"Column({self._unparsed_identifier})" | ||
|
|
||
|
|
||
class SQLExpression(Expression):
    """Expression that carries a raw SQL expression string.

    The string is placed unmodified into the proto message; it is meant to be
    parsed on the server side by Catalyst.
    """

    def __init__(self, expr: str) -> None:
        super().__init__()
        # Raw SQL text, kept verbatim until it is serialized in ``to_plan``.
        self._expr: str = expr

    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
        """Build the proto ``Expression`` whose ``expression_string`` holds the SQL text."""
        message = proto.Expression()
        message.expression_string.expression = self._expr
        return message
|
|
||
|
|
||
| class SortOrder(Expression): | ||
| def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) -> None: | ||
| super().__init__() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |
| Column, | ||
| Expression, | ||
| LiteralExpression, | ||
| SQLExpression, | ||
| ) | ||
| from pyspark.sql.types import ( | ||
| StructType, | ||
|
|
@@ -140,6 +141,29 @@ def isEmpty(self) -> bool: | |
def select(self, *cols: "ExpressionOrString") -> "DataFrame":
    """Return a new DataFrame whose plan projects ``cols`` on top of this plan."""
    projection = plan.Project(self._plan, *cols)
    return DataFrame.withPlan(projection, session=self._session)
|
|
||
def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
    """Projects a set of SQL expressions and returns a new :class:`DataFrame`.

    This is a variant of :func:`select` that accepts SQL expressions.
    Expressions may be passed as separate string arguments, as a single list
    of strings, or as a mix of strings and lists of strings.

    .. versionadded:: 3.4.0

    Returns
    -------
    :class:`DataFrame`
        A DataFrame with new/old columns transformed by expressions.
    """
    # A lone list argument stands for the whole argument sequence.
    if len(expr) == 1 and isinstance(expr[0], list):
        expr = expr[0]  # type: ignore[assignment]

    projections = []
    for item in expr:
        # Treat a bare string as a one-element group so both shapes share one path.
        fragments = [item] if isinstance(item, str) else item
        projections.extend(SQLExpression(fragment) for fragment in fragments)

    projected_plan = plan.Project(self._plan, *projections)
    return DataFrame.withPlan(projected_plan, session=self._session)
|
||
|
|
||
def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
    """Aggregate without grouping keys by delegating to ``groupBy().agg(exprs)``."""
    ungrouped = self.groupBy()
    return ungrouped.agg(exprs)
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,6 @@ | |
| import tempfile | ||
|
|
||
| import grpc # type: ignore | ||
| from grpc._channel import _MultiThreadedRendezvous # type: ignore | ||
|
|
||
| from pyspark.testing.sqlutils import have_pandas, SQLTestUtils | ||
|
|
||
|
|
@@ -245,7 +244,7 @@ def test_create_global_temp_view(self): | |
|
|
||
| # Test when creating a view which already exists | ||
| self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) | ||
| with self.assertRaises(_MultiThreadedRendezvous): | ||
| with self.assertRaises(grpc.RpcError): | ||
| self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") | ||
|
|
||
| def test_to_pandas(self): | ||
|
|
@@ -302,6 +301,31 @@ def test_to_pandas(self): | |
| self.spark.sql(query).toPandas(), | ||
| ) | ||
|
|
||
def test_select_expr(self):
    # SPARK-41201: test selectExpr API.
    # Each tuple is one selectExpr argument list: a single string, a single
    # list of strings, and multiple string arguments.
    argument_sets = [
        ("id * 2",),
        (["id * 2", "cast(name as long) as name"],),
        ("id * 2", "cast(name as long) as name"),
    ]
    for args in argument_sets:
        connect_result = self.connect.read.table(self.tbl_name).selectExpr(*args).toPandas()
        spark_result = self.spark.read.table(self.tbl_name).selectExpr(*args).toPandas()
        self.assert_eq(connect_result, spark_result)
|
|
||
| @unittest.skip("test_fill_na is flaky") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @amaliujia why did you disable this test again?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I didn't notice this. Can we re-enable it?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think so; I will send a follow-up for it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am pretty sure I removed this after conflict resolution. Actually, Martin pointed out another case: #38723 (comment). Basically, it seems to have happened more than once that, after resolving code conflicts, the code I wanted to keep was gone. Maybe I should always do a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I will follow up on this soon.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My guess is that if I have more than one commit locally and I resolve the conflict in the first one, a following commit might silently add something back...
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No worries, it is added back in #38763.
||
| def test_fill_na(self): | ||
| # SPARK-41128: Test fill na | ||
| query = """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still don't like this kind of naming, but it is at least somewhat consistent with what we have in DSv2, so I am fine with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's see in the future... I guess we will need to name more...