diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index aef16939ca..e2168baa28 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -83,9 +83,7 @@ PromoteOptions: TypeAlias = Literal["none", "default", "permissive"] -class ArrowDataFrame( - EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"] -): +class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): def __init__( self, native_dataframe: pa.Table, diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 2730ebaae9..42a01924fb 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -24,15 +24,13 @@ from narwhals._utils import Implementation if TYPE_CHECKING: - from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete + from narwhals._arrow.typing import Incomplete from narwhals._utils import Version from narwhals.dtypes import DType from narwhals.typing import NonNestedLiteral -class ArrowNamespace( - EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, "pa.Table", "ChunkedArrayAny"] -): +class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]): @property def _dataframe(self) -> type[ArrowDataFrame]: return ArrowDataFrame @@ -266,20 +264,23 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) -class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]): +class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]): @property def _then(self) -> type[ArrowThen]: return ArrowThen def _if_then_else( - self, - when: ChunkedArrayAny, - then: ChunkedArrayAny, - otherwise: ArrayOrScalar | NonNestedLiteral, - /, - ) -> ChunkedArrayAny: - otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise - return pc.if_else(when, then, otherwise) + self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, / + ) -> ArrowSeries: + if otherwise is None: + when, then = align_series_full_broadcast(when, then) + res_native = pc.if_else( + when.native, then.native, pa.nulls(len(when.native), then.native.type) + ) + else: + when, then, otherwise = align_series_full_broadcast(when, then, otherwise) + res_native = pc.if_else(when.native, then.native, otherwise.native) + return then._with_native(res_native) class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ... diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 57226ff5b4..5f210550e6 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -23,7 +23,6 @@ EagerSeriesT, NativeExprT, NativeFrameT, - NativeSeriesT, ) from narwhals._translate import ( ArrowConvertible, @@ -402,11 +401,11 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N class EagerDataFrame( CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"], CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"], - Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], + Protocol[EagerSeriesT, EagerExprT, NativeFrameT], ): def __narwhals_namespace__( self, - ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ... + ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ... def to_narwhals(self) -> DataFrame[NativeFrameT]: return self._version.dataframe(self, level="full") @@ -450,14 +449,10 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... + def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... - def _select_multi_index( - self, columns: SizedMultiIndexSelector[NativeSeriesT] - ) -> Self: ... - def _select_multi_name( - self, columns: SizedMultiNameSelector[NativeSeriesT] - ) -> Self: ... + def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ... + def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ... def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ... def _select_slice_name(self, columns: _SliceName) -> Self: ... def __getitem__( # noqa: C901, PLR0912 diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index bb783398bf..9490bc8b89 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -337,7 +337,7 @@ def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]: def __narwhals_namespace__( self, - ) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self, Any, Any]: ... + ) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self, Any]: ... def __narwhals_expr__(self) -> None: ... @classmethod diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index d0440152a8..7371232336 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -23,7 +23,6 @@ LazyExprT, NativeFrameT, NativeFrameT_co, - NativeSeriesT, ) from narwhals._utils import ( exclude_column_names, @@ -132,7 +131,7 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: class EagerNamespace( DepthTrackingNamespace[EagerDataFrameT, EagerExprT], - Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], + Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT], ): @property def _dataframe(self) -> type[EagerDataFrameT]: ... @@ -140,23 +139,9 @@ def _dataframe(self) -> type[EagerDataFrameT]: ... def _series(self) -> type[EagerSeriesT]: ... def when( self, predicate: EagerExprT - ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ... + ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT]: ... - @overload - def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ... - @overload - def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ... - # TODO @dangotbanned: Align `PandasLike` typing with `_namespace`, then drop this `@overload` - # - Using the guards there introduces `_NativeModin`, `_NativeCuDF` - # - These types haven't been integrated into the backend - # - Most of the `pandas` stuff is still untyped - @overload - def from_native( - self, data: NativeFrameT | NativeSeriesT | Any, / - ) -> EagerDataFrameT | EagerSeriesT: ... - def from_native( - self, data: NativeFrameT | NativeSeriesT | Any, / - ) -> EagerDataFrameT | EagerSeriesT: + def from_native(self, data: Any, /) -> EagerDataFrameT | EagerSeriesT: if self._dataframe._is_native(data): return self._dataframe.from_native(data, context=self) elif self._series._is_native(data): diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index cc194af4ca..fb0cb98ca7 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -300,9 +300,7 @@ def _with_native( """ ... - def __narwhals_namespace__( - self, - ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ... + def __narwhals_namespace__(self) -> EagerNamespace[Any, Self, Any, Any]: ... def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 457d71d6b5..4c3685b728 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -60,10 +60,12 @@ class ScalarKwargs(TypedDict, total=False): DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]" -EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]" +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" EagerSeriesAny: TypeAlias = "EagerSeries[Any]" EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" -EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]" +EagerNamespaceAny: TypeAlias = ( + "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame]" +) LazyExprAny: TypeAlias = "LazyExpr[Any, Any]" @@ -124,7 +126,7 @@ class ScalarKwargs(TypedDict, total=False): EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True) # NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`? -EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]") +EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True) diff --git a/narwhals/_compliant/when_then.py b/narwhals/_compliant/when_then.py index 37e70dcb55..a2a2b4fe78 100644 --- a/narwhals/_compliant/when_then.py +++ b/narwhals/_compliant/when_then.py @@ -13,7 +13,6 @@ EagerSeriesT, LazyExprAny, NativeExprT, - NativeSeriesT, WindowFunction, ) from narwhals._typing_compat import Protocol38 @@ -43,7 +42,7 @@ class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]): _condition: ExprT _then_value: IntoExpr[SeriesT, ExprT] - _otherwise_value: IntoExpr[SeriesT, ExprT] + _otherwise_value: IntoExpr[SeriesT, ExprT] | None _implementation: Implementation _backend_version: tuple[int, ...] _version: Version @@ -145,15 +144,11 @@ def from_when( class EagerWhen( CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT], - Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT], + Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT], ): def _if_then_else( - self, - when: NativeSeriesT, - then: NativeSeriesT, - otherwise: NativeSeriesT | NonNestedLiteral | Scalar, - /, - ) -> NativeSeriesT: ... + self, when: EagerSeriesT, then: EagerSeriesT, otherwise: EagerSeriesT | None, / + ) -> EagerSeriesT: ... def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]: is_expr = self._condition._is_expr @@ -165,11 +160,13 @@ def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]: then = when.alias("literal")._from_scalar(self._then_value) then._broadcast = True if is_expr(self._otherwise_value): - otherwise = df._extract_comparand(self._otherwise_value(df)[0]) + otherwise = self._otherwise_value(df)[0] + elif self._otherwise_value is not None: + otherwise = when._from_scalar(self._otherwise_value) + otherwise._broadcast = True else: otherwise = self._otherwise_value - result = self._if_then_else(when.native, df._extract_comparand(then), otherwise) - return [then._with_native(result)] + return [self._if_then_else(when, then, otherwise)] class LazyWhen( diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index c23748d016..bf5287fbad 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -103,9 +103,7 @@ ) -class PandasLikeDataFrame( - EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any", "pd.Series[Any]"] -): +class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any"]): def __init__( self, native_dataframe: Any, diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index bbff8bbc7e..7d149e4c60 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -3,7 +3,9 @@ import operator import warnings from functools import reduce -from typing import TYPE_CHECKING, Any, Literal, Sequence +from typing import TYPE_CHECKING, Literal, Sequence + +import pandas as pd from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen from narwhals._expression_parsing import ( @@ -17,8 +19,6 @@ from narwhals._pandas_like.utils import align_series_full_broadcast if TYPE_CHECKING: - import pandas as pd - from narwhals._pandas_like.typing import NDFrameT from narwhals._utils import Implementation, Version from narwhals.dtypes import DType @@ -29,13 +29,7 @@ class PandasLikeNamespace( - EagerNamespace[ - PandasLikeDataFrame, - PandasLikeSeries, - PandasLikeExpr, - "pd.DataFrame", - "pd.Series[Any]", - ] + EagerNamespace[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, pd.DataFrame] ): @property def _dataframe(self) -> type[PandasLikeDataFrame]: @@ -315,21 +309,25 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) -class PandasWhen( - EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, "pd.Series[Any]"] -): +class PandasWhen(EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr]): @property def _then(self) -> type[PandasThen]: return PandasThen def _if_then_else( self, - when: pd.Series[Any], - then: pd.Series[Any], - otherwise: pd.Series[Any] | NonNestedLiteral, + when: PandasLikeSeries, + then: PandasLikeSeries, + otherwise: PandasLikeSeries | None, /, - ) -> pd.Series[Any]: - return then.where(when) if otherwise is None else then.where(when, otherwise) + ) -> PandasLikeSeries: + if otherwise is None: + when, then = align_series_full_broadcast(when, then) + res_native = then.native.where(when.native) + else: + when, then, otherwise = align_series_full_broadcast(when, then, otherwise) + res_native = then.native.where(when.native, otherwise.native) + return then._with_native(res_native) class PandasThen(