From 264e717ab9a6429eb08d1de231e5592c55e49cae Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 17 Apr 2025 19:32:16 +0100 Subject: [PATCH 01/80] feat: Improve ``DataFrame.__getitem__`` consistency --- narwhals/_arrow/dataframe.py | 144 ++++++-------------------- narwhals/_arrow/series.py | 34 +++---- narwhals/_compliant/dataframe.py | 44 +++++++- narwhals/_compliant/series.py | 19 +++- narwhals/_ibis/dataframe.py | 5 - narwhals/_interchange/dataframe.py | 7 -- narwhals/_pandas_like/dataframe.py | 156 ++++++----------------------- narwhals/_pandas_like/expr.py | 6 +- narwhals/_pandas_like/series.py | 18 ++-- narwhals/_polars/namespace.py | 2 +- narwhals/_polars/series.py | 2 + narwhals/dataframe.py | 87 ++++++++++------ narwhals/series.py | 14 +-- narwhals/stable/v1/__init__.py | 10 +- narwhals/utils.py | 38 +++++++ tests/frame/getitem_test.py | 14 ++- 16 files changed, 266 insertions(+), 334 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 7745a9e7d4..ea511ffa2f 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -18,17 +18,14 @@ from narwhals._arrow.utils import align_series_full_broadcast from narwhals._arrow.utils import convert_str_slice_to_int_slice from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._arrow.utils import select_rows from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind -from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import check_column_exists from narwhals.utils import check_column_names_are_unique from narwhals.utils import generate_temporary_column_name -from narwhals.utils import is_sequence_but_not_str from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop from 
narwhals.utils import parse_version @@ -51,7 +48,6 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.typing import ArrowChunkedArray - from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import Mask # type: ignore[attr-defined] from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._translate import IntoArrowTable @@ -62,7 +58,6 @@ from narwhals.typing import JoinStrategy from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy - from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -252,118 +247,37 @@ def get_column(self: Self, name: str) -> ArrowSeries: def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - @overload - def __getitem__( # type: ignore[overload-overlap, unused-ignore] - self: Self, item: str | tuple[slice | Sequence[int] | _1DArray, int | str] - ) -> ArrowSeries: ... - @overload - def __getitem__( - self: Self, - item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> Self: ... 
- def __getitem__( - self: Self, - item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[slice | Sequence[int] | _1DArray, int | str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> ArrowSeries | Self: - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType] - - if isinstance(item, str): - return ArrowSeries.from_native(self.native[item], context=self, name=item) - elif ( - isinstance(item, tuple) - and len(item) == 2 - and is_sequence_but_not_str(item[1]) - and not isinstance(item[0], str) - ): - if len(item[1]) == 0: - # Return empty dataframe - return self._with_native(self.native.slice(0, 0).select([])) - selected_rows = select_rows(self.native, item[0]) - return self._with_native(selected_rows.select(cast("Indices", item[1]))) - - elif isinstance(item, tuple) and len(item) == 2: - if isinstance(item[1], slice): - columns = self.columns - indices = cast("Indices", item[0]) - if item[1] == slice(None): - if isinstance(item[0], Sequence) and len(item[0]) == 0: - return self._with_native(self.native.slice(0, 0)) - return self._with_native(self.native.take(indices)) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.take(indices).select(columns[start:stop:step]) - ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.take(indices).select( - columns[item[1].start : item[1].stop : item[1].step] - ) - ) - msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - # PyArrow columns are always strings - col_name = ( - item[1] - if isinstance(item[1], str) - else self.columns[cast("int", item[1])] - ) - if isinstance(item[0], str): # 
pragma: no cover - msg = "Can not slice with tuple with the first element as a str" - raise TypeError(msg) - if (isinstance(item[0], slice)) and (item[0] == slice(None)): - return ArrowSeries.from_native( - self.native[col_name], context=self, name=col_name - ) - selected_rows = select_rows(self.native, item[0]) - return ArrowSeries.from_native( - selected_rows[col_name], context=self, name=col_name - ) + def gather(self, item: Any) -> Self: + if len(item) == 0: + return self._with_native(self.native.slice(0, 0)) + return self._with_native(self.native.take(item)) - elif isinstance(item, slice): - if item.step is not None and item.step != 1: - msg = "Slicing with step is not supported on PyArrow tables" - raise NotImplementedError(msg) - columns = self.columns - if isinstance(item.start, str) or isinstance(item.stop, str): - start, stop, step = convert_str_slice_to_int_slice(item, columns) - return self._with_native(self.native.select(columns[start:stop:step])) - start = item.start or 0 - stop = item.stop if item.stop is not None else len(self.native) - return self._with_native(self.native.slice(start, stop - start)) - - elif isinstance(item, Sequence) or is_numpy_array_1d(item): - if isinstance(item, Sequence) and len(item) > 0 and isinstance(item[0], str): - return self._with_native(self.native.select(cast("Indices", item))) - if isinstance(item, Sequence) and len(item) == 0: - return self._with_native(self.native.slice(0, 0)) - return self._with_native(self.native.take(cast("Indices", item))) + def _gather_slice(self, item: Any) -> Self: + start = item.start or 0 + stop = item.stop if item.stop is not None else len(self.native) + if item.step is not None and item.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) - else: # pragma: no cover - msg = f"Expected str or slice, got: {type(item)}" - raise TypeError(msg) + def 
_select_slice_of_labels(self, item: Any) -> Self: + start, stop, step = convert_str_slice_to_int_slice(item, self.columns) + return self._with_native(self.native.select(self.columns[start:stop:step])) + + def _select_slice_of_indices(self, item: Any) -> Self: + return self._with_native( + self.native.select(self.columns[item.start : item.stop : item.step]) + ) + + def _select_indices(self, item: Any) -> Self: + if len(item) == 0: + return self._with_native( + self.native.__class__.from_arrays([]), validate_column_names=False + ) + return self._with_native(self.native.select([self.columns[x] for x in item])) + + def _select_labels(self, item: Any) -> Self: + return self._with_native(self.native.select(item)) @property def schema(self: Self) -> dict[str, DType]: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index efeb851e95..741bf2361b 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -405,22 +405,18 @@ def __native_namespace__(self: Self) -> ModuleType: def name(self: Self) -> str: return self._name - @overload - def __getitem__(self: Self, idx: int) -> Any: ... - - @overload - def __getitem__( - self: Self, idx: slice | Sequence[int] | ArrowChunkedArray - ) -> Self: ... 
- - def __getitem__( - self: Self, idx: int | slice | Sequence[int] | ArrowChunkedArray - ) -> Any | Self: - if isinstance(idx, int): - return maybe_extract_py_scalar(self.native[idx], return_py_scalar=True) - if isinstance(idx, (Sequence, pa.ChunkedArray)): - return self._with_native(self.native.take(idx)) - return self._with_native(self.native[idx]) + def gather(self, item: Any) -> Self: + if len(item) == 0: + return self._with_native(self.native.slice(0, 0)) + return self._with_native(self.native.take(item)) + + def _gather_slice(self, item: Any) -> Self: + start = item.start or 0 + stop = item.stop if item.stop is not None else len(self.native) + if item.step is not None and item.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: import numpy as np # ignore-banned-import @@ -924,7 +920,7 @@ def rolling_sum( result = self._with_native( pc.if_else((count_in_window >= min_samples).native, rolling_sum.native, None) ) - return result[offset:] + return result._gather_slice(slice(offset, None)) def rolling_mean( self: Self, @@ -959,7 +955,7 @@ def rolling_mean( ) / count_in_window ) - return result[offset:] + return result._gather_slice(slice(offset, None)) def rolling_var( self: Self, @@ -1007,7 +1003,7 @@ def rolling_var( ) ) / self._with_native(pc.max_element_wise((count_in_window - ddof).native, 0)) - return result[offset:] + return result._gather_slice(slice(offset, None, None)) def rolling_std( self: Self, diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 08ac415380..cd5ab2c7d4 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -25,6 +25,10 @@ from narwhals.utils import Version from narwhals.utils import _StoresNative from narwhals.utils import deprecated +from narwhals.utils import 
is_int_like_indexer +from narwhals.utils import is_null_slice +from narwhals.utils import is_sequence_like +from narwhals.utils import is_sequence_like_ints if TYPE_CHECKING: from io import BytesIO @@ -99,7 +103,6 @@ def from_numpy( schema: Mapping[str, DType] | Schema | Sequence[str] | None, ) -> Self: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... - def __getitem__(self, item: Any) -> CompliantSeriesT | Self: ... def simple_select(self, *column_names: str) -> Self: """`select` where all args are column names.""" ... @@ -227,6 +230,7 @@ def write_csv(self, file: None) -> str: ... def write_csv(self, file: str | Path | BytesIO) -> None: ... def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ... def write_parquet(self, file: str | Path | BytesIO) -> None: ... + def __getitem__(self, item: tuple[Any, Any]) -> Self: ... class CompliantLazyFrame( @@ -369,3 +373,41 @@ def _numpy_column_names( data: _2DArray, columns: Sequence[str] | None, / ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) + + def gather(self, indices: Any) -> Self: ... + def _gather_slice(self, indices: Any) -> Self: ... + def _select_indices(self, indices: Any) -> Self: ... + def _select_labels(self, indices: Any) -> Self: ... + def _select_slice_of_indices(self, indices: Any) -> Self: ... + def _select_slice_of_labels(self, indices: Any) -> Self: ... 
+ + def __getitem__(self, item: tuple[Any, Any]) -> Self: + rows, columns = item + + is_int_col_indexer = is_int_like_indexer(columns) + compliant = self + if not is_null_slice(columns): + if is_int_col_indexer and not isinstance(columns, slice): + compliant = compliant._select_indices(columns) + elif is_int_col_indexer: + compliant = compliant._select_slice_of_indices(columns) + elif isinstance(columns, slice): + compliant = compliant._select_slice_of_labels(columns) + elif is_sequence_like(columns): + compliant = self._select_labels(columns) + else: + msg = "Unreachable code" + raise AssertionError(msg) + + if not is_null_slice(rows): + if isinstance(rows, int): + compliant = compliant.gather([rows]) + elif isinstance(rows, (slice, range)): + compliant = compliant._gather_slice(rows) + elif is_sequence_like_ints(rows): + compliant = compliant.gather(rows) + else: + msg = "Unreachable code" + raise AssertionError(msg) + + return compliant diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 0347bb95ab..b17f5333a3 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -23,6 +23,8 @@ from narwhals._translate import NumpyConvertible from narwhals.utils import _StoresCompliant from narwhals.utils import _StoresNative +from narwhals.utils import is_null_slice +from narwhals.utils import is_sequence_like_ints from narwhals.utils import unstable if TYPE_CHECKING: @@ -78,7 +80,6 @@ def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def __native_namespace__(self) -> ModuleType: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ... def __contains__(self, other: Any) -> bool: ... - def __getitem__(self, item: Any) -> Any: ... def __iter__(self) -> Iterator[Any]: ... def __len__(self) -> int: return len(self.native) @@ -285,6 +286,22 @@ def list(self) -> Any: ... @property def struct(self) -> Any: ... + def gather(self, indices: Any) -> Self: ... 
+ def _gather_slice(self, indices: Any) -> Self: ... + + def __getitem__(self, rows: Any) -> Self: + if is_null_slice(rows): + return self + if isinstance(rows, int): + return self.gather([rows]) + elif isinstance(rows, (slice, range)): + return self._gather_slice(rows) + elif is_sequence_like_ints(rows): + return self.gather(rows) + else: + msg = "Unreachable code" + raise AssertionError(msg) + class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]): _native_series: Any diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 7396982c7c..ee879e102e 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -100,11 +100,6 @@ def __narwhals_lazyframe__(self) -> Any: def __native_namespace__(self: Self) -> ModuleType: return get_ibis() - def __getitem__(self, item: str) -> IbisInterchangeSeries: - from narwhals._ibis.series import IbisInterchangeSeries - - return IbisInterchangeSeries(self._native_frame[item], version=self._version) - def to_pandas(self: Self) -> pd.DataFrame: return self._native_frame.to_pandas() diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index 6252e18662..e28905551f 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -105,13 +105,6 @@ def __native_namespace__(self: Self) -> NoReturn: ) raise NotImplementedError(msg) - def __getitem__(self: Self, item: str) -> InterchangeSeries: - from narwhals._interchange.series import InterchangeSeries - - return InterchangeSeries( - self._interchange_frame.get_column_by_name(item), version=self._version - ) - def to_pandas(self: Self) -> pd.DataFrame: import pandas as pd # ignore-banned-import() diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 752b047f33..ead7f02341 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -27,7 +27,6 @@ from narwhals._pandas_like.utils import rename from 
narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_index -from narwhals.dependencies import is_numpy_array_1d from narwhals.dependencies import is_pandas_like_dataframe from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import ShapeError @@ -37,7 +36,6 @@ from narwhals.utils import check_column_exists from narwhals.utils import generate_temporary_column_name from narwhals.utils import import_dtypes_module -from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import scale_bytes @@ -68,7 +66,6 @@ from narwhals.typing import PivotAgg from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy - from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -281,135 +278,40 @@ def get_column(self: Self, name: str) -> PandasLikeSeries: def __array__(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - @overload - def __getitem__( # type: ignore[overload-overlap] - self: Self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], - ) -> PandasLikeSeries: ... + def gather(self, items: Any) -> Self: + items = list(items) if isinstance(items, tuple) else items + return self._with_native(self.native.iloc[items, :]) - @overload - def __getitem__( - self: Self, - item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> Self: ... 
- def __getitem__( - self: Self, - item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[slice | Sequence[int] | _1DArray, int | str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> PandasLikeSeries | Self: - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType] - - if isinstance(item, str): - return PandasLikeSeries.from_native(self.native[item], context=self) - - elif ( - isinstance(item, tuple) - and len(item) == 2 - and is_sequence_but_not_str(item[1]) - ): - if len(item[1]) == 0: - # Return empty dataframe - return self._with_native( - self.native.__class__(), validate_column_names=False - ) - if isinstance(item[1][0], int): - return self._with_native( - self.native.iloc[item], validate_column_names=False - ) - if isinstance(item[1][0], str): - indexer = ( - item[0], - self.native.columns.get_indexer(item[1]), - ) - return self._with_native( - self.native.iloc[indexer], validate_column_names=False - ) - msg = ( - f"Expected sequence str or int, got: {type(item[1])}" # pragma: no cover - ) - raise TypeError(msg) # pragma: no cover - - elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): - columns = self.native.columns - if item[1] == slice(None): - return self._with_native( - self.native.iloc[item[0], :], validate_column_names=False - ) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.iloc[item[0], slice(start, stop, step)], - validate_column_names=False, - ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.iloc[ - item[0], slice(item[1].start, item[1].stop, item[1].step) - ], - validate_column_names=False, - ) - msg = f"Expected slice of integers or strings, got: 
{type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - elif isinstance(item, tuple) and len(item) == 2: - if isinstance(item[1], str): - index = (item[0], self.native.columns.get_loc(item[1])) - native_series = self.native.iloc[index] - elif isinstance(item[1], int): - native_series = self.native.iloc[item] - else: # pragma: no cover - msg = f"Expected str or int, got: {type(item[1])}" - raise TypeError(msg) + def _gather_slice(self, item: Any) -> Self: + return self._with_native( + self.native.iloc[slice(item.start, item.stop, item.step), :], + validate_column_names=False, + ) - return PandasLikeSeries.from_native(native_series, context=self) + def _select_slice_of_labels(self, item: Any) -> Self: + start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) + return self._with_native( + self.native.iloc[:, slice(start, stop, step)], + validate_column_names=False, + ) - elif is_sequence_but_not_str(item) or is_numpy_array_1d(item): - if len(item) > 0 and isinstance(item[0], str): - return self._with_native( - select_columns_by_name( - self.native, - cast("list[str] | _1DArray", item), - self._backend_version, - self._implementation, - ), - validate_column_names=False, - ) - return self._with_native(self.native.iloc[item], validate_column_names=False) + def _select_slice_of_indices(self, item: Any) -> Self: + return self._with_native( + self.native.iloc[:, slice(item.start, item.stop, item.step)], + validate_column_names=False, + ) - elif isinstance(item, slice): - if isinstance(item.start, str) or isinstance(item.stop, str): - start, stop, step = convert_str_slice_to_int_slice( - item, self.native.columns - ) - return self._with_native( - self.native.iloc[:, slice(start, stop, step)], - validate_column_names=False, - ) - return self._with_native(self.native.iloc[item], validate_column_names=False) + def _select_indices(self, item: Any) -> Self: + item = list(item) if isinstance(item, tuple) else item + if len(item) == 
0: + return self._with_native(self.native.__class__(), validate_column_names=False) + return self._with_native( + self.native.iloc[:, item], + validate_column_names=False, + ) - else: # pragma: no cover - msg = f"Expected str or slice, got: {type(item)}" - raise TypeError(msg) + def _select_labels(self, indices: Any) -> PandasLikeDataFrame: + return self._with_native(self.native.loc[:, indices]) # --- properties --- @property diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index ae6197a5cd..765d8cc230 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -264,7 +264,9 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: sorting_indices = df.get_column(token) elif reverse: columns = list(set(partition_by).union(output_names)) - df = df.simple_select(*columns)[::-1] + df = df.simple_select(*columns)._gather_slice( + slice(None, None, -1) + ) grouped = df._native_frame.groupby(partition_by) if function_name.startswith("rolling"): rolling = grouped[list(output_names)].rolling(**pandas_kwargs) @@ -293,7 +295,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: s._scatter_in_place(sorting_indices, s) return results if reverse: - return [s[::-1] for s in results] + return [s._gather_slice(slice(None, None, -1)) for s in results] return results return self.__class__( diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 90bb1a94ec..7215c731d4 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -7,7 +7,6 @@ from typing import Mapping from typing import Sequence from typing import cast -from typing import overload import numpy as np @@ -26,7 +25,6 @@ from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_index from narwhals.dependencies import is_numpy_array_1d -from narwhals.dependencies import is_numpy_scalar from narwhals.dependencies import is_pandas_like_series from 
narwhals.exceptions import InvalidOperationError from narwhals.utils import Implementation @@ -147,16 +145,14 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - @overload - def __getitem__(self: Self, idx: int) -> Any: ... + def gather(self, rows: Any) -> Self: + rows = list(rows) if isinstance(rows, tuple) else rows + return self._with_native(self.native.iloc[rows]) - @overload - def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ... - - def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self: - if isinstance(idx, int) or is_numpy_scalar(idx): - return self.native.iloc[idx] - return self._with_native(self.native.iloc[idx]) + def _gather_slice(self, item: Any) -> Self: + return self._with_native( + self.native.iloc[slice(item.start, item.stop, item.step)] + ) def _with_version(self: Self, version: Version) -> Self: return self.__class__( diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 27094362f5..a0988672f7 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -254,7 +254,7 @@ def concat_str( @property def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: return cast( - "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", # pyright: ignore[reportInvalidTypeArguments] PolarsSelectorNamespace(self), ) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index ca9c0410f8..664a7b9988 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -616,6 +616,8 @@ def struct(self: Self) -> PolarsSeriesStructNamespace: drop_nulls: Method[Self] fill_null: Method[Self] filter: Method[Self] + _gather_slice: Method[Self] + gather: Method[Self] gather_every: Method[Self] head: Method[Self] is_between: Method[Self] diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 
f67572e1ae..16d737320b 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -22,7 +22,6 @@ from narwhals._expression_parsing import is_scalar_like from narwhals.dependencies import get_polars from narwhals.dependencies import is_numpy_array -from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import InvalidIntoExprError from narwhals.exceptions import LengthChangingExprError @@ -35,8 +34,10 @@ from narwhals.utils import generate_repr from narwhals.utils import is_compliant_dataframe from narwhals.utils import is_compliant_lazyframe +from narwhals.utils import is_int_like_indexer from narwhals.utils import is_list_of -from narwhals.utils import is_sequence_but_not_str +from narwhals.utils import is_null_slice +from narwhals.utils import is_sequence_like from narwhals.utils import issue_deprecation_warning from narwhals.utils import parse_version from narwhals.utils import supports_arrow_c_stream @@ -797,6 +798,9 @@ def estimated_size(self: Self, unit: SizeUnit = "b") -> int | float: """ return self._compliant_frame.estimated_size(unit=unit) + @overload + def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... + @overload def __getitem__( # type: ignore[overload-overlap] self: Self, @@ -813,7 +817,8 @@ def __getitem__( | Sequence[str] | _1DArray | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] + slice | Sequence[int] | _1DArray, + slice | Sequence[int] | Sequence[str] | _1DArray, ] ), ) -> Self: ... @@ -826,12 +831,14 @@ def __getitem__( | Sequence[int] | Sequence[str] | _1DArray + | tuple[int, int | str] | tuple[slice | Sequence[int] | _1DArray, int | str] | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] + slice | Sequence[int] | _1DArray, + slice | Sequence[int] | Sequence[str] | _1DArray, ] ), - ) -> Series[Any] | Self: + ) -> Series[Any] | Self | Any: """Extract column or slice of DataFrame. 
Arguments: @@ -881,41 +888,55 @@ def __getitem__( 1 2 Name: a, dtype: int64 """ - if isinstance(item, int): - item = [item] - if ( - isinstance(item, tuple) - and len(item) == 2 - and (isinstance(item[0], (str, int))) - ): + if isinstance(item, tuple) and len(item) > 2: msg = ( - f"Expected str or slice, got: {type(item)}.\n\n" - "Hint: if you were trying to get a single element out of a " - "dataframe, use `DataFrame.item`." + "Tuples be passed to DataFrame.__getitem__ directly.\n\n" + "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" ) raise TypeError(msg) - if ( - isinstance(item, tuple) - and len(item) == 2 - and (is_sequence_but_not_str(item[1]) or isinstance(item[1], slice)) - ): - if item[1] == slice(None) and item[0] == slice(None): - return self - return self._with_compliant(self._compliant_frame[item]) - if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2): - return self._series(self._compliant_frame[item], level=self._level) - - elif ( - is_sequence_but_not_str(item) - or isinstance(item, slice) - or (is_numpy_array_1d(item)) - ): - return self._with_compliant(self._compliant_frame[item]) + if isinstance(item, tuple) and len(item) == 2: + # These are so heavily overloaded that we just ignore the types for now. 
+ rows: Any = item[0] if not is_null_slice(item[0]) else None + columns: Any = item[1] if not is_null_slice(item[1]) else None + elif isinstance(item, tuple) and item: + rows = item[0] + columns = None + elif isinstance(item, str): + rows = None + columns = item + elif is_int_like_indexer(item): + rows = item + columns = None + elif is_sequence_like(item) or isinstance(item, (slice, range)): + rows = None + columns = item else: - msg = f"Expected str or slice, got: {type(item)}" + msg = ( + f"Expected str or slice, got: {type(item)}.\n\n" + "Hints:\n" + "- use `DataFrame.item` to select a single item.\n" + "- Use `DataFrame[indices, :]` to select rows positionally.\n" + "- Use `DataFrame.filter(mask)` to filter rows based on a boolean mask." + ) raise TypeError(msg) + compliant = self._compliant_frame + + if isinstance(rows, int) and isinstance(columns, (int, str)): + return self.item(rows, columns) + if isinstance(columns, (str, int)): + col_name = columns if isinstance(columns, str) else self.columns[columns] + series = self.get_column(col_name) + return series[rows] if rows is not None else series + if rows is None and columns is None: + return self + if rows is None: + return self._with_compliant(compliant[:, columns]) + if columns is None: + return self._with_compliant(compliant[rows, :]) + return self._with_compliant(compliant[rows, columns]) + def __contains__(self: Self, key: str) -> bool: return key in self.columns diff --git a/narwhals/series.py b/narwhals/series.py index 5fe3e3f19b..f8132c1e5c 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -8,6 +8,7 @@ from typing import Literal from typing import Mapping from typing import Sequence +from typing import cast from typing import overload from narwhals.dependencies import is_numpy_scalar @@ -132,9 +133,11 @@ def __array__(self: Self, dtype: Any = None, copy: bool | None = None) -> _1DArr def __getitem__(self: Self, idx: int) -> Any: ... 
@overload - def __getitem__(self: Self, idx: slice | Sequence[int] | Self) -> Self: ... + def __getitem__(self: Self, idx: slice | Sequence[int] | _1DArray | Self) -> Self: ... - def __getitem__(self: Self, idx: int | slice | Sequence[int] | Self) -> Any | Self: + def __getitem__( + self: Self, idx: int | slice | Sequence[int] | _1DArray | Self + ) -> Any | Self: """Retrieve elements from the object using integer indexing or slicing. Arguments: @@ -169,10 +172,9 @@ def __getitem__(self: Self, idx: int | slice | Sequence[int] | Self) -> Any | Se if isinstance(idx, int) or ( is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): - return self._compliant_series[idx] - return self._with_compliant( - self._compliant_series[to_native(idx, pass_through=True)] - ) + return self._compliant_series.item(cast("int", idx)) + idx = to_native(idx, pass_through=True) + return self._with_compliant(self._compliant_series[idx]) def __native_namespace__(self: Self) -> ModuleType: return self._compliant_series.__native_namespace__() diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index bb69caf34b..71d03e763e 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -150,26 +150,30 @@ def _series(self: Self) -> type[Series[Any]]: def _lazyframe(self: Self) -> type[LazyFrame[Any]]: return LazyFrame + @overload + def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... + @overload def __getitem__( # type: ignore[overload-overlap] self: Self, item: str | tuple[slice | Sequence[int] | _1DArray, int | str], ) -> Series[Any]: ... + @overload def __getitem__( self: Self, item: ( int | slice - | _1DArray | Sequence[int] | Sequence[str] + | _1DArray | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] + slice | Sequence[int] | _1DArray, + slice | Sequence[int] | Sequence[str] | _1DArray, ] ), ) -> Self: ... 
- def __getitem__(self: Self, item: Any) -> Any: return super().__getitem__(item) diff --git a/narwhals/utils.py b/narwhals/utils.py index d268a02261..d7d1115084 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -28,6 +28,7 @@ from narwhals.dependencies import get_duckdb from narwhals.dependencies import get_ibis from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow @@ -35,6 +36,7 @@ from narwhals.dependencies import get_sqlframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_modin_series +from narwhals.dependencies import is_numpy_array_1d from narwhals.dependencies import is_pandas_dataframe from narwhals.dependencies import is_pandas_like_dataframe from narwhals.dependencies import is_pandas_like_series @@ -1271,6 +1273,42 @@ def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T] return isinstance(sequence, Sequence) and not isinstance(sequence, str) +def is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: + np = get_numpy() + return ( + isinstance(sequence, Sequence) + and ((len(sequence) > 0 and isinstance(sequence[0], int)) or (len(sequence) == 0)) + ) or (is_numpy_array_1d(sequence) and np.issubdtype(sequence.dtype, np.integer)) + + +def is_sequence_like(sequence: Any | Sequence[_T]) -> bool: + return (isinstance(sequence, Sequence) and not isinstance(sequence, str)) or ( + is_numpy_array_1d(sequence) + ) + + +def is_slice_strs(obj: object) -> bool: + return isinstance(obj, slice) and ( + isinstance(obj.start, str) or isinstance(obj.stop, str) + ) + + +def is_slice_ints(obj: object) -> bool: + return isinstance(obj, slice) and ( + isinstance(obj.start, int) # e.g. [1:] + or isinstance(obj.stop, int) # e.g. [:3] + or (obj.start is None and obj.stop is None) # e.g. 
[::2] + ) + + +def is_int_like_indexer(cols: object) -> bool: + return isinstance(cols, int) or is_sequence_like_ints(cols) or is_slice_ints(cols) + + +def is_null_slice(obj: object) -> bool: + return isinstance(obj, slice) and obj == slice(None) + + def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[type[_T]]]: # Check if an object is a list of `tp`, only sniffing the first element. return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp)) diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index c1c4c09588..701287f2f2 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -171,11 +171,10 @@ def test_slice_slice_columns( assert_equal_data(result, expected) -def test_slice_invalid(constructor_eager: ConstructorEager) -> None: +def test_slice_item(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2], "b": [4, 5]} df = nw.from_native(constructor_eager(data), eager_only=True) - with pytest.raises(TypeError, match="Hint:"): - df[0, 0] + assert df[0, 0] == 1 def test_slice_edge_cases(constructor_eager: ConstructorEager) -> None: @@ -242,3 +241,12 @@ def test_get_item_works_with_tuple_and_list_indexing_and_str( ) -> None: nw_df = nw.from_native(constructor_eager(data), eager_only=True) nw_df[row_idx, col] + + +def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + arr: np.ndarray[tuple[int], np.dtype[np.int64]] = np.array([0, 1]) # pyright: ignore[reportAssignmentType] + result = nw_df[:, arr] + expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} + assert_equal_data(result, expected) From 8bc2bfaed07829d45ec7802ef4eea9ce3e2added Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 18:36:34 +0000 Subject: [PATCH 02/80] [pre-commit.ci] auto 
fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- narwhals/_ibis/dataframe.py | 1 - narwhals/_interchange/dataframe.py | 1 - narwhals/_pandas_like/expr.py | 4 +--- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index ee879e102e..2ecb3d32dc 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -19,7 +19,6 @@ import pyarrow as pa from typing_extensions import Self - from narwhals._ibis.series import IbisInterchangeSeries from narwhals.dtypes import DType diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index e28905551f..f2fc27cf64 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -13,7 +13,6 @@ import pyarrow as pa from typing_extensions import Self - from narwhals._interchange.series import InterchangeSeries from narwhals.dtypes import DType from narwhals.typing import DataFrameLike from narwhals.utils import Version diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 765d8cc230..fd4643363c 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -264,9 +264,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: sorting_indices = df.get_column(token) elif reverse: columns = list(set(partition_by).union(output_names)) - df = df.simple_select(*columns)._gather_slice( - slice(None, None, -1) - ) + df = df.simple_select(*columns)._gather_slice(slice(None, None, -1)) grouped = df._native_frame.groupby(partition_by) if function_name.startswith("rolling"): rolling = grouped[list(output_names)].rolling(**pandas_kwargs) From 2c4f1c1245aee475ee05651dc53671f2a8ce06d8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:46:44 +0100 Subject: [PATCH 03/80] fixup --- narwhals/_arrow/dataframe.py | 3 ++- narwhals/_compliant/dataframe.py | 5 
++++- narwhals/_compliant/series.py | 3 ++- narwhals/_duckdb/dataframe.py | 4 ++-- narwhals/_ibis/dataframe.py | 6 ++++++ narwhals/_interchange/dataframe.py | 8 ++++++++ narwhals/_pandas_like/dataframe.py | 5 ++++- narwhals/_polars/dataframe.py | 24 ++++++++++++++++++------ narwhals/dataframe.py | 2 ++ narwhals/utils.py | 30 ++++++++++++++++++++++++------ 10 files changed, 72 insertions(+), 18 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index ea511ffa2f..32e35913d9 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -76,7 +76,8 @@ class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): - # --- not in the spec --- + native_series = pa.ChunkedArray + def __init__( self: Self, native_dataframe: pa.Table, diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index cd5ab2c7d4..8538f77928 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -337,6 +337,9 @@ class EagerDataFrame( CompliantLazyFrame[EagerExprT_contra, NativeFrameT], Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT], ): + @property + def native_series(self) -> Any: ... + def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" result: Sequence[EagerSeriesT] = expr(self) @@ -404,7 +407,7 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: compliant = compliant.gather([rows]) elif isinstance(rows, (slice, range)): compliant = compliant._gather_slice(rows) - elif is_sequence_like_ints(rows): + elif is_sequence_like_ints(rows) or isinstance(rows, self.native_series): compliant = compliant.gather(rows) else: msg = "Unreachable code" diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index b17f5333a3..e6dee7c22d 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -287,6 +287,7 @@ def list(self) -> Any: ... 
def struct(self) -> Any: ... def gather(self, indices: Any) -> Self: ... + def _gather_slice(self, indices: Any) -> Self: ... def __getitem__(self, rows: Any) -> Self: @@ -296,7 +297,7 @@ def __getitem__(self, rows: Any) -> Self: return self.gather([rows]) elif isinstance(rows, (slice, range)): return self._gather_slice(rows) - elif is_sequence_like_ints(rows): + elif is_sequence_like_ints(rows) or isinstance(rows, self.native.__class__): return self.gather(rows) else: msg = "Unreachable code" diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 3a5f749d3f..1cf4067397 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -102,10 +102,10 @@ def __narwhals_namespace__(self: Self) -> DuckDBNamespace: backend_version=self._backend_version, version=self._version ) - def __getitem__(self: Self, item: str) -> DuckDBInterchangeSeries: + def get_column(self: Self, col: str) -> DuckDBInterchangeSeries: from narwhals._duckdb.series import DuckDBInterchangeSeries - return DuckDBInterchangeSeries(self.native.select(item), version=self._version) + return DuckDBInterchangeSeries(self.native.select(col), version=self._version) def _iter_columns(self) -> Iterator[duckdb.Expression]: for name in self.columns: diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 2ecb3d32dc..fba83f6a9e 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -19,6 +19,7 @@ import pyarrow as pa from typing_extensions import Self + from narwhals._ibis.series import IbisInterchangeSeries from narwhals.dtypes import DType @@ -99,6 +100,11 @@ def __narwhals_lazyframe__(self) -> Any: def __native_namespace__(self: Self) -> ModuleType: return get_ibis() + def get_column(self, col: str) -> IbisInterchangeSeries: + from narwhals._ibis.series import IbisInterchangeSeries + + return IbisInterchangeSeries(self._native_frame[col], version=self._version) + def to_pandas(self: Self) -> pd.DataFrame: return 
self._native_frame.to_pandas() diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index f2fc27cf64..b672624b06 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -13,6 +13,7 @@ import pyarrow as pa from typing_extensions import Self + from narwhals._interchange.series import InterchangeSeries from narwhals.dtypes import DType from narwhals.typing import DataFrameLike from narwhals.utils import Version @@ -104,6 +105,13 @@ def __native_namespace__(self: Self) -> NoReturn: ) raise NotImplementedError(msg) + def get_column(self, col: str) -> InterchangeSeries: + from narwhals._interchange.series import InterchangeSeries + + return InterchangeSeries( + self._interchange_frame.get_column_by_name(col), version=self._version + ) + def to_pandas(self: Self) -> pd.DataFrame: import pandas as pd # ignore-banned-import() diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index ead7f02341..3bc599094e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -100,7 +100,6 @@ class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any"]): - # --- not in the spec --- def __init__( self: Self, native_dataframe: Any, @@ -118,6 +117,10 @@ def __init__( if validate_column_names: check_column_names_are_unique(native_dataframe.columns) + @property + def native_series(self) -> type[pd.Series[Any]]: + return self.__native_namespace__().Series + @classmethod def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: implementation = context._implementation diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 6258e4218d..a0e4759d4b 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -20,6 +20,7 @@ from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table +from 
narwhals.utils import is_compliant_series from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -257,8 +258,13 @@ def shape(self: Self) -> tuple[int, int]: return self.native.shape def __getitem__(self: Self, item: Any) -> Any: + rows, columns = item + if is_compliant_series(rows): + rows = rows.native + if is_compliant_series(columns): + columns = columns.native if self._backend_version > (0, 20, 30): - return self._from_native_object(self.native.__getitem__(item)) + return self._from_native_object(self.native.__getitem__((rows, columns))) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support @@ -272,14 +278,18 @@ def __getitem__(self: Self, item: Any) -> Any: return self._with_native(self.native[0:0]) return self._with_native(self.native.__getitem__(item[0])) if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) + start, stop, step = convert_str_slice_to_int_slice( + item[1], self.columns + ) return self._with_native( - self.native.select(columns[start:stop:step]).__getitem__(item[0]) + self.native.select(self.columns[start:stop:step]).__getitem__( + item[0] + ) ) if isinstance(item[1].start, int) or isinstance(item[1].stop, int): return self._with_native( self.native.select( - columns[item[1].start : item[1].stop : item[1].step] + self.columns[item[1].start : item[1].stop : item[1].step] ).__getitem__(item[0]) ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover @@ -295,8 +305,10 @@ def __getitem__(self: Self, item: Any) -> Any: elif isinstance(item, slice) and ( isinstance(item.start, str) or isinstance(item.stop, str) ): - start, stop, step = convert_str_slice_to_int_slice(item, columns) - return self._with_native(self.native.select(columns[start:stop:step])) + start, 
stop, step = convert_str_slice_to_int_slice(item, self.columns) + return self._with_native( + self.native.select(self.columns[start:stop:step]) + ) elif is_sequence_but_not_str(item) and (len(item) == 0): result = self.native.slice(0, 0) else: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 16d737320b..cb60607c68 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -922,6 +922,8 @@ def __getitem__( raise TypeError(msg) compliant = self._compliant_frame + rows = to_native(rows, pass_through=True) + columns = to_native(columns, pass_through=True) if isinstance(rows, int) and isinstance(columns, (int, str)): return self.item(rows, columns) diff --git a/narwhals/utils.py b/narwhals/utils.py index d7d1115084..5ea5533ee2 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1276,14 +1276,25 @@ def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T] def is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: np = get_numpy() return ( - isinstance(sequence, Sequence) - and ((len(sequence) > 0 and isinstance(sequence[0], int)) or (len(sequence) == 0)) - ) or (is_numpy_array_1d(sequence) and np.issubdtype(sequence.dtype, np.integer)) + ( + isinstance(sequence, Sequence) + and ( + (len(sequence) > 0 and isinstance(sequence[0], int)) + or (len(sequence) == 0) + ) + ) + or (is_numpy_array_1d(sequence) and np.issubdtype(sequence.dtype, np.integer)) + or (is_compliant_series(sequence) and sequence.dtype.is_integer()) + ) def is_sequence_like(sequence: Any | Sequence[_T]) -> bool: - return (isinstance(sequence, Sequence) and not isinstance(sequence, str)) or ( - is_numpy_array_1d(sequence) + from narwhals.series import Series + + return ( + (isinstance(sequence, Sequence) and not isinstance(sequence, str)) + or (is_numpy_array_1d(sequence)) + or isinstance(sequence, Series) ) @@ -1302,7 +1313,14 @@ def is_slice_ints(obj: object) -> bool: def is_int_like_indexer(cols: object) -> bool: - return isinstance(cols, int) or 
is_sequence_like_ints(cols) or is_slice_ints(cols) + from narwhals.series import Series + + return ( + isinstance(cols, int) + or is_sequence_like_ints(cols) + or is_slice_ints(cols) + or (isinstance(cols, Series) and cols.dtype.is_integer()) + ) def is_null_slice(obj: object) -> bool: From 01bd787f3ebc282240402bda21f213b3ef92c9b3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:48:42 +0100 Subject: [PATCH 04/80] fixup --- narwhals/_arrow/dataframe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 32e35913d9..40b8178f5d 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -76,8 +76,6 @@ class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): - native_series = pa.ChunkedArray - def __init__( self: Self, native_dataframe: pa.Table, @@ -94,6 +92,10 @@ def __init__( self._version = version validate_backend_version(self._implementation, self._backend_version) + @property + def native_series(self) -> type[ArrowChunkedArray]: + return pa.ChunkedArray + @classmethod def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: backend_version = context._backend_version From 86a057cd5dec3bd30c1bdbaca35cabdbdd5edec9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:53:48 +0100 Subject: [PATCH 05/80] gather -> _gather --- narwhals/_arrow/dataframe.py | 2 +- narwhals/_arrow/series.py | 2 +- narwhals/_compliant/dataframe.py | 6 +++--- narwhals/_compliant/series.py | 6 +++--- narwhals/_pandas_like/dataframe.py | 2 +- narwhals/_pandas_like/series.py | 2 +- narwhals/_polars/series.py | 2 +- tests/series_only/{__getitem___test.py => getitem_test.py} | 0 8 files changed, 11 insertions(+), 11 deletions(-) rename tests/series_only/{__getitem___test.py => getitem_test.py} (100%) diff --git 
a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 40b8178f5d..d3a5b4696b 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -250,7 +250,7 @@ def get_column(self: Self, name: str) -> ArrowSeries: def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def gather(self, item: Any) -> Self: + def _gather(self, item: Any) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) return self._with_native(self.native.take(item)) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 741bf2361b..47e99c08fb 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -405,7 +405,7 @@ def __native_namespace__(self: Self) -> ModuleType: def name(self: Self) -> str: return self._name - def gather(self, item: Any) -> Self: + def _gather(self, item: Any) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) return self._with_native(self.native.take(item)) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 8538f77928..7cc01a9233 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -377,7 +377,7 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def gather(self, indices: Any) -> Self: ... + def _gather(self, indices: Any) -> Self: ... def _gather_slice(self, indices: Any) -> Self: ... def _select_indices(self, indices: Any) -> Self: ... def _select_labels(self, indices: Any) -> Self: ... 
@@ -404,11 +404,11 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: if not is_null_slice(rows): if isinstance(rows, int): - compliant = compliant.gather([rows]) + compliant = compliant._gather([rows]) elif isinstance(rows, (slice, range)): compliant = compliant._gather_slice(rows) elif is_sequence_like_ints(rows) or isinstance(rows, self.native_series): - compliant = compliant.gather(rows) + compliant = compliant._gather(rows) else: msg = "Unreachable code" raise AssertionError(msg) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index e6dee7c22d..f6f171062d 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -286,7 +286,7 @@ def list(self) -> Any: ... @property def struct(self) -> Any: ... - def gather(self, indices: Any) -> Self: ... + def _gather(self, indices: Any) -> Self: ... def _gather_slice(self, indices: Any) -> Self: ... @@ -294,11 +294,11 @@ def __getitem__(self, rows: Any) -> Self: if is_null_slice(rows): return self if isinstance(rows, int): - return self.gather([rows]) + return self._gather([rows]) elif isinstance(rows, (slice, range)): return self._gather_slice(rows) elif is_sequence_like_ints(rows) or isinstance(rows, self.native.__class__): - return self.gather(rows) + return self._gather(rows) else: msg = "Unreachable code" raise AssertionError(msg) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 3bc599094e..4e24774ab3 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -281,7 +281,7 @@ def get_column(self: Self, name: str) -> PandasLikeSeries: def __array__(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def gather(self, items: Any) -> Self: + def _gather(self, items: Any) -> Self: items = list(items) if isinstance(items, tuple) else items return self._with_native(self.native.iloc[items, :]) diff --git 
a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 7215c731d4..ec9e0e3468 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -145,7 +145,7 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def gather(self, rows: Any) -> Self: + def _gather(self, rows: Any) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 664a7b9988..5130f2f2ec 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -617,7 +617,7 @@ def struct(self: Self) -> PolarsSeriesStructNamespace: fill_null: Method[Self] filter: Method[Self] _gather_slice: Method[Self] - gather: Method[Self] + _gather: Method[Self] gather_every: Method[Self] head: Method[Self] is_between: Method[Self] diff --git a/tests/series_only/__getitem___test.py b/tests/series_only/getitem_test.py similarity index 100% rename from tests/series_only/__getitem___test.py rename to tests/series_only/getitem_test.py From 8c216fce24cd2a83b87cd9308e0a0b7f052a0a64 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:12:33 +0100 Subject: [PATCH 06/80] type alias intindexer and strindexer --- narwhals/_arrow/dataframe.py | 14 ++++++++------ narwhals/_arrow/series.py | 5 +++-- narwhals/_arrow/utils.py | 14 +++++++++++--- narwhals/_compliant/dataframe.py | 14 ++++++++------ narwhals/_compliant/series.py | 5 +++-- narwhals/_pandas_like/dataframe.py | 14 ++++++++------ narwhals/_pandas_like/series.py | 5 +++-- narwhals/_pandas_like/utils.py | 2 +- narwhals/typing.py | 5 +++++ 9 files changed, 50 insertions(+), 28 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index d3a5b4696b..d9845ec716 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py 
@@ -59,6 +59,8 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _IntIndexer + from narwhals.typing import _StrIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -250,12 +252,12 @@ def get_column(self: Self, name: str) -> ArrowSeries: def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def _gather(self, item: Any) -> Self: + def _gather(self, item: _IntIndexer) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) return self._with_native(self.native.take(item)) - def _gather_slice(self, item: Any) -> Self: + def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 stop = item.stop if item.stop is not None else len(self.native) if item.step is not None and item.step != 1: @@ -263,23 +265,23 @@ def _gather_slice(self, item: Any) -> Self: raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) - def _select_slice_of_labels(self, item: Any) -> Self: + def _select_slice_of_labels(self, item: slice | range) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.columns) return self._with_native(self.native.select(self.columns[start:stop:step])) - def _select_slice_of_indices(self, item: Any) -> Self: + def _select_slice_of_indices(self, item: slice | range) -> Self: return self._with_native( self.native.select(self.columns[item.start : item.stop : item.step]) ) - def _select_indices(self, item: Any) -> Self: + def _select_indices(self, item: _IntIndexer) -> Self: if len(item) == 0: return self._with_native( self.native.__class__.from_arrays([]), validate_column_names=False ) return self._with_native(self.native.select([self.columns[x] for x in item])) - def _select_labels(self, item: Any) -> Self: + def _select_labels(self, item: _StrIndexer) -> Self: return 
self._with_native(self.native.select(item)) @property diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 47e99c08fb..01ba885609 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -69,6 +69,7 @@ from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _2DArray + from narwhals.typing import _IntIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -405,12 +406,12 @@ def __native_namespace__(self: Self) -> ModuleType: def name(self: Self) -> str: return self._name - def _gather(self, item: Any) -> Self: + def _gather(self, item: _IntIndexer) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) return self._with_native(self.native.take(item)) - def _gather_slice(self, item: Any) -> Self: + def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 stop = item.stop if item.stop is not None else len(self.native) if item.step is not None and item.step != 1: diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 6dfdb0fd9e..6be68f6431 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -352,10 +352,18 @@ def select_rows( def convert_str_slice_to_int_slice( - str_slice: slice, columns: list[str] + str_slice: slice | range, columns: list[str] ) -> tuple[int | None, int | None, int | None]: - start = columns.index(str_slice.start) if str_slice.start is not None else None - stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None + start = ( + columns.index(cast("str", str_slice.start)) + if str_slice.start is not None + else None + ) + stop = ( + columns.index(cast("str", str_slice.stop)) + 1 + if str_slice.stop is not None + else None + ) step = str_slice.step return (start, stop, step) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 7cc01a9233..42794cdb8a 100644 --- a/narwhals/_compliant/dataframe.py +++ 
b/narwhals/_compliant/dataframe.py @@ -52,6 +52,8 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _IntIndexer + from narwhals.typing import _StrIndexer from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -377,12 +379,12 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: Any) -> Self: ... - def _gather_slice(self, indices: Any) -> Self: ... - def _select_indices(self, indices: Any) -> Self: ... - def _select_labels(self, indices: Any) -> Self: ... - def _select_slice_of_indices(self, indices: Any) -> Self: ... - def _select_slice_of_labels(self, indices: Any) -> Self: ... + def _gather(self, indices: _IntIndexer) -> Self: ... + def _gather_slice(self, indices: slice | range) -> Self: ... + def _select_indices(self, indices: _IntIndexer) -> Self: ... + def _select_labels(self, indices: _StrIndexer) -> Self: ... + def _select_slice_of_indices(self, indices: slice | range) -> Self: ... + def _select_slice_of_labels(self, indices: slice | range) -> Self: ... def __getitem__(self, item: tuple[Any, Any]) -> Self: rows, columns = item diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index f6f171062d..754a784378 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -50,6 +50,7 @@ from narwhals.typing import RollingInterpolationMethod from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray + from narwhals.typing import _IntIndexer from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -286,9 +287,9 @@ def list(self) -> Any: ... @property def struct(self) -> Any: ... - def _gather(self, indices: Any) -> Self: ... + def _gather(self, indices: _IntIndexer) -> Self: ... 
- def _gather_slice(self, indices: Any) -> Self: ... + def _gather_slice(self, indices: slice | range) -> Self: ... def __getitem__(self, rows: Any) -> Self: if is_null_slice(rows): diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 4e24774ab3..06dcf829fe 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -67,6 +67,8 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _IntIndexer + from narwhals.typing import _StrIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -281,30 +283,30 @@ def get_column(self: Self, name: str) -> PandasLikeSeries: def __array__(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, items: Any) -> Self: + def _gather(self, items: _IntIndexer) -> Self: items = list(items) if isinstance(items, tuple) else items return self._with_native(self.native.iloc[items, :]) - def _gather_slice(self, item: Any) -> Self: + def _gather_slice(self, item: slice | range) -> Self: return self._with_native( self.native.iloc[slice(item.start, item.stop, item.step), :], validate_column_names=False, ) - def _select_slice_of_labels(self, item: Any) -> Self: + def _select_slice_of_labels(self, item: slice | range) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) return self._with_native( self.native.iloc[:, slice(start, stop, step)], validate_column_names=False, ) - def _select_slice_of_indices(self, item: Any) -> Self: + def _select_slice_of_indices(self, item: slice | range) -> Self: return self._with_native( self.native.iloc[:, slice(item.start, item.stop, item.step)], validate_column_names=False, ) - def _select_indices(self, item: Any) -> Self: + def _select_indices(self, item: _IntIndexer) -> Self: item = list(item) if 
isinstance(item, tuple) else item if len(item) == 0: return self._with_native(self.native.__class__(), validate_column_names=False) @@ -313,7 +315,7 @@ def _select_indices(self, item: Any) -> Self: validate_column_names=False, ) - def _select_labels(self, indices: Any) -> PandasLikeDataFrame: + def _select_labels(self, indices: _StrIndexer) -> PandasLikeDataFrame: return self._with_native(self.native.loc[:, indices]) # --- properties --- diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index ec9e0e3468..d578cc9a2d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -56,6 +56,7 @@ from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _AnyDArray + from narwhals.typing import _IntIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -145,11 +146,11 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def _gather(self, rows: Any) -> Self: + def _gather(self, rows: _IntIndexer) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) - def _gather_slice(self, item: Any) -> Self: + def _gather_slice(self, item: slice | range) -> Self: return self._with_native( self.native.iloc[slice(item.start, item.stop, item.step)] ) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index e73369d1cf..766b66455a 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -551,7 +551,7 @@ def int_dtype_mapper(dtype: Any) -> str: def convert_str_slice_to_int_slice( - str_slice: slice, columns: pd.Index[str] + str_slice: slice | range, columns: pd.Index[str] ) -> tuple[int | None, int | None, int | None]: # We can safely cast to int because we know that `columns` doesn't contain duplicates. 
start = ( diff --git a/narwhals/typing.py b/narwhals/typing.py index c222b200f8..e843fcf462 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -309,6 +309,11 @@ def __native_namespace__(self) -> ModuleType: ... ) PythonLiteral: TypeAlias = "NonNestedLiteral | list[Any] | tuple[Any, ...]" +# Overloaded sequence of integers +_IntIndexer: TypeAlias = Any # noqa: PYI047 +# Overloaded sequence of strings +_StrIndexer: TypeAlias = Any # noqa: PYI047 + # ruff: noqa: N802 class DTypes(Protocol): From 944b5977a76358bed229be4fb65516fe66746325 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:17:17 +0100 Subject: [PATCH 07/80] reduce diff --- narwhals/_compliant/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 42794cdb8a..6f0e5ba648 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -105,6 +105,7 @@ def from_numpy( schema: Mapping[str, DType] | Schema | Sequence[str] | None, ) -> Self: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... + def __getitem__(self, item: tuple[Any, Any]) -> Self: ... def simple_select(self, *column_names: str) -> Self: """`select` where all args are column names.""" ... @@ -232,7 +233,6 @@ def write_csv(self, file: None) -> str: ... def write_csv(self, file: str | Path | BytesIO) -> None: ... def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ... def write_parquet(self, file: str | Path | BytesIO) -> None: ... - def __getitem__(self, item: tuple[Any, Any]) -> Self: ... 
class CompliantLazyFrame( From 6c98cee72e3dab43cc256d754d73f2714f6e4170 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:19:27 +0100 Subject: [PATCH 08/80] naming --- narwhals/_duckdb/dataframe.py | 4 ++-- narwhals/_ibis/dataframe.py | 4 ++-- narwhals/_interchange/dataframe.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 1cf4067397..aa3f58cd53 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -102,10 +102,10 @@ def __narwhals_namespace__(self: Self) -> DuckDBNamespace: backend_version=self._backend_version, version=self._version ) - def get_column(self: Self, col: str) -> DuckDBInterchangeSeries: + def get_column(self: Self, name: str) -> DuckDBInterchangeSeries: from narwhals._duckdb.series import DuckDBInterchangeSeries - return DuckDBInterchangeSeries(self.native.select(col), version=self._version) + return DuckDBInterchangeSeries(self.native.select(name), version=self._version) def _iter_columns(self) -> Iterator[duckdb.Expression]: for name in self.columns: diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index fba83f6a9e..64b2ceb062 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -100,10 +100,10 @@ def __narwhals_lazyframe__(self) -> Any: def __native_namespace__(self: Self) -> ModuleType: return get_ibis() - def get_column(self, col: str) -> IbisInterchangeSeries: + def get_column(self, name: str) -> IbisInterchangeSeries: from narwhals._ibis.series import IbisInterchangeSeries - return IbisInterchangeSeries(self._native_frame[col], version=self._version) + return IbisInterchangeSeries(self._native_frame[name], version=self._version) def to_pandas(self: Self) -> pd.DataFrame: return self._native_frame.to_pandas() diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index 
b672624b06..4f2884aadf 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -105,11 +105,11 @@ def __native_namespace__(self: Self) -> NoReturn: ) raise NotImplementedError(msg) - def get_column(self, col: str) -> InterchangeSeries: + def get_column(self, name: str) -> InterchangeSeries: from narwhals._interchange.series import InterchangeSeries return InterchangeSeries( - self._interchange_frame.get_column_by_name(col), version=self._version + self._interchange_frame.get_column_by_name(name), version=self._version ) def to_pandas(self: Self) -> pd.DataFrame: From dfbb63f6dc8ff3184c6aee9aab0b930b954ecee6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:27:10 +0100 Subject: [PATCH 09/80] reduce polars diff --- narwhals/_polars/dataframe.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index ba66f7688f..1e4005d748 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -20,7 +20,6 @@ from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table -from narwhals.utils import is_compliant_series from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -259,19 +258,11 @@ def shape(self: Self) -> tuple[int, int]: return self.native.shape def __getitem__(self: Self, item: Any) -> Any: - rows, columns = item - if is_compliant_series(rows): - rows = rows.native - if is_compliant_series(columns): - columns = columns.native if self._backend_version > (0, 20, 30): - return self._from_native_object(self.native.__getitem__((rows, columns))) + return self._from_native_object(self.native.__getitem__(item)) else: # pragma: no cover # TODO(marco): we can delete this branch after 
Polars==0.20.30 becomes the minimum # Polars version we support - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) - columns = self.columns if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): if item[1] == slice(None): @@ -279,18 +270,14 @@ def __getitem__(self: Self, item: Any) -> Any: return self._with_native(self.native[0:0]) return self._with_native(self.native.__getitem__(item[0])) if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice( - item[1], self.columns - ) + start, stop, step = convert_str_slice_to_int_slice(item[1], columns) return self._with_native( - self.native.select(self.columns[start:stop:step]).__getitem__( - item[0] - ) + self.native.select(columns[start:stop:step]).__getitem__(item[0]) ) if isinstance(item[1].start, int) or isinstance(item[1].stop, int): return self._with_native( self.native.select( - self.columns[item[1].start : item[1].stop : item[1].step] + columns[item[1].start : item[1].stop : item[1].step] ).__getitem__(item[0]) ) msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover @@ -306,16 +293,12 @@ def __getitem__(self: Self, item: Any) -> Any: elif isinstance(item, slice) and ( isinstance(item.start, str) or isinstance(item.stop, str) ): - start, stop, step = convert_str_slice_to_int_slice(item, self.columns) - return self._with_native( - self.native.select(self.columns[start:stop:step]) - ) + start, stop, step = convert_str_slice_to_int_slice(item, columns) + return self._with_native(self.native.select(columns[start:stop:step])) elif is_sequence_but_not_str(item) and (len(item) == 0): result = self.native.slice(0, 0) else: result = self.native.__getitem__(item) - if isinstance(result, pl.Series): - return PolarsSeries.from_native(result, context=self) return self._from_native_object(result) def simple_select(self, *column_names: str) -> Self: From 
190e265ffb2ad024fa30df3c81d09900ef9ead4e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:31:12 +0100 Subject: [PATCH 10/80] appease pyright --- narwhals/_polars/namespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 35b11a2113..14165996d3 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -253,7 +253,7 @@ def concat_str( # i. None of that is useful here # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr` @property - def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: + def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: # pyright: ignore[reportInvalidTypeArguments] return cast( "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", # pyright: ignore[reportInvalidTypeArguments] PolarsSelectorNamespace(self), From 85940eea9756f497cd88a97c7ee15cf997769e7e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:45:49 +0100 Subject: [PATCH 11/80] extra test --- narwhals/_arrow/dataframe.py | 2 ++ tests/frame/getitem_test.py | 8 ++++++++ tests/hypothesis/getitem_test.py | 14 ++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index d9845ec716..94e431417d 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -259,6 +259,8 @@ def _gather(self, item: _IntIndexer) -> Self: def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 + if start < 0: + start = len(self.native) + start stop = item.stop if item.stop is not None else len(self.native) if item.step is not None and item.step != 1: msg = "Slicing with step is not supported on PyArrow tables" diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py 
index 701287f2f2..642562c2ef 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -250,3 +250,11 @@ def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: result = nw_df[:, arr] expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} assert_equal_data(result, expected) + + +def test_getitem_negative_slice(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[-3:2, ["col3", "col1"]] + expected = {"col3": [3], "col1": ["b"]} + assert_equal_data(result, expected) diff --git a/tests/hypothesis/getitem_test.py b/tests/hypothesis/getitem_test.py index 970745f03e..714e32c982 100644 --- a/tests/hypothesis/getitem_test.py +++ b/tests/hypothesis/getitem_test.py @@ -155,6 +155,18 @@ def test_getitem( ) ) + # NotImplementedError: Slicing with step is not supported on PyArrow tables + assume( + not ( + pandas_or_pyarrow_constructor is pyarrow_table_constructor + and isinstance(selector, tuple) + and ( + (isinstance(selector[0], slice) and selector[0].step is not None) + or (isinstance(selector[1], slice) and selector[1].step is not None) + ) + ) + ) + # IndexError: Offset must be non-negative (pyarrow does not support negative indexing) assume( not ( @@ -240,6 +252,8 @@ def test_getitem( if isinstance(result_polars, nw.Series): assert_equal_data({"a": result_other}, {"a": result_polars.to_list()}) + elif isinstance(result_polars, (str, int)): + assert result_polars == result_other else: assert_equal_data( result_other, From 0d79f24f08f8a8133551b8b9604df21ff85950ea Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 16:25:04 +0100 Subject: [PATCH 12/80] arrow negative slicing --- narwhals/_arrow/dataframe.py | 4 +++- narwhals/_arrow/series.py | 4 ++++ tests/frame/getitem_test.py | 8 +++++++- 3 files 
changed, 14 insertions(+), 2 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 94e431417d..8c14591110 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -259,9 +259,11 @@ def _gather(self, item: _IntIndexer) -> Self: def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 + stop = item.stop if item.stop is not None else len(self.native) if start < 0: start = len(self.native) + start - stop = item.stop if item.stop is not None else len(self.native) + if stop < 0: + stop = len(self.native) + stop if item.step is not None and item.step != 1: msg = "Slicing with step is not supported on PyArrow tables" raise NotImplementedError(msg) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index c7c37953ab..589f717343 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -415,6 +415,10 @@ def _gather(self, item: _IntIndexer) -> Self: def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 stop = item.stop if item.stop is not None else len(self.native) + if start < 0: + start = len(self.native) + start + if stop < 0: + stop = len(self.native) + stop if item.step is not None and item.step != 1: msg = "Slicing with step is not supported on PyArrow tables" raise NotImplementedError(msg) diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 642562c2ef..366685d098 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -255,6 +255,12 @@ def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: def test_getitem_negative_slice(constructor_eager: ConstructorEager) -> None: data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} nw_df = nw.from_native(constructor_eager(data), eager_only=True) - result = nw_df[-3:2, ["col3", "col1"]] + result = nw_df[-3:-2, ["col3", "col1"]] expected = {"col3": [3], "col1": ["b"]} assert_equal_data(result, 
expected) + result = nw_df[-3:-2] + expected = {"col1": ["b"], "col2": [1], "col3": [3]} + assert_equal_data(result, expected) + result_s = nw_df["col1"][-3:-2] + expected = {"col1": ["b"]} + assert_equal_data({"col1": result_s}, expected) From 93110ef3687f818b5c4329fb0621372b2ef3554f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 16:35:26 +0100 Subject: [PATCH 13/80] old pyarrow compat --- narwhals/_arrow/dataframe.py | 2 ++ narwhals/_arrow/series.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 8c14591110..56a98ed5f4 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -255,6 +255,8 @@ def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray: def _gather(self, item: _IntIndexer) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(item, tuple): + item = list(item) return self._with_native(self.native.take(item)) def _gather_slice(self, item: slice | range) -> Self: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 589f717343..e03598307a 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -410,6 +410,8 @@ def name(self: Self) -> str: def _gather(self, item: _IntIndexer) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(item, tuple): + item = list(item) return self._with_native(self.native.take(item)) def _gather_slice(self, item: slice | range) -> Self: From 4788a47b14d35e1a72674a18ed0dc054d05b6e08 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 17:09:48 +0100 Subject: [PATCH 14/80] old polars compat --- narwhals/_arrow/dataframe.py | 2 +- narwhals/_arrow/utils.py | 17 ------- narwhals/_polars/dataframe.py | 86 
++++++++++++++++++++--------------- narwhals/_polars/utils.py | 9 ---- narwhals/utils.py | 17 +++++++ 5 files changed, 67 insertions(+), 64 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 56a98ed5f4..3c5c0e8501 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -16,7 +16,6 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import align_series_full_broadcast -from narwhals._arrow.utils import convert_str_slice_to_int_slice from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind @@ -25,6 +24,7 @@ from narwhals.utils import Version from narwhals.utils import check_column_exists from narwhals.utils import check_column_names_are_unique +from narwhals.utils import convert_str_slice_to_int_slice from narwhals.utils import generate_temporary_column_name from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 6be68f6431..901549408c 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -351,23 +351,6 @@ def select_rows( return selected_rows -def convert_str_slice_to_int_slice( - str_slice: slice | range, columns: list[str] -) -> tuple[int | None, int | None, int | None]: - start = ( - columns.index(cast("str", str_slice.start)) - if str_slice.start is not None - else None - ) - stop = ( - columns.index(cast("str", str_slice.stop)) + 1 - if str_slice.stop is not None - else None - ) - step = str_slice.step - return (start, stop, step) - - # Regex for date, time, separator and timezone components DATE_RE = r"(?P\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})" SEP_RE = r"(?P\s|T)" diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 1e4005d748..473543f560 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ 
-14,13 +14,17 @@ from narwhals._polars.namespace import PolarsNamespace from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import catch_polars_exception -from narwhals._polars.utils import convert_str_slice_to_int_slice from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import native_to_narwhals_dtype +from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table -from narwhals.utils import is_sequence_but_not_str +from narwhals.utils import convert_str_slice_to_int_slice +from narwhals.utils import is_int_like_indexer +from narwhals.utils import is_null_slice +from narwhals.utils import is_sequence_like +from narwhals.utils import is_sequence_like_ints from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import requires @@ -263,43 +267,51 @@ def __getitem__(self: Self, item: Any) -> Any: else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support - columns = self.columns - if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): - if item[1] == slice(None): - if isinstance(item[0], Sequence) and not len(item[0]): - return self._with_native(self.native[0:0]) - return self._with_native(self.native.__getitem__(item[0])) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.select(columns[start:stop:step]).__getitem__(item[0]) + rows, columns = item + rows = list(rows) if isinstance(rows, tuple) else rows + columns = list(columns) if isinstance(columns, tuple) else columns + if is_numpy_array_1d(columns): + columns = columns.tolist() + + is_int_col_indexer = is_int_like_indexer(columns) + native = 
self.native + if not is_null_slice(columns): + if hasattr(columns, "__len__") and len(columns) == 0: + native = native.select() + if is_int_col_indexer and not isinstance(columns, (slice, range)): + native = native[:, columns] + elif is_int_col_indexer and isinstance(columns, (slice, range)): + native = native.select( + self.columns[slice(columns.start, columns.stop, columns.step)] ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.select( - columns[item[1].start : item[1].stop : item[1].step] - ).__getitem__(item[0]) + elif isinstance(columns, (slice, range)): + native = native.select( + self.columns[ + slice(*convert_str_slice_to_int_slice(columns, self.columns)) + ] ) - msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - if ( - isinstance(item, tuple) - and (len(item) == 2) - and is_sequence_but_not_str(item[1]) - and (len(item[1]) == 0) - ): - result = self.native.select(item[1]) - elif isinstance(item, slice) and ( - isinstance(item.start, str) or isinstance(item.stop, str) - ): - start, stop, step = convert_str_slice_to_int_slice(item, columns) - return self._with_native(self.native.select(columns[start:stop:step])) - elif is_sequence_but_not_str(item) and (len(item) == 0): - result = self.native.slice(0, 0) - else: - result = self.native.__getitem__(item) - return self._from_native_object(result) + elif is_int_col_indexer: + native = native[:, columns] + elif is_sequence_like(columns): + native = native.select(columns) + else: + msg = "Unreachable code" + raise AssertionError(msg) + + if not is_null_slice(rows): + if isinstance(rows, int): + native = native[[rows], :] + elif ( + isinstance(rows, (slice, range)) + or is_sequence_like_ints(rows) + or isinstance(rows, self.native_series) + ): + native = native[rows, :] + else: + msg = "Unreachable code" + raise AssertionError(msg) + + return self._with_native(native) 
def simple_select(self, *column_names: str) -> Self: return self._with_native(self.native.select(*column_names)) diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 9452877eac..87276dac2f 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -221,15 +221,6 @@ def narwhals_to_native_dtype( return pl.Unknown() # pragma: no cover -def convert_str_slice_to_int_slice( - str_slice: slice, columns: list[str] -) -> tuple[int | None, int | None, int | None]: # pragma: no cover - start = columns.index(str_slice.start) if str_slice.start is not None else None - stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None - step = str_slice.step - return (start, stop, step) - - def catch_polars_exception( exception: Exception, backend_version: tuple[int, ...] ) -> NarwhalsError | Exception: diff --git a/narwhals/utils.py b/narwhals/utils.py index 1db3ce183e..ac9f741132 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1929,3 +1929,20 @@ def wrapper(instance: _ContextT, *args: P.args, **kwds: P.kwargs) -> R: # NOTE: Only getting a complaint from `mypy` return wrapper # type: ignore[return-value] + + +def convert_str_slice_to_int_slice( + str_slice: slice | range, columns: list[str] +) -> tuple[int | None, int | None, int | None]: + start = ( + columns.index(cast("str", str_slice.start)) + if str_slice.start is not None + else None + ) + stop = ( + columns.index(cast("str", str_slice.stop)) + 1 + if str_slice.stop is not None + else None + ) + step = str_slice.step + return (start, stop, step) From 2f26c45755f966604817e3795f7f7e20f4c54deb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 17:31:33 +0100 Subject: [PATCH 15/80] pandas fixup --- narwhals/_compliant/dataframe.py | 2 ++ narwhals/dataframe.py | 8 ++++---- narwhals/stable/v1/__init__.py | 4 ++-- tests/frame/getitem_test.py | 8 ++++++++ 4 files changed, 16 insertions(+), 6 deletions(-) 
diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 6f0e5ba648..3b63f1dc3f 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -392,6 +392,8 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: is_int_col_indexer = is_int_like_indexer(columns) compliant = self if not is_null_slice(columns): + if hasattr(columns, "__len__") and len(columns) == 0: + return compliant.select() if is_int_col_indexer and not isinstance(columns, slice): compliant = compliant._select_indices(columns) elif is_int_col_indexer: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index cb60607c68..3f4cf7f543 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -804,7 +804,7 @@ def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... @overload def __getitem__( # type: ignore[overload-overlap] self: Self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], + item: str | tuple[int | slice | Sequence[int] | _1DArray, int | str], ) -> Series[Any]: ... @overload @@ -817,7 +817,7 @@ def __getitem__( | Sequence[str] | _1DArray | tuple[ - slice | Sequence[int] | _1DArray, + int | slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] | _1DArray, ] ), @@ -832,9 +832,9 @@ def __getitem__( | Sequence[str] | _1DArray | tuple[int, int | str] - | tuple[slice | Sequence[int] | _1DArray, int | str] + | tuple[int | slice | Sequence[int] | _1DArray, int | str] | tuple[ - slice | Sequence[int] | _1DArray, + int | slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] | _1DArray, ] ), diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 71d03e763e..8760c1eee9 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -156,7 +156,7 @@ def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... 
@overload def __getitem__( # type: ignore[overload-overlap] self: Self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], + item: str | tuple[int | slice | Sequence[int] | _1DArray, int | str], ) -> Series[Any]: ... @overload @@ -169,7 +169,7 @@ def __getitem__( | Sequence[str] | _1DArray | tuple[ - slice | Sequence[int] | _1DArray, + int | slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] | _1DArray, ] ), diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 366685d098..45eca4c11a 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -264,3 +264,11 @@ def test_getitem_negative_slice(constructor_eager: ConstructorEager) -> None: result_s = nw_df["col1"][-3:-2] expected = {"col1": ["b"]} assert_equal_data({"col1": result_s}, expected) + + +def test_zeroth_row_no_columns(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + columns: list[str] = [] + result = nw_df[0, columns] + assert result.shape == (0, 0) From d9c619ed1ce6e2d28ef1619f8303f2947e17f4ac Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 17:41:22 +0100 Subject: [PATCH 16/80] refactor: Reuse `CompliantSeries._is_native` --- narwhals/_arrow/dataframe.py | 4 ---- narwhals/_compliant/dataframe.py | 24 ++++++++++++++---------- narwhals/_pandas_like/dataframe.py | 4 ---- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 3c5c0e8501..5b0a0a46c0 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -94,10 +94,6 @@ def __init__( self._version = version validate_backend_version(self._implementation, self._backend_version) - @property - def native_series(self) -> type[ArrowChunkedArray]: - return pa.ChunkedArray - 
@classmethod def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: backend_version = context._backend_version diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 6f0e5ba648..f3f655d567 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -14,7 +14,7 @@ from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import CompliantSeriesT -from narwhals._compliant.typing import EagerExprT_contra +from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import NativeFrameT from narwhals._expression_parsing import evaluate_output_names_and_aliases @@ -42,6 +42,7 @@ from narwhals._compliant.group_by import CompliantGroupBy from narwhals._compliant.group_by import DataFrameGroupBy + from narwhals._compliant.namespace import EagerNamespace from narwhals._translate import IntoArrowTable from narwhals.dtypes import DType from narwhals.schema import Schema @@ -335,24 +336,25 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: class EagerDataFrame( - CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT], - CompliantLazyFrame[EagerExprT_contra, NativeFrameT], - Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT], + CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT], + CompliantLazyFrame[EagerExprT, NativeFrameT], + Protocol[EagerSeriesT, EagerExprT, NativeFrameT], ): - @property - def native_series(self) -> Any: ... + def __narwhals_namespace__( + self, + ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, Any]: ... 
- def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT: + def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" result: Sequence[EagerSeriesT] = expr(self) assert len(result) == 1 # debug assertion # noqa: S101 return result[0] - def _evaluate_into_exprs(self, *exprs: EagerExprT_contra) -> Sequence[EagerSeriesT]: + def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: # NOTE: Ignore is to avoid an intermittent false positive return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] - def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSeriesT]: + def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """Return list of raw columns. For eager backends we alias operations at each step. @@ -405,11 +407,13 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: raise AssertionError(msg) if not is_null_slice(rows): + is_native_series = self.__narwhals_namespace__()._series._is_native if isinstance(rows, int): compliant = compliant._gather([rows]) elif isinstance(rows, (slice, range)): compliant = compliant._gather_slice(rows) - elif is_sequence_like_ints(rows) or isinstance(rows, self.native_series): + + elif is_sequence_like_ints(rows) or is_native_series(rows): compliant = compliant._gather(rows) else: msg = "Unreachable code" diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 06dcf829fe..1361ada3bf 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -119,10 +119,6 @@ def __init__( if validate_column_names: check_column_names_are_unique(native_dataframe.columns) - @property - def native_series(self) -> type[pd.Series[Any]]: - return self.__native_namespace__().Series - @classmethod def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: implementation 
= context._implementation From 9317d7aff76f238cea9f13508b843b81ecee7796 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:02:29 +0100 Subject: [PATCH 17/80] test and cover slicing dataframe with series --- narwhals/_arrow/dataframe.py | 4 ---- narwhals/_arrow/utils.py | 32 -------------------------- narwhals/_pandas_like/dataframe.py | 2 -- narwhals/dataframe.py | 17 ++++++++------ narwhals/stable/v1/__init__.py | 8 ++++--- narwhals/utils.py | 6 ----- tests/frame/getitem_test.py | 36 ++++++++++++++++++++++++++++++ tests/hypothesis/getitem_test.py | 4 ++-- 8 files changed, 53 insertions(+), 56 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 3c5c0e8501..6f4173e4e0 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -281,10 +281,6 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: ) def _select_indices(self, item: _IntIndexer) -> Self: - if len(item) == 0: - return self._with_native( - self.native.__class__.from_arrays([]), validate_column_names=False - ) return self._with_native(self.native.select([self.columns[x] for x in item])) def _select_labels(self, item: _StrIndexer) -> Self: diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 901549408c..1ce3d0eae4 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -7,7 +7,6 @@ from typing import Iterator from typing import Sequence from typing import cast -from typing import overload import pyarrow as pa import pyarrow.compute as pc @@ -18,8 +17,6 @@ from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: - from typing import TypeVar - from typing_extensions import Self from typing_extensions import TypeAlias from typing_extensions import TypeIs @@ -34,15 +31,12 @@ from narwhals._arrow.typing import ScalarAny from narwhals.dtypes import DType from narwhals.typing import PythonLiteral - from narwhals.typing import 
_AnyDArray from narwhals.utils import Version # NOTE: stubs don't allow for `ChunkedArray[StructArray]` # Intended to represent the `.chunks` property storing `list[pa.StructArray]` ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray - _T = TypeVar("_T") - def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ... def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ... def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ... @@ -325,32 +319,6 @@ def cast_for_truediv( return arrow_array, pa_object -@overload -def convert_slice_to_nparray(num_rows: int, rows_slice: slice) -> _AnyDArray: ... -@overload -def convert_slice_to_nparray(num_rows: int, rows_slice: _T) -> _T: ... -def convert_slice_to_nparray(num_rows: int, rows_slice: slice | _T) -> _AnyDArray | _T: - if isinstance(rows_slice, slice): - import numpy as np # ignore-banned-import - - return np.arange(num_rows)[rows_slice] - else: - return rows_slice - - -def select_rows( - table: pa.Table, rows: slice | int | Sequence[int] | _AnyDArray -) -> pa.Table: - if isinstance(rows, slice) and rows == slice(None): - selected_rows = table - elif isinstance(rows, Sequence) and not rows: - selected_rows = table.slice(0, 0) - else: - range_ = convert_slice_to_nparray(num_rows=len(table), rows_slice=rows) - selected_rows = table.take(cast("list[int]", range_)) - return selected_rows - - # Regex for date, time, separator and timezone components DATE_RE = r"(?P\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})" SEP_RE = r"(?P\s|T)" diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 06dcf829fe..69a4f9a392 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -308,8 +308,6 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: def _select_indices(self, item: _IntIndexer) -> Self: item = list(item) if isinstance(item, tuple) else item - if len(item) == 0: - return self._with_native(self.native.__class__(), validate_column_names=False) 
return self._with_native( self.native.iloc[:, item], validate_column_names=False, diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 3f4cf7f543..cfbd677207 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -804,7 +804,8 @@ def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... @overload def __getitem__( # type: ignore[overload-overlap] self: Self, - item: str | tuple[int | slice | Sequence[int] | _1DArray, int | str], + item: str + | tuple[int | slice | Sequence[int] | _1DArray | Series[Any], int | str], ) -> Series[Any]: ... @overload @@ -816,9 +817,10 @@ def __getitem__( | Sequence[int] | Sequence[str] | _1DArray + | Series[Any] | tuple[ - int | slice | Sequence[int] | _1DArray, - slice | Sequence[int] | Sequence[str] | _1DArray, + int | slice | Sequence[int] | _1DArray | Series[Any], + slice | Sequence[int] | Sequence[str] | _1DArray | Series[Any], ] ), ) -> Self: ... @@ -831,11 +833,12 @@ def __getitem__( | Sequence[int] | Sequence[str] | _1DArray + | Series[Any] | tuple[int, int | str] - | tuple[int | slice | Sequence[int] | _1DArray, int | str] + | tuple[int | slice | Sequence[int] | _1DArray | Series[Any], int | str] | tuple[ - int | slice | Sequence[int] | _1DArray, - slice | Sequence[int] | Sequence[str] | _1DArray, + int | slice | Sequence[int] | _1DArray | Series[Any], + slice | Sequence[int] | Sequence[str] | _1DArray | Series[Any], ] ), ) -> Series[Any] | Self | Any: @@ -890,7 +893,7 @@ def __getitem__( """ if isinstance(item, tuple) and len(item) > 2: msg = ( - "Tuples be passed to DataFrame.__getitem__ directly.\n\n" + "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" 
) raise TypeError(msg) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8760c1eee9..675f27b59d 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -156,7 +156,8 @@ def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... @overload def __getitem__( # type: ignore[overload-overlap] self: Self, - item: str | tuple[int | slice | Sequence[int] | _1DArray, int | str], + item: str + | tuple[int | slice | Sequence[int] | _1DArray | NwSeries[Any], int | str], ) -> Series[Any]: ... @overload @@ -168,9 +169,10 @@ def __getitem__( | Sequence[int] | Sequence[str] | _1DArray + | NwSeries[Any] | tuple[ - int | slice | Sequence[int] | _1DArray, - slice | Sequence[int] | Sequence[str] | _1DArray, + int | slice | Sequence[int] | _1DArray | NwSeries[Any], + slice | Sequence[int] | Sequence[str] | _1DArray | NwSeries[Any], ] ), ) -> Self: ... diff --git a/narwhals/utils.py b/narwhals/utils.py index ac9f741132..f8250b3ecb 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1322,12 +1322,6 @@ def is_sequence_like(sequence: Any | Sequence[_T]) -> bool: ) -def is_slice_strs(obj: object) -> bool: - return isinstance(obj, slice) and ( - isinstance(obj.start, str) or isinstance(obj.stop, str) - ) - - def is_slice_ints(obj: object) -> bool: return isinstance(obj, slice) and ( isinstance(obj.start, int) # e.g. 
[1:] diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 45eca4c11a..6fa022883d 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -53,6 +53,11 @@ def test_slice_rows_with_step_pyarrow() -> None: match="Slicing with step is not supported on PyArrow tables", ): nw.from_native(pa.table(data))[1::2] + with pytest.raises( + NotImplementedError, + match="Slicing with step is not supported on PyArrow tables", + ): + nw.from_native(pa.chunked_array([data["a"]]), series_only=True)[1::2] def test_slice_lazy_fails() -> None: @@ -272,3 +277,34 @@ def test_zeroth_row_no_columns(constructor_eager: ConstructorEager) -> None: columns: list[str] = [] result = nw_df[0, columns] assert result.shape == (0, 0) + + +def test_single_tuple(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2, 3]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + # Technically works but we should probably discourage it + # OK if overloads don't match it. 
+ result = nw_df[[0, 1],] # type: ignore[index] + expected = {"a": [1, 2]} + assert_equal_data(result, expected) + + +def test_triple_tuple(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2, 3]} + with pytest.raises(TypeError, match="Tuples cannot"): + nw.from_native(constructor_eager(data), eager_only=True)[(1, 2, 3)] + + +def test_slice_with_series( + constructor_eager: ConstructorEager, request: pytest.FixtureRequest +) -> None: + if "pandas_pyarrow" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 2, 3], "c": [0, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[nw_df["c"]] + expected = {"a": [1, 3, 2], "c": [0, 1, 2]} + assert_equal_data(result, expected) + result = nw_df[nw_df["c"], ["a"]] + expected = {"a": [1, 3, 2]} + assert_equal_data(result, expected) diff --git a/tests/hypothesis/getitem_test.py b/tests/hypothesis/getitem_test.py index 714e32c982..4c1691131c 100644 --- a/tests/hypothesis/getitem_test.py +++ b/tests/hypothesis/getitem_test.py @@ -239,7 +239,7 @@ def test_getitem( df_polars = nw.from_native(pl.DataFrame(TEST_DATA)) try: result_polars = df_polars[selector] - except TypeError: + except TypeError: # pragma: no cover # If the selector fails on polars, then skip the test. # e.g. df[0, 'a'] fails, suggesting to use DataFrame.item to extract a single # element. 
@@ -252,7 +252,7 @@ def test_getitem( if isinstance(result_polars, nw.Series): assert_equal_data({"a": result_other}, {"a": result_polars.to_list()}) - elif isinstance(result_polars, (str, int)): + elif isinstance(result_polars, (str, int)): # pragma: no cover assert result_polars == result_other else: assert_equal_data( From 16a625b52a7c4c37faa493133322b43e8a06f516 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:07:17 +0100 Subject: [PATCH 18/80] old polars fixup --- narwhals/_polars/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 473543f560..4aa4c8a521 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -304,7 +304,7 @@ def __getitem__(self: Self, item: Any) -> Any: elif ( isinstance(rows, (slice, range)) or is_sequence_like_ints(rows) - or isinstance(rows, self.native_series) + or isinstance(rows, pl.Series) ): native = native[rows, :] else: From 19e276260d793cfea99c711ab0c081fbc41fa211 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:39:21 +0100 Subject: [PATCH 19/80] avoid tolist in old polars --- narwhals/_polars/dataframe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 4aa4c8a521..141e1c449a 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -271,9 +271,11 @@ def __getitem__(self: Self, item: Any) -> Any: rows = list(rows) if isinstance(rows, tuple) else rows columns = list(columns) if isinstance(columns, tuple) else columns if is_numpy_array_1d(columns): - columns = columns.tolist() + columns = pl.Series(columns) - is_int_col_indexer = is_int_like_indexer(columns) + is_int_col_indexer = is_int_like_indexer(columns) or ( + isinstance(columns, pl.Series) and 
columns.dtype.is_integer() + ) native = self.native if not is_null_slice(columns): if hasattr(columns, "__len__") and len(columns) == 0: From 0bc31fcba58cd2ed2cf89bfecc8ee78cf4c75088 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:40:28 +0100 Subject: [PATCH 20/80] `is_null_slice` typing --- narwhals/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index f8250b3ecb..8fcc774c3f 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1341,7 +1341,7 @@ def is_int_like_indexer(cols: object) -> bool: ) -def is_null_slice(obj: object) -> bool: +def is_null_slice(obj: object) -> TypeIs[slice[None, None, None]]: return isinstance(obj, slice) and obj == slice(None) From 426f4530d197d10fdf626af24e0f2b2893c826ea Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 18:45:09 +0100 Subject: [PATCH 21/80] Revert "avoid tolist in old polars" This reverts commit 19e276260d793cfea99c711ab0c081fbc41fa211. 
--- narwhals/_polars/dataframe.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 141e1c449a..4aa4c8a521 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -271,11 +271,9 @@ def __getitem__(self: Self, item: Any) -> Any: rows = list(rows) if isinstance(rows, tuple) else rows columns = list(columns) if isinstance(columns, tuple) else columns if is_numpy_array_1d(columns): - columns = pl.Series(columns) + columns = columns.tolist() - is_int_col_indexer = is_int_like_indexer(columns) or ( - isinstance(columns, pl.Series) and columns.dtype.is_integer() - ) + is_int_col_indexer = is_int_like_indexer(columns) native = self.native if not is_null_slice(columns): if hasattr(columns, "__len__") and len(columns) == 0: From 40d552cd38a71c0769552f4e46c3db0cd1b962e0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:01:09 +0100 Subject: [PATCH 22/80] `is_sequence_like` typing --- narwhals/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index 8fcc774c3f..3bee6fc983 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -36,6 +36,7 @@ from narwhals.dependencies import get_sqlframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_modin_series +from narwhals.dependencies import is_narwhals_series from narwhals.dependencies import is_numpy_array_1d from narwhals.dependencies import is_pandas_dataframe from narwhals.dependencies import is_pandas_like_dataframe @@ -95,6 +96,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import SupportsNativeNamespace from narwhals.typing import TimeUnit + from narwhals.typing import _1DArray FrameOrSeriesT = TypeVar( "FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series[Any]] @@ -1312,13 +1314,13 @@ def 
is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: ) -def is_sequence_like(sequence: Any | Sequence[_T]) -> bool: - from narwhals.series import Series - +def is_sequence_like( + sequence: Sequence[_T] | Any, +) -> TypeIs[Sequence[_T]] | TypeIs[Series[Any]] | TypeIs[_1DArray]: return ( - (isinstance(sequence, Sequence) and not isinstance(sequence, str)) - or (is_numpy_array_1d(sequence)) - or isinstance(sequence, Series) + is_sequence_but_not_str(sequence) + or is_numpy_array_1d(sequence) + or is_narwhals_series(sequence) ) From de06d26d945d94fa0724e346c5aefec046b50555 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:32:20 +0100 Subject: [PATCH 23/80] avoid is_sequence_like_ints("") false positive --- narwhals/utils.py | 1 + tests/frame/getitem_test.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index 8fcc774c3f..edb62164b4 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1302,6 +1302,7 @@ def is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: return ( ( isinstance(sequence, Sequence) + and not isinstance(sequence, str) and ( (len(sequence) > 0 and isinstance(sequence[0], int)) or (len(sequence) == 0) diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 6fa022883d..2a8b2b8b35 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -251,7 +251,7 @@ def test_get_item_works_with_tuple_and_list_indexing_and_str( def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} nw_df = nw.from_native(constructor_eager(data), eager_only=True) - arr: np.ndarray[tuple[int], np.dtype[np.int64]] = np.array([0, 1]) # pyright: ignore[reportAssignmentType] + arr = np.arange(2) result = nw_df[:, arr] expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} 
assert_equal_data(result, expected) From b503c2f2bb41b4490aaa310cdb764698789350d5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:51:33 +0100 Subject: [PATCH 24/80] fix(typing): Resolve `PolarsSeries` issues Fixes https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2050713248 --- narwhals/_compliant/series.py | 34 +++++++++++++++++----------------- narwhals/_polars/namespace.py | 4 ++-- narwhals/_polars/series.py | 2 -- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 754a784378..1f2c27d6aa 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -81,6 +81,7 @@ def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def __native_namespace__(self) -> ModuleType: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ... def __contains__(self, other: Any) -> bool: ... + def __getitem__(self, item: Any) -> Any: ... def __iter__(self) -> Iterator[Any]: ... def __len__(self) -> int: return len(self.native) @@ -287,23 +288,6 @@ def list(self) -> Any: ... @property def struct(self) -> Any: ... - def _gather(self, indices: _IntIndexer) -> Self: ... - - def _gather_slice(self, indices: slice | range) -> Self: ... 
- - def __getitem__(self, rows: Any) -> Self: - if is_null_slice(rows): - return self - if isinstance(rows, int): - return self._gather([rows]) - elif isinstance(rows, (slice, range)): - return self._gather_slice(rows) - elif is_sequence_like_ints(rows) or isinstance(rows, self.native.__class__): - return self._gather(rows) - else: - msg = "Unreachable code" - raise AssertionError(msg) - class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]): _native_series: Any @@ -335,6 +319,22 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] + def _gather(self, indices: _IntIndexer) -> Self: ... + def _gather_slice(self, indices: slice | range) -> Self: ... + + def __getitem__(self, item: Any) -> Self: + if is_null_slice(item): + return self + if isinstance(item, int): + return self._gather([item]) + elif isinstance(item, (slice, range)): + return self._gather_slice(item) + elif is_sequence_like_ints(item) or isinstance(item, self.native.__class__): + return self._gather(item) + else: + msg = "Unreachable code" + raise AssertionError(msg) + @property def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ... @property diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 14165996d3..54af4f516b 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -253,9 +253,9 @@ def concat_str( # i. None of that is useful here # 2. 
We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr` @property - def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: # pyright: ignore[reportInvalidTypeArguments] + def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: return cast( - "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", # pyright: ignore[reportInvalidTypeArguments] + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", PolarsSelectorNamespace(self), ) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index c82518e687..971ad37cf1 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -605,8 +605,6 @@ def struct(self: Self) -> PolarsSeriesStructNamespace: drop_nulls: Method[Self] fill_null: Method[Self] filter: Method[Self] - _gather_slice: Method[Self] - _gather: Method[Self] gather_every: Method[Self] head: Method[Self] is_between: Method[Self] From d3b02240f3f1f4c23fc575355762921fc527a723 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:07:38 +0100 Subject: [PATCH 25/80] feat(typing): Add `__getitem__` aliases from polars https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2050476211, https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051016924 --- narwhals/typing.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/narwhals/typing.py b/narwhals/typing.py index e843fcf462..9bbbe7e4fb 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -4,6 +4,7 @@ from typing import Any from typing import Literal from typing import Protocol +from typing import Sequence from typing import TypeVar from typing import Union @@ -315,6 +316,17 @@ def __native_namespace__(self) -> ModuleType: ... 
_StrIndexer: TypeAlias = Any # noqa: PYI047 +# Annotations for `__getitem__` methods +_Slice: TypeAlias = "slice[Any, Any, Any]" +SingleIndexSelector: TypeAlias = int +SingleNameSelector: TypeAlias = str +MultiIndexSelector: TypeAlias = "_Slice | Sequence[int] | Series[Any] | _1DArray" +MultiNameSelector: TypeAlias = "_Slice | Sequence[str] | Series[Any] | _1DArray" +BooleanMask: TypeAlias = "Sequence[bool] | Series[Any] | _1DArray" +SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" +MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector | BooleanMask" + + # ruff: noqa: N802 class DTypes(Protocol): @property From 20f61fc432d2a6f6886cb86f514f8e027e9466ae Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:17:00 +0100 Subject: [PATCH 26/80] feat: Clone `polars` overloads --- narwhals/dataframe.py | 53 +++++++++++++++++++------------------------ narwhals/series.py | 10 ++++---- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index cfbd677207..aab12987c4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -67,10 +67,13 @@ from narwhals.typing import IntoFrame from narwhals.typing import JoinStrategy from narwhals.typing import LazyUniqueKeepStrategy + from narwhals.typing import MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleColSelector + from narwhals.typing import SingleIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy - from narwhals.typing import _1DArray from narwhals.typing import _2DArray PS = ParamSpec("PS") @@ -798,48 +801,38 @@ def estimated_size(self: Self, unit: SizeUnit = "b") -> int | float: """ return self._compliant_frame.estimated_size(unit=unit) + # `str` overlaps with `Sequence[str]` + # We can ignore this but we must keep this 
overload ordering @overload - def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... + def __getitem__(self, item: tuple[SingleIndexSelector, SingleColSelector]) -> Any: ... @overload def __getitem__( # type: ignore[overload-overlap] - self: Self, - item: str - | tuple[int | slice | Sequence[int] | _1DArray | Series[Any], int | str], + self, item: str | tuple[MultiIndexSelector, SingleColSelector] ) -> Series[Any]: ... @overload def __getitem__( - self: Self, + self, item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | Series[Any] - | tuple[ - int | slice | Sequence[int] | _1DArray | Series[Any], - slice | Sequence[int] | Sequence[str] | _1DArray | Series[Any], - ] + SingleIndexSelector + | MultiIndexSelector + | MultiColSelector + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), ) -> Self: ... def __getitem__( - self: Self, + self, item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | Series[Any] - | tuple[int, int | str] - | tuple[int | slice | Sequence[int] | _1DArray | Series[Any], int | str] - | tuple[ - int | slice | Sequence[int] | _1DArray | Series[Any], - slice | Sequence[int] | Sequence[str] | _1DArray | Series[Any], - ] + SingleIndexSelector + | SingleColSelector + | MultiColSelector + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), ) -> Series[Any] | Self | Any: """Extract column or slice of DataFrame. 
diff --git a/narwhals/series.py b/narwhals/series.py index f8132c1e5c..8ea518f0c9 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -21,7 +21,9 @@ from narwhals.series_struct import SeriesStructNamespace from narwhals.translate import to_native from narwhals.typing import IntoSeriesT +from narwhals.typing import MultiIndexSelector from narwhals.typing import NonNestedLiteral +from narwhals.typing import SingleIndexSelector from narwhals.utils import _validate_rolling_arguments from narwhals.utils import generate_repr from narwhals.utils import is_compliant_series @@ -130,14 +132,12 @@ def __array__(self: Self, dtype: Any = None, copy: bool | None = None) -> _1DArr return self._compliant_series.__array__(dtype=dtype, copy=copy) @overload - def __getitem__(self: Self, idx: int) -> Any: ... + def __getitem__(self, idx: SingleIndexSelector) -> Any: ... @overload - def __getitem__(self: Self, idx: slice | Sequence[int] | _1DArray | Self) -> Self: ... + def __getitem__(self, idx: MultiIndexSelector) -> Self: ... - def __getitem__( - self: Self, idx: int | slice | Sequence[int] | _1DArray | Self - ) -> Any | Self: + def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Self: """Retrieve elements from the object using integer indexing or slicing. 
Arguments: From 42d5a4ed04fc4b6eb4661b9eb5e88a4afe249f2e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:28:31 +0100 Subject: [PATCH 27/80] backport `v1` --- narwhals/stable/v1/__init__.py | 41 +++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 675f27b59d..5f91a88b62 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -97,7 +97,11 @@ from narwhals.typing import IntoFrame from narwhals.typing import IntoLazyFrameT from narwhals.typing import IntoSeries + from narwhals.typing import MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import NonNestedLiteral + from narwhals.typing import SingleColSelector + from narwhals.typing import SingleIndexSelector from narwhals.typing import _1DArray from narwhals.typing import _2DArray @@ -151,32 +155,37 @@ def _lazyframe(self: Self) -> type[LazyFrame[Any]]: return LazyFrame @overload - def __getitem__(self: Self, item: tuple[int, int | str]) -> Any: ... + def __getitem__(self, item: tuple[SingleIndexSelector, SingleColSelector]) -> Any: ... @overload def __getitem__( # type: ignore[overload-overlap] - self: Self, - item: str - | tuple[int | slice | Sequence[int] | _1DArray | NwSeries[Any], int | str], + self, item: str | tuple[MultiIndexSelector, SingleColSelector] ) -> Series[Any]: ... @overload def __getitem__( - self: Self, + self, item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | NwSeries[Any] - | tuple[ - int | slice | Sequence[int] | _1DArray | NwSeries[Any], - slice | Sequence[int] | Sequence[str] | _1DArray | NwSeries[Any], - ] + SingleIndexSelector + | MultiIndexSelector + | MultiColSelector + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), ) -> Self: ... 
- def __getitem__(self: Self, item: Any) -> Any: + def __getitem__( + self, + item: ( + SingleIndexSelector + | SingleColSelector + | MultiColSelector + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] + ), + ) -> Series[Any] | Self | Any: return super().__getitem__(item) def lazy( From 694102917c67b2f244ea1f68539ac96523cbd23b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 22:32:48 +0100 Subject: [PATCH 28/80] unhide typing issues --- narwhals/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index aab12987c4..7b06943bcb 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -893,8 +893,8 @@ def __getitem__( if isinstance(item, tuple) and len(item) == 2: # These are so heavily overloaded that we just ignore the types for now. - rows: Any = item[0] if not is_null_slice(item[0]) else None - columns: Any = item[1] if not is_null_slice(item[1]) else None + rows = item[0] if not is_null_slice(item[0]) else None + columns = item[1] if not is_null_slice(item[1]) else None elif isinstance(item, tuple) and item: rows = item[0] columns = None From f4aed5fa34f29fe17cc3d1c8c6dd3d3473eda121 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:12:42 +0100 Subject: [PATCH 29/80] chore: Remove `BooleanMask` https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051164421 --- narwhals/typing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/typing.py b/narwhals/typing.py index 9bbbe7e4fb..0bb2833c66 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -322,9 +322,8 @@ def __native_namespace__(self) -> ModuleType: ... 
SingleNameSelector: TypeAlias = str MultiIndexSelector: TypeAlias = "_Slice | Sequence[int] | Series[Any] | _1DArray" MultiNameSelector: TypeAlias = "_Slice | Sequence[str] | Series[Any] | _1DArray" -BooleanMask: TypeAlias = "Sequence[bool] | Series[Any] | _1DArray" SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" -MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector | BooleanMask" +MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" # ruff: noqa: N802 From f05c3a2440f8bdec74c139e34b88f5474b100f66 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:38:44 +0100 Subject: [PATCH 30/80] chore(typing): Fixing `DataFrame.__getitem__` --- narwhals/dataframe.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 7b06943bcb..ac50aae94b 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -884,27 +884,23 @@ def __getitem__( 1 2 Name: a, dtype: int64 """ - if isinstance(item, tuple) and len(item) > 2: - msg = ( - "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" - "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" - ) - raise TypeError(msg) - - if isinstance(item, tuple) and len(item) == 2: + if isinstance(item, tuple): + if len(item) > 2: + msg = ( + "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" + "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" + ) + raise TypeError(msg) # These are so heavily overloaded that we just ignore the types for now. 
- rows = item[0] if not is_null_slice(item[0]) else None - columns = item[1] if not is_null_slice(item[1]) else None - elif isinstance(item, tuple) and item: - rows = item[0] - columns = None + rows = None if not item or is_null_slice(item[0]) else item[0] + columns = None if len(item) < 2 or is_null_slice(item[1]) else item[1] elif isinstance(item, str): rows = None columns = item elif is_int_like_indexer(item): rows = item columns = None - elif is_sequence_like(item) or isinstance(item, (slice, range)): + elif is_sequence_like(item) or isinstance(item, slice): rows = None columns = item else: @@ -916,6 +912,8 @@ def __getitem__( "- Use `DataFrame.filter(mask)` to filter rows based on a boolean mask." ) raise TypeError(msg) + if rows is None and columns is None: + return self compliant = self._compliant_frame rows = to_native(rows, pass_through=True) @@ -923,12 +921,10 @@ def __getitem__( if isinstance(rows, int) and isinstance(columns, (int, str)): return self.item(rows, columns) - if isinstance(columns, (str, int)): + if isinstance(columns, (int, str)): col_name = columns if isinstance(columns, str) else self.columns[columns] series = self.get_column(col_name) return series[rows] if rows is not None else series - if rows is None and columns is None: - return self if rows is None: return self._with_compliant(compliant[:, columns]) if columns is None: From 1cf8ea8b4a1e98503137dc4b5818b6cf729724bb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:49:26 +0100 Subject: [PATCH 31/80] refactor: Merge branches --- narwhals/dataframe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index ac50aae94b..83e1eae563 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -894,13 +894,10 @@ def __getitem__( # These are so heavily overloaded that we just ignore the types for now. 
rows = None if not item or is_null_slice(item[0]) else item[0] columns = None if len(item) < 2 or is_null_slice(item[1]) else item[1] - elif isinstance(item, str): - rows = None - columns = item elif is_int_like_indexer(item): rows = item columns = None - elif is_sequence_like(item) or isinstance(item, slice): + elif is_sequence_like(item) or isinstance(item, (slice, str)): rows = None columns = item else: From 147dbac9cc96921f9d543dad8694b388aff5736c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 18 Apr 2025 23:52:00 +0100 Subject: [PATCH 32/80] =?UTF-8?q?early=20return=20even=20earlier=20?= =?UTF-8?q?=F0=9F=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 83e1eae563..f145cef7c6 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -894,6 +894,8 @@ def __getitem__( # These are so heavily overloaded that we just ignore the types for now. rows = None if not item or is_null_slice(item[0]) else item[0] columns = None if len(item) < 2 or is_null_slice(item[1]) else item[1] + if rows is None and columns is None: + return self elif is_int_like_indexer(item): rows = item columns = None @@ -909,8 +911,6 @@ def __getitem__( "- Use `DataFrame.filter(mask)` to filter rows based on a boolean mask." 
) raise TypeError(msg) - if rows is None and columns is None: - return self compliant = self._compliant_frame rows = to_native(rows, pass_through=True) From 46be441789c6131f33625e463db610f950e7e413 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 19 Apr 2025 10:48:49 +0100 Subject: [PATCH 33/80] feat(typing): Slice-typing Related https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2050928165 --- narwhals/_compliant/dataframe.py | 10 +++++----- narwhals/_compliant/series.py | 4 ++-- narwhals/_polars/dataframe.py | 10 +++++----- narwhals/dataframe.py | 10 +++++----- narwhals/typing.py | 8 +++++++- narwhals/utils.py | 26 ++++++++++++++------------ 6 files changed, 38 insertions(+), 30 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index ca040c7dd8..ee9ed3bdfa 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -25,10 +25,10 @@ from narwhals.utils import Version from narwhals.utils import _StoresNative from narwhals.utils import deprecated -from narwhals.utils import is_int_like_indexer -from narwhals.utils import is_null_slice +from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_slice_none if TYPE_CHECKING: from io import BytesIO @@ -391,9 +391,9 @@ def _select_slice_of_labels(self, indices: slice | range) -> Self: ... 
def __getitem__(self, item: tuple[Any, Any]) -> Self: rows, columns = item - is_int_col_indexer = is_int_like_indexer(columns) + is_int_col_indexer = is_index_selector(columns) compliant = self - if not is_null_slice(columns): + if not is_slice_none(columns): if hasattr(columns, "__len__") and len(columns) == 0: return compliant.select() if is_int_col_indexer and not isinstance(columns, slice): @@ -408,7 +408,7 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: msg = "Unreachable code" raise AssertionError(msg) - if not is_null_slice(rows): + if not is_slice_none(rows): is_native_series = self.__narwhals_namespace__()._series._is_native if isinstance(rows, int): compliant = compliant._gather([rows]) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 1f2c27d6aa..3b1d96cb1f 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -23,8 +23,8 @@ from narwhals._translate import NumpyConvertible from narwhals.utils import _StoresCompliant from narwhals.utils import _StoresNative -from narwhals.utils import is_null_slice from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_slice_none from narwhals.utils import unstable if TYPE_CHECKING: @@ -323,7 +323,7 @@ def _gather(self, indices: _IntIndexer) -> Self: ... def _gather_slice(self, indices: slice | range) -> Self: ... 
def __getitem__(self, item: Any) -> Self: - if is_null_slice(item): + if is_slice_none(item): return self if isinstance(item, int): return self._gather([item]) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 4aa4c8a521..887b7241b4 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -21,10 +21,10 @@ from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table from narwhals.utils import convert_str_slice_to_int_slice -from narwhals.utils import is_int_like_indexer -from narwhals.utils import is_null_slice +from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_slice_none from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import requires @@ -273,9 +273,9 @@ def __getitem__(self: Self, item: Any) -> Any: if is_numpy_array_1d(columns): columns = columns.tolist() - is_int_col_indexer = is_int_like_indexer(columns) + is_int_col_indexer = is_index_selector(columns) native = self.native - if not is_null_slice(columns): + if not is_slice_none(columns): if hasattr(columns, "__len__") and len(columns) == 0: native = native.select() if is_int_col_indexer and not isinstance(columns, (slice, range)): @@ -298,7 +298,7 @@ def __getitem__(self: Self, item: Any) -> Any: msg = "Unreachable code" raise AssertionError(msg) - if not is_null_slice(rows): + if not is_slice_none(rows): if isinstance(rows, int): native = native[[rows], :] elif ( diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index f145cef7c6..1909cd6679 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -34,10 +34,10 @@ from narwhals.utils import generate_repr from narwhals.utils import is_compliant_dataframe from narwhals.utils import is_compliant_lazyframe -from narwhals.utils import is_int_like_indexer +from narwhals.utils import 
is_index_selector from narwhals.utils import is_list_of -from narwhals.utils import is_null_slice from narwhals.utils import is_sequence_like +from narwhals.utils import is_slice_none from narwhals.utils import issue_deprecation_warning from narwhals.utils import parse_version from narwhals.utils import supports_arrow_c_stream @@ -892,11 +892,11 @@ def __getitem__( ) raise TypeError(msg) # These are so heavily overloaded that we just ignore the types for now. - rows = None if not item or is_null_slice(item[0]) else item[0] - columns = None if len(item) < 2 or is_null_slice(item[1]) else item[1] + rows = None if not item or is_slice_none(item[0]) else item[0] + columns = None if len(item) < 2 or is_slice_none(item[1]) else item[1] if rows is None and columns is None: return self - elif is_int_like_indexer(item): + elif is_index_selector(item): rows = item columns = None elif is_sequence_like(item) or isinstance(item, (slice, str)): diff --git a/narwhals/typing.py b/narwhals/typing.py index 0bb2833c66..abe469aa1c 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -318,9 +318,15 @@ def __native_namespace__(self) -> ModuleType: ... # Annotations for `__getitem__` methods _Slice: TypeAlias = "slice[Any, Any, Any]" +_SliceNone: TypeAlias = "slice[None, None, None]" +_SliceIndex: TypeAlias = ( + "slice[int, Any, Any] | slice[Any, int, Any] | slice[None, None, int] | _SliceNone" +) +"""E.g. 
`[1:]` or `[:3]` or `[::2]`.""" + SingleIndexSelector: TypeAlias = int SingleNameSelector: TypeAlias = str -MultiIndexSelector: TypeAlias = "_Slice | Sequence[int] | Series[Any] | _1DArray" +MultiIndexSelector: TypeAlias = "_SliceIndex | Sequence[int] | Series[Any] | _1DArray" MultiNameSelector: TypeAlias = "_Slice | Sequence[str] | Series[Any] | _1DArray" SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" diff --git a/narwhals/utils.py b/narwhals/utils.py index 88fd583f9c..2a4a31ef79 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -93,10 +93,14 @@ from narwhals.typing import DataFrameLike from narwhals.typing import DTypes from narwhals.typing import IntoSeriesT + from narwhals.typing import MultiIndexSelector + from narwhals.typing import SingleIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import SupportsNativeNamespace from narwhals.typing import TimeUnit from narwhals.typing import _1DArray + from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceNone FrameOrSeriesT = TypeVar( "FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series[Any]] @@ -1325,29 +1329,27 @@ def is_sequence_like( ) -def is_slice_ints(obj: object) -> bool: +def is_slice_index(obj: _SliceIndex | Any) -> TypeIs[_SliceIndex]: return isinstance(obj, slice) and ( - isinstance(obj.start, int) # e.g. [1:] - or isinstance(obj.stop, int) # e.g. [:3] - or (obj.start is None and obj.stop is None) # e.g. 
[::2] + isinstance(obj.start, int) + or isinstance(obj.stop, int) + or (isinstance(obj.step, int) and obj.start is None and obj.stop is None) ) -def is_int_like_indexer(cols: object) -> bool: - from narwhals.series import Series +def is_slice_none(obj: object) -> TypeIs[_SliceNone]: + return isinstance(obj, slice) and obj == slice(None) + +def is_index_selector(cols: SingleIndexSelector | MultiIndexSelector | Any) -> bool: return ( isinstance(cols, int) or is_sequence_like_ints(cols) - or is_slice_ints(cols) - or (isinstance(cols, Series) and cols.dtype.is_integer()) + or is_slice_index(cols) + or (is_narwhals_series(cols) and cols.dtype.is_integer()) ) -def is_null_slice(obj: object) -> TypeIs[slice[None, None, None]]: - return isinstance(obj, slice) and obj == slice(None) - - def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[type[_T]]]: # Check if an object is a list of `tp`, only sniffing the first element. return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp)) From 7a75fb290f5568fc1be1e75a6f539e6fc8c79c23 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 19 Apr 2025 12:24:45 +0100 Subject: [PATCH 34/80] fix: `get_column` rename --- narwhals/_interchange/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index 03ac0f3b35..c52cb0a506 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -105,7 +105,7 @@ def __native_namespace__(self) -> NoReturn: ) raise NotImplementedError(msg) - def __getitem__(self, name: str) -> InterchangeSeries: + def get_column(self, name: str) -> InterchangeSeries: from narwhals._interchange.series import InterchangeSeries return InterchangeSeries( From d24f1b71fbbd6a968685f24186e94e309da075e1 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 19 Apr 2025 19:49:34 +0100 
Subject: [PATCH 35/80] pass compliant down instead of native --- narwhals/_arrow/dataframe.py | 4 ++++ narwhals/_compliant/dataframe.py | 19 ++++++++++++------- narwhals/_compliant/series.py | 5 ++++- narwhals/_polars/dataframe.py | 9 +++++++-- narwhals/_polars/series.py | 3 +++ narwhals/dataframe.py | 16 +++++++++++++--- narwhals/series.py | 3 ++- narwhals/utils.py | 16 +++++++--------- tests/frame/getitem_test.py | 20 ++++++++++++++++++++ 9 files changed, 72 insertions(+), 23 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 61ffda5c62..cf949f9cd8 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -273,9 +273,13 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: ) def _select_indices(self, item: _IntIndexer) -> Self: + if isinstance(item, pa.ChunkedArray): + item = item.to_pylist() return self._with_native(self.native.select([self.columns[x] for x in item])) def _select_labels(self, item: _StrIndexer) -> Self: + if isinstance(item, pa.ChunkedArray): + item = item.to_pylist() return self._with_native(self.native.select(item)) @property diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index e4888a9476..7b60fcae8c 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -25,6 +25,7 @@ from narwhals.utils import Version from narwhals.utils import _StoresNative from narwhals.utils import deprecated +from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like from narwhals.utils import is_sequence_like_ints @@ -396,12 +397,16 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: if not is_slice_none(columns): if hasattr(columns, "__len__") and len(columns) == 0: return compliant.select() - if is_int_col_indexer and not isinstance(columns, slice): - compliant = compliant._select_indices(columns) - elif is_int_col_indexer: + if 
is_int_col_indexer and isinstance(columns, (slice, range)): compliant = compliant._select_slice_of_indices(columns) - elif isinstance(columns, slice): + elif is_int_col_indexer and is_compliant_series(columns): + compliant = self._select_indices(columns.native) + elif is_int_col_indexer and is_sequence_like_ints(columns): + compliant = compliant._select_indices(columns) + elif isinstance(columns, (slice, range)): compliant = compliant._select_slice_of_labels(columns) + elif is_compliant_series(columns): + compliant = self._select_labels(columns.native) elif is_sequence_like(columns): compliant = self._select_labels(columns) else: @@ -409,13 +414,13 @@ def __getitem__(self, item: tuple[Any, Any]) -> Self: raise AssertionError(msg) if not is_slice_none(rows): - is_native_series = self.__narwhals_namespace__()._series._is_native if isinstance(rows, int): compliant = compliant._gather([rows]) elif isinstance(rows, (slice, range)): compliant = compliant._gather_slice(rows) - - elif is_sequence_like_ints(rows) or is_native_series(rows): + elif is_compliant_series(rows): + compliant = compliant._gather(rows.native) + elif is_sequence_like_ints(rows): compliant = compliant._gather(rows) else: msg = "Unreachable code" diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 42bad57ff2..e3d8abf6d9 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -23,6 +23,7 @@ from narwhals._translate import NumpyConvertible from narwhals.utils import _StoresCompliant from narwhals.utils import _StoresNative +from narwhals.utils import is_compliant_series from narwhals.utils import is_sequence_like_ints from narwhals.utils import is_slice_none from narwhals.utils import unstable @@ -329,7 +330,9 @@ def __getitem__(self, item: Any) -> Self: return self._gather([item]) elif isinstance(item, (slice, range)): return self._gather_slice(item) - elif is_sequence_like_ints(item) or isinstance(item, self.native.__class__): + elif 
is_compliant_series(item): + return self._gather(item.native) + elif is_sequence_like_ints(item): return self._gather(item) else: msg = "Unreachable code" diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 71001c8899..9439e41adf 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -21,6 +21,7 @@ from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table from narwhals.utils import convert_str_slice_to_int_slice +from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like from narwhals.utils import is_sequence_like_ints @@ -262,12 +263,16 @@ def shape(self) -> tuple[int, int]: return self.native.shape def __getitem__(self, item: Any) -> Any: + rows, columns = item + if is_compliant_series(rows): + rows = rows.native + if is_compliant_series(columns): + columns = columns.native if self._backend_version > (0, 20, 30): - return self._from_native_object(self.native.__getitem__(item)) + return self._from_native_object(self.native.__getitem__((rows, columns))) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support - rows, columns = item rows = list(rows) if isinstance(rows, tuple) else rows columns = list(columns) if isinstance(columns, tuple) else columns if is_numpy_array_1d(columns): diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 433e0648c1..b38c1ef0bb 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -18,6 +18,7 @@ from narwhals._polars.utils import native_to_narwhals_dtype from narwhals.dependencies import is_numpy_array_1d from narwhals.utils import Implementation +from narwhals.utils import is_compliant_series from narwhals.utils import requires from narwhals.utils import validate_backend_version @@ -182,6 +183,8 @@ def __getitem__(self, item: int) -> Any: ... 
def __getitem__(self, item: slice | Sequence[int] | pl.Series) -> Self: ... def __getitem__(self, item: int | slice | Sequence[int] | pl.Series) -> Any | Self: + if is_compliant_series(item): + item = item.native return self._from_native_object(self.native.__getitem__(item)) def cast(self, dtype: DType | type[DType]) -> Self: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index b35b385ae9..86803f156c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -12,6 +12,7 @@ from typing import NoReturn from typing import Sequence from typing import TypeVar +from typing import cast from typing import overload from warnings import warn @@ -882,6 +883,8 @@ def __getitem__( 1 2 Name: a, dtype: int64 """ + from narwhals.series import Series + if isinstance(item, tuple): if len(item) > 2: msg = ( @@ -911,15 +914,22 @@ def __getitem__( raise TypeError(msg) compliant = self._compliant_frame - rows = to_native(rows, pass_through=True) - columns = to_native(columns, pass_through=True) if isinstance(rows, int) and isinstance(columns, (int, str)): return self.item(rows, columns) if isinstance(columns, (int, str)): col_name = columns if isinstance(columns, str) else self.columns[columns] series = self.get_column(col_name) - return series[rows] if rows is not None else series + return ( + series[cast("SingleIndexSelector | MultiIndexSelector", rows)] + if rows is not None + else series + ) + + if isinstance(rows, Series): + rows = rows._compliant_series + if isinstance(columns, Series): + columns = columns._compliant_series if rows is None: return self._with_compliant(compliant[:, columns]) if columns is None: diff --git a/narwhals/series.py b/narwhals/series.py index 9e23c73c39..3d47557498 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -173,7 +173,8 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): return self._compliant_series.item(cast("int", idx)) - idx = 
to_native(idx, pass_through=True) + if isinstance(idx, Series): + return self._with_compliant(self._compliant_series[idx._compliant_series]) return self._with_compliant(self._compliant_series[idx]) def __native_namespace__(self) -> ModuleType: diff --git a/narwhals/utils.py b/narwhals/utils.py index a6691cf357..ced8dda98b 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1307,15 +1307,17 @@ def is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: np = get_numpy() return ( ( - isinstance(sequence, Sequence) - and not isinstance(sequence, str) + is_sequence_but_not_str(sequence) and ( (len(sequence) > 0 and isinstance(sequence[0], int)) or (len(sequence) == 0) ) ) or (is_numpy_array_1d(sequence) and np.issubdtype(sequence.dtype, np.integer)) - or (is_compliant_series(sequence) and sequence.dtype.is_integer()) + or ( + (is_narwhals_series(sequence) or is_compliant_series(sequence)) + and sequence.dtype.is_integer() + ) ) @@ -1326,6 +1328,7 @@ def is_sequence_like( is_sequence_but_not_str(sequence) or is_numpy_array_1d(sequence) or is_narwhals_series(sequence) + or is_compliant_series(sequence) ) @@ -1342,12 +1345,7 @@ def is_slice_none(obj: object) -> TypeIs[_SliceNone]: def is_index_selector(cols: SingleIndexSelector | MultiIndexSelector | Any) -> bool: - return ( - isinstance(cols, int) - or is_sequence_like_ints(cols) - or is_slice_index(cols) - or (is_narwhals_series(cols) and cols.dtype.is_integer()) - ) + return isinstance(cols, int) or is_sequence_like_ints(cols) or is_slice_index(cols) def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]: diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 2a8b2b8b35..a49371f084 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -308,3 +308,23 @@ def test_slice_with_series( result = nw_df[nw_df["c"], ["a"]] expected = {"a": [1, 3, 2]} assert_equal_data(result, expected) + + +def test_horizontal_slice_with_series(constructor_eager: ConstructorEager) -> 
None: + data = {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[nw_df["d"]] + expected = {"c": [0, 2], "a": [1, 2]} + assert_equal_data(result, expected) + + +def test_horizontal_slice_with_series_2( + constructor_eager: ConstructorEager, request: pytest.FixtureRequest +) -> None: + if "pandas_pyarrow" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[:, nw_df["c"]] + expected = {"a": [1, 2], "d": ["c", "a"]} + assert_equal_data(result, expected) From e85a1ad5ec08e81fdfbcacf1a29b8e452764602c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 11:33:15 +0100 Subject: [PATCH 36/80] old polars fixup --- narwhals/_polars/dataframe.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 9439e41adf..3a9214bff8 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -264,15 +264,12 @@ def shape(self) -> tuple[int, int]: def __getitem__(self, item: Any) -> Any: rows, columns = item - if is_compliant_series(rows): - rows = rows.native - if is_compliant_series(columns): - columns = columns.native if self._backend_version > (0, 20, 30): return self._from_native_object(self.native.__getitem__((rows, columns))) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support + # This mostly mirrors the logic in `EagerDataFrame.__getitem__`. 
rows = list(rows) if isinstance(rows, tuple) else rows columns = list(columns) if isinstance(columns, tuple) else columns if is_numpy_array_1d(columns): @@ -283,20 +280,22 @@ def __getitem__(self, item: Any) -> Any: if not is_slice_none(columns): if hasattr(columns, "__len__") and len(columns) == 0: native = native.select() - if is_int_col_indexer and not isinstance(columns, (slice, range)): - native = native[:, columns] - elif is_int_col_indexer and isinstance(columns, (slice, range)): + if is_int_col_indexer and isinstance(columns, (slice, range)): native = native.select( self.columns[slice(columns.start, columns.stop, columns.step)] ) + elif is_int_col_indexer and is_compliant_series(columns): + native = native[:, cast("pl.Series", columns.native).to_list()] + elif is_int_col_indexer and is_sequence_like_ints(columns): + native = native[:, columns] elif isinstance(columns, (slice, range)): native = native.select( self.columns[ slice(*convert_str_slice_to_int_slice(columns, self.columns)) ] ) - elif is_int_col_indexer: - native = native[:, columns] + elif is_compliant_series(columns): + native = native.select(cast("pl.Series", columns.native).to_list()) elif is_sequence_like(columns): native = native.select(columns) else: @@ -306,11 +305,11 @@ def __getitem__(self, item: Any) -> Any: if not is_slice_none(rows): if isinstance(rows, int): native = native[[rows], :] - elif ( - isinstance(rows, (slice, range)) - or is_sequence_like_ints(rows) - or isinstance(rows, pl.Series) - ): + elif isinstance(rows, (slice, range)): + native = native[rows, :] + elif is_compliant_series(rows): + native = native[rows.native, :] + elif is_sequence_like(rows): native = native[rows, :] else: msg = "Unreachable code" From 037c9d729ca2b2ac49a1badc17f3fac4a77fbafe Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 11:42:10 +0100 Subject: [PATCH 37/80] type __getitem__ --- narwhals/_compliant/dataframe.py | 21 
++++++++++++++++++--- narwhals/_polars/dataframe.py | 7 ++++--- narwhals/dataframe.py | 1 + 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 7b60fcae8c..7f5c0eee9f 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -50,7 +50,10 @@ from narwhals.typing import AsofJoinStrategy from narwhals.typing import JoinStrategy from narwhals.typing import LazyUniqueKeepStrategy + from narwhals.typing import MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray @@ -107,7 +110,13 @@ def from_numpy( schema: Mapping[str, DType] | Schema | Sequence[str] | None, ) -> Self: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... - def __getitem__(self, item: tuple[Any, Any]) -> Self: ... + def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector | CompliantSeriesT, + MultiIndexSelector | MultiColSelector | CompliantSeriesT, + ], + ) -> Self: ... def simple_select(self, *column_names: str) -> Self: """`select` where all args are column names.""" ... @@ -389,13 +398,19 @@ def _select_labels(self, indices: _StrIndexer) -> Self: ... def _select_slice_of_indices(self, indices: slice | range) -> Self: ... def _select_slice_of_labels(self, indices: slice | range) -> Self: ... 
- def __getitem__(self, item: tuple[Any, Any]) -> Self: + def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector | CompliantSeriesT, + MultiIndexSelector | MultiColSelector | CompliantSeriesT, + ], + ) -> Self: rows, columns = item is_int_col_indexer = is_index_selector(columns) compliant = self if not is_slice_none(columns): - if hasattr(columns, "__len__") and len(columns) == 0: + if isinstance(columns, Sized) and len(columns) == 0: return compliant.select() if is_int_col_indexer and isinstance(columns, (slice, range)): compliant = compliant._select_slice_of_indices(columns) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 3a9214bff8..29233d07b4 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -6,6 +6,7 @@ from typing import Literal from typing import Mapping from typing import Sequence +from typing import Sized from typing import cast from typing import overload @@ -278,7 +279,7 @@ def __getitem__(self, item: Any) -> Any: is_int_col_indexer = is_index_selector(columns) native = self.native if not is_slice_none(columns): - if hasattr(columns, "__len__") and len(columns) == 0: + if isinstance(columns, Sized) and len(columns) == 0: native = native.select() if is_int_col_indexer and isinstance(columns, (slice, range)): native = native.select( @@ -287,7 +288,7 @@ def __getitem__(self, item: Any) -> Any: elif is_int_col_indexer and is_compliant_series(columns): native = native[:, cast("pl.Series", columns.native).to_list()] elif is_int_col_indexer and is_sequence_like_ints(columns): - native = native[:, columns] + native = native[:, cast("Sequence[int]", columns)] elif isinstance(columns, (slice, range)): native = native.select( self.columns[ @@ -297,7 +298,7 @@ def __getitem__(self, item: Any) -> Any: elif is_compliant_series(columns): native = native.select(cast("pl.Series", columns.native).to_list()) elif is_sequence_like(columns): - native = native.select(columns) + native 
= native.select(cast("Sequence[str]", columns)) else: msg = "Unreachable code" raise AssertionError(msg) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 86803f156c..fab9d9c201 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -926,6 +926,7 @@ def __getitem__( else series ) + rows = cast("SingleIndexSelector | MultiIndexSelector", rows) if isinstance(rows, Series): rows = rows._compliant_series if isinstance(columns, Series): From 0fcbcad5a11d894df488013d2875c23dbc15a5b6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 12:08:27 +0100 Subject: [PATCH 38/80] fix modern polars too :sunglasses: --- narwhals/_polars/dataframe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 29233d07b4..c435a0f9cb 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -266,6 +266,10 @@ def shape(self) -> tuple[int, int]: def __getitem__(self, item: Any) -> Any: rows, columns = item if self._backend_version > (0, 20, 30): + if is_compliant_series(rows): + rows = rows.native + if is_compliant_series(columns): + columns = columns.native return self._from_native_object(self.native.__getitem__((rows, columns))) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum From 6d2c1aa682d4b48936456406101d032ee0b188fe Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 12:18:49 +0100 Subject: [PATCH 39/80] log type --- narwhals/_compliant/dataframe.py | 4 ++-- narwhals/_compliant/series.py | 2 +- narwhals/_polars/dataframe.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 7f5c0eee9f..ce4922ff5b 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -425,7 
+425,7 @@ def __getitem__( elif is_sequence_like(columns): compliant = self._select_labels(columns) else: - msg = "Unreachable code" + msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) if not is_slice_none(rows): @@ -438,7 +438,7 @@ def __getitem__( elif is_sequence_like_ints(rows): compliant = compliant._gather(rows) else: - msg = "Unreachable code" + msg = f"Unreachable code, got unexpected type: {type(rows)}" raise AssertionError(msg) return compliant diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index e3d8abf6d9..d635440132 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -335,7 +335,7 @@ def __getitem__(self, item: Any) -> Self: elif is_sequence_like_ints(item): return self._gather(item) else: - msg = "Unreachable code" + msg = f"Unreachable code, got unexpected type: {type(item)}" raise AssertionError(msg) @property diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index c435a0f9cb..cbaf492d6e 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -304,7 +304,7 @@ def __getitem__(self, item: Any) -> Any: elif is_sequence_like(columns): native = native.select(cast("Sequence[str]", columns)) else: - msg = "Unreachable code" + msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) if not is_slice_none(rows): @@ -317,7 +317,7 @@ def __getitem__(self, item: Any) -> Any: elif is_sequence_like(rows): native = native[rows, :] else: - msg = "Unreachable code" + msg = f"Unreachable code, got unexpected type: {type(rows)}" raise AssertionError(msg) return self._with_native(native) From 5b3cfef16768f191afaa84c521bd75220bcd2d85 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:00:59 +0100 Subject: [PATCH 40/80] allow slicing Series with native objects --- narwhals/dataframe.py | 8 ++++---- narwhals/series.py | 15 
+++++++++++++++ tests/frame/getitem_test.py | 9 ++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index fab9d9c201..5b357cdf6a 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -905,11 +905,11 @@ def __getitem__( columns = item else: msg = ( - f"Expected str or slice, got: {type(item)}.\n\n" + f"Unexpected type for `DataFrame.__getitem__`, got: {type(item)}.\n\n" "Hints:\n" - "- use `DataFrame.item` to select a single item.\n" - "- Use `DataFrame[indices, :]` to select rows positionally.\n" - "- Use `DataFrame.filter(mask)` to filter rows based on a boolean mask." + "- use `df.item` to select a single item.\n" + "- Use `df[indices, :]` to select rows positionally.\n" + "- Use `df.filter(mask)` to filter rows based on a boolean mask." ) raise TypeError(msg) diff --git a/narwhals/series.py b/narwhals/series.py index 3d47557498..90a1f2e810 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -27,6 +27,7 @@ from narwhals.utils import _validate_rolling_arguments from narwhals.utils import generate_repr from narwhals.utils import is_compliant_series +from narwhals.utils import is_index_selector from narwhals.utils import parse_version from narwhals.utils import supports_arrow_c_stream @@ -173,6 +174,20 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): return self._compliant_series.item(cast("int", idx)) + + if isinstance(idx, self.to_native().__class__): + idx = self._with_compliant(self._compliant_series._with_native(idx)) + + # For Series.__getitem__, we only + if not is_index_selector(idx): + msg = ( + f"Expected sequence-like or slice of ints, got: {type(idx)}.\n\n" + "Hints:\n" + "- use `s.item` to select a single item.\n" + "- Use `s[indices]` to select rows positionally.\n" + "- Use `s.filter(mask)` to filter rows based on a boolean mask." 
+ ) + raise TypeError(msg) if isinstance(idx, Series): return self._with_compliant(self._compliant_series[idx._compliant_series]) return self._with_compliant(self._compliant_series[idx]) diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index a49371f084..718f947d11 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -76,7 +76,7 @@ def test_slice_int(constructor_eager: ConstructorEager) -> None: def test_slice_fails(constructor_eager: ConstructorEager) -> None: class Foo: ... - with pytest.raises(TypeError, match="Expected str or slice, got:"): + with pytest.raises(TypeError, match="Unexpected type.*, got:"): nw.from_native(constructor_eager(data), eager_only=True)[Foo()] # type: ignore[call-overload, unused-ignore] @@ -328,3 +328,10 @@ def test_horizontal_slice_with_series_2( result = nw_df[:, nw_df["c"]] expected = {"a": [1, 2], "d": ["c", "a"]} assert_equal_data(result, expected) + + +def test_native_slice_series(constructor_eager: ConstructorEager) -> None: + s = nw.from_native(constructor_eager({"a": [0, 2, 1]}), eager_only=True)["a"] + result = {"a": s[s.to_native()]} + expected = {"a": [0, 1, 2]} + assert_equal_data(result, expected) From 2791f492a9371a8cee70d1332573568a22b7ec28 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:17:32 +0100 Subject: [PATCH 41/80] coverage --- narwhals/series.py | 2 +- tests/series_only/getitem_test.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/narwhals/series.py b/narwhals/series.py index 90a1f2e810..b4b153f882 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -181,7 +181,7 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se # For Series.__getitem__, we only if not is_index_selector(idx): msg = ( - f"Expected sequence-like or slice of ints, got: {type(idx)}.\n\n" + f"Unexpected type for `Series.__getitem__`: {type(idx)}.\n\n" "Hints:\n" "- 
use `s.item` to select a single item.\n" "- Use `s[indices]` to select rows positionally.\n" diff --git a/tests/series_only/getitem_test.py b/tests/series_only/getitem_test.py index 8ecc5fb0a3..de3c5b9a17 100644 --- a/tests/series_only/getitem_test.py +++ b/tests/series_only/getitem_test.py @@ -61,3 +61,11 @@ def test_getitem_other_series(constructor_eager: ConstructorEager) -> None: ] other = nw.from_native(constructor_eager({"b": [1, 3]}), eager_only=True)["b"] assert_equal_data(series[other].to_frame(), {"a": [None, 3]}) + + +def test_getitem_invalid_series(constructor_eager: ConstructorEager) -> None: + series = nw.from_native(constructor_eager({"a": [1, None, 2, 3]}), eager_only=True)[ + "a" + ] + with pytest.raises(TypeError, match="Unexpected type"): + series[series > 1] From 498d6a45446e923f052e57284af30baa29667fa5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 15:00:44 +0100 Subject: [PATCH 42/80] extra test --- narwhals/_arrow/dataframe.py | 10 +++++----- narwhals/_arrow/series.py | 4 ++-- narwhals/_compliant/dataframe.py | 13 +++++++------ narwhals/_compliant/series.py | 6 +++--- narwhals/_pandas_like/dataframe.py | 10 +++++----- narwhals/_pandas_like/series.py | 4 ++-- narwhals/typing.py | 12 ++++-------- narwhals/utils.py | 5 ++++- tests/frame/getitem_test.py | 14 ++++++++++++++ 9 files changed, 46 insertions(+), 32 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index cf949f9cd8..971edac99c 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -56,11 +56,11 @@ from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.typing import JoinStrategy + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray - 
from narwhals.typing import _IntIndexer - from narwhals.typing import _StrIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -244,7 +244,7 @@ def get_column(self, name: str) -> ArrowSeries: def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def _gather(self, item: _IntIndexer) -> Self: + def _gather(self, item: SizedMultiIndexSelector) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): @@ -272,12 +272,12 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: self.native.select(self.columns[item.start : item.stop : item.step]) ) - def _select_indices(self, item: _IntIndexer) -> Self: + def _select_indices(self, item: SizedMultiIndexSelector) -> Self: if isinstance(item, pa.ChunkedArray): item = item.to_pylist() return self._with_native(self.native.select([self.columns[x] for x in item])) - def _select_labels(self, item: _StrIndexer) -> Self: + def _select_labels(self, item: SizedMultiNameSelector) -> Self: if isinstance(item, pa.ChunkedArray): item = item.to_pylist() return self._with_native(self.native.select(item)) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index fba5a16bc0..a975858831 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -67,10 +67,10 @@ from narwhals.typing import PythonLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _2DArray - from narwhals.typing import _IntIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -407,7 +407,7 @@ def __native_namespace__(self) -> ModuleType: def name(self) -> str: return self._name - def _gather(self, item: _IntIndexer) -> 
Self: + def _gather(self, item: SizedMultiIndexSelector) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index ce4922ff5b..255e3b63a7 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -10,6 +10,7 @@ from typing import Sequence from typing import Sized from typing import TypeVar +from typing import cast from typing import overload from narwhals._compliant.typing import CompliantExprT_contra @@ -54,11 +55,11 @@ from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg from narwhals.typing import SingleIndexSelector + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray - from narwhals.typing import _IntIndexer - from narwhals.typing import _StrIndexer from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -391,10 +392,10 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: _IntIndexer) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... def _gather_slice(self, indices: slice | range) -> Self: ... - def _select_indices(self, indices: _IntIndexer) -> Self: ... - def _select_labels(self, indices: _StrIndexer) -> Self: ... + def _select_indices(self, indices: SizedMultiIndexSelector) -> Self: ... + def _select_labels(self, indices: SizedMultiNameSelector) -> Self: ... def _select_slice_of_indices(self, indices: slice | range) -> Self: ... def _select_slice_of_labels(self, indices: slice | range) -> Self: ... 
@@ -423,7 +424,7 @@ def __getitem__( elif is_compliant_series(columns): compliant = self._select_labels(columns.native) elif is_sequence_like(columns): - compliant = self._select_labels(columns) + compliant = self._select_labels(cast("SizedMultiNameSelector", columns)) else: msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index d635440132..7ea446cab6 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -49,9 +49,9 @@ from narwhals.typing import NumericLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray - from narwhals.typing import _IntIndexer from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -320,7 +320,7 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] - def _gather(self, indices: _IntIndexer) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... def _gather_slice(self, indices: slice | range) -> Self: ... 
def __getitem__(self, item: Any) -> Self: @@ -332,7 +332,7 @@ def __getitem__(self, item: Any) -> Self: return self._gather_slice(item) elif is_compliant_series(item): return self._gather(item.native) - elif is_sequence_like_ints(item): + elif isinstance(item, self._native_series) or is_sequence_like_ints(item): return self._gather(item) else: msg = f"Unreachable code, got unexpected type: {type(item)}" diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index f0d6806d74..a1f97a3bff 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -64,11 +64,11 @@ from narwhals.typing import DTypeBackend from narwhals.typing import JoinStrategy from narwhals.typing import PivotAgg + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray - from narwhals.typing import _IntIndexer - from narwhals.typing import _StrIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -279,7 +279,7 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, items: _IntIndexer) -> Self: + def _gather(self, items: SizedMultiIndexSelector) -> Self: items = list(items) if isinstance(items, tuple) else items return self._with_native(self.native.iloc[items, :]) @@ -302,14 +302,14 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: validate_column_names=False, ) - def _select_indices(self, item: _IntIndexer) -> Self: + def _select_indices(self, item: SizedMultiIndexSelector) -> Self: item = list(item) if isinstance(item, tuple) else item return self._with_native( self.native.iloc[:, item], validate_column_names=False, ) - def _select_labels(self, indices: _StrIndexer) -> 
PandasLikeDataFrame: + def _select_labels(self, indices: SizedMultiNameSelector) -> PandasLikeDataFrame: return self._with_native(self.native.loc[:, indices]) # --- properties --- diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index d4f859a43d..01e4df0052 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -53,10 +53,10 @@ from narwhals.typing import NumericLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _AnyDArray - from narwhals.typing import _IntIndexer from narwhals.utils import Version from narwhals.utils import _FullContext @@ -146,7 +146,7 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def _gather(self, rows: _IntIndexer) -> Self: + def _gather(self, rows: SizedMultiIndexSelector) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) diff --git a/narwhals/typing.py b/narwhals/typing.py index abe469aa1c..1adfd945a2 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -310,12 +310,6 @@ def __native_namespace__(self) -> ModuleType: ... ) PythonLiteral: TypeAlias = "NonNestedLiteral | list[Any] | tuple[Any, ...]" -# Overloaded sequence of integers -_IntIndexer: TypeAlias = Any # noqa: PYI047 -# Overloaded sequence of strings -_StrIndexer: TypeAlias = Any # noqa: PYI047 - - # Annotations for `__getitem__` methods _Slice: TypeAlias = "slice[Any, Any, Any]" _SliceNone: TypeAlias = "slice[None, None, None]" @@ -326,8 +320,10 @@ def __native_namespace__(self) -> ModuleType: ... 
SingleIndexSelector: TypeAlias = int SingleNameSelector: TypeAlias = str -MultiIndexSelector: TypeAlias = "_SliceIndex | Sequence[int] | Series[Any] | _1DArray" -MultiNameSelector: TypeAlias = "_Slice | Sequence[str] | Series[Any] | _1DArray" +SizedMultiIndexSelector: TypeAlias = "Sequence[int] | Series[Any] | _1DArray" +MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector" +SizedMultiNameSelector: TypeAlias = "Sequence[str] | Series[Any] | _1DArray" +MultiNameSelector: TypeAlias = "_Slice | SizedMultiNameSelector" SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" diff --git a/narwhals/utils.py b/narwhals/utils.py index b5894bbd86..ec38672d83 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -87,6 +87,7 @@ from narwhals.typing import IntoSeriesT from narwhals.typing import MultiIndexSelector from narwhals.typing import SingleIndexSelector + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import SupportsNativeNamespace from narwhals.typing import TimeUnit @@ -1268,7 +1269,9 @@ def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T] return isinstance(sequence, Sequence) and not isinstance(sequence, str) -def is_sequence_like_ints(sequence: Any | Sequence[_T]) -> bool: +def is_sequence_like_ints( + sequence: Any | Sequence[_T], +) -> TypeIs[SizedMultiIndexSelector]: np = get_numpy() return ( ( diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 718f947d11..9cc769b4fc 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime from typing import TYPE_CHECKING from typing import Any from typing import cast @@ -335,3 +336,16 @@ def test_native_slice_series(constructor_eager: ConstructorEager) -> None: result = {"a": s[s.to_native()]} expected = 
{"a": [0, 1, 2]} assert_equal_data(result, expected) + + +def test_pandas_non_str_columns() -> None: + # The general rule with getitem is: ints are always treated as positions. The rest, we should + # be able to hand down to the native frame. Here we check what happens for pandas with + # datetime column names. + df = nw.from_native( + pd.DataFrame({datetime(2020, 1, 1): [1, 2, 3], datetime(2020, 1, 2): [4, 5, 6]}), + eager_only=True, + ) + result = df[:, [datetime(2020, 1, 1)]] # type: ignore[index] + expected = {datetime(2020, 1, 1): [1, 2, 3]} + assert result.to_dict(as_series=False) == expected From 80bc093493de52a0750ea54750f508da2993a6fa Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 15:21:42 +0100 Subject: [PATCH 43/80] uurgh undo accidental change --- narwhals/_compliant/series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 7ea446cab6..2fd6f35905 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -332,7 +332,7 @@ def __getitem__(self, item: Any) -> Self: return self._gather_slice(item) elif is_compliant_series(item): return self._gather(item.native) - elif isinstance(item, self._native_series) or is_sequence_like_ints(item): + elif is_sequence_like_ints(item): return self._gather(item) else: msg = f"Unreachable code, got unexpected type: {type(item)}" From 5fcc35b7232b09137e78fbb7cf37085ae3f6ed5f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 15:37:02 +0100 Subject: [PATCH 44/80] pyright, simplify --- narwhals/_arrow/dataframe.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 971edac99c..cbe967c4c9 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -19,6 +19,7 @@ from 
narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind +from narwhals.dependencies import is_numpy_array from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version @@ -48,6 +49,7 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.typing import ArrowChunkedArray + from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import Mask # type: ignore[attr-defined] from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._translate import IntoArrowTable @@ -249,7 +251,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) + return self._with_native(self.native.take(cast("Indices", item))) def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 @@ -274,13 +276,16 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: def _select_indices(self, item: SizedMultiIndexSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() - return self._with_native(self.native.select([self.columns[x] for x in item])) + item = cast("list[int]", item.to_pylist()) + if is_numpy_array(item): + item = cast("list[int]", item.tolist()) + return self._with_native(self.native.select(cast("Indices", item))) def _select_labels(self, item: SizedMultiNameSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() - return self._with_native(self.native.select(item)) + item = cast("list[str]", item.to_pylist()) + # pyarrow-stubs overly strict, accept list[str] | Indices + return self._with_native(self.native.select(item)) # pyright: 
ignore[reportArgumentType] @property def schema(self) -> dict[str, DType]: From 9591d10de3877ee708bbbe2519b315f8f8522eeb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 20 Apr 2025 15:41:12 +0100 Subject: [PATCH 45/80] _arrow/series.py typing --- narwhals/_arrow/series.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index a975858831..5dfddc1b57 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -52,6 +52,7 @@ from narwhals._arrow.typing import ArrowArray from narwhals._arrow.typing import ArrowChunkedArray from narwhals._arrow.typing import Incomplete + from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import NullPlacement from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._arrow.typing import ScalarAny @@ -412,7 +413,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) + return self._with_native(self.native.take(cast("Indices", item))) def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 From b43feae7f9dac159ef853f8e7dd43990c2243d26 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 16:23:14 +0100 Subject: [PATCH 46/80] refactor(typing): Add `_SliceName`, reuse `_Slice` for `_SliceIndex` --- narwhals/typing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/narwhals/typing.py b/narwhals/typing.py index 1adfd945a2..c62eae57d0 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -311,19 +311,19 @@ def __native_namespace__(self) -> ModuleType: ... 
PythonLiteral: TypeAlias = "NonNestedLiteral | list[Any] | tuple[Any, ...]" # Annotations for `__getitem__` methods -_Slice: TypeAlias = "slice[Any, Any, Any]" +_T = TypeVar("_T") +_Slice: TypeAlias = "slice[_T, Any, Any] | slice[Any, _T, Any] | slice[None, None, _T]" _SliceNone: TypeAlias = "slice[None, None, None]" -_SliceIndex: TypeAlias = ( - "slice[int, Any, Any] | slice[Any, int, Any] | slice[None, None, int] | _SliceNone" -) +_SliceIndex: TypeAlias = "_Slice[int] | _SliceNone" """E.g. `[1:]` or `[:3]` or `[::2]`.""" +_SliceName: TypeAlias = "_Slice[str] | _SliceNone" SingleIndexSelector: TypeAlias = int SingleNameSelector: TypeAlias = str SizedMultiIndexSelector: TypeAlias = "Sequence[int] | Series[Any] | _1DArray" MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector" SizedMultiNameSelector: TypeAlias = "Sequence[str] | Series[Any] | _1DArray" -MultiNameSelector: TypeAlias = "_Slice | SizedMultiNameSelector" +MultiNameSelector: TypeAlias = "_SliceName | SizedMultiNameSelector" SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" From f5ab0ab72271961293bebc856253cb443b1efe0c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 17:33:34 +0100 Subject: [PATCH 47/80] chore(typing): Unhide some issues Related https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051747123 --- narwhals/_compliant/dataframe.py | 3 +-- narwhals/utils.py | 12 ++---------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 255e3b63a7..46758fd30a 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -10,7 +10,6 @@ from typing import Sequence from typing import Sized from typing import TypeVar -from typing import cast from typing import overload from narwhals._compliant.typing import 
CompliantExprT_contra @@ -424,7 +423,7 @@ def __getitem__( elif is_compliant_series(columns): compliant = self._select_labels(columns.native) elif is_sequence_like(columns): - compliant = self._select_labels(cast("SizedMultiNameSelector", columns)) + compliant = self._select_labels(columns) else: msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) diff --git a/narwhals/utils.py b/narwhals/utils.py index ec38672d83..2f781c9c23 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1879,15 +1879,7 @@ def wrapper(instance: _ContextT, *args: P.args, **kwds: P.kwargs) -> R: def convert_str_slice_to_int_slice( str_slice: slice | range, columns: list[str] ) -> tuple[int | None, int | None, int | None]: - start = ( - columns.index(cast("str", str_slice.start)) - if str_slice.start is not None - else None - ) - stop = ( - columns.index(cast("str", str_slice.stop)) + 1 - if str_slice.stop is not None - else None - ) + start = columns.index(str_slice.start) if str_slice.start is not None else None + stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None step = str_slice.step return (start, stop, step) From 1aa04be64915c5ae204bddd6658557c722a97b37 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 17:34:31 +0100 Subject: [PATCH 48/80] chore: reorganise aliases --- narwhals/typing.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/narwhals/typing.py b/narwhals/typing.py index c62eae57d0..575e17a27a 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -314,16 +314,18 @@ def __native_namespace__(self) -> ModuleType: ... _T = TypeVar("_T") _Slice: TypeAlias = "slice[_T, Any, Any] | slice[Any, _T, Any] | slice[None, None, _T]" _SliceNone: TypeAlias = "slice[None, None, None]" +# Index/column positions +SingleIndexSelector: TypeAlias = int _SliceIndex: TypeAlias = "_Slice[int] | _SliceNone" """E.g. 
`[1:]` or `[:3]` or `[::2]`.""" - -_SliceName: TypeAlias = "_Slice[str] | _SliceNone" -SingleIndexSelector: TypeAlias = int -SingleNameSelector: TypeAlias = str SizedMultiIndexSelector: TypeAlias = "Sequence[int] | Series[Any] | _1DArray" MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector" +# Labels/column names +SingleNameSelector: TypeAlias = str +_SliceName: TypeAlias = "_Slice[str] | _SliceNone" SizedMultiNameSelector: TypeAlias = "Sequence[str] | Series[Any] | _1DArray" MultiNameSelector: TypeAlias = "_SliceName | SizedMultiNameSelector" +# Mixed selectors SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" From 7465b907487df2ccb351141b2938b8bfd2977d2f Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 17:36:45 +0100 Subject: [PATCH 49/80] refactor: More matching aliases to guards Also plugged a soundness hole for `bool` --- narwhals/_compliant/dataframe.py | 6 +++--- narwhals/_compliant/series.py | 4 ++-- narwhals/_polars/dataframe.py | 4 ++-- narwhals/utils.py | 24 +++++++++++++++++------- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 46758fd30a..f55df2ca7c 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -28,7 +28,7 @@ from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like -from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_sized_multi_index_selector from narwhals.utils import is_slice_none if TYPE_CHECKING: @@ -416,7 +416,7 @@ def __getitem__( compliant = compliant._select_slice_of_indices(columns) elif is_int_col_indexer and is_compliant_series(columns): compliant = self._select_indices(columns.native) - elif is_int_col_indexer and 
is_sequence_like_ints(columns): + elif is_int_col_indexer and is_sized_multi_index_selector(columns): compliant = compliant._select_indices(columns) elif isinstance(columns, (slice, range)): compliant = compliant._select_slice_of_labels(columns) @@ -435,7 +435,7 @@ def __getitem__( compliant = compliant._gather_slice(rows) elif is_compliant_series(rows): compliant = compliant._gather(rows.native) - elif is_sequence_like_ints(rows): + elif is_sized_multi_index_selector(rows): compliant = compliant._gather(rows) else: msg = f"Unreachable code, got unexpected type: {type(rows)}" diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 2fd6f35905..5ab2d26109 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -24,7 +24,7 @@ from narwhals.utils import _StoresCompliant from narwhals.utils import _StoresNative from narwhals.utils import is_compliant_series -from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_sized_multi_index_selector from narwhals.utils import is_slice_none from narwhals.utils import unstable @@ -332,7 +332,7 @@ def __getitem__(self, item: Any) -> Self: return self._gather_slice(item) elif is_compliant_series(item): return self._gather(item.native) - elif is_sequence_like_ints(item): + elif is_sized_multi_index_selector(item): return self._gather(item) else: msg = f"Unreachable code, got unexpected type: {type(item)}" diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index cbaf492d6e..b5647671d4 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -25,7 +25,7 @@ from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like -from narwhals.utils import is_sequence_like_ints +from narwhals.utils import is_sized_multi_index_selector from narwhals.utils import is_slice_none from narwhals.utils import parse_columns_to_drop from narwhals.utils import 
parse_version @@ -291,7 +291,7 @@ def __getitem__(self, item: Any) -> Any: ) elif is_int_col_indexer and is_compliant_series(columns): native = native[:, cast("pl.Series", columns.native).to_list()] - elif is_int_col_indexer and is_sequence_like_ints(columns): + elif is_int_col_indexer and is_sized_multi_index_selector(columns): native = native[:, cast("Sequence[int]", columns)] elif isinstance(columns, (slice, range)): native = native.select( diff --git a/narwhals/utils.py b/narwhals/utils.py index 2f781c9c23..43315fac4c 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1269,8 +1269,12 @@ def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T] return isinstance(sequence, Sequence) and not isinstance(sequence, str) -def is_sequence_like_ints( - sequence: Any | Sequence[_T], +def is_slice_none(obj: object) -> TypeIs[_SliceNone]: + return isinstance(obj, slice) and obj == slice(None) + + +def is_sized_multi_index_selector( + sequence: Sequence[_T] | Any, ) -> TypeIs[SizedMultiIndexSelector]: np = get_numpy() return ( @@ -1300,7 +1304,7 @@ def is_sequence_like( ) -def is_slice_index(obj: _SliceIndex | Any) -> TypeIs[_SliceIndex]: +def is_slice_index(obj: Any) -> TypeIs[_SliceIndex]: return isinstance(obj, slice) and ( isinstance(obj.start, int) or isinstance(obj.stop, int) @@ -1308,12 +1312,18 @@ def is_slice_index(obj: _SliceIndex | Any) -> TypeIs[_SliceIndex]: ) -def is_slice_none(obj: object) -> TypeIs[_SliceNone]: - return isinstance(obj, slice) and obj == slice(None) +def is_single_index_selector(obj: Any) -> TypeIs[SingleIndexSelector]: + return isinstance(obj, int) and not isinstance(obj, bool) -def is_index_selector(cols: SingleIndexSelector | MultiIndexSelector | Any) -> bool: - return isinstance(cols, int) or is_sequence_like_ints(cols) or is_slice_index(cols) +def is_index_selector( + obj: Any, +) -> TypeIs[SingleIndexSelector] | TypeIs[MultiIndexSelector]: + return ( + is_single_index_selector(obj) + or 
is_sized_multi_index_selector(obj) + or is_slice_index(obj) + ) def is_list_of(obj: Any, tp: type[_T]) -> TypeIs[list[_T]]: From de2d9ba9bf21a3ca73d99e93d7db0e37e67aa335 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 17:44:43 +0100 Subject: [PATCH 50/80] revert(typing): Remove every new `cast` Waaaaay too much is being hidden --- narwhals/_arrow/dataframe.py | 8 ++++---- narwhals/_arrow/series.py | 3 +-- narwhals/_polars/dataframe.py | 8 ++++---- narwhals/dataframe.py | 8 +------- narwhals/series.py | 3 +-- 5 files changed, 11 insertions(+), 19 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index cbe967c4c9..0b693c2e2a 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -251,7 +251,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(cast("Indices", item))) + return self._with_native(self.native.take(item)) def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 @@ -276,14 +276,14 @@ def _select_slice_of_indices(self, item: slice | range) -> Self: def _select_indices(self, item: SizedMultiIndexSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = cast("list[int]", item.to_pylist()) + item = item.to_pylist() if is_numpy_array(item): - item = cast("list[int]", item.tolist()) + item = item.tolist() return self._with_native(self.native.select(cast("Indices", item))) def _select_labels(self, item: SizedMultiNameSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = cast("list[str]", item.to_pylist()) + item = item.to_pylist() # pyarrow-stubs overly strict, accept list[str] | Indices return self._with_native(self.native.select(item)) # pyright: ignore[reportArgumentType] diff --git a/narwhals/_arrow/series.py 
b/narwhals/_arrow/series.py index 5dfddc1b57..a975858831 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -52,7 +52,6 @@ from narwhals._arrow.typing import ArrowArray from narwhals._arrow.typing import ArrowChunkedArray from narwhals._arrow.typing import Incomplete - from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import NullPlacement from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._arrow.typing import ScalarAny @@ -413,7 +412,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(cast("Indices", item))) + return self._with_native(self.native.take(item)) def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index b5647671d4..8f44cc09d7 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -290,9 +290,9 @@ def __getitem__(self, item: Any) -> Any: self.columns[slice(columns.start, columns.stop, columns.step)] ) elif is_int_col_indexer and is_compliant_series(columns): - native = native[:, cast("pl.Series", columns.native).to_list()] + native = native[:, columns.native.to_list()] elif is_int_col_indexer and is_sized_multi_index_selector(columns): - native = native[:, cast("Sequence[int]", columns)] + native = native[:, columns] elif isinstance(columns, (slice, range)): native = native.select( self.columns[ @@ -300,9 +300,9 @@ def __getitem__(self, item: Any) -> Any: ] ) elif is_compliant_series(columns): - native = native.select(cast("pl.Series", columns.native).to_list()) + native = native.select(columns.native.to_list()) elif is_sequence_like(columns): - native = native.select(cast("Sequence[str]", columns)) + native = native.select(columns) else: msg 
= f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 5b357cdf6a..8cecd900c2 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -12,7 +12,6 @@ from typing import NoReturn from typing import Sequence from typing import TypeVar -from typing import cast from typing import overload from warnings import warn @@ -920,13 +919,8 @@ def __getitem__( if isinstance(columns, (int, str)): col_name = columns if isinstance(columns, str) else self.columns[columns] series = self.get_column(col_name) - return ( - series[cast("SingleIndexSelector | MultiIndexSelector", rows)] - if rows is not None - else series - ) + return series[rows] if rows is not None else series - rows = cast("SingleIndexSelector | MultiIndexSelector", rows) if isinstance(rows, Series): rows = rows._compliant_series if isinstance(columns, Series): diff --git a/narwhals/series.py b/narwhals/series.py index b4b153f882..f83670e160 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -8,7 +8,6 @@ from typing import Literal from typing import Mapping from typing import Sequence -from typing import cast from typing import overload from narwhals.dependencies import is_numpy_scalar @@ -173,7 +172,7 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se if isinstance(idx, int) or ( is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): - return self._compliant_series.item(cast("int", idx)) + return self._compliant_series.item(idx) if isinstance(idx, self.to_native().__class__): idx = self._with_compliant(self._compliant_series._with_native(idx)) From 9fec62dfce9057dbbc6a28fe923b1669f579ae09 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 18:47:35 +0100 Subject: [PATCH 51/80] fix: `range` != `slice` Part of https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051747123 --- 
narwhals/_arrow/dataframe.py | 3 ++- narwhals/_compliant/dataframe.py | 5 +++-- narwhals/_pandas_like/dataframe.py | 3 ++- narwhals/_polars/dataframe.py | 2 +- narwhals/utils.py | 5 +++-- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 0b693c2e2a..2aac70841c 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -63,6 +63,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -265,7 +266,7 @@ def _gather_slice(self, item: slice | range) -> Self: raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) - def _select_slice_of_labels(self, item: slice | range) -> Self: + def _select_slice_of_labels(self, item: _SliceName) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.columns) return self._with_native(self.native.select(self.columns[start:stop:step])) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index f55df2ca7c..1074d87430 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -59,6 +59,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceName from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -396,7 +397,7 @@ def _gather_slice(self, indices: slice | range) -> Self: ... def _select_indices(self, indices: SizedMultiIndexSelector) -> Self: ... def _select_labels(self, indices: SizedMultiNameSelector) -> Self: ... def _select_slice_of_indices(self, indices: slice | range) -> Self: ... - def _select_slice_of_labels(self, indices: slice | range) -> Self: ... 
+ def _select_slice_of_labels(self, item: _SliceName) -> Self: ... def __getitem__( self, @@ -418,7 +419,7 @@ def __getitem__( compliant = self._select_indices(columns.native) elif is_int_col_indexer and is_sized_multi_index_selector(columns): compliant = compliant._select_indices(columns) - elif isinstance(columns, (slice, range)): + elif isinstance(columns, slice): compliant = compliant._select_slice_of_labels(columns) elif is_compliant_series(columns): compliant = self._select_labels(columns.native) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index a1f97a3bff..037d70ccbc 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -69,6 +69,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -289,7 +290,7 @@ def _gather_slice(self, item: slice | range) -> Self: validate_column_names=False, ) - def _select_slice_of_labels(self, item: slice | range) -> Self: + def _select_slice_of_labels(self, item: _SliceName) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) return self._with_native( self.native.iloc[:, slice(start, stop, step)], diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 8f44cc09d7..5980774458 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -293,7 +293,7 @@ def __getitem__(self, item: Any) -> Any: native = native[:, columns.native.to_list()] elif is_int_col_indexer and is_sized_multi_index_selector(columns): native = native[:, columns] - elif isinstance(columns, (slice, range)): + elif isinstance(columns, slice): native = native.select( self.columns[ slice(*convert_str_slice_to_int_slice(columns, self.columns)) diff --git a/narwhals/utils.py b/narwhals/utils.py index 43315fac4c..af8473722d 
100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -93,6 +93,7 @@ from narwhals.typing import TimeUnit from narwhals.typing import _1DArray from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceName from narwhals.typing import _SliceNone FrameOrSeriesT = TypeVar( @@ -1887,8 +1888,8 @@ def wrapper(instance: _ContextT, *args: P.args, **kwds: P.kwargs) -> R: def convert_str_slice_to_int_slice( - str_slice: slice | range, columns: list[str] -) -> tuple[int | None, int | None, int | None]: + str_slice: _SliceName, columns: Sequence[str] +) -> tuple[int | None, int | None, Any]: start = columns.index(str_slice.start) if str_slice.start is not None else None stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None step = str_slice.step From 156037ffd1e1b1ec3b651698161fb6d0074345a3 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:09:41 +0100 Subject: [PATCH 52/80] fix(typing): `EagerDataFrame.__getitem__` exhaustive happy See https://github.com/narwhals-dev/narwhals/pull/2393#issuecomment-2817253986 --- narwhals/_compliant/dataframe.py | 20 ++++++++-------- narwhals/utils.py | 39 +++++++++++++------------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 1074d87430..8442e244cb 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -13,6 +13,7 @@ from typing import overload from narwhals._compliant.typing import CompliantExprT_contra +from narwhals._compliant.typing import CompliantSeriesAny from narwhals._compliant.typing import CompliantSeriesT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT @@ -402,23 +403,22 @@ def _select_slice_of_labels(self, item: _SliceName) -> Self: ... 
def __getitem__( self, item: tuple[ - SingleIndexSelector | MultiIndexSelector | CompliantSeriesT, - MultiIndexSelector | MultiColSelector | CompliantSeriesT, + SingleIndexSelector | MultiIndexSelector | CompliantSeriesAny, + MultiIndexSelector | MultiColSelector | CompliantSeriesAny, ], ) -> Self: rows, columns = item - - is_int_col_indexer = is_index_selector(columns) compliant = self if not is_slice_none(columns): if isinstance(columns, Sized) and len(columns) == 0: return compliant.select() - if is_int_col_indexer and isinstance(columns, (slice, range)): - compliant = compliant._select_slice_of_indices(columns) - elif is_int_col_indexer and is_compliant_series(columns): - compliant = self._select_indices(columns.native) - elif is_int_col_indexer and is_sized_multi_index_selector(columns): - compliant = compliant._select_indices(columns) + if is_index_selector(columns): + if isinstance(columns, (slice, range)): + compliant = compliant._select_slice_of_indices(columns) + elif is_compliant_series(columns): + compliant = self._select_indices(columns.native) + else: + compliant = compliant._select_indices(columns) elif isinstance(columns, slice): compliant = compliant._select_slice_of_labels(columns) elif is_compliant_series(columns): diff --git a/narwhals/utils.py b/narwhals/utils.py index af8473722d..c459d88ed0 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1266,42 +1266,37 @@ def parse_columns_to_drop( return to_drop -def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T]]: +def is_sequence_but_not_str(sequence: Sequence[_T] | Any) -> TypeIs[Sequence[_T]]: return isinstance(sequence, Sequence) and not isinstance(sequence, str) -def is_slice_none(obj: object) -> TypeIs[_SliceNone]: +def is_slice_none(obj: Any) -> TypeIs[_SliceNone]: return isinstance(obj, slice) and obj == slice(None) -def is_sized_multi_index_selector( - sequence: Sequence[_T] | Any, -) -> TypeIs[SizedMultiIndexSelector]: +def 
is_sized_multi_index_selector(obj: Any) -> TypeIs[SizedMultiIndexSelector]: np = get_numpy() return ( ( - is_sequence_but_not_str(sequence) - and ( - (len(sequence) > 0 and isinstance(sequence[0], int)) - or (len(sequence) == 0) - ) + is_sequence_but_not_str(obj) + and ((len(obj) > 0 and isinstance(obj[0], int)) or (len(obj) == 0)) ) - or (is_numpy_array_1d(sequence) and np.issubdtype(sequence.dtype, np.integer)) + or (is_numpy_array_1d(obj) and np.issubdtype(obj.dtype, np.integer)) or ( - (is_narwhals_series(sequence) or is_compliant_series(sequence)) - and sequence.dtype.is_integer() + (is_narwhals_series(obj) or is_compliant_series(obj)) + and obj.dtype.is_integer() ) ) def is_sequence_like( - sequence: Sequence[_T] | Any, -) -> TypeIs[Sequence[_T]] | TypeIs[Series[Any]] | TypeIs[_1DArray]: + obj: Sequence[_T] | Any, +) -> TypeIs[Sequence[_T] | Series[Any] | _1DArray]: return ( - is_sequence_but_not_str(sequence) - or is_numpy_array_1d(sequence) - or is_narwhals_series(sequence) - or is_compliant_series(sequence) + is_sequence_but_not_str(obj) + or is_numpy_array_1d(obj) + or is_narwhals_series(obj) + or is_compliant_series(obj) ) @@ -1314,12 +1309,10 @@ def is_slice_index(obj: Any) -> TypeIs[_SliceIndex]: def is_single_index_selector(obj: Any) -> TypeIs[SingleIndexSelector]: - return isinstance(obj, int) and not isinstance(obj, bool) + return bool(isinstance(obj, int) and not isinstance(obj, bool)) -def is_index_selector( - obj: Any, -) -> TypeIs[SingleIndexSelector] | TypeIs[MultiIndexSelector]: +def is_index_selector(obj: Any) -> TypeIs[SingleIndexSelector | MultiIndexSelector]: return ( is_single_index_selector(obj) or is_sized_multi_index_selector(obj) From 7f864c5cc6043d8180be2dc96d15c93042494c0b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:26:34 +0100 Subject: [PATCH 53/80] chore(typing): Update some signatures --- narwhals/_arrow/dataframe.py | 5 +++-- 
narwhals/_compliant/dataframe.py | 5 +++-- narwhals/_pandas_like/dataframe.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 2aac70841c..f91c2af923 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -63,6 +63,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -254,7 +255,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: item = list(item) return self._with_native(self.native.take(item)) - def _gather_slice(self, item: slice | range) -> Self: + def _gather_slice(self, item: _SliceIndex | range) -> Self: start = item.start or 0 stop = item.stop if item.stop is not None else len(self.native) if start < 0: @@ -270,7 +271,7 @@ def _select_slice_of_labels(self, item: _SliceName) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.columns) return self._with_native(self.native.select(self.columns[start:stop:step])) - def _select_slice_of_indices(self, item: slice | range) -> Self: + def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.select(self.columns[item.start : item.stop : item.step]) ) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 8442e244cb..d18b8613a4 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -60,6 +60,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex from narwhals.typing import _SliceName from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -394,10 +395,10 @@ def _numpy_column_names( return 
list(columns or (f"column_{x}" for x in range(data.shape[1]))) def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... - def _gather_slice(self, indices: slice | range) -> Self: ... + def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... def _select_indices(self, indices: SizedMultiIndexSelector) -> Self: ... def _select_labels(self, indices: SizedMultiNameSelector) -> Self: ... - def _select_slice_of_indices(self, indices: slice | range) -> Self: ... + def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... def _select_slice_of_labels(self, item: _SliceName) -> Self: ... def __getitem__( diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 037d70ccbc..156766cbbb 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -69,6 +69,7 @@ from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -284,7 +285,7 @@ def _gather(self, items: SizedMultiIndexSelector) -> Self: items = list(items) if isinstance(items, tuple) else items return self._with_native(self.native.iloc[items, :]) - def _gather_slice(self, item: slice | range) -> Self: + def _gather_slice(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.iloc[slice(item.start, item.stop, item.step), :], validate_column_names=False, @@ -297,7 +298,7 @@ def _select_slice_of_labels(self, item: _SliceName) -> Self: validate_column_names=False, ) - def _select_slice_of_indices(self, item: slice | range) -> Self: + def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.iloc[:, slice(item.start, item.stop, item.step)], validate_column_names=False, From 4276779d2ed177c6085cc5deb1442280f3eaa031 Mon Sep 17 00:00:00 
2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:28:05 +0100 Subject: [PATCH 54/80] fix(typing): Some `PolarsDataFrame` progress Really needs to be split up --- narwhals/_polars/dataframe.py | 40 +++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 5980774458..228f6e92bd 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -25,7 +25,6 @@ from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector from narwhals.utils import is_sequence_like -from narwhals.utils import is_sized_multi_index_selector from narwhals.utils import is_slice_none from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -51,7 +50,10 @@ from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.typing import JoinStrategy + from narwhals.typing import MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleIndexSelector from narwhals.typing import _2DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -263,14 +265,20 @@ def collect_schema(self) -> dict[str, DType]: def shape(self) -> tuple[int, int]: return self.native.shape - def __getitem__(self, item: Any) -> Any: + def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector | PolarsSeries, + MultiIndexSelector | MultiColSelector | PolarsSeries, + ], + ) -> Any: rows, columns = item if self._backend_version > (0, 20, 30): - if is_compliant_series(rows): - rows = rows.native - if is_compliant_series(columns): - columns = columns.native - return self._from_native_object(self.native.__getitem__((rows, columns))) + rows_native = rows.native if is_compliant_series(rows) else rows + columns_native = 
columns.native if is_compliant_series(columns) else columns + selector = rows_native, columns_native + selected = self.native.__getitem__(selector) + return self._from_native_object(selected) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support @@ -280,19 +288,19 @@ def __getitem__(self, item: Any) -> Any: if is_numpy_array_1d(columns): columns = columns.tolist() - is_int_col_indexer = is_index_selector(columns) native = self.native if not is_slice_none(columns): if isinstance(columns, Sized) and len(columns) == 0: native = native.select() - if is_int_col_indexer and isinstance(columns, (slice, range)): - native = native.select( - self.columns[slice(columns.start, columns.stop, columns.step)] - ) - elif is_int_col_indexer and is_compliant_series(columns): - native = native[:, columns.native.to_list()] - elif is_int_col_indexer and is_sized_multi_index_selector(columns): - native = native[:, columns] + if is_index_selector(columns): + if isinstance(columns, (slice, range)): + native = native.select( + self.columns[slice(columns.start, columns.stop, columns.step)] + ) + elif is_compliant_series(columns): + native = native[:, columns.native.to_list()] + else: + native = native[:, columns] elif isinstance(columns, slice): native = native.select( self.columns[ From 7a579d34fec43e1d873cf0dcab54697a97656a35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:31:15 +0100 Subject: [PATCH 55/80] fix(typing): `Series` allows `_NumpyScalar` --- narwhals/_compliant/series.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 5ab2d26109..d8e015f104 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -52,6 +52,7 @@ from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from 
narwhals.typing import _1DArray + from narwhals.typing import _NumpyScalar from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -192,7 +193,7 @@ def is_nan(self) -> Self: ... def is_null(self) -> Self: ... def is_sorted(self, *, descending: bool) -> bool: ... def is_unique(self) -> Self: ... - def item(self, index: int | None) -> Any: ... + def item(self, index: int | _NumpyScalar | None) -> Any: ... def len(self) -> int: ... def max(self) -> Any: ... def mean(self) -> float: ... From ff6f2eb6532e694bfa1c789b21856dd232e936e8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 21:03:19 +0100 Subject: [PATCH 56/80] refactor: Don't pass down numpy scalar (#1515) didn't need to introduce this for every backend # --- narwhals/_compliant/series.py | 3 +-- narwhals/series.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index d8e015f104..5ab2d26109 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -52,7 +52,6 @@ from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray - from narwhals.typing import _NumpyScalar from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -193,7 +192,7 @@ def is_nan(self) -> Self: ... def is_null(self) -> Self: ... def is_sorted(self, *, descending: bool) -> bool: ... def is_unique(self) -> Self: ... - def item(self, index: int | _NumpyScalar | None) -> Any: ... + def item(self, index: int | None) -> Any: ... def len(self) -> int: ... def max(self) -> Any: ... def mean(self) -> float: ... 
diff --git a/narwhals/series.py b/narwhals/series.py index f83670e160..85dd3442a0 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -172,6 +172,7 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se if isinstance(idx, int) or ( is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): + idx = int(idx) if not isinstance(idx, int) else idx return self._compliant_series.item(idx) if isinstance(idx, self.to_native().__class__): From 7edc0ca5b50e28c8e02841a8c9ccb5523db0fb7c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 21:19:03 +0100 Subject: [PATCH 57/80] fix: Narrow correctly in `DataFrame.__getitem__` Fixes https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051784049 --- narwhals/dataframe.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 8cecd900c2..cf7e2054b9 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -914,13 +914,12 @@ def __getitem__( compliant = self._compliant_frame - if isinstance(rows, int) and isinstance(columns, (int, str)): - return self.item(rows, columns) - if isinstance(columns, (int, str)): + if isinstance(columns, (int, str)) and not isinstance(rows, str): + if isinstance(rows, int): + return self.item(rows, columns) col_name = columns if isinstance(columns, str) else self.columns[columns] series = self.get_column(col_name) return series[rows] if rows is not None else series - if isinstance(rows, Series): rows = rows._compliant_series if isinstance(columns, Series): From becdc48f284b6881091442b31140d331dd131b7a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sun, 20 Apr 2025 21:27:24 +0100 Subject: [PATCH 58/80] fix a mypy ``` narwhals/_compliant/dataframe.py:418: error: Argument 1 to "_select_slice_of_indices" of "EagerDataFrame" has incompatible type "slice[int, Any, Any] | 
slice[Any, int, Any] | slice[None, None, int] | slice[None, None, None] | range | slice[str, Any, Any] | slice[Any, str, Any] | slice[None, None, str]"; expected "slice[int, Any, Any] | slice[Any, int, Any] | slice[None, None, int] | slice[None, None, None] | range" [arg-type] ... compliant = compliant._select_slice_of_indices(columns) ^~~~~~~ ``` --- narwhals/_compliant/dataframe.py | 4 +++- narwhals/utils.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index d18b8613a4..830169c472 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -28,8 +28,10 @@ from narwhals.utils import deprecated from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector +from narwhals.utils import is_range from narwhals.utils import is_sequence_like from narwhals.utils import is_sized_multi_index_selector +from narwhals.utils import is_slice_index from narwhals.utils import is_slice_none if TYPE_CHECKING: @@ -414,7 +416,7 @@ def __getitem__( if isinstance(columns, Sized) and len(columns) == 0: return compliant.select() if is_index_selector(columns): - if isinstance(columns, (slice, range)): + if is_slice_index(columns) or is_range(columns): compliant = compliant._select_slice_of_indices(columns) elif is_compliant_series(columns): compliant = self._select_indices(columns.native) diff --git a/narwhals/utils.py b/narwhals/utils.py index c459d88ed0..dcd8735104 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1308,6 +1308,10 @@ def is_slice_index(obj: Any) -> TypeIs[_SliceIndex]: ) +def is_range(obj: Any) -> TypeIs[range]: + return isinstance(obj, range) + + def is_single_index_selector(obj: Any) -> TypeIs[SingleIndexSelector]: return bool(isinstance(obj, int) and not isinstance(obj, bool)) From c29ecfdb4c9a860e021cace67a41d81891cf1499 Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 21 Apr 2025 08:20:30 +0100 Subject: [PATCH 59/80] silence some _polars/_arrow type errors for now --- narwhals/_arrow/dataframe.py | 6 +++--- narwhals/_arrow/series.py | 2 +- narwhals/_polars/dataframe.py | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index f91c2af923..230e83d263 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -253,7 +253,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) + return self._with_native(self.native.take(item)) # pyright: ignore[reportArgumentType] def _gather_slice(self, item: _SliceIndex | range) -> Self: start = item.start or 0 @@ -278,14 +278,14 @@ def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: def _select_indices(self, item: SizedMultiIndexSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() + item = item.to_pylist() # pyright: ignore[reportAssignmentType] if is_numpy_array(item): item = item.tolist() return self._with_native(self.native.select(cast("Indices", item))) def _select_labels(self, item: SizedMultiNameSelector) -> Self: if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() + item = item.to_pylist() # pyright: ignore[reportAssignmentType] # pyarrow-stubs overly strict, accept list[str] | Indices return self._with_native(self.native.select(item)) # pyright: ignore[reportArgumentType] diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index a975858831..7536a52992 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -412,7 +412,7 @@ def _gather(self, item: SizedMultiIndexSelector) -> Self: return self._with_native(self.native.slice(0, 0)) if 
self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) + return self._with_native(self.native.take(item)) # pyright: ignore[reportArgumentType] def _gather_slice(self, item: slice | range) -> Self: start = item.start or 0 diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 228f6e92bd..17c389f6f5 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -277,7 +277,7 @@ def __getitem__( rows_native = rows.native if is_compliant_series(rows) else rows columns_native = columns.native if is_compliant_series(columns) else columns selector = rows_native, columns_native - selected = self.native.__getitem__(selector) + selected = self.native.__getitem__(selector) # type: ignore[index] return self._from_native_object(selected) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum @@ -300,7 +300,7 @@ def __getitem__( elif is_compliant_series(columns): native = native[:, columns.native.to_list()] else: - native = native[:, columns] + native = native[:, columns] # type: ignore[index] elif isinstance(columns, slice): native = native.select( self.columns[ @@ -317,18 +317,18 @@ def __getitem__( if not is_slice_none(rows): if isinstance(rows, int): - native = native[[rows], :] + native = native[[rows], :] # pyright: ignore[reportArgumentType,reportCallIssue] elif isinstance(rows, (slice, range)): - native = native[rows, :] + native = native[rows, :] # pyright: ignore[reportArgumentType,reportCallIssue] elif is_compliant_series(rows): - native = native[rows.native, :] + native = native[rows.native, :] # pyright: ignore[reportArgumentType,reportCallIssue] elif is_sequence_like(rows): - native = native[rows, :] + native = native[rows, :] # type: ignore[index] else: msg = f"Unreachable code, got unexpected type: {type(rows)}" raise AssertionError(msg) - return self._with_native(native) + return 
self._with_native(native) # pyright: ignore[reportArgumentType] def simple_select(self, *column_names: str) -> Self: return self._with_native(self.native.select(*column_names)) From dc71fd90e5a040970e24f68a90b790cad0cc05db Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:56:39 +0100 Subject: [PATCH 60/80] extra tests, align polars logic more --- narwhals/_polars/dataframe.py | 6 ++++-- narwhals/dataframe.py | 24 ++++++++++++++---------- tests/frame/getitem_test.py | 25 ++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 17c389f6f5..e3aa0adfb2 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -24,7 +24,9 @@ from narwhals.utils import convert_str_slice_to_int_slice from narwhals.utils import is_compliant_series from narwhals.utils import is_index_selector +from narwhals.utils import is_range from narwhals.utils import is_sequence_like +from narwhals.utils import is_slice_index from narwhals.utils import is_slice_none from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -291,9 +293,9 @@ def __getitem__( native = self.native if not is_slice_none(columns): if isinstance(columns, Sized) and len(columns) == 0: - native = native.select() + return self.select() if is_index_selector(columns): - if isinstance(columns, (slice, range)): + if is_slice_index(columns) or is_range(columns): native = native.select( self.columns[slice(columns.start, columns.stop, columns.step)] ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index cf7e2054b9..6c2871dbce 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -884,13 +884,21 @@ def __getitem__( """ from narwhals.series import Series + msg = ( + f"Unexpected type for `DataFrame.__getitem__`, got: {type(item)}.\n\n" + "Hints:\n" + "- use `df.item` to select a single 
item.\n" + "- Use `df[indices, :]` to select rows positionally.\n" + "- Use `df.filter(mask)` to filter rows based on a boolean mask." + ) + if isinstance(item, tuple): if len(item) > 2: - msg = ( + tuple_msg = ( "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" ) - raise TypeError(msg) + raise TypeError(tuple_msg) # These are so heavily overloaded that we just ignore the types for now. rows = None if not item or is_slice_none(item[0]) else item[0] columns = None if len(item) < 2 or is_slice_none(item[1]) else item[1] @@ -903,18 +911,14 @@ def __getitem__( rows = None columns = item else: - msg = ( - f"Unexpected type for `DataFrame.__getitem__`, got: {type(item)}.\n\n" - "Hints:\n" - "- use `df.item` to select a single item.\n" - "- Use `df[indices, :]` to select rows positionally.\n" - "- Use `df.filter(mask)` to filter rows based on a boolean mask." - ) + raise TypeError(msg) + + if isinstance(rows, str): raise TypeError(msg) compliant = self._compliant_frame - if isinstance(columns, (int, str)) and not isinstance(rows, str): + if isinstance(columns, (int, str)): if isinstance(rows, int): return self.item(rows, columns) col_name = columns if isinstance(columns, str) else self.columns[columns] diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 9cc769b4fc..39f97c286c 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -119,7 +119,7 @@ def test_gather_rows_cols(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_slice_both_tuples_of_ints(constructor_eager: ConstructorEager) -> None: +def test_slice_both_list_of_ints(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor_eager(data), eager_only=True) result = df[[0, 1], [0, 2]] @@ -127,6 +127,14 @@ def test_slice_both_tuples_of_ints(constructor_eager: ConstructorEager) 
-> None: assert_equal_data(result, expected) +def test_slice_both_tuple(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df[(0, 1), ("a", "c")] + expected = {"a": [1, 2], "c": [7, 8]} + assert_equal_data(result, expected) + + def test_slice_int_rows_str_columns(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor_eager(data), eager_only=True) @@ -258,6 +266,15 @@ def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) +def test_getitem_ndarray_columns_labels(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + arr: np.ndarray[tuple[int], np.dtype[Any]] = np.array(["col1", "col2"]) # pyright: ignore[reportAssignmentType] + result = nw_df[:, arr] + expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} + assert_equal_data(result, expected) + + def test_getitem_negative_slice(constructor_eager: ConstructorEager) -> None: data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} nw_df = nw.from_native(constructor_eager(data), eager_only=True) @@ -349,3 +366,9 @@ def test_pandas_non_str_columns() -> None: result = df[:, [datetime(2020, 1, 1)]] # type: ignore[index] expected = {datetime(2020, 1, 1): [1, 2, 3]} assert result.to_dict(as_series=False) == expected + + +def test_select_rows_by_name(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager({"a": [0, 2, 1]}), eager_only=True) + with pytest.raises(TypeError, match="Unexpected type"): + df["a", :] # type: ignore[index] From 693c4eb009ae8f2de12db10e2e5623359227dfc9 Mon Sep 17 00:00:00 2001 From: dangotbanned 
<125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 10:57:07 +0100 Subject: [PATCH 61/80] Start making generic https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2051809254 --- narwhals/dataframe.py | 8 ++++++-- narwhals/series.py | 2 +- narwhals/stable/v1/__init__.py | 4 ++-- narwhals/typing.py | 10 +++++----- narwhals/utils.py | 4 ++-- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index cf7e2054b9..298e1091d5 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -53,6 +53,7 @@ from typing_extensions import Concatenate from typing_extensions import ParamSpec from typing_extensions import Self + from typing_extensions import TypeAlias from narwhals._compliant import CompliantDataFrame from narwhals._compliant import CompliantLazyFrame @@ -67,8 +68,8 @@ from narwhals.typing import IntoFrame from narwhals.typing import JoinStrategy from narwhals.typing import LazyUniqueKeepStrategy - from narwhals.typing import MultiColSelector - from narwhals.typing import MultiIndexSelector + from narwhals.typing import MultiColSelector as _MultiColSelector + from narwhals.typing import MultiIndexSelector as _MultiIndexSelector from narwhals.typing import PivotAgg from narwhals.typing import SingleColSelector from narwhals.typing import SingleIndexSelector @@ -83,6 +84,9 @@ DataFrameT = TypeVar("DataFrameT", bound="IntoDataFrame") R = TypeVar("R") +MultiColSelector: TypeAlias = "_MultiColSelector[Series[Any]]" +MultiIndexSelector: TypeAlias = "_MultiIndexSelector[Series[Any]]" + class BaseFrame(Generic[_FrameT]): _compliant_frame: Any diff --git a/narwhals/series.py b/narwhals/series.py index 85dd3442a0..84016429af 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -20,7 +20,6 @@ from narwhals.series_struct import SeriesStructNamespace from narwhals.translate import to_native from narwhals.typing import IntoSeriesT -from narwhals.typing import MultiIndexSelector 
from narwhals.typing import NonNestedLiteral from narwhals.typing import SingleIndexSelector from narwhals.utils import _validate_rolling_arguments @@ -40,6 +39,7 @@ from narwhals._arrow.typing import ArrowArray from narwhals._compliant import CompliantSeries from narwhals.dataframe import DataFrame + from narwhals.dataframe import MultiIndexSelector from narwhals.dtypes import DType from narwhals.typing import ClosedInterval from narwhals.typing import FillNullStrategy diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 047b8ca81b..f966434725 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -91,14 +91,14 @@ from typing_extensions import TypeVar from narwhals._translate import IntoArrowTable + from narwhals.dataframe import MultiColSelector + from narwhals.dataframe import MultiIndexSelector from narwhals.dtypes import DType from narwhals.typing import ConcatMethod from narwhals.typing import IntoExpr from narwhals.typing import IntoFrame from narwhals.typing import IntoLazyFrameT from narwhals.typing import IntoSeries - from narwhals.typing import MultiColSelector - from narwhals.typing import MultiIndexSelector from narwhals.typing import NonNestedLiteral from narwhals.typing import SingleColSelector from narwhals.typing import SingleIndexSelector diff --git a/narwhals/typing.py b/narwhals/typing.py index 575e17a27a..cd19d24f05 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -318,16 +318,16 @@ def __native_namespace__(self) -> ModuleType: ... SingleIndexSelector: TypeAlias = int _SliceIndex: TypeAlias = "_Slice[int] | _SliceNone" """E.g. 
`[1:]` or `[:3]` or `[::2]`.""" -SizedMultiIndexSelector: TypeAlias = "Sequence[int] | Series[Any] | _1DArray" -MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector" +SizedMultiIndexSelector: TypeAlias = "Sequence[int] | _T | _1DArray" +MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector[_T]" # Labels/column names SingleNameSelector: TypeAlias = str _SliceName: TypeAlias = "_Slice[str] | _SliceNone" -SizedMultiNameSelector: TypeAlias = "Sequence[str] | Series[Any] | _1DArray" -MultiNameSelector: TypeAlias = "_SliceName | SizedMultiNameSelector" +SizedMultiNameSelector: TypeAlias = "Sequence[str] | _T | _1DArray" +MultiNameSelector: TypeAlias = "_SliceName | SizedMultiNameSelector[_T]" # Mixed selectors SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" -MultiColSelector: TypeAlias = "MultiIndexSelector | MultiNameSelector" +MultiColSelector: TypeAlias = "MultiIndexSelector[_T] | MultiNameSelector[_T]" # ruff: noqa: N802 diff --git a/narwhals/utils.py b/narwhals/utils.py index dcd8735104..ed716d9241 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1274,7 +1274,7 @@ def is_slice_none(obj: Any) -> TypeIs[_SliceNone]: return isinstance(obj, slice) and obj == slice(None) -def is_sized_multi_index_selector(obj: Any) -> TypeIs[SizedMultiIndexSelector]: +def is_sized_multi_index_selector(obj: Any) -> TypeIs[SizedMultiIndexSelector[Any]]: np = get_numpy() return ( ( @@ -1316,7 +1316,7 @@ def is_single_index_selector(obj: Any) -> TypeIs[SingleIndexSelector]: return bool(isinstance(obj, int) and not isinstance(obj, bool)) -def is_index_selector(obj: Any) -> TypeIs[SingleIndexSelector | MultiIndexSelector]: +def is_index_selector(obj: Any) -> TypeIs[SingleIndexSelector | MultiIndexSelector[Any]]: return ( is_single_index_selector(obj) or is_sized_multi_index_selector(obj) From 02cc6072fc15cff85ecb5b1024a12e2acde849cc Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:00:09 +0100 Subject: [PATCH 62/80] :party: remove many unneeded assumes from hypothesis test --- tests/hypothesis/getitem_test.py | 68 -------------------------------- 1 file changed, 68 deletions(-) diff --git a/tests/hypothesis/getitem_test.py b/tests/hypothesis/getitem_test.py index 4c1691131c..940712d29f 100644 --- a/tests/hypothesis/getitem_test.py +++ b/tests/hypothesis/getitem_test.py @@ -166,74 +166,6 @@ def test_getitem( ) ) ) - - # IndexError: Offset must be non-negative (pyarrow does not support negative indexing) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, slice) - and isinstance(selector.start, int) - and selector.start < 0 - ) - ) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, slice) - and isinstance(selector.stop, int) - and selector.stop < 0 - ) - ) - - # Pairs of slices are not supported - # NB a few trivial cases are supported, eg df[0:1, :] - # TypeError: Got unexpected argument type for compute function - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, tuple) - and isinstance(selector[0], slice) - and isinstance(selector[1], slice) - and ( - selector[0] != slice(None, None, None) - or selector[1] != slice(None, None, None) - ) - ) - ) - - # df[[], "a":], df[[], :] etc fail in pyarrow: - # ArrowNotImplementedError: Function 'array_take' has no kernel matching input types (int64, null) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, tuple) - and isinstance(selector[0], list) - and len(selector[0]) == 0 - and isinstance(selector[1], slice) - ) - ) - - # df[[], "a":], df[[], :] etc return different results between pandas/polars: - assume( - not ( - pandas_or_pyarrow_constructor is pandas_constructor - and isinstance(selector, tuple) - and
isinstance(selector[0], list) - and len(selector[0]) == 0 - and isinstance(selector[1], slice) - ) - ) - - # df[..., ::step] is not fine: - # TypeError: Expected slice of integers or strings, got: - assume( - not ( - isinstance(selector, tuple) - and isinstance(selector[1], slice) - and selector[1].start is None - and selector[1].stop is None - ) - ) # End TODO ================================================================ df_polars = nw.from_native(pl.DataFrame(TEST_DATA)) From e68ee9bd2b38a2af125917cf33e3ca295f375955 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:07:34 +0100 Subject: [PATCH 63/80] fill in some holes https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2052213222 --- narwhals/_compliant/dataframe.py | 18 +++++++++++------- narwhals/_polars/dataframe.py | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 830169c472..c050d1eaae 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -118,8 +118,8 @@ def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... def __getitem__( self, item: tuple[ - SingleIndexSelector | MultiIndexSelector | CompliantSeriesT, - MultiIndexSelector | MultiColSelector | CompliantSeriesT, + SingleIndexSelector | MultiIndexSelector[CompliantSeriesT], + MultiIndexSelector[CompliantSeriesT] | MultiColSelector[CompliantSeriesT], ], ) -> Self: ... def simple_select(self, *column_names: str) -> Self: @@ -396,18 +396,22 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector[CompliantSeriesAny]) -> Self: ... def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... 
- def _select_indices(self, indices: SizedMultiIndexSelector) -> Self: ... - def _select_labels(self, indices: SizedMultiNameSelector) -> Self: ... + def _select_indices( + self, indices: SizedMultiIndexSelector[CompliantSeriesAny] + ) -> Self: ... + def _select_labels( + self, indices: SizedMultiNameSelector[CompliantSeriesAny] + ) -> Self: ... def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... def _select_slice_of_labels(self, item: _SliceName) -> Self: ... def __getitem__( self, item: tuple[ - SingleIndexSelector | MultiIndexSelector | CompliantSeriesAny, - MultiIndexSelector | MultiColSelector | CompliantSeriesAny, + SingleIndexSelector | MultiIndexSelector[CompliantSeriesAny], + MultiIndexSelector[CompliantSeriesAny] | MultiColSelector[CompliantSeriesAny], ], ) -> Self: rows, columns = item diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index e3aa0adfb2..d3fb60faae 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -270,8 +270,8 @@ def shape(self) -> tuple[int, int]: def __getitem__( self, item: tuple[ - SingleIndexSelector | MultiIndexSelector | PolarsSeries, - MultiIndexSelector | MultiColSelector | PolarsSeries, + SingleIndexSelector | MultiIndexSelector[PolarsSeries], + MultiIndexSelector[PolarsSeries] | MultiColSelector[PolarsSeries], ], ) -> Any: rows, columns = item From 17538d3e7523c3312f5e725cbe036f40ea4da97d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:13:55 +0100 Subject: [PATCH 64/80] remove outdated comment --- narwhals/series.py | 1 - 1 file changed, 1 deletion(-) diff --git a/narwhals/series.py b/narwhals/series.py index 84016429af..02d9f5d673 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -178,7 +178,6 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se if isinstance(idx, self.to_native().__class__): idx = 
self._with_compliant(self._compliant_series._with_native(idx)) - # For Series.__getitem__, we only if not is_index_selector(idx): msg = ( f"Unexpected type for `Series.__getitem__`: {type(idx)}.\n\n" From 2a5e8668660c50129f1c3f94f12cc73ea4eed122 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:39:20 +0100 Subject: [PATCH 65/80] more updating annotations --- narwhals/_arrow/dataframe.py | 34 +++++++++++++++++++----------- narwhals/_arrow/series.py | 7 +++--- narwhals/_compliant/dataframe.py | 16 ++++++-------- narwhals/_compliant/series.py | 5 +++-- narwhals/_pandas_like/dataframe.py | 8 ++++--- narwhals/_pandas_like/series.py | 5 +++-- narwhals/_polars/dataframe.py | 2 +- 7 files changed, 44 insertions(+), 33 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 230e83d263..207b4514c1 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -49,7 +49,6 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.typing import ArrowChunkedArray - from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import Mask # type: ignore[attr-defined] from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._translate import IntoArrowTable @@ -62,6 +61,7 @@ from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy + from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.typing import _SliceIndex from narwhals.typing import _SliceName @@ -248,12 +248,12 @@ def get_column(self, name: str) -> ArrowSeries: def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def _gather(self, item: SizedMultiIndexSelector) -> Self: + def _gather(self, item: 
SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) # pyright: ignore[reportArgumentType] + return self._with_native(self.native.take(item)) def _gather_slice(self, item: _SliceIndex | range) -> Self: start = item.start or 0 @@ -276,18 +276,28 @@ def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: self.native.select(self.columns[item.start : item.stop : item.step]) ) - def _select_indices(self, item: SizedMultiIndexSelector) -> Self: + def _select_indices(self, item: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: + selector: Sequence[int] | Sequence[str] if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() # pyright: ignore[reportAssignmentType] - if is_numpy_array(item): - item = item.tolist() - return self._with_native(self.native.select(cast("Indices", item))) + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[int]", item.to_pylist()) + # NOTE: Probably don't need to convert to a list here? 
+ elif is_numpy_array(item): + selector = item.tolist() + else: + selector = item + return self._with_native(self.native.select(selector)) - def _select_labels(self, item: SizedMultiNameSelector) -> Self: + def _select_labels(self, item: SizedMultiNameSelector[ArrowChunkedArray]) -> Self: + selector: Sequence[str] | _1DArray if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() # pyright: ignore[reportAssignmentType] - # pyarrow-stubs overly strict, accept list[str] | Indices - return self._with_native(self.native.select(item)) # pyright: ignore[reportArgumentType] + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[str]", item.to_pylist()) + else: + selector = item + # TODO @dangotbanned: Fix upstream `pa.Table.select` https://github.com/zen-xu/pyarrow-stubs/blob/f899bb35e10b36f7906a728e9f8acf3e0a1f9f64/pyarrow-stubs/__lib_pxi/table.pyi#L597 + # NOTE: Investigate what `cython` actually checks + return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType] @property def schema(self) -> dict[str, DType]: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 7536a52992..6e6b280f77 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -71,6 +71,7 @@ from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex from narwhals.utils import Version from narwhals.utils import _FullContext @@ -407,14 +408,14 @@ def __native_namespace__(self) -> ModuleType: def name(self) -> str: return self._name - def _gather(self, item: SizedMultiIndexSelector) -> Self: + def _gather(self, item: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: if len(item) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(item, tuple): item = list(item) - return self._with_native(self.native.take(item)) # 
pyright: ignore[reportArgumentType] + return self._with_native(self.native.take(item)) - def _gather_slice(self, item: slice | range) -> Self: + def _gather_slice(self, item: _SliceIndex | range) -> Self: start = item.start or 0 stop = item.stop if item.stop is not None else len(self.native) if start < 0: diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index c050d1eaae..2d923cc2c6 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -13,7 +13,6 @@ from typing import overload from narwhals._compliant.typing import CompliantExprT_contra -from narwhals._compliant.typing import CompliantSeriesAny from narwhals._compliant.typing import CompliantSeriesT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT @@ -55,6 +54,7 @@ from narwhals.typing import LazyUniqueKeepStrategy from narwhals.typing import MultiColSelector from narwhals.typing import MultiIndexSelector + from narwhals.typing import NativeSeries from narwhals.typing import PivotAgg from narwhals.typing import SingleIndexSelector from narwhals.typing import SizedMultiIndexSelector @@ -396,22 +396,18 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: SizedMultiIndexSelector[CompliantSeriesAny]) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector[NativeSeries]) -> Self: ... def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... - def _select_indices( - self, indices: SizedMultiIndexSelector[CompliantSeriesAny] - ) -> Self: ... - def _select_labels( - self, indices: SizedMultiNameSelector[CompliantSeriesAny] - ) -> Self: ... + def _select_indices(self, indices: SizedMultiIndexSelector[NativeSeries]) -> Self: ... + def _select_labels(self, indices: SizedMultiNameSelector[NativeSeries]) -> Self: ... def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... 
def _select_slice_of_labels(self, item: _SliceName) -> Self: ... def __getitem__( self, item: tuple[ - SingleIndexSelector | MultiIndexSelector[CompliantSeriesAny], - MultiIndexSelector[CompliantSeriesAny] | MultiColSelector[CompliantSeriesAny], + SingleIndexSelector | MultiIndexSelector[EagerSeriesT], + MultiIndexSelector[EagerSeriesT] | MultiColSelector[EagerSeriesT], ], ) -> Self: rows, columns = item diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 5ab2d26109..6a475c1a60 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -52,6 +52,7 @@ from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray + from narwhals.typing import _SliceIndex from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -320,8 +321,8 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] - def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... - def _gather_slice(self, indices: slice | range) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... + def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... 
def __getitem__(self, item: Any) -> Self: if is_slice_none(item): diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 156766cbbb..364be1174c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -281,7 +281,7 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, items: SizedMultiIndexSelector) -> Self: + def _gather(self, items: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: items = list(items) if isinstance(items, tuple) else items return self._with_native(self.native.iloc[items, :]) @@ -304,14 +304,16 @@ def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: validate_column_names=False, ) - def _select_indices(self, item: SizedMultiIndexSelector) -> Self: + def _select_indices(self, item: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: item = list(item) if isinstance(item, tuple) else item return self._with_native( self.native.iloc[:, item], validate_column_names=False, ) - def _select_labels(self, indices: SizedMultiNameSelector) -> PandasLikeDataFrame: + def _select_labels( + self, indices: SizedMultiNameSelector[pd.Series[Any]] + ) -> PandasLikeDataFrame: return self._with_native(self.native.loc[:, indices]) # --- properties --- diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 01e4df0052..0b8cc32432 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -57,6 +57,7 @@ from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _AnyDArray + from narwhals.typing import _SliceIndex from narwhals.utils import Version from narwhals.utils import _FullContext @@ -146,11 +147,11 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def 
_gather(self, rows: SizedMultiIndexSelector) -> Self: + def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) - def _gather_slice(self, item: slice | range) -> Self: + def _gather_slice(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.iloc[slice(item.start, item.stop, item.step)] ) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index d3fb60faae..8fd3508283 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -302,7 +302,7 @@ def __getitem__( elif is_compliant_series(columns): native = native[:, columns.native.to_list()] else: - native = native[:, columns] # type: ignore[index] + native = native[:, columns] elif isinstance(columns, slice): native = native.select( self.columns[ From 308e197c8ea2ba8afc1ba2e4097dfe1c74b44bb5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:52:21 +0100 Subject: [PATCH 66/80] feat(typing): Add `NativeSeriesT` to `EagerDataFrame` Fixes a lot https://github.com/narwhals-dev/narwhals/pull/2393#discussion_r2052251020 --- narwhals/_arrow/dataframe.py | 4 +++- narwhals/_compliant/dataframe.py | 14 ++++++++------ narwhals/_compliant/typing.py | 4 ++-- narwhals/_pandas_like/dataframe.py | 4 +++- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 207b4514c1..5f6311d00f 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -81,7 +81,9 @@ PromoteOptions: TypeAlias = Literal["none", "default", "permissive"] -class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): +class ArrowDataFrame( + EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "pa.ChunkedArray[Any]"] +): def __init__( self, native_dataframe: pa.Table, diff --git 
a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 2d923cc2c6..02fae999f3 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -17,6 +17,7 @@ from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import NativeFrameT +from narwhals._compliant.typing import NativeSeriesT from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._translate import ArrowConvertible from narwhals._translate import DictConvertible @@ -54,7 +55,6 @@ from narwhals.typing import LazyUniqueKeepStrategy from narwhals.typing import MultiColSelector from narwhals.typing import MultiIndexSelector - from narwhals.typing import NativeSeries from narwhals.typing import PivotAgg from narwhals.typing import SingleIndexSelector from narwhals.typing import SizedMultiIndexSelector @@ -353,11 +353,11 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: class EagerDataFrame( CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT], CompliantLazyFrame[EagerExprT, NativeFrameT], - Protocol[EagerSeriesT, EagerExprT, NativeFrameT], + Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], ): def __narwhals_namespace__( self, - ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, Any]: ... + ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ... def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" @@ -396,10 +396,12 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: SizedMultiIndexSelector[NativeSeries]) -> Self: ... + def _gather(self, indices: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... 
- def _select_indices(self, indices: SizedMultiIndexSelector[NativeSeries]) -> Self: ... - def _select_labels(self, indices: SizedMultiNameSelector[NativeSeries]) -> Self: ... + def _select_indices( + self, indices: SizedMultiIndexSelector[NativeSeriesT] + ) -> Self: ... + def _select_labels(self, indices: SizedMultiNameSelector[NativeSeriesT]) -> Self: ... def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... def _select_slice_of_labels(self, item: _SliceName) -> Self: ... diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 30788b3210..56c092799d 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -47,7 +47,7 @@ DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]" -EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]" EagerSeriesAny: TypeAlias = "EagerSeries[Any]" EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]" @@ -111,7 +111,7 @@ EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True) # NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`? 
-EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") +EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]") LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 364be1174c..8163da9183 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -103,7 +103,9 @@ ) -class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any"]): +class PandasLikeDataFrame( + EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any", "pd.Series[Any]"] +): def __init__( self, native_dataframe: Any, From 8e567d5c65b1bb0eb3cc82f813aa8299512c8554 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:59:17 +0100 Subject: [PATCH 67/80] chore(typing): Update `polars` ignores --- narwhals/_polars/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 8fd3508283..fe9d3aa140 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -300,7 +300,7 @@ def __getitem__( self.columns[slice(columns.start, columns.stop, columns.step)] ) elif is_compliant_series(columns): - native = native[:, columns.native.to_list()] + native = native[:, columns.native.to_list()] # type: ignore[attr-defined, index] else: native = native[:, columns] elif isinstance(columns, slice): @@ -325,7 +325,7 @@ def __getitem__( elif is_compliant_series(rows): native = native[rows.native, :] # pyright: ignore[reportArgumentType,reportCallIssue] elif is_sequence_like(rows): - native = native[rows, :] # type: ignore[index] + native = native[rows, :] # pyright: ignore[reportArgumentType,reportCallIssue] else: msg = f"Unreachable code, got unexpected type: 
{type(rows)}" raise AssertionError(msg) From 9a2f890cd7e723f8177b7c0c3c4e78b86d3c4243 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 13:17:14 +0100 Subject: [PATCH 68/80] docs(typing): Add note on `pa.Table.select` --- narwhals/_arrow/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 5f6311d00f..a65b1720c1 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -283,7 +283,8 @@ def _select_indices(self, item: SizedMultiIndexSelector[ArrowChunkedArray]) -> S if isinstance(item, pa.ChunkedArray): # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` selector = cast("Sequence[int]", item.to_pylist()) - # NOTE: Probably don't need to convert to a list here? + # TODO @dangotbanned: Fix upstream, it is actually much narrower + # **Doesn't accept `ndarray`** elif is_numpy_array(item): selector = item.tolist() else: From 2c0100778978fbc2f562ba25810c5184809e9920 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 13:30:56 +0100 Subject: [PATCH 69/80] refactor: More align names --- narwhals/_arrow/dataframe.py | 10 ++++++---- narwhals/_compliant/dataframe.py | 27 +++++++++++++-------------- narwhals/_compliant/series.py | 5 ++--- narwhals/_pandas_like/dataframe.py | 16 ++++++++-------- narwhals/_pandas_like/series.py | 4 ++-- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index a65b1720c1..2c67d544fa 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -269,16 +269,18 @@ def _gather_slice(self, item: _SliceIndex | range) -> Self: raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) - def _select_slice_of_labels(self, item: _SliceName) -> Self: + 
def _select_slice_name(self, item: _SliceName) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.columns) return self._with_native(self.native.select(self.columns[start:stop:step])) - def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: + def _select_slice_index(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.select(self.columns[item.start : item.stop : item.step]) ) - def _select_indices(self, item: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: + def _select_multi_index( + self, item: SizedMultiIndexSelector[ArrowChunkedArray] + ) -> Self: selector: Sequence[int] | Sequence[str] if isinstance(item, pa.ChunkedArray): # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` @@ -291,7 +293,7 @@ def _select_indices(self, item: SizedMultiIndexSelector[ArrowChunkedArray]) -> S selector = item return self._with_native(self.native.select(selector)) - def _select_labels(self, item: SizedMultiNameSelector[ArrowChunkedArray]) -> Self: + def _select_multi_name(self, item: SizedMultiNameSelector[ArrowChunkedArray]) -> Self: selector: Sequence[str] | _1DArray if isinstance(item, pa.ChunkedArray): # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 02fae999f3..f012c88cfc 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -396,15 +396,14 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... - def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... - def _select_indices( - self, indices: SizedMultiIndexSelector[NativeSeriesT] + def _gather(self, item: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... 
+ def _gather_slice(self, item: _SliceIndex | range) -> Self: ... + def _select_multi_index( + self, item: SizedMultiIndexSelector[NativeSeriesT] ) -> Self: ... - def _select_labels(self, indices: SizedMultiNameSelector[NativeSeriesT]) -> Self: ... - def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... - def _select_slice_of_labels(self, item: _SliceName) -> Self: ... - + def _select_multi_name(self, item: SizedMultiNameSelector[NativeSeriesT]) -> Self: ... + def _select_slice_index(self, item: _SliceIndex | range) -> Self: ... + def _select_slice_name(self, item: _SliceName) -> Self: ... def __getitem__( self, item: tuple[ @@ -419,17 +418,17 @@ def __getitem__( return compliant.select() if is_index_selector(columns): if is_slice_index(columns) or is_range(columns): - compliant = compliant._select_slice_of_indices(columns) + compliant = compliant._select_slice_index(columns) elif is_compliant_series(columns): - compliant = self._select_indices(columns.native) + compliant = self._select_multi_index(columns.native) else: - compliant = compliant._select_indices(columns) + compliant = compliant._select_multi_index(columns) elif isinstance(columns, slice): - compliant = compliant._select_slice_of_labels(columns) + compliant = compliant._select_slice_name(columns) elif is_compliant_series(columns): - compliant = self._select_labels(columns.native) + compliant = self._select_multi_name(columns.native) elif is_sequence_like(columns): - compliant = self._select_labels(columns) + compliant = self._select_multi_name(columns) else: msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 6a475c1a60..affc2e29fb 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -321,9 +321,8 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # 
type: ignore[no-any-return] - def _gather(self, indices: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... - def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... - + def _gather(self, item: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... + def _gather_slice(self, item: _SliceIndex | range) -> Self: ... def __getitem__(self, item: Any) -> Self: if is_slice_none(item): return self diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 8163da9183..da5e9d326b 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -283,8 +283,8 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, items: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: - items = list(items) if isinstance(items, tuple) else items + def _gather(self, item: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: + items = list(item) if isinstance(item, tuple) else item return self._with_native(self.native.iloc[items, :]) def _gather_slice(self, item: _SliceIndex | range) -> Self: @@ -293,30 +293,30 @@ def _gather_slice(self, item: _SliceIndex | range) -> Self: validate_column_names=False, ) - def _select_slice_of_labels(self, item: _SliceName) -> Self: + def _select_slice_name(self, item: _SliceName) -> Self: start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) return self._with_native( self.native.iloc[:, slice(start, stop, step)], validate_column_names=False, ) - def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: + def _select_slice_index(self, item: _SliceIndex | range) -> Self: return self._with_native( self.native.iloc[:, slice(item.start, item.stop, item.step)], validate_column_names=False, ) - def _select_indices(self, item: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: + def _select_multi_index(self, item: 
SizedMultiIndexSelector[pd.Series[Any]]) -> Self: item = list(item) if isinstance(item, tuple) else item return self._with_native( self.native.iloc[:, item], validate_column_names=False, ) - def _select_labels( - self, indices: SizedMultiNameSelector[pd.Series[Any]] + def _select_multi_name( + self, item: SizedMultiNameSelector[pd.Series[Any]] ) -> PandasLikeDataFrame: - return self._with_native(self.native.loc[:, indices]) + return self._with_native(self.native.loc[:, item]) # --- properties --- @property diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 0b8cc32432..535265a13d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -147,8 +147,8 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: - rows = list(rows) if isinstance(rows, tuple) else rows + def _gather(self, item: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: + rows = list(item) if isinstance(item, tuple) else item return self._with_native(self.native.iloc[rows]) def _gather_slice(self, item: _SliceIndex | range) -> Self: From 1119ad7e9dd5b03980bd714dcb355c385102361a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Mon, 21 Apr 2025 15:53:13 +0100 Subject: [PATCH 70/80] refactor: Factor-in `_pandas_like.utils.convert_str_slice_to_int_slice` --- narwhals/_pandas_like/dataframe.py | 9 +++++---- narwhals/_pandas_like/utils.py | 19 ------------------- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index da5e9d326b..1245d823ea 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -19,7 +19,6 @@ from narwhals._pandas_like.utils import align_and_extract_native from narwhals._pandas_like.utils import 
align_series_full_broadcast from narwhals._pandas_like.utils import check_column_names_are_unique -from narwhals._pandas_like.utils import convert_str_slice_to_int_slice from narwhals._pandas_like.utils import get_dtype_backend from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import object_native_to_narwhals_dtype @@ -294,10 +293,12 @@ def _gather_slice(self, item: _SliceIndex | range) -> Self: ) def _select_slice_name(self, item: _SliceName) -> Self: - start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) + columns = self.native.columns + start = columns.get_loc(item.start) if item.start is not None else None + stop = columns.get_loc(item.stop) + 1 if item.stop is not None else None + selector = slice(start, stop, item.step) return self._with_native( - self.native.iloc[:, slice(start, stop, step)], - validate_column_names=False, + self.native.iloc[:, selector], validate_column_names=False ) def _select_slice_index(self, item: _SliceIndex | range) -> Self: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 2d237a79aa..74ab8f88e8 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -8,7 +8,6 @@ from typing import Sequence from typing import Sized from typing import TypeVar -from typing import cast import pandas as pd @@ -565,24 +564,6 @@ def int_dtype_mapper(dtype: Any) -> str: return "int64" -def convert_str_slice_to_int_slice( - str_slice: slice | range, columns: pd.Index[str] -) -> tuple[int | None, int | None, int | None]: - # We can safely cast to int because we know that `columns` doesn't contain duplicates. 
- start = ( - cast("int", columns.get_loc(str_slice.start)) - if str_slice.start is not None - else None - ) - stop = ( - cast("int", columns.get_loc(str_slice.stop)) + 1 - if str_slice.stop is not None - else None - ) - step = str_slice.step - return (start, stop, step) - - def calculate_timestamp_datetime( s: pd.Series[int], original_time_unit: str, time_unit: str ) -> pd.Series[int]: From b6aec5921ee987a5bceee788721246dfbf5ca92c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 12:31:21 +0100 Subject: [PATCH 71/80] remove outdated comment, type CompliantSeies.__getitem__ --- narwhals/_compliant/series.py | 3 ++- narwhals/_pandas_like/dataframe.py | 10 ++-------- narwhals/dataframe.py | 1 - 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 5ab2d26109..d3827c5070 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -45,6 +45,7 @@ from narwhals.typing import ClosedInterval from narwhals.typing import FillNullStrategy from narwhals.typing import Into1DArray + from narwhals.typing import MultiIndexSelector from narwhals.typing import NonNestedLiteral from narwhals.typing import NumericLiteral from narwhals.typing import RankMethod @@ -323,7 +324,7 @@ def _to_expr(self) -> EagerExpr[Any, Any]: def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... def _gather_slice(self, indices: slice | range) -> Self: ... 
- def __getitem__(self, item: Any) -> Self: + def __getitem__(self, item: MultiIndexSelector) -> Self: if is_slice_none(item): return self if isinstance(item, int): diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 156766cbbb..53dc36412e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -299,17 +299,11 @@ def _select_slice_of_labels(self, item: _SliceName) -> Self: ) def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: - return self._with_native( - self.native.iloc[:, slice(item.start, item.stop, item.step)], - validate_column_names=False, - ) + return self._with_native(self.native.iloc[:, item], validate_column_names=False) def _select_indices(self, item: SizedMultiIndexSelector) -> Self: item = list(item) if isinstance(item, tuple) else item - return self._with_native( - self.native.iloc[:, item], - validate_column_names=False, - ) + return self._with_native(self.native.iloc[:, item], validate_column_names=False) def _select_labels(self, indices: SizedMultiNameSelector) -> PandasLikeDataFrame: return self._with_native(self.native.loc[:, indices]) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index c281e23ada..2ab5c354c4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -903,7 +903,6 @@ def __getitem__( "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" ) raise TypeError(tuple_msg) - # These are so heavily overloaded that we just ignore the types for now. 
rows = None if not item or is_slice_none(item[0]) else item[0] columns = None if len(item) < 2 or is_slice_none(item[1]) else item[1] if rows is None and columns is None: From f87422e8c037c0d2f63c8cc4844ffe04ea1bf9fe Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:06:49 +0100 Subject: [PATCH 72/80] more consistent naming --- narwhals/_arrow/dataframe.py | 50 +++++++++++++++--------------- narwhals/_arrow/series.py | 18 +++++------ narwhals/_compliant/dataframe.py | 22 ++++++------- narwhals/_compliant/series.py | 4 +-- narwhals/_pandas_like/dataframe.py | 32 ++++++++++--------- narwhals/_pandas_like/series.py | 4 +-- 6 files changed, 67 insertions(+), 63 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 230e83d263..346c881c69 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -248,46 +248,46 @@ def get_column(self, name: str) -> ArrowSeries: def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def _gather(self, item: SizedMultiIndexSelector) -> Self: - if len(item) == 0: + def _gather(self, rows: SizedMultiIndexSelector) -> Self: + if len(rows) == 0: return self._with_native(self.native.slice(0, 0)) - if self._backend_version < (18,) and isinstance(item, tuple): - item = list(item) - return self._with_native(self.native.take(item)) # pyright: ignore[reportArgumentType] + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) # pyright: ignore[reportArgumentType] - def _gather_slice(self, item: _SliceIndex | range) -> Self: - start = item.start or 0 - stop = item.stop if item.stop is not None else len(self.native) + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) if start < 0: start = 
len(self.native) + start if stop < 0: stop = len(self.native) + stop - if item.step is not None and item.step != 1: + if rows.step is not None and rows.step != 1: msg = "Slicing with step is not supported on PyArrow tables" raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) - def _select_slice_of_labels(self, item: _SliceName) -> Self: - start, stop, step = convert_str_slice_to_int_slice(item, self.columns) + def _select_slice_name(self, columns: _SliceName) -> Self: + start, stop, step = convert_str_slice_to_int_slice(columns, self.columns) return self._with_native(self.native.select(self.columns[start:stop:step])) - def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: return self._with_native( - self.native.select(self.columns[item.start : item.stop : item.step]) + self.native.select(self.columns[columns.start : columns.stop : columns.step]) ) - def _select_indices(self, item: SizedMultiIndexSelector) -> Self: - if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() # pyright: ignore[reportAssignmentType] - if is_numpy_array(item): - item = item.tolist() - return self._with_native(self.native.select(cast("Indices", item))) - - def _select_labels(self, item: SizedMultiNameSelector) -> Self: - if isinstance(item, pa.ChunkedArray): - item = item.to_pylist() # pyright: ignore[reportAssignmentType] - # pyarrow-stubs overly strict, accept list[str] | Indices - return self._with_native(self.native.select(item)) # pyright: ignore[reportArgumentType] + def _select_indices(self, columns: SizedMultiIndexSelector) -> Self: + if isinstance(columns, pa.ChunkedArray): + columns = columns.to_pylist() # pyright: ignore[reportAssignmentType] + if is_numpy_array(columns): + columns = columns.tolist() + return self._with_native(self.native.select(cast("Indices", columns))) + + def _select_multi_name(self, columns: SizedMultiNameSelector) -> 
Self: + if isinstance(columns, pa.ChunkedArray): + columns = columns.to_pylist() # pyright: ignore[reportAssignmentType] + # pyarrow-stubs overly strict, accepts list[str] | Indices + return self._with_native(self.native.select(columns)) # pyright: ignore[reportArgumentType] @property def schema(self) -> dict[str, DType]: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 7536a52992..235a4b3225 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -407,21 +407,21 @@ def __native_namespace__(self) -> ModuleType: def name(self) -> str: return self._name - def _gather(self, item: SizedMultiIndexSelector) -> Self: - if len(item) == 0: + def _gather(self, rows: SizedMultiIndexSelector) -> Self: + if len(rows) == 0: return self._with_native(self.native.slice(0, 0)) - if self._backend_version < (18,) and isinstance(item, tuple): - item = list(item) - return self._with_native(self.native.take(item)) # pyright: ignore[reportArgumentType] + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) # pyright: ignore[reportArgumentType] - def _gather_slice(self, item: slice | range) -> Self: - start = item.start or 0 - stop = item.stop if item.stop is not None else len(self.native) + def _gather_slice(self, rows: slice | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) if start < 0: start = len(self.native) + start if stop < 0: stop = len(self.native) + stop - if item.step is not None and item.step != 1: + if rows.step is not None and rows.step != 1: msg = "Slicing with step is not supported on PyArrow tables" raise NotImplementedError(msg) return self._with_native(self.native.slice(start, stop - start)) diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index c050d1eaae..70050eba00 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ 
-396,16 +396,16 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, indices: SizedMultiIndexSelector[CompliantSeriesAny]) -> Self: ... - def _gather_slice(self, indices: _SliceIndex | range) -> Self: ... + def _gather(self, rows: SizedMultiIndexSelector[CompliantSeriesAny]) -> Self: ... + def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... def _select_indices( - self, indices: SizedMultiIndexSelector[CompliantSeriesAny] + self, columns: SizedMultiIndexSelector[CompliantSeriesAny] ) -> Self: ... - def _select_labels( - self, indices: SizedMultiNameSelector[CompliantSeriesAny] + def _select_multi_name( + self, columns: SizedMultiNameSelector[CompliantSeriesAny] ) -> Self: ... - def _select_slice_of_indices(self, indices: _SliceIndex | range) -> Self: ... - def _select_slice_of_labels(self, item: _SliceName) -> Self: ... + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ... + def _select_slice_name(self, columns: _SliceName) -> Self: ... 
def __getitem__( self, @@ -421,17 +421,17 @@ def __getitem__( return compliant.select() if is_index_selector(columns): if is_slice_index(columns) or is_range(columns): - compliant = compliant._select_slice_of_indices(columns) + compliant = compliant._select_slice_index(columns) elif is_compliant_series(columns): compliant = self._select_indices(columns.native) else: compliant = compliant._select_indices(columns) elif isinstance(columns, slice): - compliant = compliant._select_slice_of_labels(columns) + compliant = compliant._select_slice_name(columns) elif is_compliant_series(columns): - compliant = self._select_labels(columns.native) + compliant = self._select_multi_name(columns.native) elif is_sequence_like(columns): - compliant = self._select_labels(columns) + compliant = self._select_multi_name(columns) else: msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index d3827c5070..db14e944f2 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -321,8 +321,8 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] - def _gather(self, indices: SizedMultiIndexSelector) -> Self: ... - def _gather_slice(self, indices: slice | range) -> Self: ... + def _gather(self, rows: SizedMultiIndexSelector) -> Self: ... + def _gather_slice(self, rows: slice | range) -> Self: ... 
def __getitem__(self, item: MultiIndexSelector) -> Self: if is_slice_none(item): diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 53dc36412e..340aeb92a9 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -281,32 +281,36 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, items: SizedMultiIndexSelector) -> Self: - items = list(items) if isinstance(items, tuple) else items - return self._with_native(self.native.iloc[items, :]) + def _gather(self, rows: SizedMultiIndexSelector) -> Self: + rows = list(rows) if isinstance(rows, tuple) else rows + return self._with_native(self.native.iloc[rows, :]) - def _gather_slice(self, item: _SliceIndex | range) -> Self: + def _gather_slice(self, rows: _SliceIndex | range) -> Self: return self._with_native( - self.native.iloc[slice(item.start, item.stop, item.step), :], + self.native.iloc[slice(rows.start, rows.stop, rows.step), :], validate_column_names=False, ) - def _select_slice_of_labels(self, item: _SliceName) -> Self: - start, stop, step = convert_str_slice_to_int_slice(item, self.native.columns) + def _select_slice_name(self, columns: _SliceName) -> Self: + start, stop, step = convert_str_slice_to_int_slice(columns, self.native.columns) return self._with_native( self.native.iloc[:, slice(start, stop, step)], validate_column_names=False, ) - def _select_slice_of_indices(self, item: _SliceIndex | range) -> Self: - return self._with_native(self.native.iloc[:, item], validate_column_names=False) + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: + return self._with_native( + self.native.iloc[:, columns], validate_column_names=False + ) - def _select_indices(self, item: SizedMultiIndexSelector) -> Self: - item = list(item) if isinstance(item, tuple) else item - return 
self._with_native(self.native.iloc[:, item], validate_column_names=False) + def _select_indices(self, columns: SizedMultiIndexSelector) -> Self: + columns = list(columns) if isinstance(columns, tuple) else columns + return self._with_native( + self.native.iloc[:, columns], validate_column_names=False + ) - def _select_labels(self, indices: SizedMultiNameSelector) -> PandasLikeDataFrame: - return self._with_native(self.native.loc[:, indices]) + def _select_multi_name(self, columns: SizedMultiNameSelector) -> PandasLikeDataFrame: + return self._with_native(self.native.loc[:, columns]) # --- properties --- @property diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 01e4df0052..81bf3b38ee 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -150,9 +150,9 @@ def _gather(self, rows: SizedMultiIndexSelector) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) - def _gather_slice(self, item: slice | range) -> Self: + def _gather_slice(self, rows: slice | range) -> Self: return self._with_native( - self.native.iloc[slice(item.start, item.stop, item.step)] + self.native.iloc[slice(rows.start, rows.stop, rows.step)] ) def _with_version(self, version: Version) -> Self: From 230ce6d8c378721c0c972a6676694eac93105bea Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 13:12:09 +0100 Subject: [PATCH 73/80] remove another Any + dead code --- narwhals/_arrow/dataframe.py | 19 +++++++++++-------- narwhals/_arrow/series.py | 2 +- narwhals/_compliant/dataframe.py | 18 ++++++++++++------ narwhals/_compliant/namespace.py | 7 ++++--- narwhals/_compliant/series.py | 17 +++++++---------- narwhals/_pandas_like/dataframe.py | 8 +++++--- narwhals/_pandas_like/series.py | 2 +- narwhals/_polars/dataframe.py | 6 +++--- narwhals/series.py | 1 + pyproject.toml | 4 +--- 10 files changed, 46 insertions(+), 38 
deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 346c881c69..fa04644757 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -49,7 +49,6 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.typing import ArrowChunkedArray - from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import Mask # type: ignore[attr-defined] from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._translate import IntoArrowTable @@ -248,7 +247,7 @@ def get_column(self, name: str) -> ArrowSeries: def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - def _gather(self, rows: SizedMultiIndexSelector) -> Self: + def _gather(self, rows: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: if len(rows) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(rows, tuple): @@ -276,16 +275,20 @@ def _select_slice_index(self, columns: _SliceIndex | range) -> Self: self.native.select(self.columns[columns.start : columns.stop : columns.step]) ) - def _select_indices(self, columns: SizedMultiIndexSelector) -> Self: + def _select_indices( + self, columns: SizedMultiIndexSelector[ArrowChunkedArray] + ) -> Self: if isinstance(columns, pa.ChunkedArray): - columns = columns.to_pylist() # pyright: ignore[reportAssignmentType] + columns = cast("list[int]", columns.to_pylist()) if is_numpy_array(columns): - columns = columns.tolist() - return self._with_native(self.native.select(cast("Indices", columns))) + columns = cast("list[int]", columns.tolist()) + return self._with_native(self.native.select(columns)) - def _select_multi_name(self, columns: SizedMultiNameSelector) -> Self: + def _select_multi_name( + self, columns: SizedMultiNameSelector[ArrowChunkedArray] + ) -> Self: if 
isinstance(columns, pa.ChunkedArray): - columns = columns.to_pylist() # pyright: ignore[reportAssignmentType] + columns = cast("list[str]", columns.to_pylist()) # pyarrow-stubs overly strict, accepts list[str] | Indices return self._with_native(self.native.select(columns)) # pyright: ignore[reportArgumentType] diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 235a4b3225..2fe1267dcd 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -407,7 +407,7 @@ def __native_namespace__(self) -> ModuleType: def name(self) -> str: return self._name - def _gather(self, rows: SizedMultiIndexSelector) -> Self: + def _gather(self, rows: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: if len(rows) == 0: return self._with_native(self.native.slice(0, 0)) if self._backend_version < (18,) and isinstance(rows, tuple): diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 70050eba00..29be5f5024 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -381,7 +381,9 @@ def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """ _, aliases = evaluate_output_names_and_aliases(expr, self, []) result = expr(self) - if list(aliases) != (result_aliases := [s.name for s in result]): + if list(aliases) != ( + result_aliases := [s.name for s in result] + ): # pragma: no cover msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}" raise AssertionError(msg) return result @@ -396,13 +398,17 @@ def _numpy_column_names( ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) - def _gather(self, rows: SizedMultiIndexSelector[CompliantSeriesAny]) -> Self: ... + def _gather(self, rows: SizedMultiIndexSelector[Any]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... 
def _select_indices( - self, columns: SizedMultiIndexSelector[CompliantSeriesAny] + # TODO(unassigned): `Any` should be `NativeSeriesT` + self, + columns: SizedMultiIndexSelector[Any], ) -> Self: ... def _select_multi_name( - self, columns: SizedMultiNameSelector[CompliantSeriesAny] + # TODO(unassigned): `Any` should be `NativeSeriesT` + self, + columns: SizedMultiNameSelector[Any], ) -> Self: ... def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ... def _select_slice_name(self, columns: _SliceName) -> Self: ... @@ -432,7 +438,7 @@ def __getitem__( compliant = self._select_multi_name(columns.native) elif is_sequence_like(columns): compliant = self._select_multi_name(columns) - else: + else: # pragma: no cover msg = f"Unreachable code, got unexpected type: {type(columns)}" raise AssertionError(msg) @@ -445,7 +451,7 @@ def __getitem__( compliant = compliant._gather(rows.native) elif is_sized_multi_index_selector(rows): compliant = compliant._gather(rows) - else: + else: # pragma: no cover msg = f"Unreachable code, got unexpected type: {type(rows)}" raise AssertionError(msg) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index ac8d922e9f..4fba98de1d 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -127,8 +127,9 @@ def _lazyframe(self) -> type[CompliantLazyFrameT]: ... 
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: if self._lazyframe._is_native(data): return self._lazyframe.from_native(data, context=self) - msg = f"Unsupported type: {type(data).__name__!r}" - raise TypeError(msg) + else: # pragma: no cover + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) class EagerNamespace( @@ -198,6 +199,6 @@ def concat( native = self._concat_vertical(dfs) elif how == "diagonal": native = self._concat_diagonal(dfs) - else: + else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index db14e944f2..5741732817 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -25,7 +25,6 @@ from narwhals.utils import _StoresNative from narwhals.utils import is_compliant_series from narwhals.utils import is_sized_multi_index_selector -from narwhals.utils import is_slice_none from narwhals.utils import unstable if TYPE_CHECKING: @@ -83,7 +82,7 @@ def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def __native_namespace__(self) -> ModuleType: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ... def __contains__(self, other: Any) -> bool: ... - def __getitem__(self, item: Any) -> Any: ... + def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ... def __iter__(self) -> Iterator[Any]: ... def __len__(self) -> int: return len(self.native) @@ -321,21 +320,19 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] - def _gather(self, rows: SizedMultiIndexSelector) -> Self: ... + def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: slice | range) -> Self: ... 
- def __getitem__(self, item: MultiIndexSelector) -> Self: - if is_slice_none(item): - return self - if isinstance(item, int): - return self._gather([item]) - elif isinstance(item, (slice, range)): + def __getitem__( + self, item: MultiIndexSelector[CompliantSeries[NativeSeriesT]] + ) -> Self: + if isinstance(item, (slice, range)): return self._gather_slice(item) elif is_compliant_series(item): return self._gather(item.native) elif is_sized_multi_index_selector(item): return self._gather(item) - else: + else: # pragma: no cover msg = f"Unreachable code, got unexpected type: {type(item)}" raise AssertionError(msg) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 340aeb92a9..59b5fe3e9e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -281,7 +281,7 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - def _gather(self, rows: SizedMultiIndexSelector) -> Self: + def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows, :]) @@ -303,13 +303,15 @@ def _select_slice_index(self, columns: _SliceIndex | range) -> Self: self.native.iloc[:, columns], validate_column_names=False ) - def _select_indices(self, columns: SizedMultiIndexSelector) -> Self: + def _select_indices(self, columns: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: columns = list(columns) if isinstance(columns, tuple) else columns return self._with_native( self.native.iloc[:, columns], validate_column_names=False ) - def _select_multi_name(self, columns: SizedMultiNameSelector) -> PandasLikeDataFrame: + def _select_multi_name( + self, columns: SizedMultiNameSelector[pd.Series[Any]] + ) -> PandasLikeDataFrame: return self._with_native(self.native.loc[:, columns]) # --- properties 
--- diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 81bf3b38ee..2f66089bc9 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -146,7 +146,7 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - def _gather(self, rows: SizedMultiIndexSelector) -> Self: + def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: rows = list(rows) if isinstance(rows, tuple) else rows return self._with_native(self.native.iloc[rows]) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index d3fb60faae..d439dbf599 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -300,9 +300,9 @@ def __getitem__( self.columns[slice(columns.start, columns.stop, columns.step)] ) elif is_compliant_series(columns): - native = native[:, columns.native.to_list()] + native = native[:, columns.native.to_list()] # type: ignore[attr-defined,index] else: - native = native[:, columns] # type: ignore[index] + native = native[:, columns] elif isinstance(columns, slice): native = native.select( self.columns[ @@ -325,7 +325,7 @@ def __getitem__( elif is_compliant_series(rows): native = native[rows.native, :] # pyright: ignore[reportArgumentType,reportCallIssue] elif is_sequence_like(rows): - native = native[rows, :] # type: ignore[index] + native = native[rows, :] else: msg = f"Unreachable code, got unexpected type: {type(rows)}" raise AssertionError(msg) diff --git a/narwhals/series.py b/narwhals/series.py index 02d9f5d673..c55b526750 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -189,6 +189,7 @@ def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Se raise TypeError(msg) if isinstance(idx, Series): return self._with_compliant(self._compliant_series[idx._compliant_series]) + assert not isinstance(idx, int) # noqa: S101 # help mypy return 
self._with_compliant(self._compliant_series[idx]) def __native_namespace__(self) -> ModuleType: diff --git a/pyproject.toml b/pyproject.toml index d5910e7ac6..7ef7fd2589 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -247,10 +247,8 @@ omit = [ 'narwhals/_spark_like/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*', - # we don't run these in every environment - 'tests/ibis_test.py', # Remove after finishing eager sub-protocols - 'narwhals/_compliant/*', + 'narwhals/_compliant/namespace.py', ] exclude_also = [ "if sys.version_info() <", From 427c9c9158f55b6f9c7ef13665bf8455fde89949 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:14:47 +0200 Subject: [PATCH 74/80] remove dead code --- narwhals/_compliant/series.py | 4 +--- narwhals/_polars/series.py | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 6af2ba1cef..d33dcf8ad1 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -324,9 +324,7 @@ def _to_expr(self) -> EagerExpr[Any, Any]: def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... 
def __getitem__(self, item: Any) -> Self: - if isinstance(item, int): - return self._gather([item]) - elif isinstance(item, (slice, range)): + if isinstance(item, (slice, range)): return self._gather_slice(item) elif is_compliant_series(item): return self._gather(item.native) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index b38c1ef0bb..cf3b50226d 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -590,6 +590,8 @@ def struct(self) -> PolarsSeriesStructNamespace: __ror__: Method[Self] __rtruediv__: Method[Self] __truediv__: Method[Self] + _gather: Method[Self] + _gather_slice: Method[Self] abs: Method[Self] all: Method[bool] any: Method[bool] From dc448ed2888a154220398c6f645d7df5de208a4e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:17:34 +0200 Subject: [PATCH 75/80] fill in Any annotation --- narwhals/_compliant/series.py | 4 +++- narwhals/_polars/namespace.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index d33dcf8ad1..2f1cc7842e 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -323,7 +323,9 @@ def _to_expr(self) -> EagerExpr[Any, Any]: def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... 
- def __getitem__(self, item: Any) -> Self: + def __getitem__( + self, item: MultiIndexSelector[CompliantSeries[NativeSeriesT]] + ) -> Self: if isinstance(item, (slice, range)): return self._gather_slice(item) elif is_compliant_series(item): diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index af64f46806..ce97994857 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -251,7 +251,7 @@ def concat_str( @property def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: return cast( - "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", # pyright: ignore[reportInvalidTypeArguments] PolarsSelectorNamespace(self), ) From e8e5a8a47b22f21a9e4cef5c3806db3e2cb47122 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:21:20 +0200 Subject: [PATCH 76/80] better typing in _polars --- narwhals/_polars/series.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index cf3b50226d..abcbb54fd8 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -37,6 +37,7 @@ from narwhals._polars.namespace import PolarsNamespace from narwhals.dtypes import DType from narwhals.typing import Into1DArray + from narwhals.typing import MultiIndexSelector from narwhals.typing import _1DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -182,7 +183,7 @@ def __getitem__(self, item: int) -> Any: ... @overload def __getitem__(self, item: slice | Sequence[int] | pl.Series) -> Self: ... 
- def __getitem__(self, item: int | slice | Sequence[int] | pl.Series) -> Any | Self: + def __getitem__(self, item: MultiIndexSelector[pl.Series]) -> Any | Self: if is_compliant_series(item): item = item.native return self._from_native_object(self.native.__getitem__(item)) From bfb1cf0999f7091be4cccd32b47f1cd3745a74cc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:23:40 +0200 Subject: [PATCH 77/80] yay figured out the pyright error :party: --- narwhals/_polars/namespace.py | 2 +- narwhals/_polars/series.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index ce97994857..af64f46806 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -251,7 +251,7 @@ def concat_str( @property def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]: return cast( - "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", # pyright: ignore[reportInvalidTypeArguments] + "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]", PolarsSelectorNamespace(self), ) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index abcbb54fd8..c3e7ba73ad 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -177,13 +177,7 @@ def native(self) -> pl.Series: def alias(self, name: str) -> Self: return self._from_native_object(self.native.alias(name)) - @overload - def __getitem__(self, item: int) -> Any: ... - - @overload - def __getitem__(self, item: slice | Sequence[int] | pl.Series) -> Self: ... 
- - def __getitem__(self, item: MultiIndexSelector[pl.Series]) -> Any | Self: + def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self: if is_compliant_series(item): item = item.native return self._from_native_object(self.native.__getitem__(item)) From b72560617fd7c215105f65886711ed6c304098c3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:26:42 +0200 Subject: [PATCH 78/80] remove incorrect Sequence[str] --- narwhals/_arrow/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index c41b6e6cd7..b1fcf85e94 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -281,7 +281,7 @@ def _select_slice_index(self, columns: _SliceIndex | range) -> Self: def _select_multi_index( self, columns: SizedMultiIndexSelector[ArrowChunkedArray] ) -> Self: - selector: Sequence[int] | Sequence[str] + selector: Sequence[int] if isinstance(columns, pa.ChunkedArray): # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` selector = cast("Sequence[int]", columns.to_pylist()) From 98fe85066d349e7d0cd19e875061a3bb27fd744c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:37:40 +0200 Subject: [PATCH 79/80] fixup pyright --- narwhals/_polars/series.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index c3e7ba73ad..ea8ed4d31e 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -18,7 +18,6 @@ from narwhals._polars.utils import native_to_narwhals_dtype from narwhals.dependencies import is_numpy_array_1d from narwhals.utils import Implementation -from narwhals.utils import is_compliant_series from narwhals.utils import requires from narwhals.utils import validate_backend_version @@ -178,8 +177,8 @@ 
def alias(self, name: str) -> Self: return self._from_native_object(self.native.alias(name)) def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self: - if is_compliant_series(item): - item = item.native + if isinstance(item, PolarsSeries): + return self._from_native_object(self.native.__getitem__(item.native)) return self._from_native_object(self.native.__getitem__(item)) def cast(self, dtype: DType | type[DType]) -> Self: @@ -585,8 +584,6 @@ def struct(self) -> PolarsSeriesStructNamespace: __ror__: Method[Self] __rtruediv__: Method[Self] __truediv__: Method[Self] - _gather: Method[Self] - _gather_slice: Method[Self] abs: Method[Self] all: Method[bool] any: Method[bool] From 5b02b592183b8d39e2d32e0aedd6c234bb22d405 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 22 Apr 2025 17:23:55 +0200 Subject: [PATCH 80/80] coverage again --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7ef7fd2589..d97ea34bd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -247,6 +247,8 @@ omit = [ 'narwhals/_spark_like/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*', + # we don't (yet) run these in every environment + 'tests/ibis_test.py', # Remove after finishing eager sub-protocols 'narwhals/_compliant/namespace.py', ]