diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 9f610bf18fee..c4ffc54c20b7 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -263,6 +263,22 @@ def __str__(self) -> str:
         return f"Column({self._unparsed_identifier})"
 
 
+class SQLExpression(Expression):
+    """An expression holding a raw SQL string; the server side
+    parses it with Catalyst.
+    """
+
+    def __init__(self, expr: str) -> None:
+        super().__init__()
+        self._expr: str = expr
+
+    def to_plan(self, session: "RemoteSparkSession") -> proto.Expression:
+        """Returns the Proto representation of the SQL expression."""
+        expr = proto.Expression()
+        expr.expression_string.expression = self._expr
+        return expr
+
+
 class SortOrder(Expression):
     def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) -> None:
         super().__init__()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 579403299fe8..82dc1f6a558d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -36,6 +36,7 @@
     Column,
     Expression,
     LiteralExpression,
+    SQLExpression,
 )
 from pyspark.sql.types import (
     StructType,
@@ -140,6 +141,29 @@ def isEmpty(self) -> bool:
     def select(self, *cols: "ExpressionOrString") -> "DataFrame":
         return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)
 
+    def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
+        """Projects a set of SQL expressions and returns a new :class:`DataFrame`.
+
+        This is a variant of :func:`select` that accepts SQL expressions.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`DataFrame`
+            A DataFrame with new/old columns transformed by expressions.
+        """
+        sql_expr = []
+        if len(expr) == 1 and isinstance(expr[0], list):
+            expr = expr[0]  # type: ignore[assignment]
+        for element in expr:
+            if isinstance(element, str):
+                sql_expr.append(SQLExpression(element))
+            else:
+                sql_expr.extend([SQLExpression(e) for e in element])
+
+        return DataFrame.withPlan(plan.Project(self._plan, *sql_expr), session=self._session)
+
     def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
         return self.groupBy().agg(exprs)
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 49973ba70c39..7304c2b9940c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,7 +20,6 @@
 import tempfile
 
 import grpc  # type: ignore
-from grpc._channel import _MultiThreadedRendezvous  # type: ignore
 
 from pyspark.testing.sqlutils import have_pandas, SQLTestUtils
 
@@ -245,7 +244,7 @@ def test_create_global_temp_view(self):
 
         # Test when creating a view which is alreayd exists but
         self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
-        with self.assertRaises(_MultiThreadedRendezvous):
+        with self.assertRaises(grpc.RpcError):
             self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
     def test_to_pandas(self):
@@ -302,6 +301,31 @@ def test_to_pandas(self):
             self.spark.sql(query).toPandas(),
         )
 
+    def test_select_expr(self):
+        # SPARK-41201: test selectExpr API.
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
+            self.spark.read.table(self.tbl_name).selectExpr("id * 2").toPandas(),
+        )
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .selectExpr(["id * 2", "cast(name as long) as name"])
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .selectExpr(["id * 2", "cast(name as long) as name"])
+            .toPandas(),
+        )
+
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .selectExpr("id * 2", "cast(name as long) as name")
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .selectExpr("id * 2", "cast(name as long) as name")
+            .toPandas(),
+        )
+
+    @unittest.skip("test_fill_na is flaky")
     def test_fill_na(self):
         # SPARK-41128: Test fill na
         query = """