Skip to content
4 changes: 1 addition & 3 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


class ArrowDataFrame(
EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]):
def __init__(
self,
native_dataframe: pa.Table,
Expand Down
27 changes: 14 additions & 13 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@
from narwhals._utils import Implementation

if TYPE_CHECKING:
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
from narwhals._arrow.typing import Incomplete
from narwhals._utils import Version
from narwhals.dtypes import DType
from narwhals.typing import NonNestedLiteral


class ArrowNamespace(
EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, "pa.Table", "ChunkedArrayAny"]
):
class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table]):
@property
def _dataframe(self) -> type[ArrowDataFrame]:
return ArrowDataFrame
Expand Down Expand Up @@ -266,20 +264,23 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
)


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr]):
@property
def _then(self) -> type[ArrowThen]:
return ArrowThen

def _if_then_else(
self,
when: ChunkedArrayAny,
then: ChunkedArrayAny,
otherwise: ArrayOrScalar | NonNestedLiteral,
/,
) -> ChunkedArrayAny:
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
return pc.if_else(when, then, otherwise)
self, when: ArrowSeries, then: ArrowSeries, otherwise: ArrowSeries | None, /
) -> ArrowSeries:
if otherwise is None:
when, then = align_series_full_broadcast(when, then)
res_native = pc.if_else(
when.native, then.native, pa.nulls(len(when.native), then.native.type)
)
else:
when, then, otherwise = align_series_full_broadcast(when, then, otherwise)
res_native = pc.if_else(when.native, then.native, otherwise.native)
return then._with_native(res_native)


class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr): ...
15 changes: 5 additions & 10 deletions narwhals/_compliant/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
EagerSeriesT,
NativeExprT,
NativeFrameT,
NativeSeriesT,
)
from narwhals._translate import (
ArrowConvertible,
Expand Down Expand Up @@ -402,11 +401,11 @@ def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | N
class EagerDataFrame(
CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
Protocol[EagerSeriesT, EagerExprT, NativeFrameT],
):
def __narwhals_namespace__(
self,
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...
) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT]: ...

def to_narwhals(self) -> DataFrame[NativeFrameT]:
return self._version.dataframe(self, level="full")
Expand Down Expand Up @@ -450,14 +449,10 @@ def _numpy_column_names(
) -> list[str]:
return list(columns or (f"column_{x}" for x in range(data.shape[1])))

def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def _select_multi_index(
self, columns: SizedMultiIndexSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_name(
self, columns: SizedMultiNameSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_index(self, columns: SizedMultiIndexSelector[Any]) -> Self: ...
def _select_multi_name(self, columns: SizedMultiNameSelector[Any]) -> Self: ...
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
Comment on lines +452 to 456
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah can you revert this PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A refactor of `when` shouldn't have to remove typing from other parts of narwhals.

def _select_slice_name(self, columns: _SliceName) -> Self: ...
def __getitem__( # noqa: C901, PLR0912
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]:

def __narwhals_namespace__(
self,
) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self, Any, Any]: ...
) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self, Any]: ...
def __narwhals_expr__(self) -> None: ...

@classmethod
Expand Down
21 changes: 3 additions & 18 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._utils import (
exclude_column_names,
Expand Down Expand Up @@ -132,31 +131,17 @@ def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:

class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT],
):
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
@property
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT]: ...
Comment on lines 143 to +142
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm struggling to follow this

If I remove NativeSeriesT from EagerWhen, and hence from the return type, then mypy / pyright tell me:

   /home/runner/work/narwhals/narwhals/narwhals/_compliant/namespace.py:133:7 - warning: Type variable "NativeSeriesT" used in generic Protocol "EagerNamespace" should be covariant (reportInvalidTypeVarUse)

But I can't make it covariant, because it's used as an argument to `from_native`. If I did, type checkers would complain. I can't tell what they're expecting me to do 😄

@dangotbanned sorry for the ping, any tips would be appreciated 🙏

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @MarcoGorelli, I'll try to take a look at this today 🙏

Would you be able to give some more background on what you're trying to accomplish with this PR please?

I'm not following the motivation from either the title or commit messages 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, it's to enable #2652 and #2645

I spent about 2 hours trying different things and wasn't able to get anywhere with this 😳. Based on #2662 (comment), it looks to me like there were already issues with these protocols. Happy for some parts to be brought back, so long as we're able to move towards the linked issues.


