From 7b13db41883ed3d42fb69a2073e7d4c5127e3c47 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 6 Dec 2022 18:10:52 +0800
Subject: [PATCH 1/3] init

---
 .../protobuf/spark/connect/expressions.proto  |  15 ++
 .../connect/planner/SparkConnectPlanner.scala |   9 +
 python/pyspark/sql/connect/column.py          | 160 +++++++++++++++++-
 python/pyspark/sql/connect/functions.py       |  95 ++++++-----
 .../sql/connect/proto/expressions_pb2.py      |  82 ++++++---
 .../sql/connect/proto/expressions_pb2.pyi     |  66 ++++++++
 .../tests/connect/test_connect_function.py    | 111 ++++++++++++
 7 files changed, 462 insertions(+), 76 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 6c0facbfeee89..88d172b762b4d 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -36,6 +36,7 @@ message Expression {
     UnresolvedStar unresolved_star = 5;
     Alias alias = 6;
     Cast cast = 7;
+    CaseWhen case_when = 8;
   }
 
   message Cast {
@@ -180,4 +181,18 @@ message Expression {
     // (Optional) Alias metadata expressed as a JSON map.
     optional string metadata = 3;
   }
+
+  // Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+  message CaseWhen {
+    // (Required) The sequence of (branch condition, branch value) pairs.
+    repeated Branch branches = 1;
+
+    // (Optional) Value for the else branch.
+    Expression else_value = 2;
+
+    message Branch {
+      Expression condition = 1;
+      Expression value = 2;
+    }
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 55283ca96b13e..e25ce0ac1f8c7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -410,6 +410,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.UNRESOLVED_STAR =>
         transformUnresolvedStar(exp.getUnresolvedStar)
       case proto.Expression.ExprTypeCase.CAST => transformCast(exp.getCast)
+      case proto.Expression.ExprTypeCase.CASE_WHEN => transformCaseWhen(exp.getCaseWhen)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -530,6 +531,14 @@ class SparkConnectPlanner(session: SparkSession) {
     }
   }
 
+  private def transformCaseWhen(casewhen: proto.Expression.CaseWhen): Expression = {
+    CaseWhen(
+      branches = casewhen.getBranchesList.asScala
+        .map(b => (transformExpression(b.getCondition), transformExpression(b.getValue))),
+      elseValue =
+        if (casewhen.hasElseValue) Some(transformExpression(casewhen.getElseValue)) else None)
+  }
+
   private def transformSetOperation(u: proto.SetOperation): LogicalPlan = {
     assert(u.hasLeftInput && u.hasRightInput, "Union must have 2 inputs")
 
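For orientation, a minimal sketch of the wire shape defined above: it builds the CaseWhen message by hand with the generated Python bindings. The ExpressionString condition and the integer literals are illustrative stand-ins; clients normally construct this message through Column.when()/otherwise(), added later in this patch.

    from pyspark.sql.connect.proto import expressions_pb2 as pb2

    expr = pb2.Expression()
    # Each WHEN ... THEN ... pair becomes one Branch in CaseWhen.branches.
    branch = expr.case_when.branches.add()
    branch.condition.expression_string.expression = "age > 4"
    branch.value.literal.integer = 1
    # else_value is optional; leaving it unset makes unmatched rows yield NULL.
    expr.case_when.else_value.literal.integer = 0
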
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 63e95c851db85..25d0b7f037f01 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -15,7 +15,18 @@
 # limitations under the License.
 #
-from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload, cast, Sequence
+from typing import (
+    get_args,
+    TYPE_CHECKING,
+    Callable,
+    Any,
+    Union,
+    overload,
+    cast,
+    Sequence,
+    Tuple,
+    Optional,
+)
 
 import json
 import decimal
@@ -130,6 +141,44 @@ def name(self) -> str:
         ...
 
 
+class CaseWhen(Expression):
+    def __init__(
+        self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
+    ):
+
+        assert isinstance(branches, list)
+        for branch in branches:
+            assert (
+                isinstance(branch, tuple)
+                and len(branch) == 2
+                and all(isinstance(expr, Expression) for expr in branch)
+            )
+        self._branches = branches
+
+        if else_value is not None:
+            assert isinstance(else_value, Expression)
+
+        self._else_value = else_value
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        for condition, value in self._branches:
+            branch = proto.Expression.CaseWhen.Branch()
+            branch.condition.CopyFrom(condition.to_plan(session))
+            branch.value.CopyFrom(value.to_plan(session))
+            expr.case_when.branches.append(branch)
+
+        if self._else_value is not None:
+            expr.case_when.else_value.CopyFrom(self._else_value.to_plan(session))
+
+        return expr
+
+    def __repr__(self) -> str:
+        _cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
+        _else = f" ELSE {self._else_value}" if self._else_value is not None else ""
+        return "CASE" + _cases + _else + " END"
+
+
 class ColumnAlias(Expression):
     def __init__(self, parent: Expression, alias: list[str], metadata: Any):
 
@@ -591,6 +640,115 @@ def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column":
     startswith = _bin_op("startsWith", _startswith_doc)
     endswith = _bin_op("endsWith", _endswith_doc)
 
+    def when(self, condition: "Column", value: Any) -> "Column":
+        """
+        Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        condition : :class:`Column`
+            a boolean :class:`Column` expression.
+        value
+            a literal value, or a :class:`Column` expression.
+
+        Returns
+        -------
+        :class:`Column`
+            Column representing the result of the case-when expression.
+
+        Examples
+        --------
+        >>> from pyspark.sql import functions as F
+        >>> df = spark.createDataFrame(
[(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+------------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END| + +-----+------------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+------------------------------------------------------------+ + + See Also + -------- + pyspark.sql.functions.when + """ + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + + if not isinstance(self._expr, CaseWhen): + raise TypeError( + "when() can only be applied on a Column previously generated by when() function" + ) + + case_when: CaseWhen = cast(CaseWhen, self._expr) + if case_when._else_value is not None: + raise TypeError("when() cannot be applied once otherwise() is applied") + + if isinstance(value, Column): + _value = value._expr + else: + _value = LiteralExpression(value) + + _branches = case_when._branches + [(condition._expr, _value)] + + return Column(CaseWhen(branches=_branches, else_value=None)) + + def otherwise(self, value: Any) -> "Column": + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + value + a literal value, or a :class:`Column` expression. + + Returns + ------- + :class:`Column` + Column representing whether each element of Column is unmatched conditions. + + Examples + -------- + >>> from pyspark.sql import functions as F + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (5, "Bob")], ["age", "name"]) + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+-------------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END| + +-----+-------------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+-------------------------------------+ + + See Also + -------- + pyspark.sql.functions.when + """ + if not isinstance(self._expr, CaseWhen): + raise TypeError( + "otherwise() can only be applied on a Column previously generated by when()" + ) + + case_when: CaseWhen = cast(CaseWhen, self._expr) + if case_when._else_value is not None: + raise TypeError( + "otherwise() can only be applied once on a Column previously generated by when()" + ) + + if isinstance(value, Column): + _value = value._expr + else: + _value = LiteralExpression(value) + + return Column(CaseWhen(branches=case_when._branches, else_value=_value)) + def like(self: "Column", other: str) -> "Column": """ SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match. diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 7eb17bd89acd7..c484349e1ae39 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -16,6 +16,7 @@ # from pyspark.sql.connect.column import ( Column, + CaseWhen, Expression, LiteralExpression, ColumnReference, @@ -546,53 +547,53 @@ def spark_partition_id() -> Column: return _invoke_function("spark_partition_id") -# TODO(SPARK-41319): Support case-when in Column -# def when(condition: Column, value: Any) -> Column: -# """Evaluates a list of conditions and returns one of multiple possible result expressions. -# If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched -# conditions. -# -# .. 
versionadded:: 3.4.0 -# -# Parameters -# ---------- -# condition : :class:`~pyspark.sql.Column` -# a boolean :class:`~pyspark.sql.Column` expression. -# value : -# a literal value, or a :class:`~pyspark.sql.Column` expression. -# -# Returns -# ------- -# :class:`~pyspark.sql.Column` -# column representing when expression. -# -# Examples -# -------- -# >>> df = spark.range(3) -# >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show() -# +---+ -# |age| -# +---+ -# | 4| -# | 4| -# | 3| -# +---+ -# -# >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show() -# +----+ -# | age| -# +----+ -# |null| -# |null| -# | 3| -# +----+ -# """ -# # Explicitly not using ColumnOrName type here to make reading condition less opaque -# if not isinstance(condition, Column): -# raise TypeError("condition should be a Column") -# v = value._jc if isinstance(value, Column) else value -# -# return _invoke_function("when", condition._jc, v) +def when(condition: Column, value: Any) -> Column: + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched + conditions. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + condition : :class:`~pyspark.sql.Column` + a boolean :class:`~pyspark.sql.Column` expression. + value : + a literal value, or a :class:`~pyspark.sql.Column` expression. + + Returns + ------- + :class:`~pyspark.sql.Column` + column representing when expression. + + Examples + -------- + >>> df = spark.range(3) + >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show() + +---+ + |age| + +---+ + | 4| + | 4| + | 3| + +---+ + + >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show() + +----+ + | age| + +----+ + |null| + |null| + | 3| + +----+ + """ + # Explicitly not using ColumnOrName type here to make reading condition less opaque + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + + value_expr = value._expr if isinstance(value, Column) else LiteralExpression(value) + + return Column(CaseWhen(branches=[(condition._expr, value_expr)], else_value=None)) # Sort Functions diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 91c57a9ef2202..c4a5111e2d6ea 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 
\x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xb9\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12\x41\n\tcase_when\x18\x08 \x01(\x0b\x32".spark.connect.Expression.CaseWhenH\x00R\x08\x63\x61seWhen\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 
\x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\xff\x01\n\x08\x43\x61seWhen\x12\x45\n\x08\x62ranches\x18\x01 \x03(\x0b\x32).spark.connect.Expression.CaseWhen.BranchR\x08\x62ranches\x12\x38\n\nelse_value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\telseValue\x1ar\n\x06\x42ranch\x12\x37\n\tcondition\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05valueB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -51,6 +51,8 @@ _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"] _EXPRESSION_UNRESOLVEDSTAR = _EXPRESSION.nested_types_by_name["UnresolvedStar"] _EXPRESSION_ALIAS = _EXPRESSION.nested_types_by_name["Alias"] +_EXPRESSION_CASEWHEN = _EXPRESSION.nested_types_by_name["CaseWhen"] +_EXPRESSION_CASEWHEN_BRANCH = _EXPRESSION_CASEWHEN.nested_types_by_name["Branch"] Expression = _reflection.GeneratedProtocolMessageType( "Expression", (_message.Message,), @@ -172,6 +174,24 @@ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Alias) }, ), + "CaseWhen": _reflection.GeneratedProtocolMessageType( + "CaseWhen", + (_message.Message,), + { + "Branch": _reflection.GeneratedProtocolMessageType( + "Branch", + (_message.Message,), + { + "DESCRIPTOR": _EXPRESSION_CASEWHEN_BRANCH, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Expression.CaseWhen.Branch) + }, + ), + "DESCRIPTOR": _EXPRESSION_CASEWHEN, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Expression.CaseWhen) + }, + ), "DESCRIPTOR": _EXPRESSION, "__module__": "spark.connect.expressions_pb2" # @@protoc_insertion_point(class_scope:spark.connect.Expression) @@ -191,37 +211,43 @@ _sym_db.RegisterMessage(Expression.ExpressionString) _sym_db.RegisterMessage(Expression.UnresolvedStar) _sym_db.RegisterMessage(Expression.Alias) +_sym_db.RegisterMessage(Expression.CaseWhen) +_sym_db.RegisterMessage(Expression.CaseWhen.Branch) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 2754 - 
_EXPRESSION_CAST._serialized_start = 640 - _EXPRESSION_CAST._serialized_end = 785 - _EXPRESSION_LITERAL._serialized_start = 788 - _EXPRESSION_LITERAL._serialized_end = 2246 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038 - _EXPRESSION_LITERAL_MAP._serialized_start = 2041 - _EXPRESSION_LITERAL_MAP._serialized_end = 2230 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619 - _EXPRESSION_ALIAS._serialized_start = 2621 - _EXPRESSION_ALIAS._serialized_end = 2741 + _EXPRESSION._serialized_end = 3079 + _EXPRESSION_CAST._serialized_start = 707 + _EXPRESSION_CAST._serialized_end = 852 + _EXPRESSION_LITERAL._serialized_start = 855 + _EXPRESSION_LITERAL._serialized_end = 2313 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1751 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1868 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1870 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1968 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 1970 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 2037 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 2039 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 2105 + _EXPRESSION_LITERAL_MAP._serialized_start = 2108 + _EXPRESSION_LITERAL_MAP._serialized_end = 2297 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2181 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2297 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2315 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2385 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2388 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2592 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2594 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2644 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2646 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2686 + _EXPRESSION_ALIAS._serialized_start = 2688 + _EXPRESSION_ALIAS._serialized_end = 2808 + _EXPRESSION_CASEWHEN._serialized_start = 2811 + _EXPRESSION_CASEWHEN._serialized_end = 3066 + _EXPRESSION_CASEWHEN_BRANCH._serialized_start = 2952 + _EXPRESSION_CASEWHEN_BRANCH._serialized_end = 3066 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 2c486f62a9ddb..2eebabb672dab 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -633,6 +633,63 @@ class Expression(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"] ) -> typing_extensions.Literal["metadata"] | None: ... 
+    class CaseWhen(google.protobuf.message.Message):
+        """Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END"."""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        class Branch(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            CONDITION_FIELD_NUMBER: builtins.int
+            VALUE_FIELD_NUMBER: builtins.int
+            @property
+            def condition(self) -> global___Expression: ...
+            @property
+            def value(self) -> global___Expression: ...
+            def __init__(
+                self,
+                *,
+                condition: global___Expression | None = ...,
+                value: global___Expression | None = ...,
+            ) -> None: ...
+            def HasField(
+                self,
+                field_name: typing_extensions.Literal["condition", b"condition", "value", b"value"],
+            ) -> builtins.bool: ...
+            def ClearField(
+                self,
+                field_name: typing_extensions.Literal["condition", b"condition", "value", b"value"],
+            ) -> None: ...
+
+        BRANCHES_FIELD_NUMBER: builtins.int
+        ELSE_VALUE_FIELD_NUMBER: builtins.int
+        @property
+        def branches(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            global___Expression.CaseWhen.Branch
+        ]:
+            """(Required) The sequence of (branch condition, branch value) pairs."""
+        @property
+        def else_value(self) -> global___Expression:
+            """(Optional) Value for the else branch."""
+        def __init__(
+            self,
+            *,
+            branches: collections.abc.Iterable[global___Expression.CaseWhen.Branch] | None = ...,
+            else_value: global___Expression | None = ...,
+        ) -> None: ...
+        def HasField(
+            self, field_name: typing_extensions.Literal["else_value", b"else_value"]
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "branches", b"branches", "else_value", b"else_value"
+            ],
+        ) -> None: ...
+
     LITERAL_FIELD_NUMBER: builtins.int
     UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
     UNRESOLVED_FUNCTION_FIELD_NUMBER: builtins.int
@@ -640,6 +697,7 @@ class Expression(google.protobuf.message.Message):
     UNRESOLVED_STAR_FIELD_NUMBER: builtins.int
     ALIAS_FIELD_NUMBER: builtins.int
     CAST_FIELD_NUMBER: builtins.int
+    CASE_WHEN_FIELD_NUMBER: builtins.int
    @property
    def literal(self) -> global___Expression.Literal: ...
    @property
@@ -654,6 +712,8 @@ class Expression(google.protobuf.message.Message):
     def alias(self) -> global___Expression.Alias: ...
     @property
     def cast(self) -> global___Expression.Cast: ...
+    @property
+    def case_when(self) -> global___Expression.CaseWhen: ...
     def __init__(
         self,
         *,
@@ -664,12 +724,15 @@ class Expression(google.protobuf.message.Message):
         unresolved_star: global___Expression.UnresolvedStar | None = ...,
         alias: global___Expression.Alias | None = ...,
         cast: global___Expression.Cast | None = ...,
+        case_when: global___Expression.CaseWhen | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "case_when",
+            b"case_when",
             "cast",
             b"cast",
             "expr_type",
@@ -691,6 +754,8 @@ class Expression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "case_when",
+            b"case_when",
             "cast",
             b"cast",
             "expr_type",
@@ -717,6 +782,7 @@ class Expression(google.protobuf.message.Message):
         "unresolved_star",
         "alias",
         "cast",
+        "case_when",
     ] | None: ...
 
 global___Expression = Expression
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 96801c58ca9e6..1caf138f34dec 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -164,6 +164,117 @@ def test_normal_functions(self):
             sdf.select(SF.spark_partition_id()).toPandas(),
         )
 
+    def test_when_otherwise(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT * FROM VALUES
+            (0, float("NAN"), NULL), (1, NULL, 2.0), (2, 2.1, 3.5), (3, 3.1, float("NAN"))
+            AS tab(a, b, c)
+            """
+        # +---+----+----+
+        # |  a|   b|   c|
+        # +---+----+----+
+        # |  0| NaN|null|
+        # |  1|null| 2.0|
+        # |  2| 2.1| 3.5|
+        # |  3| 3.1| NaN|
+        # +---+----+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        self.assert_eq(
+            cdf.select(CF.when(cdf.a == 0, 1.0).otherwise(2.0)).toPandas(),
+            sdf.select(SF.when(sdf.a == 0, 1.0).otherwise(2.0)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.when(cdf.a < 1, cdf.b).otherwise(cdf.c)).toPandas(),
+            sdf.select(SF.when(sdf.a < 1, sdf.b).otherwise(sdf.c)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(
+                CF.when(cdf.a == 0, 1.0)
+                .when(CF.col("a") == 1, 2.0)
+                .when(cdf.a == 2, -1.0)
+                .otherwise(cdf.c)
+            ).toPandas(),
+            sdf.select(
+                SF.when(sdf.a == 0, 1.0)
+                .when(SF.col("a") == 1, 2.0)
+                .when(sdf.a == 2, -1.0)
+                .otherwise(sdf.c)
+            ).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(
+                CF.when(cdf.a < cdf.b, 1.0)
+                .when(CF.col("a") == 1, CF.abs("c") + cdf.b)
+                .otherwise(cdf.c + CF.col("a"))
+            ).toPandas(),
+            sdf.select(
+                SF.when(sdf.a < sdf.b, 1.0)
+                .when(SF.col("a") == 1, SF.abs("c") + sdf.b)
+                .otherwise(sdf.c + SF.col("a"))
+            ).toPandas(),
+        )
+
+        # when without otherwise
+        self.assert_eq(
+            cdf.select(CF.when(cdf.a < 1, cdf.b)).toPandas(),
+            sdf.select(SF.when(sdf.a < 1, sdf.b)).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(
+                CF.when(cdf.a == 0, 1.0)
+                .when(CF.col("a") == 1, cdf.b + CF.col("c"))
+                .when(cdf.a == 2, CF.abs(cdf.b))
+            ).toPandas(),
+            sdf.select(
+                SF.when(sdf.a == 0, 1.0)
+                .when(SF.col("a") == 1, sdf.b + SF.col("c"))
+                .when(sdf.a == 2, SF.abs(sdf.b))
+            ).toPandas(),
+        )
+
+        # check error
+        with self.assertRaisesRegex(
+            TypeError,
+            "when.* can only be applied on a Column previously generated by when.* function",
+        ):
+            cdf.a.when(cdf.a == 0, 1.0)
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "when.* can only be applied on a Column previously generated by when.* function",
+        ):
+            CF.col("c").when(cdf.a == 0, 1.0)
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "otherwise.* can only be applied on a Column previously generated by when",
+        ):
+            cdf.a.otherwise(1.0)
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "otherwise.* can only be applied on a Column previously generated by when",
+        ):
+            CF.col("c").otherwise(1.0)
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "otherwise.* can only be applied once on a Column previously generated by when",
+        ):
+            CF.when(cdf.a == 0, 1.0).otherwise(1.0).otherwise(1.0)
+
+        with self.assertRaisesRegex(
+            TypeError,
+            """condition should be a Column""",
+        ):
+            CF.when(True, 1.0).otherwise(1.0)
+
     def test_sorting_functions_with_column(self):
         from pyspark.sql.connect import functions as CF
         from pyspark.sql.connect.column import Column
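Taken together, patch 1/3 lets a Connect client express CASE WHEN end to end. A minimal usage sketch, assuming a Spark Connect session bound to `spark` (data and column names mirror the doctests above):

    from pyspark.sql.connect import functions as CF

    df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
    df.select(
        df.name,
        CF.when(df.age > 4, "senior")  # first branch -> CaseWhen.branches[0]
        .when(df.age < 3, "junior")  # each extra when() appends a branch
        .otherwise("regular")  # fills CaseWhen.else_value
        .alias("group"),
    ).show()
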
From 8feb7eecea42bf92f79d150814c5529b2cde9ed0 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 6 Dec 2022 20:22:06 +0800
Subject: [PATCH 2/3] fix lint

---
 python/pyspark/sql/connect/column.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 25d0b7f037f01..c9582f3ed0c87 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -684,8 +684,7 @@ def when(self, condition: "Column", value: Any) -> "Column":
                 "when() can only be applied on a Column previously generated by when() function"
             )
 
-        case_when: CaseWhen = cast(CaseWhen, self._expr)
-        if case_when._else_value is not None:
+        if self._expr._else_value is not None:
             raise TypeError("when() cannot be applied once otherwise() is applied")
 
         if isinstance(value, Column):
@@ -693,7 +692,7 @@ def when(self, condition: "Column", value: Any) -> "Column":
         else:
             _value = LiteralExpression(value)
 
-        _branches = case_when._branches + [(condition._expr, _value)]
+        _branches = self._expr._branches + [(condition._expr, _value)]
 
         return Column(CaseWhen(branches=_branches, else_value=None))
 
@@ -736,8 +735,7 @@ def otherwise(self, value: Any) -> "Column":
                 "otherwise() can only be applied on a Column previously generated by when()"
             )
 
-        case_when: CaseWhen = cast(CaseWhen, self._expr)
-        if case_when._else_value is not None:
+        if self._expr._else_value is not None:
             raise TypeError(
                 "otherwise() can only be applied once on a Column previously generated by when()"
             )
@@ -747,7 +745,7 @@ def otherwise(self, value: Any) -> "Column":
         else:
             _value = LiteralExpression(value)
 
-        return Column(CaseWhen(branches=case_when._branches, else_value=_value))
+        return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
 
     def like(self: "Column", other: str) -> "Column":
         """
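The lint fix above leans on type narrowing: after the isinstance() guard, mypy already treats self._expr as CaseWhen, so the explicit typing.cast was redundant. A standalone sketch of the pattern, with hypothetical names:

    from typing import Optional

    class Expression: ...

    class CaseWhen(Expression):
        def __init__(self) -> None:
            self._else_value: Optional[Expression] = None

    def ensure_open_case(expr: Expression) -> None:
        if not isinstance(expr, CaseWhen):
            raise TypeError("expected a CaseWhen")
        # expr is narrowed to CaseWhen here; no typing.cast needed.
        if expr._else_value is not None:
            raise TypeError("otherwise() already applied")
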
From 110d7e7f4f9cc1038fd95f0ceda08680dc9c3ec6 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Thu, 8 Dec 2022 17:33:57 +0800
Subject: [PATCH 3/3] fix

---
 .../org/apache/spark/sql/connect/config/Connect.scala       | 7 ++++---
 .../spark/sql/connect/planner/SparkConnectPlanner.scala     | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 358cb3c8f79b4..60fdd96401888 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -38,9 +38,10 @@ private[spark] object Connect {
 
   val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE =
     ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize")
-      .doc("When using Apache Arrow, limit the maximum size of one arrow batch that " +
-        "can be sent from server side to client side. Currently, we conservatively use 70% " +
-        "of it because the size is not accurate but estimated.")
+      .doc(
+        "When using Apache Arrow, limit the maximum size of one arrow batch that " +
+          "can be sent from the server side to the client side. Currently, we conservatively " +
+          "use 70% of this limit because the size is estimated rather than exact.")
       .version("3.4.0")
       .bytesConf(ByteUnit.MiB)
       .createWithDefaultString("4m")
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e25ce0ac1f8c7..050b4f72ff3e9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -533,7 +533,7 @@ class SparkConnectPlanner(session: SparkSession) {
 
   private def transformCaseWhen(casewhen: proto.Expression.CaseWhen): Expression = {
     CaseWhen(
-      branches = casewhen.getBranchesList.asScala
+      branches = casewhen.getBranchesList.asScala.toSeq
        .map(b => (transformExpression(b.getCondition), transformExpression(b.getValue))),
       elseValue =
         if (casewhen.hasElseValue) Some(transformExpression(casewhen.getElseValue)) else None)
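The trailing `.toSeq` matters on Scala 2.13, where `asScala` yields a mutable Buffer while Catalyst's CaseWhen expects an immutable Seq. On the client side, the finished expression tree can also be checked without a server; a small sanity check, assuming the classes from patch 1/3 (the exact rendering of literals in the repr may differ):

    from pyspark.sql.connect.column import CaseWhen, LiteralExpression

    cw = CaseWhen(
        branches=[(LiteralExpression(True), LiteralExpression(1))],
        else_value=LiteralExpression(0),
    )
    print(cw)  # CASE WHEN <true> THEN <1> ELSE <0> END, modulo literal formatting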