diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index be311f1a31..b1fcf85e94 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -16,19 +16,17 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.utils import align_series_full_broadcast -from narwhals._arrow.utils import convert_str_slice_to_int_slice from narwhals._arrow.utils import native_to_narwhals_dtype -from narwhals._arrow.utils import select_rows from narwhals._compliant import EagerDataFrame from narwhals._expression_parsing import ExprKind -from narwhals.dependencies import is_numpy_array_1d +from narwhals.dependencies import is_numpy_array from narwhals.exceptions import ShapeError from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import check_column_exists from narwhals.utils import check_column_names_are_unique +from narwhals.utils import convert_str_slice_to_int_slice from narwhals.utils import generate_temporary_column_name -from narwhals.utils import is_sequence_but_not_str from narwhals.utils import not_implemented from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version @@ -51,7 +49,6 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.typing import ArrowChunkedArray - from narwhals._arrow.typing import Indices # type: ignore[attr-defined] from narwhals._arrow.typing import Mask # type: ignore[attr-defined] from narwhals._arrow.typing import Order # type: ignore[attr-defined] from narwhals._translate import IntoArrowTable @@ -60,10 +57,14 @@ from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.typing import JoinStrategy + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from 
narwhals.typing import _1DArray from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -80,8 +81,9 @@ PromoteOptions: TypeAlias = Literal["none", "default", "permissive"] -class ArrowDataFrame(EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table"]): - # --- not in the spec --- +class ArrowDataFrame( + EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "pa.ChunkedArray[Any]"] +): def __init__( self, native_dataframe: pa.Table, @@ -248,118 +250,61 @@ def get_column(self, name: str) -> ArrowSeries: def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: return self.native.__array__(dtype, copy=copy) - @overload - def __getitem__( # type: ignore[overload-overlap, unused-ignore] - self, item: str | tuple[slice | Sequence[int] | _1DArray, int | str] - ) -> ArrowSeries: ... - @overload - def __getitem__( - self, - item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> Self: ... 
- def __getitem__( - self, - item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[slice | Sequence[int] | _1DArray, int | str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> ArrowSeries | Self: - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType] - - if isinstance(item, str): - return ArrowSeries.from_native(self.native[item], context=self, name=item) - elif ( - isinstance(item, tuple) - and len(item) == 2 - and is_sequence_but_not_str(item[1]) - and not isinstance(item[0], str) - ): - if len(item[1]) == 0: - # Return empty dataframe - return self._with_native(self.native.slice(0, 0).select([])) - selected_rows = select_rows(self.native, item[0]) - return self._with_native(selected_rows.select(cast("Indices", item[1]))) - - elif isinstance(item, tuple) and len(item) == 2: - if isinstance(item[1], slice): - columns = self.columns - indices = cast("Indices", item[0]) - if item[1] == slice(None): - if isinstance(item[0], Sequence) and len(item[0]) == 0: - return self._with_native(self.native.slice(0, 0)) - return self._with_native(self.native.take(indices)) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.take(indices).select(columns[start:stop:step]) - ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.take(indices).select( - columns[item[1].start : item[1].stop : item[1].step] - ) - ) - msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - # PyArrow columns are always strings - col_name = ( - item[1] - if isinstance(item[1], str) - else self.columns[cast("int", item[1])] - ) - if isinstance(item[0], str): # 
pragma: no cover - msg = "Can not slice with tuple with the first element as a str" - raise TypeError(msg) - if (isinstance(item[0], slice)) and (item[0] == slice(None)): - return ArrowSeries.from_native( - self.native[col_name], context=self, name=col_name - ) - selected_rows = select_rows(self.native, item[0]) - return ArrowSeries.from_native( - selected_rows[col_name], context=self, name=col_name - ) + def _gather(self, rows: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: + if len(rows) == 0: + return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) # pyright: ignore[reportArgumentType] + + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) + if start < 0: + start = len(self.native) + start + if stop < 0: + stop = len(self.native) + stop + if rows.step is not None and rows.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) + + def _select_slice_name(self, columns: _SliceName) -> Self: + start, stop, step = convert_str_slice_to_int_slice(columns, self.columns) + return self._with_native(self.native.select(self.columns[start:stop:step])) + + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: + return self._with_native( + self.native.select(self.columns[columns.start : columns.stop : columns.step]) + ) - elif isinstance(item, slice): - if item.step is not None and item.step != 1: - msg = "Slicing with step is not supported on PyArrow tables" - raise NotImplementedError(msg) - columns = self.columns - if isinstance(item.start, str) or isinstance(item.stop, str): - start, stop, step = convert_str_slice_to_int_slice(item, columns) - return 
self._with_native(self.native.select(columns[start:stop:step])) - start = item.start or 0 - stop = item.stop if item.stop is not None else len(self.native) - return self._with_native(self.native.slice(start, stop - start)) - - elif isinstance(item, Sequence) or is_numpy_array_1d(item): - if isinstance(item, Sequence) and len(item) > 0 and isinstance(item[0], str): - return self._with_native(self.native.select(cast("Indices", item))) - if isinstance(item, Sequence) and len(item) == 0: - return self._with_native(self.native.slice(0, 0)) - return self._with_native(self.native.take(cast("Indices", item))) + def _select_multi_index( + self, columns: SizedMultiIndexSelector[ArrowChunkedArray] + ) -> Self: + selector: Sequence[int] + if isinstance(columns, pa.ChunkedArray): + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[int]", columns.to_pylist()) + # TODO @dangotbanned: Fix upstream, it is actually much narrower + # **Doesn't accept `ndarray`** + elif is_numpy_array(columns): + selector = columns.tolist() + else: + selector = columns + return self._with_native(self.native.select(selector)) - else: # pragma: no cover - msg = f"Expected str or slice, got: {type(item)}" - raise TypeError(msg) + def _select_multi_name( + self, columns: SizedMultiNameSelector[ArrowChunkedArray] + ) -> Self: + selector: Sequence[str] | _1DArray + if isinstance(columns, pa.ChunkedArray): + # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:` + selector = cast("Sequence[str]", columns.to_pylist()) + else: + selector = columns + # TODO @dangotbanned: Fix upstream `pa.Table.select` https://github.com/zen-xu/pyarrow-stubs/blob/f899bb35e10b36f7906a728e9f8acf3e0a1f9f64/pyarrow-stubs/__lib_pxi/table.pyi#L597 + # NOTE: Investigate what `cython` actually checks + return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType] @property def schema(self) -> dict[str, 
DType]: diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 0dcae8fb10..b08b5374ca 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -67,9 +67,11 @@ from narwhals.typing import PythonLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex from narwhals.utils import Version from narwhals.utils import _FullContext @@ -406,20 +408,24 @@ def __native_namespace__(self) -> ModuleType: def name(self) -> str: return self._name - @overload - def __getitem__(self, idx: int) -> Any: ... - - @overload - def __getitem__(self, idx: slice | Sequence[int] | ArrowChunkedArray) -> Self: ... - - def __getitem__( - self, idx: int | slice | Sequence[int] | ArrowChunkedArray - ) -> Any | Self: - if isinstance(idx, int): - return maybe_extract_py_scalar(self.native[idx], return_py_scalar=True) - if isinstance(idx, (Sequence, pa.ChunkedArray)): - return self._with_native(self.native.take(idx)) - return self._with_native(self.native[idx]) + def _gather(self, rows: SizedMultiIndexSelector[ArrowChunkedArray]) -> Self: + if len(rows) == 0: + return self._with_native(self.native.slice(0, 0)) + if self._backend_version < (18,) and isinstance(rows, tuple): + rows = list(rows) + return self._with_native(self.native.take(rows)) + + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + start = rows.start or 0 + stop = rows.stop if rows.stop is not None else len(self.native) + if start < 0: + start = len(self.native) + start + if stop < 0: + stop = len(self.native) + stop + if rows.step is not None and rows.step != 1: + msg = "Slicing with step is not supported on PyArrow tables" + raise NotImplementedError(msg) + return self._with_native(self.native.slice(start, stop - start)) def 
scatter(self, indices: int | Sequence[int], values: Any) -> Self: import numpy as np # ignore-banned-import @@ -911,7 +917,7 @@ def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Se result = self._with_native( pc.if_else((count_in_window >= min_samples).native, rolling_sum.native, None) ) - return result[offset:] + return result._gather_slice(slice(offset, None)) def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self: min_samples = min_samples if min_samples is not None else window_size @@ -940,7 +946,7 @@ def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> S ) / count_in_window ) - return result[offset:] + return result._gather_slice(slice(offset, None)) def rolling_var( self, window_size: int, *, min_samples: int, center: bool, ddof: int @@ -983,7 +989,7 @@ def rolling_var( ) ) / self._with_native(pc.max_element_wise((count_in_window - ddof).native, 0)) - return result[offset:] + return result._gather_slice(slice(offset, None, None)) def rolling_std( self, window_size: int, *, min_samples: int, center: bool, ddof: int diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 23164982a1..da51f8a025 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -7,7 +7,6 @@ from typing import Iterator from typing import Sequence from typing import cast -from typing import overload import pyarrow as pa import pyarrow.compute as pc @@ -18,8 +17,6 @@ from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: - from typing import TypeVar - from typing_extensions import TypeAlias from typing_extensions import TypeIs @@ -33,15 +30,12 @@ from narwhals._arrow.typing import ScalarAny from narwhals.dtypes import DType from narwhals.typing import PythonLiteral - from narwhals.typing import _AnyDArray from narwhals.utils import Version # NOTE: stubs don't allow for `ChunkedArray[StructArray]` # Intended to represent the `.chunks` property storing 
`list[pa.StructArray]` ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray - _T = TypeVar("_T") - def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ... def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ... def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ... @@ -324,41 +318,6 @@ def cast_for_truediv( return arrow_array, pa_object -@overload -def convert_slice_to_nparray(num_rows: int, rows_slice: slice) -> _AnyDArray: ... -@overload -def convert_slice_to_nparray(num_rows: int, rows_slice: _T) -> _T: ... -def convert_slice_to_nparray(num_rows: int, rows_slice: slice | _T) -> _AnyDArray | _T: - if isinstance(rows_slice, slice): - import numpy as np # ignore-banned-import - - return np.arange(num_rows)[rows_slice] - else: - return rows_slice - - -def select_rows( - table: pa.Table, rows: slice | int | Sequence[int] | _AnyDArray -) -> pa.Table: - if isinstance(rows, slice) and rows == slice(None): - selected_rows = table - elif isinstance(rows, Sequence) and not rows: - selected_rows = table.slice(0, 0) - else: - range_ = convert_slice_to_nparray(num_rows=len(table), rows_slice=rows) - selected_rows = table.take(cast("list[int]", range_)) - return selected_rows - - -def convert_str_slice_to_int_slice( - str_slice: slice, columns: list[str] -) -> tuple[int | None, int | None, int | None]: - start = columns.index(str_slice.start) if str_slice.start is not None else None - stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None - step = str_slice.step - return (start, stop, step) - - # Regex for date, time, separator and timezone components DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})" SEP_RE = r"(?P<sep>\s|T)" diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 67cfad404c..f218db68c9 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -14,9 +14,10 @@ from narwhals._compliant.typing import CompliantExprT_contra from narwhals._compliant.typing import 
CompliantSeriesT -from narwhals._compliant.typing import EagerExprT_contra +from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import NativeFrameT +from narwhals._compliant.typing import NativeSeriesT from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._translate import ArrowConvertible from narwhals._translate import DictConvertible @@ -25,6 +26,13 @@ from narwhals.utils import Version from narwhals.utils import _StoresNative from narwhals.utils import deprecated +from narwhals.utils import is_compliant_series +from narwhals.utils import is_index_selector +from narwhals.utils import is_range +from narwhals.utils import is_sequence_like +from narwhals.utils import is_sized_multi_index_selector +from narwhals.utils import is_slice_index +from narwhals.utils import is_slice_none if TYPE_CHECKING: from io import BytesIO @@ -38,16 +46,24 @@ from narwhals._compliant.group_by import CompliantGroupBy from narwhals._compliant.group_by import DataFrameGroupBy + from narwhals._compliant.namespace import EagerNamespace from narwhals._translate import IntoArrowTable from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.typing import AsofJoinStrategy from narwhals.typing import JoinStrategy from narwhals.typing import LazyUniqueKeepStrategy + from narwhals.typing import MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleIndexSelector + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceName from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -99,7 +115,13 @@ def 
from_numpy( schema: Mapping[str, DType] | Schema | Sequence[str] | None, ) -> Self: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... - def __getitem__(self, item: Any) -> CompliantSeriesT | Self: ... + def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector[CompliantSeriesT], + MultiIndexSelector[CompliantSeriesT] | MultiColSelector[CompliantSeriesT], + ], + ) -> Self: ... def simple_select(self, *column_names: str) -> Self: """`select` where all args are column names.""" ... @@ -329,21 +351,25 @@ def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any: class EagerDataFrame( - CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT], - CompliantLazyFrame[EagerExprT_contra, NativeFrameT], - Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT], + CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT], + CompliantLazyFrame[EagerExprT, NativeFrameT], + Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT], ): - def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT: + def __narwhals_namespace__( + self, + ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ... + + def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT: """Evaluate `expr` and ensure it has a **single** output.""" result: Sequence[EagerSeriesT] = expr(self) assert len(result) == 1 # debug assertion # noqa: S101 return result[0] - def _evaluate_into_exprs(self, *exprs: EagerExprT_contra) -> Sequence[EagerSeriesT]: + def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]: # NOTE: Ignore is to avoid an intermittent false positive return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType] - def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSeriesT]: + def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]: """Return list of raw columns. 
For eager backends we alias operations at each step. @@ -355,7 +381,9 @@ def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSerie """ _, aliases = evaluate_output_names_and_aliases(expr, self, []) result = expr(self) - if list(aliases) != (result_aliases := [s.name for s in result]): + if list(aliases) != ( + result_aliases := [s.name for s in result] + ): # pragma: no cover msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}" raise AssertionError(msg) return result @@ -369,3 +397,57 @@ def _numpy_column_names( data: _2DArray, columns: Sequence[str] | None, / ) -> list[str]: return list(columns or (f"column_{x}" for x in range(data.shape[1]))) + + def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... + def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... + def _select_multi_index( + self, columns: SizedMultiIndexSelector[NativeSeriesT] + ) -> Self: ... + def _select_multi_name( + self, columns: SizedMultiNameSelector[NativeSeriesT] + ) -> Self: ... + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ... + def _select_slice_name(self, columns: _SliceName) -> Self: ... 
+ def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector[EagerSeriesT], + MultiIndexSelector[EagerSeriesT] | MultiColSelector[EagerSeriesT], + ], + ) -> Self: + rows, columns = item + compliant = self + if not is_slice_none(columns): + if isinstance(columns, Sized) and len(columns) == 0: + return compliant.select() + if is_index_selector(columns): + if is_slice_index(columns) or is_range(columns): + compliant = compliant._select_slice_index(columns) + elif is_compliant_series(columns): + compliant = self._select_multi_index(columns.native) + else: + compliant = compliant._select_multi_index(columns) + elif isinstance(columns, slice): + compliant = compliant._select_slice_name(columns) + elif is_compliant_series(columns): + compliant = self._select_multi_name(columns.native) + elif is_sequence_like(columns): + compliant = self._select_multi_name(columns) + else: # pragma: no cover + msg = f"Unreachable code, got unexpected type: {type(columns)}" + raise AssertionError(msg) + + if not is_slice_none(rows): + if isinstance(rows, int): + compliant = compliant._gather([rows]) + elif isinstance(rows, (slice, range)): + compliant = compliant._gather_slice(rows) + elif is_compliant_series(rows): + compliant = compliant._gather(rows.native) + elif is_sized_multi_index_selector(rows): + compliant = compliant._gather(rows) + else: # pragma: no cover + msg = f"Unreachable code, got unexpected type: {type(rows)}" + raise AssertionError(msg) + + return compliant diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index ac8d922e9f..4fba98de1d 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -127,8 +127,9 @@ def _lazyframe(self) -> type[CompliantLazyFrameT]: ... 
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT: if self._lazyframe._is_native(data): return self._lazyframe.from_native(data, context=self) - msg = f"Unsupported type: {type(data).__name__!r}" - raise TypeError(msg) + else: # pragma: no cover + msg = f"Unsupported type: {type(data).__name__!r}" + raise TypeError(msg) class EagerNamespace( @@ -198,6 +199,6 @@ def concat( native = self._concat_vertical(dfs) elif how == "diagonal": native = self._concat_diagonal(dfs) - else: + else: # pragma: no cover raise NotImplementedError return self._dataframe.from_native(native, context=self) diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 2841ea51bc..2f1cc7842e 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -23,6 +23,8 @@ from narwhals._translate import NumpyConvertible from narwhals.utils import _StoresCompliant from narwhals.utils import _StoresNative +from narwhals.utils import is_compliant_series +from narwhals.utils import is_sized_multi_index_selector from narwhals.utils import unstable if TYPE_CHECKING: @@ -42,12 +44,15 @@ from narwhals.typing import ClosedInterval from narwhals.typing import FillNullStrategy from narwhals.typing import Into1DArray + from narwhals.typing import MultiIndexSelector from narwhals.typing import NonNestedLiteral from narwhals.typing import NumericLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray + from narwhals.typing import _SliceIndex from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -78,7 +83,7 @@ def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ... def __native_namespace__(self) -> ModuleType: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ... 
def __contains__(self, other: Any) -> bool: ... - def __getitem__(self, item: Any) -> Any: ... + def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ... def __iter__(self) -> Iterator[Any]: ... def __len__(self) -> int: return len(self.native) @@ -316,6 +321,21 @@ def __narwhals_namespace__( def _to_expr(self) -> EagerExpr[Any, Any]: return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return] + def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ... + def _gather_slice(self, rows: _SliceIndex | range) -> Self: ... + def __getitem__( + self, item: MultiIndexSelector[CompliantSeries[NativeSeriesT]] + ) -> Self: + if isinstance(item, (slice, range)): + return self._gather_slice(item) + elif is_compliant_series(item): + return self._gather(item.native) + elif is_sized_multi_index_selector(item): + return self._gather(item) + else: # pragma: no cover + msg = f"Unreachable code, got unexpected type: {type(item)}" + raise AssertionError(msg) + @property def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ... @property diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 30788b3210..56c092799d 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -47,7 +47,7 @@ DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]" -EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any]" +EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]" EagerSeriesAny: TypeAlias = "EagerSeries[Any]" EagerExprAny: TypeAlias = "EagerExpr[Any, Any]" EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]" @@ -111,7 +111,7 @@ EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True) # NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`? 
-EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any]") +EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]") LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny) LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index eca27cbda3..dd350e5165 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -102,10 +102,10 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: backend_version=self._backend_version, version=self._version ) - def __getitem__(self, item: str) -> DuckDBInterchangeSeries: + def get_column(self, name: str) -> DuckDBInterchangeSeries: from narwhals._duckdb.series import DuckDBInterchangeSeries - return DuckDBInterchangeSeries(self.native.select(item), version=self._version) + return DuckDBInterchangeSeries(self.native.select(name), version=self._version) def _iter_columns(self) -> Iterator[duckdb.Expression]: for name in self.columns: diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index 11eaaa2a52..4b34412fe0 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -100,10 +100,10 @@ def __narwhals_lazyframe__(self) -> Any: def __native_namespace__(self) -> ModuleType: return get_ibis() - def __getitem__(self, item: str) -> IbisInterchangeSeries: + def get_column(self, name: str) -> IbisInterchangeSeries: from narwhals._ibis.series import IbisInterchangeSeries - return IbisInterchangeSeries(self._native_frame[item], version=self._version) + return IbisInterchangeSeries(self._native_frame[name], version=self._version) def to_pandas(self) -> pd.DataFrame: return self._native_frame.to_pandas() diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index b9377d886a..c52cb0a506 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -105,11 
+105,11 @@ def __native_namespace__(self) -> NoReturn: ) raise NotImplementedError(msg) - def __getitem__(self, item: str) -> InterchangeSeries: + def get_column(self, name: str) -> InterchangeSeries: from narwhals._interchange.series import InterchangeSeries return InterchangeSeries( - self._interchange_frame.get_column_by_name(item), version=self._version + self._interchange_frame.get_column_by_name(name), version=self._version ) def to_pandas(self) -> pd.DataFrame: diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index b73b0b5d2a..2117f68eac 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -19,7 +19,6 @@ from narwhals._pandas_like.utils import align_and_extract_native from narwhals._pandas_like.utils import align_series_full_broadcast from narwhals._pandas_like.utils import check_column_names_are_unique -from narwhals._pandas_like.utils import convert_str_slice_to_int_slice from narwhals._pandas_like.utils import get_dtype_backend from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals._pandas_like.utils import object_native_to_narwhals_dtype @@ -27,7 +26,6 @@ from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_index -from narwhals.dependencies import is_numpy_array_1d from narwhals.dependencies import is_pandas_like_dataframe from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import ShapeError @@ -37,7 +35,6 @@ from narwhals.utils import check_column_exists from narwhals.utils import generate_temporary_column_name from narwhals.utils import import_dtypes_module -from narwhals.utils import is_sequence_but_not_str from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import scale_bytes @@ -66,10 +63,13 @@ from narwhals.typing import DTypeBackend from narwhals.typing import 
JoinStrategy from narwhals.typing import PivotAgg + from narwhals.typing import SizedMultiIndexSelector + from narwhals.typing import SizedMultiNameSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy - from narwhals.typing import _1DArray from narwhals.typing import _2DArray + from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceName from narwhals.utils import Version from narwhals.utils import _FullContext @@ -102,8 +102,9 @@ ) -class PandasLikeDataFrame(EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any"]): - # --- not in the spec --- +class PandasLikeDataFrame( + EagerDataFrame["PandasLikeSeries", "PandasLikeExpr", "Any", "pd.Series[Any]"] +): def __init__( self, native_dataframe: Any, @@ -281,135 +282,49 @@ def get_column(self, name: str) -> PandasLikeSeries: def __array__(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: return self.to_numpy(dtype=dtype, copy=copy) - @overload - def __getitem__( # type: ignore[overload-overlap] - self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], - ) -> PandasLikeSeries: ... + def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: + items = list(rows) if isinstance(rows, tuple) else rows + return self._with_native(self.native.iloc[items, :]) - @overload - def __getitem__( - self, - item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> Self: ... 
- def __getitem__( - self, - item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[slice | Sequence[int] | _1DArray, int | str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] - ), - ) -> PandasLikeSeries | Self: - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType] - - if isinstance(item, str): - return PandasLikeSeries.from_native(self.native[item], context=self) - - elif ( - isinstance(item, tuple) - and len(item) == 2 - and is_sequence_but_not_str(item[1]) - ): - if len(item[1]) == 0: - # Return empty dataframe - return self._with_native( - self.native.__class__(), validate_column_names=False - ) - if isinstance(item[1][0], int): - return self._with_native( - self.native.iloc[item], validate_column_names=False - ) - if isinstance(item[1][0], str): - indexer = ( - item[0], - self.native.columns.get_indexer(item[1]), - ) - return self._with_native( - self.native.iloc[indexer], validate_column_names=False - ) - msg = ( - f"Expected sequence str or int, got: {type(item[1])}" # pragma: no cover - ) - raise TypeError(msg) # pragma: no cover - - elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): - columns = self.native.columns - if item[1] == slice(None): - return self._with_native( - self.native.iloc[item[0], :], validate_column_names=False - ) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.iloc[item[0], slice(start, stop, step)], - validate_column_names=False, - ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.iloc[ - item[0], slice(item[1].start, item[1].stop, item[1].step) - ], - validate_column_names=False, - ) - msg = f"Expected slice of integers or strings, got: 
{type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - elif isinstance(item, tuple) and len(item) == 2: - if isinstance(item[1], str): - index = (item[0], self.native.columns.get_loc(item[1])) - native_series = self.native.iloc[index] - elif isinstance(item[1], int): - native_series = self.native.iloc[item] - else: # pragma: no cover - msg = f"Expected str or int, got: {type(item[1])}" - raise TypeError(msg) + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + return self._with_native( + self.native.iloc[slice(rows.start, rows.stop, rows.step), :], + validate_column_names=False, + ) - return PandasLikeSeries.from_native(native_series, context=self) + def _select_slice_name(self, columns: _SliceName) -> Self: + start = ( + self.native.columns.get_loc(columns.start) + if columns.start is not None + else None + ) + stop = ( + self.native.columns.get_loc(columns.stop) + 1 + if columns.stop is not None + else None + ) + selector = slice(start, stop, columns.step) + return self._with_native( + self.native.iloc[:, selector], validate_column_names=False + ) - elif is_sequence_but_not_str(item) or is_numpy_array_1d(item): - if len(item) > 0 and isinstance(item[0], str): - return self._with_native( - select_columns_by_name( - self.native, - cast("list[str] | _1DArray", item), - self._backend_version, - self._implementation, - ), - validate_column_names=False, - ) - return self._with_native(self.native.iloc[item], validate_column_names=False) + def _select_slice_index(self, columns: _SliceIndex | range) -> Self: + return self._with_native( + self.native.iloc[:, columns], validate_column_names=False + ) - elif isinstance(item, slice): - if isinstance(item.start, str) or isinstance(item.stop, str): - start, stop, step = convert_str_slice_to_int_slice( - item, self.native.columns - ) - return self._with_native( - self.native.iloc[:, slice(start, stop, step)], - validate_column_names=False, - ) - return 
self._with_native(self.native.iloc[item], validate_column_names=False) + def _select_multi_index( + self, columns: SizedMultiIndexSelector[pd.Series[Any]] + ) -> Self: + columns = list(columns) if isinstance(columns, tuple) else columns + return self._with_native( + self.native.iloc[:, columns], validate_column_names=False + ) - else: # pragma: no cover - msg = f"Expected str or slice, got: {type(item)}" - raise TypeError(msg) + def _select_multi_name( + self, columns: SizedMultiNameSelector[pd.Series[Any]] + ) -> PandasLikeDataFrame: + return self._with_native(self.native.loc[:, columns]) # --- properties --- @property diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 2ecc0acc56..6e8f841295 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -264,7 +264,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: sorting_indices = df.get_column(token) elif reverse: columns = list(set(partition_by).union(output_names)) - df = df.simple_select(*columns)[::-1] + df = df.simple_select(*columns)._gather_slice(slice(None, None, -1)) grouped = df._native_frame.groupby(partition_by) if function_name.startswith("rolling"): rolling = grouped[list(output_names)].rolling(**pandas_kwargs) @@ -293,7 +293,7 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: s._scatter_in_place(sorting_indices, s) return results if reverse: - return [s[::-1] for s in results] + return [s._gather_slice(slice(None, None, -1)) for s in results] return results return self.__class__( diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 32d6c7b8dc..08d94746e0 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -7,7 +7,6 @@ from typing import Mapping from typing import Sequence from typing import cast -from typing import overload import numpy as np @@ -26,7 +25,6 @@ from narwhals._pandas_like.utils import select_columns_by_name from 
narwhals._pandas_like.utils import set_index from narwhals.dependencies import is_numpy_array_1d -from narwhals.dependencies import is_numpy_scalar from narwhals.dependencies import is_pandas_like_series from narwhals.exceptions import InvalidOperationError from narwhals.utils import Implementation @@ -55,9 +53,11 @@ from narwhals.typing import NumericLiteral from narwhals.typing import RankMethod from narwhals.typing import RollingInterpolationMethod + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import TemporalLiteral from narwhals.typing import _1DArray from narwhals.typing import _AnyDArray + from narwhals.typing import _SliceIndex from narwhals.utils import Version from narwhals.utils import _FullContext @@ -147,16 +147,14 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace: self._implementation, self._backend_version, self._version ) - @overload - def __getitem__(self, idx: int) -> Any: ... + def _gather(self, rows: SizedMultiIndexSelector[pd.Series[Any]]) -> Self: + rows = list(rows) if isinstance(rows, tuple) else rows + return self._with_native(self.native.iloc[rows]) - @overload - def __getitem__(self, idx: slice | Sequence[int]) -> Self: ... 
- - def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self: - if isinstance(idx, int) or is_numpy_scalar(idx): - return self.native.iloc[idx] - return self._with_native(self.native.iloc[idx]) + def _gather_slice(self, rows: _SliceIndex | range) -> Self: + return self._with_native( + self.native.iloc[slice(rows.start, rows.stop, rows.step)] + ) def _with_version(self, version: Version) -> Self: return self.__class__( diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 7f1a0ea020..74ab8f88e8 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -8,7 +8,6 @@ from typing import Sequence from typing import Sized from typing import TypeVar -from typing import cast import pandas as pd @@ -565,24 +564,6 @@ def int_dtype_mapper(dtype: Any) -> str: return "int64" -def convert_str_slice_to_int_slice( - str_slice: slice, columns: pd.Index[str] -) -> tuple[int | None, int | None, int | None]: - # We can safely cast to int because we know that `columns` doesn't contain duplicates. 
- start = ( - cast("int", columns.get_loc(str_slice.start)) - if str_slice.start is not None - else None - ) - stop = ( - cast("int", columns.get_loc(str_slice.stop)) + 1 - if str_slice.stop is not None - else None - ) - step = str_slice.step - return (start, stop, step) - - def calculate_timestamp_datetime( s: pd.Series[int], original_time_unit: str, time_unit: str ) -> pd.Series[int]: diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 65e7bf5cee..fe9d3aa140 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -6,6 +6,7 @@ from typing import Literal from typing import Mapping from typing import Sequence +from typing import Sized from typing import cast from typing import overload @@ -14,13 +15,19 @@ from narwhals._polars.namespace import PolarsNamespace from narwhals._polars.series import PolarsSeries from narwhals._polars.utils import catch_polars_exception -from narwhals._polars.utils import convert_str_slice_to_int_slice from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import native_to_narwhals_dtype +from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ColumnNotFoundError from narwhals.utils import Implementation from narwhals.utils import _into_arrow_table -from narwhals.utils import is_sequence_but_not_str +from narwhals.utils import convert_str_slice_to_int_slice +from narwhals.utils import is_compliant_series +from narwhals.utils import is_index_selector +from narwhals.utils import is_range +from narwhals.utils import is_sequence_like +from narwhals.utils import is_slice_index +from narwhals.utils import is_slice_none from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import requires @@ -45,7 +52,10 @@ from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.typing import JoinStrategy + from narwhals.typing import 
MultiColSelector + from narwhals.typing import MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleIndexSelector from narwhals.typing import _2DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -257,54 +267,70 @@ def collect_schema(self) -> dict[str, DType]: def shape(self) -> tuple[int, int]: return self.native.shape - def __getitem__(self, item: Any) -> Any: + def __getitem__( + self, + item: tuple[ + SingleIndexSelector | MultiIndexSelector[PolarsSeries], + MultiIndexSelector[PolarsSeries] | MultiColSelector[PolarsSeries], + ], + ) -> Any: + rows, columns = item if self._backend_version > (0, 20, 30): - return self._from_native_object(self.native.__getitem__(item)) + rows_native = rows.native if is_compliant_series(rows) else rows + columns_native = columns.native if is_compliant_series(columns) else columns + selector = rows_native, columns_native + selected = self.native.__getitem__(selector) # type: ignore[index] + return self._from_native_object(selected) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support - if isinstance(item, tuple): - item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) - - columns = self.columns - if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): - if item[1] == slice(None): - if isinstance(item[0], Sequence) and not len(item[0]): - return self._with_native(self.native[0:0]) - return self._with_native(self.native.__getitem__(item[0])) - if isinstance(item[1].start, str) or isinstance(item[1].stop, str): - start, stop, step = convert_str_slice_to_int_slice(item[1], columns) - return self._with_native( - self.native.select(columns[start:stop:step]).__getitem__(item[0]) - ) - if isinstance(item[1].start, int) or isinstance(item[1].stop, int): - return self._with_native( - self.native.select( - columns[item[1].start : item[1].stop : 
item[1].step] - ).__getitem__(item[0]) + # This mostly mirrors the logic in `EagerDataFrame.__getitem__`. + rows = list(rows) if isinstance(rows, tuple) else rows + columns = list(columns) if isinstance(columns, tuple) else columns + if is_numpy_array_1d(columns): + columns = columns.tolist() + + native = self.native + if not is_slice_none(columns): + if isinstance(columns, Sized) and len(columns) == 0: + return self.select() + if is_index_selector(columns): + if is_slice_index(columns) or is_range(columns): + native = native.select( + self.columns[slice(columns.start, columns.stop, columns.step)] + ) + elif is_compliant_series(columns): + native = native[:, columns.native.to_list()] # type: ignore[attr-defined, index] + else: + native = native[:, columns] + elif isinstance(columns, slice): + native = native.select( + self.columns[ + slice(*convert_str_slice_to_int_slice(columns, self.columns)) + ] ) - msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover - raise TypeError(msg) # pragma: no cover - - if ( - isinstance(item, tuple) - and (len(item) == 2) - and is_sequence_but_not_str(item[1]) - and (len(item[1]) == 0) - ): - result = self.native.select(item[1]) - elif isinstance(item, slice) and ( - isinstance(item.start, str) or isinstance(item.stop, str) - ): - start, stop, step = convert_str_slice_to_int_slice(item, columns) - return self._with_native(self.native.select(columns[start:stop:step])) - elif is_sequence_but_not_str(item) and (len(item) == 0): - result = self.native.slice(0, 0) - else: - result = self.native.__getitem__(item) - if isinstance(result, pl.Series): - return PolarsSeries.from_native(result, context=self) - return self._from_native_object(result) + elif is_compliant_series(columns): + native = native.select(columns.native.to_list()) + elif is_sequence_like(columns): + native = native.select(columns) + else: + msg = f"Unreachable code, got unexpected type: {type(columns)}" + raise AssertionError(msg) + + if 
not is_slice_none(rows): + if isinstance(rows, int): + native = native[[rows], :] # pyright: ignore[reportArgumentType,reportCallIssue] + elif isinstance(rows, (slice, range)): + native = native[rows, :] # pyright: ignore[reportArgumentType,reportCallIssue] + elif is_compliant_series(rows): + native = native[rows.native, :] # pyright: ignore[reportArgumentType,reportCallIssue] + elif is_sequence_like(rows): + native = native[rows, :] # pyright: ignore[reportArgumentType,reportCallIssue] + else: + msg = f"Unreachable code, got unexpected type: {type(rows)}" + raise AssertionError(msg) + + return self._with_native(native) # pyright: ignore[reportArgumentType] def simple_select(self, *column_names: str) -> Self: return self._with_native(self.native.select(*column_names)) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 433e0648c1..ea8ed4d31e 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -36,6 +36,7 @@ from narwhals._polars.namespace import PolarsNamespace from narwhals.dtypes import DType from narwhals.typing import Into1DArray + from narwhals.typing import MultiIndexSelector from narwhals.typing import _1DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -175,13 +176,9 @@ def native(self) -> pl.Series: def alias(self, name: str) -> Self: return self._from_native_object(self.native.alias(name)) - @overload - def __getitem__(self, item: int) -> Any: ... - - @overload - def __getitem__(self, item: slice | Sequence[int] | pl.Series) -> Self: ... 
- - def __getitem__(self, item: int | slice | Sequence[int] | pl.Series) -> Any | Self: + def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self: + if isinstance(item, PolarsSeries): + return self._from_native_object(self.native.__getitem__(item.native)) return self._from_native_object(self.native.__getitem__(item)) def cast(self, dtype: DType | type[DType]) -> Self: diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index dd0df4d8a7..de34caae2d 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -228,15 +228,6 @@ def narwhals_to_native_dtype( return pl.Unknown() # pragma: no cover -def convert_str_slice_to_int_slice( - str_slice: slice, columns: list[str] -) -> tuple[int | None, int | None, int | None]: # pragma: no cover - start = columns.index(str_slice.start) if str_slice.start is not None else None - stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None - step = str_slice.step - return (start, stop, step) - - def catch_polars_exception( exception: Exception, backend_version: tuple[int, ...] 
) -> NarwhalsError | Exception: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index e7b86124d6..2ab5c354c4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -22,7 +22,6 @@ from narwhals._expression_parsing import is_scalar_like from narwhals.dependencies import get_polars from narwhals.dependencies import is_numpy_array -from narwhals.dependencies import is_numpy_array_1d from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import InvalidIntoExprError from narwhals.exceptions import LengthChangingExprError @@ -35,8 +34,10 @@ from narwhals.utils import generate_repr from narwhals.utils import is_compliant_dataframe from narwhals.utils import is_compliant_lazyframe +from narwhals.utils import is_index_selector from narwhals.utils import is_list_of -from narwhals.utils import is_sequence_but_not_str +from narwhals.utils import is_sequence_like +from narwhals.utils import is_slice_none from narwhals.utils import issue_deprecation_warning from narwhals.utils import parse_version from narwhals.utils import supports_arrow_c_stream @@ -52,6 +53,7 @@ from typing_extensions import Concatenate from typing_extensions import ParamSpec from typing_extensions import Self + from typing_extensions import TypeAlias from narwhals._compliant import CompliantDataFrame from narwhals._compliant import CompliantLazyFrame @@ -66,10 +68,13 @@ from narwhals.typing import IntoFrame from narwhals.typing import JoinStrategy from narwhals.typing import LazyUniqueKeepStrategy + from narwhals.typing import MultiColSelector as _MultiColSelector + from narwhals.typing import MultiIndexSelector as _MultiIndexSelector from narwhals.typing import PivotAgg + from narwhals.typing import SingleColSelector + from narwhals.typing import SingleIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import UniqueKeepStrategy - from narwhals.typing import _1DArray from narwhals.typing import _2DArray PS = ParamSpec("PS") @@ -79,6 +84,9 @@ 
DataFrameT = TypeVar("DataFrameT", bound="IntoDataFrame") R = TypeVar("R") +MultiColSelector: TypeAlias = "_MultiColSelector[Series[Any]]" +MultiIndexSelector: TypeAlias = "_MultiIndexSelector[Series[Any]]" + class BaseFrame(Generic[_FrameT]): _compliant_frame: Any @@ -795,41 +803,40 @@ def estimated_size(self, unit: SizeUnit = "b") -> int | float: """ return self._compliant_frame.estimated_size(unit=unit) + # `str` overlaps with `Sequence[str]` + # We can ignore this but we must keep this overload ordering + @overload + def __getitem__(self, item: tuple[SingleIndexSelector, SingleColSelector]) -> Any: ... + @overload def __getitem__( # type: ignore[overload-overlap] - self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], + self, item: str | tuple[MultiIndexSelector, SingleColSelector] ) -> Series[Any]: ... @overload def __getitem__( self, item: ( - int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] + SingleIndexSelector + | MultiIndexSelector + | MultiColSelector + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), ) -> Self: ... def __getitem__( self, item: ( - str - | int - | slice - | Sequence[int] - | Sequence[str] - | _1DArray - | tuple[slice | Sequence[int] | _1DArray, int | str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] + SingleIndexSelector + | SingleColSelector + | MultiColSelector + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), - ) -> Series[Any] | Self: + ) -> Series[Any] | Self | Any: """Extract column or slice of DataFrame. 
Arguments: @@ -879,41 +886,57 @@ def __getitem__( 1 2 Name: a, dtype: int64 """ - if isinstance(item, int): - item = [item] - if ( - isinstance(item, tuple) - and len(item) == 2 - and (isinstance(item[0], (str, int))) - ): - msg = ( - f"Expected str or slice, got: {type(item)}.\n\n" - "Hint: if you were trying to get a single element out of a " - "dataframe, use `DataFrame.item`." - ) - raise TypeError(msg) - if ( - isinstance(item, tuple) - and len(item) == 2 - and (is_sequence_but_not_str(item[1]) or isinstance(item[1], slice)) - ): - if item[1] == slice(None) and item[0] == slice(None): - return self - return self._with_compliant(self._compliant_frame[item]) - if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2): - return self._series(self._compliant_frame[item], level=self._level) - - elif ( - is_sequence_but_not_str(item) - or isinstance(item, slice) - or (is_numpy_array_1d(item)) - ): - return self._with_compliant(self._compliant_frame[item]) + from narwhals.series import Series + + msg = ( + f"Unexpected type for `DataFrame.__getitem__`, got: {type(item)}.\n\n" + "Hints:\n" + "- use `df.item` to select a single item.\n" + "- Use `df[indices, :]` to select rows positionally.\n" + "- Use `df.filter(mask)` to filter rows based on a boolean mask." + ) + if isinstance(item, tuple): + if len(item) > 2: + tuple_msg = ( + "Tuples cannot be passed to DataFrame.__getitem__ directly.\n\n" + "Hint: instead of `df[indices]`, did you mean `df[indices, :]`?" 
+ ) + raise TypeError(tuple_msg) + rows = None if not item or is_slice_none(item[0]) else item[0] + columns = None if len(item) < 2 or is_slice_none(item[1]) else item[1] + if rows is None and columns is None: + return self + elif is_index_selector(item): + rows = item + columns = None + elif is_sequence_like(item) or isinstance(item, (slice, str)): + rows = None + columns = item else: - msg = f"Expected str or slice, got: {type(item)}" raise TypeError(msg) + if isinstance(rows, str): + raise TypeError(msg) + + compliant = self._compliant_frame + + if isinstance(columns, (int, str)): + if isinstance(rows, int): + return self.item(rows, columns) + col_name = columns if isinstance(columns, str) else self.columns[columns] + series = self.get_column(col_name) + return series[rows] if rows is not None else series + if isinstance(rows, Series): + rows = rows._compliant_series + if isinstance(columns, Series): + columns = columns._compliant_series + if rows is None: + return self._with_compliant(compliant[:, columns]) + if columns is None: + return self._with_compliant(compliant[rows, :]) + return self._with_compliant(compliant[rows, columns]) + def __contains__(self, key: str) -> bool: return key in self.columns diff --git a/narwhals/series.py b/narwhals/series.py index 062ccacaa9..c55b526750 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -21,9 +21,11 @@ from narwhals.translate import to_native from narwhals.typing import IntoSeriesT from narwhals.typing import NonNestedLiteral +from narwhals.typing import SingleIndexSelector from narwhals.utils import _validate_rolling_arguments from narwhals.utils import generate_repr from narwhals.utils import is_compliant_series +from narwhals.utils import is_index_selector from narwhals.utils import parse_version from narwhals.utils import supports_arrow_c_stream @@ -37,6 +39,7 @@ from narwhals._arrow.typing import ArrowArray from narwhals._compliant import CompliantSeries from narwhals.dataframe import DataFrame + from 
narwhals.dataframe import MultiIndexSelector from narwhals.dtypes import DType from narwhals.typing import ClosedInterval from narwhals.typing import FillNullStrategy @@ -129,12 +132,12 @@ def __array__(self, dtype: Any = None, copy: bool | None = None) -> _1DArray: # return self._compliant_series.__array__(dtype=dtype, copy=copy) @overload - def __getitem__(self, idx: int) -> Any: ... + def __getitem__(self, idx: SingleIndexSelector) -> Any: ... @overload - def __getitem__(self, idx: slice | Sequence[int] | Self) -> Self: ... + def __getitem__(self, idx: MultiIndexSelector) -> Self: ... - def __getitem__(self, idx: int | slice | Sequence[int] | Self) -> Any | Self: + def __getitem__(self, idx: SingleIndexSelector | MultiIndexSelector) -> Any | Self: """Retrieve elements from the object using integer indexing or slicing. Arguments: @@ -169,10 +172,25 @@ def __getitem__(self, idx: int | slice | Sequence[int] | Self) -> Any | Self: if isinstance(idx, int) or ( is_numpy_scalar(idx) and idx.dtype.kind in {"i", "u"} ): - return self._compliant_series[idx] - return self._with_compliant( - self._compliant_series[to_native(idx, pass_through=True)] - ) + idx = int(idx) if not isinstance(idx, int) else idx + return self._compliant_series.item(idx) + + if isinstance(idx, self.to_native().__class__): + idx = self._with_compliant(self._compliant_series._with_native(idx)) + + if not is_index_selector(idx): + msg = ( + f"Unexpected type for `Series.__getitem__`: {type(idx)}.\n\n" + "Hints:\n" + "- use `s.item` to select a single item.\n" + "- Use `s[indices]` to select rows positionally.\n" + "- Use `s.filter(mask)` to filter rows based on a boolean mask." 
+ ) + raise TypeError(msg) + if isinstance(idx, Series): + return self._with_compliant(self._compliant_series[idx._compliant_series]) + assert not isinstance(idx, int) # noqa: S101 # help mypy + return self._with_compliant(self._compliant_series[idx]) def __native_namespace__(self) -> ModuleType: return self._compliant_series.__native_namespace__() diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index cfc035f463..f966434725 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -91,6 +91,8 @@ from typing_extensions import TypeVar from narwhals._translate import IntoArrowTable + from narwhals.dataframe import MultiColSelector + from narwhals.dataframe import MultiIndexSelector from narwhals.dtypes import DType from narwhals.typing import ConcatMethod from narwhals.typing import IntoExpr @@ -98,6 +100,8 @@ from narwhals.typing import IntoLazyFrameT from narwhals.typing import IntoSeries from narwhals.typing import NonNestedLiteral + from narwhals.typing import SingleColSelector + from narwhals.typing import SingleIndexSelector from narwhals.typing import _1DArray from narwhals.typing import _2DArray @@ -150,27 +154,38 @@ def _series(self) -> type[Series[Any]]: def _lazyframe(self) -> type[LazyFrame[Any]]: return LazyFrame + @overload + def __getitem__(self, item: tuple[SingleIndexSelector, SingleColSelector]) -> Any: ... + @overload def __getitem__( # type: ignore[overload-overlap] - self, - item: str | tuple[slice | Sequence[int] | _1DArray, int | str], + self, item: str | tuple[MultiIndexSelector, SingleColSelector] ) -> Series[Any]: ... + @overload def __getitem__( self, item: ( - int - | slice - | _1DArray - | Sequence[int] - | Sequence[str] - | tuple[ - slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str] - ] + SingleIndexSelector + | MultiIndexSelector + | MultiColSelector + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), ) -> Self: ... 
- - def __getitem__(self, item: Any) -> Any: + def __getitem__( + self, + item: ( + SingleIndexSelector + | SingleColSelector + | MultiColSelector + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] + ), + ) -> Series[Any] | Self | Any: return super().__getitem__(item) def lazy( diff --git a/narwhals/typing.py b/narwhals/typing.py index c222b200f8..cd19d24f05 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -4,6 +4,7 @@ from typing import Any from typing import Literal from typing import Protocol +from typing import Sequence from typing import TypeVar from typing import Union @@ -309,6 +310,25 @@ def __native_namespace__(self) -> ModuleType: ... ) PythonLiteral: TypeAlias = "NonNestedLiteral | list[Any] | tuple[Any, ...]" +# Annotations for `__getitem__` methods +_T = TypeVar("_T") +_Slice: TypeAlias = "slice[_T, Any, Any] | slice[Any, _T, Any] | slice[None, None, _T]" +_SliceNone: TypeAlias = "slice[None, None, None]" +# Index/column positions +SingleIndexSelector: TypeAlias = int +_SliceIndex: TypeAlias = "_Slice[int] | _SliceNone" +"""E.g. 
`[1:]` or `[:3]` or `[::2]`.""" +SizedMultiIndexSelector: TypeAlias = "Sequence[int] | _T | _1DArray" +MultiIndexSelector: TypeAlias = "_SliceIndex | SizedMultiIndexSelector[_T]" +# Labels/column names +SingleNameSelector: TypeAlias = str +_SliceName: TypeAlias = "_Slice[str] | _SliceNone" +SizedMultiNameSelector: TypeAlias = "Sequence[str] | _T | _1DArray" +MultiNameSelector: TypeAlias = "_SliceName | SizedMultiNameSelector[_T]" +# Mixed selectors +SingleColSelector: TypeAlias = "SingleIndexSelector | SingleNameSelector" +MultiColSelector: TypeAlias = "MultiIndexSelector[_T] | MultiNameSelector[_T]" + # ruff: noqa: N802 class DTypes(Protocol): diff --git a/narwhals/utils.py b/narwhals/utils.py index b48e5d43cd..ed716d9241 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -29,6 +29,7 @@ from narwhals.dependencies import get_duckdb from narwhals.dependencies import get_ibis from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow @@ -37,6 +38,8 @@ from narwhals.dependencies import get_sqlframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_modin_series +from narwhals.dependencies import is_narwhals_series +from narwhals.dependencies import is_numpy_array_1d from narwhals.dependencies import is_pandas_dataframe from narwhals.dependencies import is_pandas_like_dataframe from narwhals.dependencies import is_pandas_like_series @@ -82,9 +85,16 @@ from narwhals.typing import DataFrameLike from narwhals.typing import DTypes from narwhals.typing import IntoSeriesT + from narwhals.typing import MultiIndexSelector + from narwhals.typing import SingleIndexSelector + from narwhals.typing import SizedMultiIndexSelector from narwhals.typing import SizeUnit from narwhals.typing import SupportsNativeNamespace from narwhals.typing import TimeUnit + from narwhals.typing 
import _1DArray + from narwhals.typing import _SliceIndex + from narwhals.typing import _SliceName + from narwhals.typing import _SliceNone FrameOrSeriesT = TypeVar( "FrameOrSeriesT", bound=Union[LazyFrame[Any], DataFrame[Any], Series[Any]] @@ -1256,10 +1266,64 @@ def parse_columns_to_drop( return to_drop -def is_sequence_but_not_str(sequence: Any | Sequence[_T]) -> TypeIs[Sequence[_T]]: +def is_sequence_but_not_str(sequence: Sequence[_T] | Any) -> TypeIs[Sequence[_T]]: return isinstance(sequence, Sequence) and not isinstance(sequence, str) +def is_slice_none(obj: Any) -> TypeIs[_SliceNone]: + return isinstance(obj, slice) and obj == slice(None) + + +def is_sized_multi_index_selector(obj: Any) -> TypeIs[SizedMultiIndexSelector[Any]]: + np = get_numpy() + return ( + ( + is_sequence_but_not_str(obj) + and ((len(obj) > 0 and isinstance(obj[0], int)) or (len(obj) == 0)) + ) + or (is_numpy_array_1d(obj) and np.issubdtype(obj.dtype, np.integer)) + or ( + (is_narwhals_series(obj) or is_compliant_series(obj)) + and obj.dtype.is_integer() + ) + ) + + +def is_sequence_like( + obj: Sequence[_T] | Any, +) -> TypeIs[Sequence[_T] | Series[Any] | _1DArray]: + return ( + is_sequence_but_not_str(obj) + or is_numpy_array_1d(obj) + or is_narwhals_series(obj) + or is_compliant_series(obj) + ) + + +def is_slice_index(obj: Any) -> TypeIs[_SliceIndex]: + return isinstance(obj, slice) and ( + isinstance(obj.start, int) + or isinstance(obj.stop, int) + or (isinstance(obj.step, int) and obj.start is None and obj.stop is None) + ) + + +def is_range(obj: Any) -> TypeIs[range]: + return isinstance(obj, range) + + +def is_single_index_selector(obj: Any) -> TypeIs[SingleIndexSelector]: + return bool(isinstance(obj, int) and not isinstance(obj, bool)) + + +def is_index_selector(obj: Any) -> TypeIs[SingleIndexSelector | MultiIndexSelector[Any]]: + return ( + is_single_index_selector(obj) + or is_sized_multi_index_selector(obj) + or is_slice_index(obj) + ) + + def is_list_of(obj: Any, tp: type[_T]) 
-> TypeIs[list[_T]]: # Check if an object is a list of `tp`, only sniffing the first element. return bool(isinstance(obj, list) and obj and isinstance(obj[0], tp)) @@ -1818,3 +1882,12 @@ def wrapper(instance: _ContextT, *args: P.args, **kwds: P.kwargs) -> R: # NOTE: Only getting a complaint from `mypy` return wrapper # type: ignore[return-value] + + +def convert_str_slice_to_int_slice( + str_slice: _SliceName, columns: Sequence[str] +) -> tuple[int | None, int | None, Any]: + start = columns.index(str_slice.start) if str_slice.start is not None else None + stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None + step = str_slice.step + return (start, stop, step) diff --git a/pyproject.toml b/pyproject.toml index d5910e7ac6..d97ea34bd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -247,10 +247,10 @@ omit = [ 'narwhals/_spark_like/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*', - # we don't run these in every environment + # we don't (yet) run these in every environment 'tests/ibis_test.py', # Remove after finishing eager sub-protocols - 'narwhals/_compliant/*', + 'narwhals/_compliant/namespace.py', ] exclude_also = [ "if sys.version_info() <", diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index c1c4c09588..39f97c286c 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime from typing import TYPE_CHECKING from typing import Any from typing import cast @@ -53,6 +54,11 @@ def test_slice_rows_with_step_pyarrow() -> None: match="Slicing with step is not supported on PyArrow tables", ): nw.from_native(pa.table(data))[1::2] + with pytest.raises( + NotImplementedError, + match="Slicing with step is not supported on PyArrow tables", + ): + nw.from_native(pa.chunked_array([data["a"]]), series_only=True)[1::2] def test_slice_lazy_fails() -> 
None: @@ -71,7 +77,7 @@ def test_slice_int(constructor_eager: ConstructorEager) -> None: def test_slice_fails(constructor_eager: ConstructorEager) -> None: class Foo: ... - with pytest.raises(TypeError, match="Expected str or slice, got:"): + with pytest.raises(TypeError, match="Unexpected type.*, got:"): nw.from_native(constructor_eager(data), eager_only=True)[Foo()] # type: ignore[call-overload, unused-ignore] @@ -113,7 +119,7 @@ def test_gather_rows_cols(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_slice_both_tuples_of_ints(constructor_eager: ConstructorEager) -> None: +def test_slice_both_list_of_ints(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor_eager(data), eager_only=True) result = df[[0, 1], [0, 2]] @@ -121,6 +127,14 @@ def test_slice_both_tuples_of_ints(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) +def test_slice_both_tuple(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df[(0, 1), ("a", "c")] + expected = {"a": [1, 2], "c": [7, 8]} + assert_equal_data(result, expected) + + def test_slice_int_rows_str_columns(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} df = nw.from_native(constructor_eager(data), eager_only=True) @@ -171,11 +185,10 @@ def test_slice_slice_columns( assert_equal_data(result, expected) -def test_slice_invalid(constructor_eager: ConstructorEager) -> None: +def test_slice_item(constructor_eager: ConstructorEager) -> None: data = {"a": [1, 2], "b": [4, 5]} df = nw.from_native(constructor_eager(data), eager_only=True) - with pytest.raises(TypeError, match="Hint:"): - df[0, 0] + assert df[0, 0] == 1 def test_slice_edge_cases(constructor_eager: ConstructorEager) -> None: @@ -242,3 +255,120 @@ 
def test_get_item_works_with_tuple_and_list_indexing_and_str( ) -> None: nw_df = nw.from_native(constructor_eager(data), eager_only=True) nw_df[row_idx, col] + + +def test_getitem_ndarray_columns(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + arr = np.arange(2) + result = nw_df[:, arr] + expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} + assert_equal_data(result, expected) + + +def test_getitem_ndarray_columns_labels(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + arr: np.ndarray[tuple[int], np.dtype[Any]] = np.array(["col1", "col2"]) # pyright: ignore[reportAssignmentType] + result = nw_df[:, arr] + expected = {"col1": ["a", "b", "c", "d"], "col2": [0, 1, 2, 3]} + assert_equal_data(result, expected) + + +def test_getitem_negative_slice(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[-3:-2, ["col3", "col1"]] + expected = {"col3": [3], "col1": ["b"]} + assert_equal_data(result, expected) + result = nw_df[-3:-2] + expected = {"col1": ["b"], "col2": [1], "col3": [3]} + assert_equal_data(result, expected) + result_s = nw_df["col1"][-3:-2] + expected = {"col1": ["b"]} + assert_equal_data({"col1": result_s}, expected) + + +def test_zeroth_row_no_columns(constructor_eager: ConstructorEager) -> None: + data = {"col1": ["a", "b", "c", "d"], "col2": np.arange(4), "col3": [4, 3, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + columns: list[str] = [] + result = nw_df[0, columns] + assert result.shape == (0, 0) + + +def test_single_tuple(constructor_eager: ConstructorEager) 
-> None: + data = {"a": [1, 2, 3]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + # Technically works but we should probably discourage it + # OK if overloads don't match it. + result = nw_df[[0, 1],] # type: ignore[index] + expected = {"a": [1, 2]} + assert_equal_data(result, expected) + + +def test_triple_tuple(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2, 3]} + with pytest.raises(TypeError, match="Tuples cannot"): + nw.from_native(constructor_eager(data), eager_only=True)[(1, 2, 3)] + + +def test_slice_with_series( + constructor_eager: ConstructorEager, request: pytest.FixtureRequest +) -> None: + if "pandas_pyarrow" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 2, 3], "c": [0, 2, 1]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[nw_df["c"]] + expected = {"a": [1, 3, 2], "c": [0, 1, 2]} + assert_equal_data(result, expected) + result = nw_df[nw_df["c"], ["a"]] + expected = {"a": [1, 3, 2]} + assert_equal_data(result, expected) + + +def test_horizontal_slice_with_series(constructor_eager: ConstructorEager) -> None: + data = {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[nw_df["d"]] + expected = {"c": [0, 2], "a": [1, 2]} + assert_equal_data(result, expected) + + +def test_horizontal_slice_with_series_2( + constructor_eager: ConstructorEager, request: pytest.FixtureRequest +) -> None: + if "pandas_pyarrow" in str(constructor_eager): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]} + nw_df = nw.from_native(constructor_eager(data), eager_only=True) + result = nw_df[:, nw_df["c"]] + expected = {"a": [1, 2], "d": ["c", "a"]} + assert_equal_data(result, expected) + + +def test_native_slice_series(constructor_eager: ConstructorEager) -> None: + s = nw.from_native(constructor_eager({"a": [0, 2, 1]}), eager_only=True)["a"] + 
result = {"a": s[s.to_native()]} + expected = {"a": [0, 1, 2]} + assert_equal_data(result, expected) + + +def test_pandas_non_str_columns() -> None: + # The general rule with getitem is: ints are always treated as positions. The rest, we should + # be able to hand down to the native frame. Here we check what happens for pandas with + # datetime column names. + df = nw.from_native( + pd.DataFrame({datetime(2020, 1, 1): [1, 2, 3], datetime(2020, 1, 2): [4, 5, 6]}), + eager_only=True, + ) + result = df[:, [datetime(2020, 1, 1)]] # type: ignore[index] + expected = {datetime(2020, 1, 1): [1, 2, 3]} + assert result.to_dict(as_series=False) == expected + + +def test_select_rows_by_name(constructor_eager: ConstructorEager) -> None: + df = nw.from_native(constructor_eager({"a": [0, 2, 1]}), eager_only=True) + with pytest.raises(TypeError, match="Unexpected type"): + df["a", :] # type: ignore[index] diff --git a/tests/hypothesis/getitem_test.py b/tests/hypothesis/getitem_test.py index 970745f03e..940712d29f 100644 --- a/tests/hypothesis/getitem_test.py +++ b/tests/hypothesis/getitem_test.py @@ -155,79 +155,23 @@ def test_getitem( ) ) - # IndexError: Offset must be non-negative (pyarrow does not support negative indexing) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, slice) - and isinstance(selector.start, int) - and selector.start < 0 - ) - ) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, slice) - and isinstance(selector.stop, int) - and selector.stop < 0 - ) - ) - - # Pairs of slices are not supported - # NB a few trivial cases are supported, eg df[0:1, :] - # TypeError: Got unexpected argument type for compute function + # NotImplementedError: Slicing with step is not supported on PyArrow tables assume( not ( pandas_or_pyarrow_constructor is pyarrow_table_constructor and isinstance(selector, tuple) - and isinstance(selector[0], slice) - and 
isinstance(selector[1], slice) and ( - selector[0] != slice(None, None, None) - or selector[1] != slice(None, None, None) + (isinstance(selector[0], slice) and selector[0].step is not None) + or (isinstance(selector[1], slice) and selector[1].step is not None) ) ) ) - - # df[[], "a":], df[[], :] etc fail in pyarrow: - # ArrowNotImplementedError: Function 'array_take' has no kernel matching input types (int64, null) - assume( - not ( - pandas_or_pyarrow_constructor is pyarrow_table_constructor - and isinstance(selector, tuple) - and isinstance(selector[0], list) - and len(selector[0]) == 0 - and isinstance(selector[1], slice) - ) - ) - - # df[[], "a":], df[[], :] etc return different results between pandas/polars: - assume( - not ( - pandas_or_pyarrow_constructor is pandas_constructor - and isinstance(selector, tuple) - and isinstance(selector[0], list) - and len(selector[0]) == 0 - and isinstance(selector[1], slice) - ) - ) - - # df[..., ::step] is not fine: - # TypeError: Expected slice of integers or strings, got: - assume( - not ( - isinstance(selector, tuple) - and isinstance(selector[1], slice) - and selector[1].start is None - and selector[1].stop is None - ) - ) # End TODO ================================================================ df_polars = nw.from_native(pl.DataFrame(TEST_DATA)) try: result_polars = df_polars[selector] - except TypeError: + except TypeError: # pragma: no cover # If the selector fails on polars, then skip the test. # e.g. df[0, 'a'] fails, suggesting to use DataFrame.item to extract a single # element. 
@@ -240,6 +184,8 @@ def test_getitem( if isinstance(result_polars, nw.Series): assert_equal_data({"a": result_other}, {"a": result_polars.to_list()}) + elif isinstance(result_polars, (str, int)): # pragma: no cover + assert result_polars == result_other else: assert_equal_data( result_other, diff --git a/tests/series_only/__getitem___test.py b/tests/series_only/getitem_test.py similarity index 88% rename from tests/series_only/__getitem___test.py rename to tests/series_only/getitem_test.py index 8ecc5fb0a3..de3c5b9a17 100644 --- a/tests/series_only/__getitem___test.py +++ b/tests/series_only/getitem_test.py @@ -61,3 +61,11 @@ def test_getitem_other_series(constructor_eager: ConstructorEager) -> None: ] other = nw.from_native(constructor_eager({"b": [1, 3]}), eager_only=True)["b"] assert_equal_data(series[other].to_frame(), {"a": [None, 3]}) + + +def test_getitem_invalid_series(constructor_eager: ConstructorEager) -> None: + series = nw.from_native(constructor_eager({"a": [1, None, 2, 3]}), eager_only=True)[ + "a" + ] + with pytest.raises(TypeError, match="Unexpected type"): + series[series > 1]