@overload
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
# TODO @dangotbanned: Align `PandasLike` typing with `_namespace`, then drop this `@overload`
# - Using the guards there introduces `_NativeModin`, `_NativeCuDF`
# - These types haven't been integrated into the backend
# - Most of the `pandas` stuff is still untyped
@overload
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT: ...
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
Comment on lines -145 to -159
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should have been merged

def from_native(self, data: Any, /) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
elif self._series._is_native(data):
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_compliant/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,7 @@ def _with_native(
"""
...

def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
def __narwhals_namespace__(self) -> EagerNamespace[Any, Self, Any, Any]: ...

def _to_expr(self) -> EagerExpr[Any, Any]:
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]
Expand Down
8 changes: 5 additions & 3 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ class ScalarKwargs(TypedDict, total=False):

DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"

EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"
EagerNamespaceAny: TypeAlias = (
"EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame]"
)

LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"

Expand Down Expand Up @@ -124,7 +126,7 @@ class ScalarKwargs(TypedDict, total=False):
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)

# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]")

LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)
Expand Down
21 changes: 9 additions & 12 deletions narwhals/_compliant/when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
EagerSeriesT,
LazyExprAny,
NativeExprT,
NativeSeriesT,
WindowFunction,
)
from narwhals._typing_compat import Protocol38
Expand Down Expand Up @@ -43,7 +42,7 @@
class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]):
_condition: ExprT
_then_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT] | None
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
Expand Down Expand Up @@ -145,15 +144,11 @@ def from_when(

class EagerWhen(
CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT],
):
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
/,
) -> NativeSeriesT: ...
self, when: EagerSeriesT, then: EagerSeriesT, otherwise: EagerSeriesT | None, /
) -> EagerSeriesT: ...

def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
is_expr = self._condition._is_expr
Expand All @@ -165,11 +160,13 @@ def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
then = when.alias("literal")._from_scalar(self._then_value)
then._broadcast = True
if is_expr(self._otherwise_value):
otherwise = df._extract_comparand(self._otherwise_value(df)[0])
otherwise = self._otherwise_value(df)[0]
elif self._otherwise_value is not None:
otherwise = when._from_scalar(self._otherwise_value)
otherwise._broadcast = True
else:
otherwise = self._otherwise_value
result = self._if_then_else(when.native, df._extract_comparand(then), otherwise)
return [then._with_native(result)]
return [self._if_then_else(when, then, otherwise)]


class LazyWhen(
Expand Down
4 changes: 1 addition & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@
)


class PandasLikeDataFrame(
EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any", "pd.Series[Any]"]
):
class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any"]):
def __init__(
self,
native_dataframe: Any,
Expand Down
34 changes: 16 additions & 18 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import operator
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Any, Literal, Sequence
from typing import TYPE_CHECKING, Literal, Sequence

import pandas as pd

from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
Expand All @@ -17,8 +19,6 @@
from narwhals._pandas_like.utils import align_series_full_broadcast

if TYPE_CHECKING:
import pandas as pd

from narwhals._pandas_like.typing import NDFrameT
from narwhals._utils import Implementation, Version
from narwhals.dtypes import DType
Expand All @@ -29,13 +29,7 @@


class PandasLikeNamespace(
EagerNamespace[
PandasLikeDataFrame,
PandasLikeSeries,
PandasLikeExpr,
"pd.DataFrame",
"pd.Series[Any]",
]
EagerNamespace[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, pd.DataFrame]
):
@property
def _dataframe(self) -> type[PandasLikeDataFrame]:
Expand Down Expand Up @@ -315,21 +309,25 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
)


class PandasWhen(
EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, "pd.Series[Any]"]
):
class PandasWhen(EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr]):
@property
def _then(self) -> type[PandasThen]:
return PandasThen

def _if_then_else(
self,
when: pd.Series[Any],
then: pd.Series[Any],
otherwise: pd.Series[Any] | NonNestedLiteral,
when: PandasLikeSeries,
then: PandasLikeSeries,
otherwise: PandasLikeSeries | None,
/,
) -> pd.Series[Any]:
return then.where(when) if otherwise is None else then.where(when, otherwise)
) -> PandasLikeSeries:
if otherwise is None:
when, then = align_series_full_broadcast(when, then)
res_native = then.native.where(when.native)
else:
when, then, otherwise = align_series_full_broadcast(when, then, otherwise)
res_native = then.native.where(when.native, otherwise.native)
return then._with_native(res_native)


class PandasThen(
Expand Down
Loading