@@ -551,6 +551,13 @@ class SparkConnectPlanner(session: SparkSession) {
.Product(transformExpression(fun.getArgumentsList.asScala.head))
.toAggregateExpression())

case "when" =>
if (fun.getArgumentsCount == 0) {
throw InvalidPlanInput("CaseWhen requires at least one child expression")
}
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
Some(CaseWhen.createFromParser(children))

case "in" =>
if (fun.getArgumentsCount == 0) {
throw InvalidPlanInput("In requires at least one child expression")
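Reviewer note: the planner hands the flattened `when` children straight to `CaseWhen.createFromParser`, which (as assumed here) pairs them up as alternating (condition, value) expressions, with an odd trailing child treated as the ELSE value. A minimal Python sketch of that pairing convention — not the Catalyst code itself; `children` is a flat list standing in for the transformed proto expressions:

```python
from typing import List, Optional, Tuple


def pair_case_when(children: List[str]) -> Tuple[List[Tuple[str, str]], Optional[str]]:
    # Even-indexed children are conditions, odd-indexed are values;
    # zip drops a leftover trailing child, which becomes the ELSE expression.
    branches = list(zip(children[0::2], children[1::2]))
    else_value = children[-1] if len(children) % 2 == 1 else None
    return branches, else_value


print(pair_case_when(["c1", "v1", "c2", "v2", "e"]))
# ([('c1', 'v1'), ('c2', 'v2')], 'e')
```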
164 changes: 157 additions & 7 deletions python/pyspark/sql/connect/column.py
@@ -15,7 +15,18 @@
# limitations under the License.
#

from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload, cast, Sequence
from typing import (
get_args,
TYPE_CHECKING,
Callable,
Any,
Union,
overload,
cast,
Sequence,
Tuple,
Optional,
)

import json
import decimal
@@ -152,6 +163,44 @@ def name(self) -> str:
...


class CaseWhen(Expression):
def __init__(
self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
):

assert isinstance(branches, list)
for branch in branches:
assert (
isinstance(branch, tuple)
and len(branch) == 2
and all(isinstance(expr, Expression) for expr in branch)
)
self._branches = branches

if else_value is not None:
assert isinstance(else_value, Expression)

self._else_value = else_value

def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
args = []
for condition, value in self._branches:
args.append(condition)
args.append(value)

if self._else_value is not None:
args.append(self._else_value)

unresolved_function = UnresolvedFunction(name="when", args=args)

return unresolved_function.to_plan(session)

def __repr__(self) -> str:
_cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
_else = f" ELSE {self._else_value}" if self._else_value is not None else ""
return "CASE" + _cases + _else + " END"
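
Reviewer note: `to_plan` above is the inverse of the server-side pairing — branches are flattened back into condition/value alternation, with the optional else appended last. A quick sketch of that ordering, using stand-in strings instead of `Expression` objects:

```python
branches = [("c1", "v1"), ("c2", "v2")]
else_value = "e"

# Mirrors CaseWhen.to_plan: flatten (condition, value) pairs, else last.
args = [expr for branch in branches for expr in branch]
if else_value is not None:
    args.append(else_value)

print(args)  # ['c1', 'v1', 'c2', 'v2', 'e']
```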


class ColumnAlias(Expression):
def __init__(self, parent: Expression, alias: list[str], metadata: Any):

@@ -706,6 +755,113 @@ def contains(self, other: Union["PrimitiveType", "Column"]) -> "Column":
startswith = _bin_op("startsWith", _startswith_doc)
endswith = _bin_op("endsWith", _endswith_doc)

def when(self, condition: "Column", value: Any) -> "Column":
"""
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

.. versionadded:: 3.4.0

Parameters
----------
condition : :class:`Column`
a boolean :class:`Column` expression.
value
a literal value, or a :class:`Column` expression.

Returns
-------
:class:`Column`
Column representing the result of the CASE WHEN expression.

Examples
--------
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice|                                                          -1|
|  Bob|                                                           1|
+-----+------------------------------------------------------------+

See Also
--------
pyspark.sql.functions.when
"""
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")

if not isinstance(self._expr, CaseWhen):
raise TypeError(
"when() can only be applied on a Column previously generated by when() function"
)

if self._expr._else_value is not None:
raise TypeError("when() cannot be applied once otherwise() is applied")

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression(value, LiteralExpression._infer_type(value))

_branches = self._expr._branches + [(condition._expr, _value)]

return Column(CaseWhen(branches=_branches, else_value=None))

def otherwise(self, value: Any) -> "Column":
"""
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

.. versionadded:: 3.4.0

Parameters
----------
value
a literal value, or a :class:`Column` expression.

Returns
-------
:class:`Column`
Column representing the result of the CASE WHEN expression, with ``value`` as the result for unmatched conditions.

Examples
--------
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+-----+-------------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+-----+-------------------------------------+
|Alice|                                    0|
|  Bob|                                    1|
+-----+-------------------------------------+

See Also
--------
pyspark.sql.functions.when
"""
if not isinstance(self._expr, CaseWhen):
raise TypeError(
"otherwise() can only be applied on a Column previously generated by when()"
)

if self._expr._else_value is not None:
raise TypeError(
"otherwise() can only be applied once on a Column previously generated by when()"
)

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression(value, LiteralExpression._infer_type(value))

return Column(CaseWhen(branches=self._expr._branches, else_value=_value))

def like(self: "Column", other: str) -> "Column":
"""
SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match.
@@ -902,9 +1058,6 @@ def cast(self, dataType: Union[DataType, str]) -> "Column":
def __repr__(self) -> str:
return "Column<'%s'>" % self._expr.__repr__()

def otherwise(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("otherwise() is not yet implemented.")

def over(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("over() is not yet implemented.")

@@ -943,9 +1096,6 @@ def isin(self, *cols: Any) -> "Column":

return Column(UnresolvedFunction("in", [self._expr] + [lit(c)._expr for c in _cols]))

def when(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("when() is not yet implemented.")

def getItem(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("getItem() is not yet implemented.")

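Reviewer note: a usage sketch of the new chaining behavior, assuming a Spark Connect session bound to `spark`. The guards above mean `when()` and `otherwise()` only compose in the expected order:

```python
from pyspark.sql.connect.functions import when

df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])

# Chained branches build up a single client-side CaseWhen expression.
df.select(df.name, when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()

# The else slot can be filled only once:
try:
    when(df.age > 4, 1).otherwise(0).otherwise(2)
except TypeError as exc:
    print(exc)  # otherwise() can only be applied once ...
```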
95 changes: 48 additions & 47 deletions python/pyspark/sql/connect/functions.py
@@ -16,6 +16,7 @@
#
from pyspark.sql.connect.column import (
Column,
CaseWhen,
Expression,
LiteralExpression,
ColumnReference,
@@ -549,53 +550,53 @@ def spark_partition_id() -> Column:
return _invoke_function("spark_partition_id")


# TODO(SPARK-41319): Support case-when in Column
# def when(condition: Column, value: Any) -> Column:
# """Evaluates a list of conditions and returns one of multiple possible result expressions.
# If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
# conditions.
#
# .. versionadded:: 3.4.0
#
# Parameters
# ----------
# condition : :class:`~pyspark.sql.Column`
# a boolean :class:`~pyspark.sql.Column` expression.
# value :
# a literal value, or a :class:`~pyspark.sql.Column` expression.
#
# Returns
# -------
# :class:`~pyspark.sql.Column`
# column representing when expression.
#
# Examples
# --------
# >>> df = spark.range(3)
# >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
# +---+
# |age|
# +---+
# | 4|
# | 4|
# | 3|
# +---+
#
# >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
# +----+
# | age|
# +----+
# |null|
# |null|
# | 3|
# +----+
# """
# # Explicitly not using ColumnOrName type here to make reading condition less opaque
# if not isinstance(condition, Column):
# raise TypeError("condition should be a Column")
# v = value._jc if isinstance(value, Column) else value
#
# return _invoke_function("when", condition._jc, v)
def when(condition: Column, value: Any) -> Column:
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
conditions.

.. versionadded:: 3.4.0

Parameters
----------
condition : :class:`~pyspark.sql.Column`
a boolean :class:`~pyspark.sql.Column` expression.
value :
a literal value, or a :class:`~pyspark.sql.Column` expression.

Returns
-------
:class:`~pyspark.sql.Column`
column representing when expression.

Examples
--------
>>> df = spark.range(3)
>>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
+---+
|age|
+---+
|  4|
|  4|
|  3|
+---+

>>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
+----+
| age|
+----+
|null|
|null|
|   3|
+----+
"""
# Explicitly not using ColumnOrName type here to make reading condition less opaque
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")

value_col = value if isinstance(value, Column) else lit(value)

return Column(CaseWhen(branches=[(condition._expr, value_col._expr)], else_value=None))


# Sort Functions
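Reviewer note: without a trailing `otherwise()`, unmatched rows fall through to NULL, and a non-`Column` condition is rejected up front. A short sketch, again assuming a connect session bound to `spark`:

```python
from pyspark.sql.connect.functions import when

df = spark.range(3)

# No otherwise(): ids 0 and 1 fall through to NULL.
df.select(when(df.id == 2, df.id + 1).alias("age")).show()

# The condition must be a Column, not a SQL string:
try:
    when("id = 2", 1)
except TypeError as exc:
    print(exc)  # condition should be a Column
```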
2 changes: 0 additions & 2 deletions python/pyspark/sql/tests/connect/test_connect_column.py
@@ -347,9 +347,7 @@ def test_unsupported_functions(self):
# SPARK-41225: Disable unsupported functions.
c = self.connect.range(1).id
for f in (
"otherwise",
"over",
"when",
"getItem",
"astype",
"between",