16 changes: 16 additions & 0 deletions python/pyspark/sql/connect/column.py
@@ -263,6 +263,22 @@ def __str__(self) -> str:
return f"Column({self._unparsed_identifier})"


class SQLExpression(Expression):
Member: I still don't like this kind of naming, but it is at least somewhat consistent with what we have in DSv2, so I am fine with it.

Contributor Author: Let's see in the future; I guess we will need to name more of these.

"""An expression that wraps a SQL expression string; the server side
parses it with Catalyst.
"""

def __init__(self, expr: str) -> None:
super().__init__()
Contributor: Should we assert here that `expr` is a string and not another expression?

Contributor Author: In this implementation we don't need to, because the caller has already verified the type (so mypy does not complain).

I think this is a good question, though. Generally speaking, a public API should throw a user-facing exception, while an internal API can assert when we want a defensive check against unexpected input.

So the question is whether we want to enforce such checks across all public/private APIs (each in the corresponding way). Maybe not right now, but it is worth doing at the right time (perhaps before the 3.4 release).

Member: I think that's fine. One point of having type hints is to avoid asserts on those types.

self._expr: str = expr

def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
"""Returns the Proto representation of the SQL expression."""
expr = proto.Expression()
expr.expression_string.expression = self._expr
return expr


class SortOrder(Expression):
def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) -> None:
super().__init__()
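To make the new class concrete, here is a minimal standalone sketch of the `SQLExpression` pattern from the diff above. It is an illustration only: plain dicts stand in for the protobuf messages, and the `RemoteSparkSession` parameter is omitted; the real `to_plan` builds a `proto.Expression` and sets `expr.expression_string.expression`.

```python
# Minimal standalone sketch of the SQLExpression pattern shown in the diff.
# Plain dicts stand in for the protobuf Expression messages; the real
# implementation sets expr.expression_string.expression on a proto.Expression.


class Expression:
    """Base class for client-side expression nodes."""


class SQLExpression(Expression):
    """Wraps a raw SQL string; the server side parses it with Catalyst."""

    def __init__(self, expr: str) -> None:
        super().__init__()
        self._expr: str = expr

    def to_plan(self) -> dict:
        # Stand-in for a proto.Expression carrying an expression_string field.
        return {"expression_string": {"expression": self._expr}}


print(SQLExpression("id * 2").to_plan())
# {'expression_string': {'expression': 'id * 2'}}
```

The class deliberately does no parsing on the client: the string is shipped as-is and resolved by Catalyst on the server, which is why the constructor only stores it.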
24 changes: 24 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -36,6 +36,7 @@
Column,
Expression,
LiteralExpression,
SQLExpression,
)
from pyspark.sql.types import (
StructType,
@@ -140,6 +141,29 @@ def isEmpty(self) -> bool:
def select(self, *cols: "ExpressionOrString") -> "DataFrame":
return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)

def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
"""Projects a set of SQL expressions and returns a new :class:`DataFrame`.

This is a variant of :func:`select` that accepts SQL expressions.

.. versionadded:: 3.4.0

Returns
-------
:class:`DataFrame`
A DataFrame with new/old columns transformed by expressions.
"""
sql_expr = []
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0] # type: ignore[assignment]
for element in expr:
if isinstance(element, str):
sql_expr.append(SQLExpression(element))
else:
sql_expr.extend([SQLExpression(e) for e in element])

return DataFrame.withPlan(plan.Project(self._plan, *sql_expr), session=self._session)
Contributor: So this becomes an unresolved attribute and just works out of the box?

Contributor Author: Actually it becomes `Expression.Builder().setExpressionString()`, which carries SQL expression strings. A `str` can mean different things in the DataFrame API.

Contributor: That was not quite the question I had; I was misremembering the type interface to Project:

def __init__(self, child: Optional["LogicalPlan"], *columns: "ExpressionOrString") -> None:

In this case SQLExpression is an Expression and it just works.
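The argument handling in `selectExpr` above can be sketched in isolation. The following `normalize_select_exprs` helper is hypothetical (it does not exist in PySpark); it mirrors only the normalization step, returning plain strings where the real method wraps each one in a `SQLExpression` and builds a `plan.Project`.

```python
from typing import List, Union


def normalize_select_exprs(*expr: Union[str, List[str]]) -> List[str]:
    """Mirror selectExpr's handling of varargs vs. a single list argument."""
    # A single list argument is unpacked, so selectExpr(["a", "b"])
    # behaves like selectExpr("a", "b").
    if len(expr) == 1 and isinstance(expr[0], list):
        expr = tuple(expr[0])
    out: List[str] = []
    for element in expr:
        if isinstance(element, str):
            out.append(element)  # the real code appends SQLExpression(element)
        else:
            out.extend(element)  # a nested list of strings is flattened
    return out


print(normalize_select_exprs(["id * 2", "cast(name as long) as name"]))
print(normalize_select_exprs("id * 2", "cast(name as long) as name"))
# Both calls produce ['id * 2', 'cast(name as long) as name']
```

This is why both call styles exercised in the tests below (a single list, or plain varargs) reach `plan.Project` in the same shape.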


def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
return self.groupBy().agg(exprs)

28 changes: 26 additions & 2 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,7 +20,6 @@
import tempfile

import grpc # type: ignore
from grpc._channel import _MultiThreadedRendezvous # type: ignore

from pyspark.testing.sqlutils import have_pandas, SQLTestUtils

@@ -245,7 +244,7 @@ def test_create_global_temp_view(self):

# Test when creating a view which already exists.
self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
with self.assertRaises(_MultiThreadedRendezvous):
with self.assertRaises(grpc.RpcError):
self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")

def test_to_pandas(self):
@@ -302,6 +301,31 @@ def test_to_pandas(self):
self.spark.sql(query).toPandas(),
)

def test_select_expr(self):
# SPARK-41201: test selectExpr API.
self.assert_eq(
Member: Comment with a JIRA ID?

Contributor Author: Done.

self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
)
self.assert_eq(
self.connect.read.table(self.tbl_name)
.selectExpr(["id * 2", "cast(name as long) as name"])
.toPandas(),
self.spark.read.table(self.tbl_name)
.selectExpr(["id * 2", "cast(name as long) as name"])
.toPandas(),
)

self.assert_eq(
self.connect.read.table(self.tbl_name)
.selectExpr("id * 2", "cast(name as long) as name")
.toPandas(),
self.spark.read.table(self.tbl_name)
.selectExpr("id * 2", "cast(name as long) as name")
.toPandas(),
)

@unittest.skip("test_fill_na is flaky")
Contributor: @amaliujia why did you disable this test again?

Member: Ah, I didn't notice this. Can we enable it back?

Contributor: I think so; I will send a follow-up for it.

Contributor Author: I am pretty sure I removed this after conflict resolution.

Actually, Martin pointed out another case: #38723 (comment).

It seems to have happened more than once that, after resolving a conflict, the code I wanted to keep is gone.

Maybe I should always squash to a single commit with an interactive rebase, in case rebasing more than one commit causes unexpected results.

Contributor Author: I will follow up on this soon.

Contributor Author: My guess is that if I have more than one commit locally and resolve the conflict in the first one, a following commit might silently add something back.

Contributor: No worries, it is added back in #38763.

def test_fill_na(self):
# SPARK-41128: Test fill na
query = """