Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def _is_elementary(self) -> bool:
def __repr__(self) -> str: # pragma: no cover
return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})"

@classmethod
def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
return hasattr(obj, "__narwhals_expr__")


class EagerExpr(
DepthTrackingExpr[EagerDataFrameT, EagerSeriesT],
Expand Down Expand Up @@ -451,10 +455,6 @@ def _reuse_series_extra_kwargs(
) -> dict[str, Any]:
return {}

@classmethod
def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
return hasattr(obj, "__narwhals_expr__")

def _reuse_series_inner(
self,
df: EagerDataFrameT,
Expand Down
39 changes: 15 additions & 24 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import align_series_full_broadcast
from narwhals._dask.utils import extract_comparand
from narwhals._dask.utils import name_preserving_div
from narwhals._dask.utils import name_preserving_sum
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._dask.utils import validate_comparand
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation

if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from typing_extensions import Self

from narwhals.dtypes import DType
from narwhals.utils import Version

try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx


class DaskNamespace(DepthTrackingNamespace[DaskLazyFrame, "DaskExpr"]):
_implementation: Implementation = Implementation.DASK
Expand Down Expand Up @@ -313,26 +309,21 @@ class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):
def _then(self) -> type[DaskThen]:
return DaskThen

def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
condition = self._condition(df)[0]

if isinstance(self._then_value, DaskExpr):
then_value = self._then_value(df)[0]
else:
then_value = self._then_value
(then_series,) = align_series_full_broadcast(df, then_value)
validate_comparand(condition, then_series)

if self._otherwise_value is None:
return [then_series.where(condition)]
def _if_then_else(
self, when: dx.Series, then: dx.Series, otherwise: Any, /
) -> dx.Series:
return then.where(when) if otherwise is None else then.where(when, otherwise)

if isinstance(self._otherwise_value, DaskExpr):
otherwise_value = self._otherwise_value(df)[0]
def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
is_expr = self._condition._is_expr
when = self._condition(df)[0]
then = self._then_value(df)[0] if is_expr(self._then_value) else self._then_value
if is_expr(self._otherwise_value):
otherwise = extract_comparand(df, when, self._otherwise_value(df)[0])
else:
return [then_series.where(condition, self._otherwise_value)] # pyright: ignore[reportArgumentType]
(otherwise_series,) = align_series_full_broadcast(df, otherwise_value)
validate_comparand(condition, otherwise_series)
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
otherwise = self._otherwise_value
result = self._if_then_else(when, extract_comparand(df, when, then), otherwise)
return [result]


class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr], DaskExpr): ...
28 changes: 18 additions & 10 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx

if TYPE_CHECKING:
import dask.dataframe as dd
import dask.dataframe.dask_expr as dx

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
else:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx


def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
Expand Down Expand Up @@ -74,12 +75,19 @@ def add_row_index(
)


def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx
def extract_comparand(
df: DaskLazyFrame, condition: dx.Series, value: dx.Series | Any
) -> dx.Series:
rhs = (
value
if isinstance(value, dx.Series)
else df.native.assign(_literal=value)["_literal"]
)
validate_comparand(condition, rhs)
return rhs


def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
if not dx.expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
Expand Down
Loading