diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 83e7104e1f..209d831bf1 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -112,7 +112,7 @@ def __narwhals_lazyframe__(self: Self) -> Self: def _change_version(self: Self, version: Version) -> Self: return self.__class__( - self._native_frame, + self.native, backend_version=self._backend_version, version=version, validate_column_names=False, @@ -130,13 +130,13 @@ def _from_native_frame( @property def shape(self: Self) -> tuple[int, int]: - return self._native_frame.shape + return self.native.shape def __len__(self: Self) -> int: - return len(self._native_frame) + return len(self.native) def row(self: Self, index: int) -> tuple[Any, ...]: - return tuple(col[index] for col in self._native_frame.itercolumns()) + return tuple(col[index] for col in self.native.itercolumns()) @overload def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ... @@ -152,12 +152,12 @@ def rows( def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: if not named: return list(self.iter_rows(named=False, buffer_size=512)) # type: ignore[return-value] - return self._native_frame.to_pylist() + return self.native.to_pylist() def iter_columns(self) -> Iterator[ArrowSeries]: from narwhals._arrow.series import ArrowSeries - for name, series in zip(self.columns, self._native_frame.itercolumns()): + for name, series in zip(self.columns, self.native.itercolumns()): yield ArrowSeries( series, name=name, @@ -170,7 +170,7 @@ def iter_columns(self) -> Iterator[ArrowSeries]: def iter_rows( self: Self, *, named: bool, buffer_size: int ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]: - df = self._native_frame + df = self.native num_rows = df.num_rows if not named: @@ -189,14 +189,14 @@ def get_column(self: Self, name: str) -> ArrowSeries: raise TypeError(msg) return ArrowSeries( - self._native_frame[name], + self.native[name], name=name, backend_version=self._backend_version, version=self._version, ) def __array__(self: Self, dtype: Any, *, copy: bool | None) -> _2DArray: - return self._native_frame.__array__(dtype, copy=copy) + return self.native.__array__(dtype, copy=copy) @overload def __getitem__( # type: ignore[overload-overlap, unused-ignore] @@ -238,7 +238,7 @@ def __getitem__( from narwhals._arrow.series import ArrowSeries return ArrowSeries( - self._native_frame[item], + self.native[item], name=item, backend_version=self._backend_version, version=self._version, @@ -251,8 +251,8 @@ def __getitem__( ): if len(item[1]) == 0: # Return empty dataframe - return self._from_native_frame(self._native_frame.slice(0, 0).select([])) - selected_rows = select_rows(self._native_frame, item[0]) + return self._from_native_frame(self.native.slice(0, 0).select([])) + selected_rows = select_rows(self.native, item[0]) return self._from_native_frame(selected_rows.select(cast("Indices", item[1]))) elif isinstance(item, tuple) and len(item) == 2: @@ -261,16 +261,16 @@ def __getitem__( indices = cast("Indices", item[0]) if item[1] == slice(None): if isinstance(item[0], Sequence) and len(item[0]) == 0: - return self._from_native_frame(self._native_frame.slice(0, 0)) - return self._from_native_frame(self._native_frame.take(indices)) + return self._from_native_frame(self.native.slice(0, 0)) + return self._from_native_frame(self.native.take(indices)) if isinstance(item[1].start, str) or isinstance(item[1].stop, str): start, stop, step = convert_str_slice_to_int_slice(item[1], columns) return self._from_native_frame( - self._native_frame.take(indices).select(columns[start:stop:step]) + self.native.take(indices).select(columns[start:stop:step]) ) if isinstance(item[1].start, int) or isinstance(item[1].stop, int): return self._from_native_frame( - self._native_frame.take(indices).select( + self.native.take(indices).select( columns[item[1].start : item[1].stop : item[1].step] ) ) @@ -289,12 +289,12 @@ def __getitem__( raise TypeError(msg) if (isinstance(item[0], slice)) and (item[0] == slice(None)): return ArrowSeries( - self._native_frame[col_name], + self.native[col_name], name=col_name, backend_version=self._backend_version, version=self._version, ) - selected_rows = select_rows(self._native_frame, item[0]) + selected_rows = select_rows(self.native, item[0]) return ArrowSeries( selected_rows[col_name], name=col_name, @@ -310,11 +310,11 @@ def __getitem__( if isinstance(item.start, str) or isinstance(item.stop, str): start, stop, step = convert_str_slice_to_int_slice(item, columns) return self._from_native_frame( - self._native_frame.select(columns[start:stop:step]) + self.native.select(columns[start:stop:step]) ) start = item.start or 0 - stop = item.stop if item.stop is not None else len(self._native_frame) - return self._from_native_frame(self._native_frame.slice(start, stop - start)) + stop = item.stop if item.stop is not None else len(self.native) + return self._from_native_frame(self.native.slice(start, stop - start)) elif isinstance(item, Sequence) or is_numpy_array_1d(item): if ( @@ -322,12 +322,10 @@ def __getitem__( and all(isinstance(x, str) for x in item) and len(item) > 0 ): - return self._from_native_frame( - self._native_frame.select(cast("Indices", item)) - ) + return self._from_native_frame(self.native.select(cast("Indices", item))) if isinstance(item, Sequence) and len(item) == 0: - return self._from_native_frame(self._native_frame.slice(0, 0)) - return self._from_native_frame(self._native_frame.take(cast("Indices", item))) + return self._from_native_frame(self.native.slice(0, 0)) + return self._from_native_frame(self.native.take(cast("Indices", item))) else: # pragma: no cover msg = f"Expected str or slice, got: {type(item)}" @@ -335,7 +333,7 @@ def __getitem__( @property def schema(self: Self) -> dict[str, DType]: - schema = self._native_frame.schema + schema = self.native.schema return { name: native_to_narwhals_dtype(dtype, self._version) for name, dtype in zip(schema.names, schema.types) @@ -345,18 +343,18 @@ def collect_schema(self: Self) -> dict[str, DType]: return self.schema def estimated_size(self: Self, unit: SizeUnit) -> int | float: - sz = self._native_frame.nbytes + sz = self.native.nbytes return scale_bytes(sz, unit) explode = not_implemented() @property def columns(self: Self) -> list[str]: - return self._native_frame.schema.names + return self.native.schema.names def simple_select(self, *column_names: str) -> Self: return self._from_native_frame( - self._native_frame.select(list(column_names)), validate_column_names=False + self.native.select(list(column_names)), validate_column_names=False ) def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: @@ -364,7 +362,7 @@ def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: if not new_series: # return empty dataframe, like Polars does return self._from_native_frame( - self._native_frame.__class__.from_arrays([]), validate_column_names=False + self.native.__class__.from_arrays([]), validate_column_names=False ) names = [s.name for s in new_series] reshaped = align_series_full_broadcast(*new_series) @@ -372,7 +370,9 @@ def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: return self._from_native_frame(df, validate_column_names=True) def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame: - native_frame = self._native_frame + # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame) + # All `pyarrow` data is immutable, so this is fine + native_frame = self.native new_columns = self._evaluate_into_exprs(*exprs) length = len(self) @@ -440,7 +440,7 @@ def join( ) return self._from_native_frame( - self._native_frame.join( + self.native.join( other._native_frame, keys=left_on or [], # type: ignore[arg-type] right_keys=right_on, # type: ignore[arg-type] @@ -456,13 +456,13 @@ def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self: compliant_frame=self, columns=columns, strict=strict ) return self._from_native_frame( - self._native_frame.drop(to_drop), validate_column_names=False + self.native.drop(to_drop), validate_column_names=False ) def drop_nulls(self: ArrowDataFrame, subset: Sequence[str] | None) -> ArrowDataFrame: if subset is None: return self._from_native_frame( - self._native_frame.drop_null(), validate_column_names=False + self.native.drop_null(), validate_column_names=False ) plx = self.__narwhals_namespace__() return self.filter(~plx.any_horizontal(plx.col(*subset).is_null())) @@ -473,8 +473,6 @@ def sort( descending: bool | Sequence[bool], nulls_last: bool, ) -> Self: - df = self._native_frame - if isinstance(descending, bool): order: Order = "descending" if descending else "ascending" sorting: list[tuple[str, Order]] = [(key, order) for key in by] @@ -487,22 +485,22 @@ def sort( null_placement = "at_end" if nulls_last else "at_start" return self._from_native_frame( - df.sort_by(sorting, null_placement=null_placement), + self.native.sort_by(sorting, null_placement=null_placement), validate_column_names=False, ) def to_pandas(self: Self) -> pd.DataFrame: - return self._native_frame.to_pandas() + return self.native.to_pandas() def to_polars(self: Self) -> pl.DataFrame: import polars as pl # ignore-banned-import - return pl.from_arrow(self._native_frame) # type: ignore[return-value] + return pl.from_arrow(self.native) # type: ignore[return-value] def to_numpy(self: Self) -> _2DArray: import numpy as np # ignore-banned-import - arr: Any = np.column_stack([col.to_numpy() for col in self._native_frame.columns]) + arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns]) return arr @overload @@ -514,7 +512,7 @@ def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: . def to_dict( self: Self, *, as_series: bool ) -> dict[str, ArrowSeries] | dict[str, list[Any]]: - df = self._native_frame + df = self.native names_and_values = zip(df.column_names, df.columns) if as_series: @@ -533,7 +531,7 @@ def to_dict( return {name: col.to_pylist() for name, col in names_and_values} def with_row_index(self: Self, name: str) -> Self: - df = self._native_frame + df = self.native cols = self.columns row_indices = pa.array(range(df.num_rows)) @@ -548,14 +546,13 @@ def filter( mask_native: Mask | ArrowChunkedArray = predicate else: # `[0]` is safe as the predicate's expression only returns a single column - mask_native = self._evaluate_into_exprs(predicate)[0]._native_series + mask_native = self._evaluate_into_exprs(predicate)[0].native return self._from_native_frame( - self._native_frame.filter(mask_native), # pyright: ignore[reportArgumentType] - validate_column_names=False, + self.native.filter(mask_native), validate_column_names=False ) def head(self: Self, n: int) -> Self: - df = self._native_frame + df = self.native if n >= 0: return self._from_native_frame(df.slice(0, n), validate_column_names=False) else: @@ -565,7 +562,7 @@ def head(self: Self, n: int) -> Self: ) def tail(self: Self, n: int) -> Self: - df = self._native_frame + df = self.native if n >= 0: num_rows = df.num_rows return self._from_native_frame( @@ -577,8 +574,6 @@ def tail(self: Self, n: int) -> Self: def lazy( self: Self, *, backend: Implementation | None = None ) -> CompliantLazyFrame[Any, Any]: - from narwhals.utils import parse_version - if backend is None: return self elif backend is Implementation.DUCKDB: @@ -586,9 +581,9 @@ def lazy( from narwhals._duckdb.dataframe import DuckDBLazyFrame - df = self._native_frame # noqa: F841 + df = self.native # noqa: F841 return DuckDBLazyFrame( - df=duckdb.table("df"), + duckdb.table("df"), backend_version=parse_version(duckdb), version=self._version, ) @@ -598,7 +593,7 @@ def lazy( from narwhals._polars.dataframe import PolarsLazyFrame return PolarsLazyFrame( - df=pl.from_arrow(self._native_frame).lazy(), # type: ignore[union-attr] + cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(), backend_version=parse_version(pl), version=self._version, ) @@ -609,7 +604,7 @@ def lazy( from narwhals._dask.dataframe import DaskLazyFrame return DaskLazyFrame( - native_dataframe=dd.from_pandas(self._native_frame.to_pandas()), + dd.from_pandas(self.native.to_pandas()), backend_version=parse_version(dask), version=self._version, ) @@ -619,12 +614,12 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: if backend is Implementation.PYARROW or backend is None: from narwhals._arrow.dataframe import ArrowDataFrame return ArrowDataFrame( - native_dataframe=self._native_frame, + self.native, backend_version=self._backend_version, version=self._version, validate_column_names=False, @@ -636,7 +631,7 @@ def collect( from narwhals._pandas_like.dataframe import PandasLikeDataFrame return PandasLikeDataFrame( - native_dataframe=self._native_frame.to_pandas(), + self.native.to_pandas(), implementation=Implementation.PANDAS, backend_version=parse_version(pd), version=self._version, @@ -649,7 +644,7 @@ def collect( from narwhals._polars.dataframe import PolarsDataFrame return PolarsDataFrame( - df=pl.from_arrow(self._native_frame), # type: ignore[arg-type] + cast("pl.DataFrame", pl.from_arrow(self.native)), backend_version=parse_version(pl), version=self._version, ) @@ -670,28 +665,24 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any: f" frame has shape {self.shape!r}" ) raise ValueError(msg) - return maybe_extract_py_scalar( - self._native_frame[0][0], return_py_scalar=True - ) + return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True) elif row is None or column is None: msg = "cannot call `.item()` with only one of `row` or `column`" raise ValueError(msg) _col = self.columns.index(column) if isinstance(column, str) else column - return maybe_extract_py_scalar( - self._native_frame[_col][row], return_py_scalar=True - ) + return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True) def rename(self: Self, mapping: Mapping[str, str]) -> Self: - df = self._native_frame + df = self.native new_cols = [mapping.get(c, c) for c in df.column_names] return self._from_native_frame(df.rename_columns(new_cols)) def write_parquet(self: Self, file: str | Path | BytesIO) -> None: import pyarrow.parquet as pp - pp.write_table(self._native_frame, file) + pp.write_table(self.native, file) @overload def write_csv(self: Self, file: None) -> str: ... @@ -702,12 +693,11 @@ def write_csv(self: Self, file: str | Path | BytesIO) -> None: ... def write_csv(self: Self, file: str | Path | BytesIO | None) -> str | None: import pyarrow.csv as pa_csv - pa_table = self._native_frame if file is None: csv_buffer = pa.BufferOutputStream() - pa_csv.write_csv(pa_table, csv_buffer) + pa_csv.write_csv(self.native, csv_buffer) return csv_buffer.getvalue().to_pybytes().decode() - pa_csv.write_csv(pa_table, file) + pa_csv.write_csv(self.native, file) return None def is_unique(self: Self) -> ArrowSeries: @@ -716,7 +706,7 @@ def is_unique(self: Self) -> ArrowSeries: col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) row_index = pa.array(range(len(self))) keep_idx = ( - self._native_frame.append_column(col_token, row_index) + self.native.append_column(col_token, row_index) .group_by(self.columns) .aggregate([(col_token, "min"), (col_token, "max")]) ) @@ -743,7 +733,6 @@ def unique( # and has no effect on the output. import numpy as np # ignore-banned-import - df = self._native_frame check_column_exists(self.columns, subset) subset = list(subset or self.columns) @@ -753,13 +742,14 @@ def unique( agg_func = agg_func_map[keep] col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) keep_idx_native = ( - df.append_column(col_token, pa.array(np.arange(len(self)))) + self.native.append_column(col_token, pa.array(np.arange(len(self)))) .group_by(subset) .aggregate([(col_token, agg_func)]) .column(f"{col_token}_{agg_func}") ) - indices = cast("Indices", keep_idx_native) - return self._from_native_frame(df.take(indices), validate_column_names=False) + return self._from_native_frame( + self.native.take(keep_idx_native), validate_column_names=False + ) keep_idx = self.simple_select(*subset).is_unique() plx = self.__narwhals_namespace__() @@ -767,11 +757,11 @@ def unique( def gather_every(self: Self, n: int, offset: int) -> Self: return self._from_native_frame( - self._native_frame[offset::n], validate_column_names=False + self.native[offset::n], validate_column_names=False ) def to_arrow(self: Self) -> pa.Table: - return self._native_frame + return self.native def sample( self: Self, @@ -783,16 +773,15 @@ def sample( ) -> Self: import numpy as np # ignore-banned-import - frame = self._native_frame num_rows = len(self) if n is None and fraction is not None: n = int(num_rows * fraction) - rng = np.random.default_rng(seed=seed) idx = np.arange(0, num_rows) mask = rng.choice(idx, size=n, replace=with_replacement) - - return self._from_native_frame(pc.take(frame, mask), validate_column_names=False) # type: ignore[call-overload, unused-ignore] + return self._from_native_frame( + self.native.take(mask), validate_column_names=False + ) def unpivot( self: Self, @@ -801,7 +790,6 @@ def unpivot( variable_name: str, value_name: str, ) -> Self: - native_frame = self._native_frame n_rows = len(self) index_ = [] if index is None else index on_ = [c for c in self.columns if c not in index_] if on is None else on @@ -816,12 +804,12 @@ def unpivot( [ pa.Table.from_arrays( [ - *(native_frame.column(idx_col) for idx_col in index_), + *(self.native.column(idx_col) for idx_col in index_), cast( "ArrowChunkedArray", pa.array([on_col] * n_rows, pa.string()), ), - native_frame.column(on_col), + self.native.column(on_col), ], names=names, ) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index e47d1a9d60..2c65296a66 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -23,6 +23,7 @@ from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerSeriesT from narwhals._compliant.typing import IntoCompliantExpr +from narwhals._compliant.typing import NativeFrameT_co __all__ = [ "CompliantDataFrame", @@ -48,4 +49,5 @@ "IntoCompliantExpr", "LazyExpr", "LazySelectorNamespace", + "NativeFrameT_co", ] diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index e23ddc0ab0..0246fd7dcf 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -44,7 +44,13 @@ T = TypeVar("T") -class CompliantDataFrame(Sized, Protocol[CompliantSeriesT, CompliantExprT_contra]): +class CompliantDataFrame( + _StoresNative[NativeFrameT_co], + Sized, + Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT_co], +): + _native_frame: Any + def __narwhals_dataframe__(self) -> Self: ... def __narwhals_namespace__(self) -> Any: ... def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ... @@ -60,6 +66,10 @@ def aggregate(self, *exprs: CompliantExprT_contra) -> Self: """ return self.select(*exprs) + @property + def native(self) -> NativeFrameT_co: + return self._native_frame # type: ignore[no-any-return] + @property def columns(self) -> Sequence[str]: ... @property @@ -69,7 +79,7 @@ def shape(self) -> tuple[int, int]: ... def clone(self) -> Self: ... def collect( self, backend: Implementation | None, **kwargs: Any - ) -> CompliantDataFrame[Any, Any]: ... + ) -> CompliantDataFrame[Any, Any, Any]: ... def collect_schema(self) -> Mapping[str, DType]: ... def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... @@ -195,7 +205,7 @@ def schema(self) -> Mapping[str, DType]: ... def _iter_columns(self) -> Iterator[Any]: ... def collect( self, backend: Implementation | None, **kwargs: Any - ) -> CompliantDataFrame[Any, Any]: ... + ) -> CompliantDataFrame[Any, Any, Any]: ... def collect_schema(self) -> Mapping[str, DType]: ... def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ... def drop_nulls(self, subset: Sequence[str] | None) -> Self: ... @@ -252,7 +262,7 @@ def with_row_index(self, name: str) -> Self: ... class EagerDataFrame( - CompliantDataFrame[EagerSeriesT, EagerExprT_contra], + CompliantDataFrame[EagerSeriesT, EagerExprT_contra, NativeFrameT_co], CompliantLazyFrame[EagerExprT_contra, NativeFrameT_co], Protocol[EagerSeriesT, EagerExprT_contra, NativeFrameT_co], ): diff --git a/narwhals/_compliant/selectors.py b/narwhals/_compliant/selectors.py index efcb5a4f33..42803d3e15 100644 --- a/narwhals/_compliant/selectors.py +++ b/narwhals/_compliant/selectors.py @@ -66,9 +66,9 @@ SeriesT = TypeVar("SeriesT", bound="CompliantSeries") ExprT = TypeVar("ExprT", bound="NativeExpr") FrameT = TypeVar( - "FrameT", bound="CompliantDataFrame[Any, Any] | CompliantLazyFrame[Any, Any]" + "FrameT", bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]" ) -DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any, Any]") +DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any, Any, Any]") LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrame[Any, Any]") SelectorOrExpr: TypeAlias = ( "CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]" @@ -311,7 +311,7 @@ def __repr__(self: Self) -> str: # pragma: no cover def _eval_lhs_rhs( - df: CompliantDataFrame[Any, Any] | CompliantLazyFrame[Any, Any], + df: CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any], lhs: CompliantExpr[Any, Any], rhs: CompliantExpr[Any, Any], ) -> tuple[Sequence[str], Sequence[str]]: diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index d2efbaa489..1da99e6685 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -28,6 +28,7 @@ "CompliantLazyFrameT", "CompliantSeriesT", "IntoCompliantExpr", + "NativeFrameT_co", ] NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True) CompliantSeriesT = TypeVar("CompliantSeriesT", bound="CompliantSeries") @@ -36,12 +37,14 @@ bound="CompliantSeries | NativeExpr", covariant=True, ) - NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True) CompliantFrameT = TypeVar( - "CompliantFrameT", bound="CompliantDataFrame[Any, Any] | CompliantLazyFrame[Any, Any]" + "CompliantFrameT", + bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]", +) +CompliantDataFrameT = TypeVar( + "CompliantDataFrameT", bound="CompliantDataFrame[Any, Any, Any]" ) -CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any, Any]") CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame[Any, Any]") IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 95ea9aa6cb..58add7a79b 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -95,7 +95,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: result = self._native_frame.compute(**kwargs) if backend is None or backend is Implementation.PANDAS: diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 34197c779b..5ac1642b23 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -94,7 +94,7 @@ def collect( self: Self, backend: ModuleType | Implementation | str | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: if backend is None or backend is Implementation.PYARROW: import pyarrow as pa # ignore-banned-import diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 95f34c7033..2c5f8cb506 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -109,7 +109,7 @@ def extract_compliant( def evaluate_output_names_and_aliases( expr: CompliantExpr[Any, Any], - df: CompliantDataFrame[Any, Any] | CompliantLazyFrame[Any, Any], + df: CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any], exclude: Sequence[str], ) -> tuple[Sequence[str], Sequence[str]]: output_names = expr._evaluate_output_names(df) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 5a4043e4f3..49b2e86ed6 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -127,11 +127,11 @@ def __native_namespace__(self: Self) -> ModuleType: raise AssertionError(msg) def __len__(self: Self) -> int: - return len(self._native_frame) + return len(self.native) def _change_version(self: Self, version: Version) -> Self: return self.__class__( - self._native_frame, + self.native, implementation=self._implementation, backend_version=self._backend_version, version=version, @@ -151,7 +151,7 @@ def _from_native_frame( def get_column(self: Self, name: str) -> PandasLikeSeries: return PandasLikeSeries( - self._native_frame[name], + self.native[name], implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -200,7 +200,7 @@ def __getitem__( if isinstance(item, str): return PandasLikeSeries( - self._native_frame[item], + self.native[item], implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -214,19 +214,19 @@ def __getitem__( if len(item[1]) == 0: # Return empty dataframe return self._from_native_frame( - self._native_frame.__class__(), validate_column_names=False + self.native.__class__(), validate_column_names=False ) if all(isinstance(x, int) for x in item[1]): # type: ignore[var-annotated] return self._from_native_frame( - self._native_frame.iloc[item], validate_column_names=False + self.native.iloc[item], validate_column_names=False ) if all(isinstance(x, str) for x in item[1]): # type: ignore[var-annotated] indexer = ( item[0], - self._native_frame.columns.get_indexer(item[1]), + self.native.columns.get_indexer(item[1]), ) return self._from_native_frame( - self._native_frame.iloc[indexer], validate_column_names=False + self.native.iloc[indexer], validate_column_names=False ) msg = ( f"Expected sequence str or int, got: {type(item[1])}" # pragma: no cover @@ -234,20 +234,20 @@ def __getitem__( raise TypeError(msg) # pragma: no cover elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): - columns = self._native_frame.columns + columns = self.native.columns if item[1] == slice(None): return self._from_native_frame( - self._native_frame.iloc[item[0], :], validate_column_names=False + self.native.iloc[item[0], :], validate_column_names=False ) if isinstance(item[1].start, str) or isinstance(item[1].stop, str): start, stop, step = convert_str_slice_to_int_slice(item[1], columns) return self._from_native_frame( - self._native_frame.iloc[item[0], slice(start, stop, step)], + self.native.iloc[item[0], slice(start, stop, step)], validate_column_names=False, ) if isinstance(item[1].start, int) or isinstance(item[1].stop, int): return self._from_native_frame( - self._native_frame.iloc[ + self.native.iloc[ item[0], slice(item[1].start, item[1].stop, item[1].step) ], validate_column_names=False, @@ -257,10 +257,10 @@ def __getitem__( elif isinstance(item, tuple) and len(item) == 2: if isinstance(item[1], str): - index = (item[0], self._native_frame.columns.get_loc(item[1])) - native_series = self._native_frame.iloc[index] + index = (item[0], self.native.columns.get_loc(item[1])) + native_series = self.native.iloc[index] elif isinstance(item[1], int): - native_series = self._native_frame.iloc[item] + native_series = self.native.iloc[item] else: # pragma: no cover msg = f"Expected str or int, got: {type(item[1])}" raise TypeError(msg) @@ -276,7 +276,7 @@ def __getitem__( if all(isinstance(x, str) for x in item) and len(item) > 0: return self._from_native_frame( select_columns_by_name( - self._native_frame, + self.native, cast("list[str] | _1DArray", item), self._backend_version, self._implementation, @@ -284,20 +284,20 @@ def __getitem__( validate_column_names=False, ) return self._from_native_frame( - self._native_frame.iloc[item], validate_column_names=False + self.native.iloc[item], validate_column_names=False ) elif isinstance(item, slice): if isinstance(item.start, str) or isinstance(item.stop, str): start, stop, step = convert_str_slice_to_int_slice( - item, self._native_frame.columns + item, self.native.columns ) return self._from_native_frame( - self._native_frame.iloc[:, slice(start, stop, step)], + self.native.iloc[:, slice(start, stop, step)], validate_column_names=False, ) return self._from_native_frame( - self._native_frame.iloc[item], validate_column_names=False + self.native.iloc[item], validate_column_names=False ) else: # pragma: no cover @@ -307,7 +307,7 @@ def __getitem__( # --- properties --- @property def columns(self: Self) -> list[str]: - return self._native_frame.columns.tolist() + return self.native.columns.tolist() @overload def rows( @@ -337,12 +337,12 @@ def rows(self: Self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, A # Extract the row values from the named rows return [tuple(row.values()) for row in self.rows(named=True)] - return list(self._native_frame.itertuples(index=False, name=None)) + return list(self.native.itertuples(index=False, name=None)) - return self._native_frame.to_dict(orient="records") + return self.native.to_dict(orient="records") def iter_columns(self) -> Iterator[PandasLikeSeries]: - for _name, series in self._native_frame.items(): # noqa: PERF102 + for _name, series in self.native.items(): # noqa: PERF102 yield PandasLikeSeries( series, implementation=self._implementation, @@ -361,24 +361,24 @@ def iter_rows( # The param ``buffer_size`` is only here for compatibility with the Polars API # and has no effect on the output. if not named: - yield from self._native_frame.itertuples(index=False, name=None) + yield from self.native.itertuples(index=False, name=None) else: - col_names = self._native_frame.columns - for row in self._native_frame.itertuples(index=False): + col_names = self.native.columns + for row in self.native.itertuples(index=False): yield dict(zip(col_names, row)) @property def schema(self: Self) -> dict[str, DType]: - native_dtypes = self._native_frame.dtypes + native_dtypes = self.native.dtypes return { col: native_to_narwhals_dtype( native_dtypes[col], self._version, self._implementation ) if native_dtypes[col] != "object" else object_native_to_narwhals_dtype( - self._native_frame[col], self._version, self._implementation + self.native[col], self._version, self._implementation ) - for col in self._native_frame.columns + for col in self.native.columns } def collect_schema(self: Self) -> dict[str, DType]: @@ -388,7 +388,7 @@ def collect_schema(self: Self) -> dict[str, DType]: def simple_select(self: Self, *column_names: str) -> Self: return self._from_native_frame( select_columns_by_name( - self._native_frame, + self.native, list(column_names), self._backend_version, self._implementation, @@ -401,7 +401,7 @@ def select(self: PandasLikeDataFrame, *exprs: PandasLikeExpr) -> PandasLikeDataF if not new_series: # return empty dataframe, like Polars does return self._from_native_frame( - self._native_frame.__class__(), validate_column_names=False + self.native.__class__(), validate_column_names=False ) new_series = align_series_full_broadcast(*new_series) df = horizontal_concat( @@ -416,17 +416,17 @@ def drop_nulls( ) -> PandasLikeDataFrame: if subset is None: return self._from_native_frame( - self._native_frame.dropna(axis=0), validate_column_names=False + self.native.dropna(axis=0), validate_column_names=False ) plx = self.__narwhals_namespace__() return self.filter(~plx.any_horizontal(plx.col(*subset).is_null())) def estimated_size(self: Self, unit: SizeUnit) -> int | float: - sz = self._native_frame.memory_usage(deep=True).sum() + sz = self.native.memory_usage(deep=True).sum() return scale_bytes(sz, unit=unit) def with_row_index(self: Self, name: str) -> Self: - frame = self._native_frame + frame = self.native namespace = self.__narwhals_namespace__() row_index = namespace._series._from_iterable( range(len(frame)), name="", context=self, index=frame.index @@ -440,7 +440,7 @@ def with_row_index(self: Self, name: str) -> Self: ) def row(self: Self, index: int) -> tuple[Any, ...]: - return tuple(x for x in self._native_frame.iloc[index]) + return tuple(x for x in self.native.iloc[index]) def filter( self: PandasLikeDataFrame, predicate: PandasLikeExpr | list[bool] @@ -450,16 +450,16 @@ def filter( else: # `[0]` is safe as the predicate's expression only returns a single column mask = self._evaluate_into_exprs(predicate)[0] - mask_native = extract_dataframe_comparand(self._native_frame.index, mask) + mask_native = extract_dataframe_comparand(self.native.index, mask) return self._from_native_frame( - self._native_frame.loc[mask_native], validate_column_names=False + self.native.loc[mask_native], validate_column_names=False ) def with_columns( self: PandasLikeDataFrame, *exprs: PandasLikeExpr ) -> PandasLikeDataFrame: - index = self._native_frame.index + index = self.native.index new_columns = self._evaluate_into_exprs(*exprs) if not new_columns and len(self) == 0: return self @@ -467,7 +467,7 @@ def with_columns( new_column_name_to_new_column_map = {s.name: s for s in new_columns} to_concat = [] # Make sure to preserve column order - for name in self._native_frame.columns: + for name in self.native.columns: if name in new_column_name_to_new_column_map: to_concat.append( extract_dataframe_comparand( @@ -475,7 +475,7 @@ def with_columns( ) ) else: - to_concat.append(self._native_frame[name]) + to_concat.append(self.native[name]) to_concat.extend( extract_dataframe_comparand(index, new_column_name_to_new_column_map[s]) for s in new_column_name_to_new_column_map @@ -491,7 +491,7 @@ def with_columns( def rename(self: Self, mapping: Mapping[str, str]) -> Self: return self._from_native_frame( rename( - self._native_frame, + self.native, columns=mapping, implementation=self._implementation, backend_version=self._backend_version, @@ -503,7 +503,7 @@ def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self: compliant_frame=self, columns=columns, strict=strict ) return self._from_native_frame( - self._native_frame.drop(columns=to_drop), validate_column_names=False + self.native.drop(columns=to_drop), validate_column_names=False ) # --- transform --- @@ -513,7 +513,7 @@ def sort( descending: bool | Sequence[bool], nulls_last: bool, ) -> Self: - df = self._native_frame + df = self.native if isinstance(descending, bool): ascending: bool | list[bool] = not descending else: @@ -529,10 +529,10 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: if backend is None: return PandasLikeDataFrame( - self._native_frame, + self.native, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -608,7 +608,7 @@ def join( ) return self._from_native_frame( - self._native_frame.assign(**{key_token: 0}) + self.native.assign(**{key_token: 0}) .merge( other._native_frame.assign(**{key_token: 0}), how="inner", @@ -620,7 +620,7 @@ def join( ) else: return self._from_native_frame( - self._native_frame.merge( + self.native.merge( other._native_frame, how="cross", suffixes=("", suffix), @@ -630,7 +630,7 @@ def join( if how == "anti": if self._implementation is Implementation.CUDF: return self._from_native_frame( - self._native_frame.merge( + self.native.merge( other._native_frame, how="leftanti", left_on=left_on, @@ -658,7 +658,7 @@ def join( backend_version=self._backend_version, ).drop_duplicates() return self._from_native_frame( - self._native_frame.merge( + self.native.merge( other_native, how="outer", indicator=indicator_token, @@ -688,7 +688,7 @@ def join( ).drop_duplicates() # avoids potential rows duplication from inner join ) return self._from_native_frame( - self._native_frame.merge( + self.native.merge( other_native, how="inner", left_on=left_on, @@ -698,7 +698,7 @@ def join( if how == "left": other_native = other._native_frame - result_native = self._native_frame.merge( + result_native = self.native.merge( other_native, how="left", left_on=left_on, @@ -714,7 +714,7 @@ def join( return self._from_native_frame(result_native.drop(columns=extra)) return self._from_native_frame( - self._native_frame.merge( + self.native.merge( other._native_frame, left_on=left_on, right_on=right_on, @@ -737,7 +737,7 @@ def join_asof( plx = self.__native_namespace__() return self._from_native_frame( plx.merge_asof( - self._native_frame, + self.native, other._native_frame, left_on=left_on, right_on=right_on, @@ -751,14 +751,10 @@ def join_asof( # --- partial reduction --- def head(self: Self, n: int) -> Self: - return self._from_native_frame( - self._native_frame.head(n), validate_column_names=False - ) + return self._from_native_frame(self.native.head(n), validate_column_names=False) def tail(self: Self, n: int) -> Self: - return self._from_native_frame( - self._native_frame.tail(n), validate_column_names=False - ) + return self._from_native_frame(self.native.tail(n), validate_column_names=False) def unique( self: Self, @@ -772,7 +768,7 @@ def unique( mapped_keep = {"none": False, "any": "first"}.get(keep, keep) check_column_exists(self.columns, subset) return self._from_native_frame( - self._native_frame.drop_duplicates(subset=subset, keep=mapped_keep), + self.native.drop_duplicates(subset=subset, keep=mapped_keep), validate_column_names=False, ) @@ -820,23 +816,23 @@ def lazy( @property def shape(self: Self) -> tuple[int, int]: - return self._native_frame.shape + return self.native.shape def to_dict(self: Self, *, as_series: bool) -> dict[str, Any]: if as_series: return { col: PandasLikeSeries( - self._native_frame[col], + self.native[col], implementation=self._implementation, backend_version=self._backend_version, version=self._version, ) for col in self.columns } - return self._native_frame.to_dict(orient="list") + return self.native.to_dict(orient="list") def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray: - native_dtypes = self._native_frame.dtypes + native_dtypes = self.native.dtypes if copy is None: # pandas default differs from Polars, but cuDF default is True @@ -845,8 +841,8 @@ def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DA if native_dtypes.isin(CLASSICAL_NUMPY_DTYPES).all(): # Fast path, no conversions necessary. if dtype is not None: - return self._native_frame.to_numpy(dtype=dtype, copy=copy) - return self._native_frame.to_numpy(copy=copy) + return self.native.to_numpy(dtype=dtype, copy=copy) + return self.native.to_numpy(copy=copy) dtypes = import_dtypes_module(self._version) @@ -863,7 +859,7 @@ def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DA .dt.replace_time_zone(None) )._native_frame else: - df = self._native_frame + df = self.native if dtype is not None: return df.to_numpy(dtype=dtype, copy=copy) @@ -887,11 +883,11 @@ def to_numpy(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _2DA def to_pandas(self: Self) -> pd.DataFrame: if self._implementation is Implementation.PANDAS: - return self._native_frame + return self.native elif self._implementation is Implementation.CUDF: # pragma: no cover - return self._native_frame.to_pandas() + return self.native.to_pandas() elif self._implementation is Implementation.MODIN: - return self._native_frame._to_pandas() + return self.native._to_pandas() msg = f"Unknown implementation: {self._implementation}" # pragma: no cover raise AssertionError(msg) @@ -899,16 +895,16 @@ def to_polars(self: Self) -> pl.DataFrame: import polars as pl # ignore-banned-import if self._implementation is Implementation.PANDAS: - return pl.from_pandas(self._native_frame) + return pl.from_pandas(self.native) elif self._implementation is Implementation.CUDF: # pragma: no cover - return pl.from_pandas(self._native_frame.to_pandas()) + return pl.from_pandas(self.native.to_pandas()) elif self._implementation is Implementation.MODIN: - return pl.from_pandas(self._native_frame._to_pandas()) + return pl.from_pandas(self.native._to_pandas()) msg = f"Unknown implementation: {self._implementation}" # pragma: no cover raise AssertionError(msg) def write_parquet(self: Self, file: str | Path | BytesIO) -> None: - self._native_frame.to_parquet(file) + self.native.to_parquet(file) @overload def write_csv(self: Self, file: None) -> str: ... @@ -917,12 +913,12 @@ def write_csv(self: Self, file: None) -> str: ... def write_csv(self: Self, file: str | Path | BytesIO) -> None: ... def write_csv(self: Self, file: str | Path | BytesIO | None) -> str | None: - return self._native_frame.to_csv(file, index=False) + return self.native.to_csv(file, index=False) # --- descriptive --- def is_unique(self: Self) -> PandasLikeSeries: return PandasLikeSeries( - ~self._native_frame.duplicated(keep=False), + ~self.native.duplicated(keep=False), implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -937,23 +933,21 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any: f" frame has shape {self.shape!r}" ) raise ValueError(msg) - return self._native_frame.iloc[0, 0] + return self.native.iloc[0, 0] elif row is None or column is None: msg = "cannot call `.item()` with only one of `row` or `column`" raise ValueError(msg) _col = self.columns.index(column) if isinstance(column, str) else column - return self._native_frame.iloc[row, _col] + return self.native.iloc[row, _col] def clone(self: Self) -> Self: - return self._from_native_frame( - self._native_frame.copy(), validate_column_names=False - ) + return self._from_native_frame(self.native.copy(), validate_column_names=False) def gather_every(self: Self, n: int, offset: int) -> Self: return self._from_native_frame( - self._native_frame.iloc[offset::n], validate_column_names=False + self.native.iloc[offset::n], validate_column_names=False ) def pivot( @@ -976,7 +970,7 @@ def pivot( raise NotImplementedError(msg) from itertools import product - frame = self._native_frame + frame = self.native if index is None: index = [c for c in self.columns if c not in {*on, *values}] # type: ignore[misc] @@ -1005,19 +999,17 @@ def pivot( # Put columns in the right order if sort_columns and self._implementation is Implementation.CUDF: uniques = { - col: sorted(self._native_frame[col].unique().to_arrow().to_pylist()) + col: sorted(self.native[col].unique().to_arrow().to_pylist()) for col in on } elif sort_columns: - uniques = { - col: sorted(self._native_frame[col].unique().tolist()) for col in on - } + uniques = {col: sorted(self.native[col].unique().tolist()) for col in on} elif self._implementation is Implementation.CUDF: uniques = { - col: self._native_frame[col].unique().to_arrow().to_pylist() for col in on + col: self.native[col].unique().to_arrow().to_pylist() for col in on } else: - uniques = {col: self._native_frame[col].unique().tolist() for col in on} + uniques = {col: self.native[col].unique().tolist() for col in on} ordered_cols = list(product(values, *uniques.values())) result = result.loc[:, ordered_cols] columns = result.columns.tolist() @@ -1041,11 +1033,11 @@ def pivot( def to_arrow(self: Self) -> Any: if self._implementation is Implementation.CUDF: - return self._native_frame.to_arrow(preserve_index=False) + return self.native.to_arrow(preserve_index=False) import pyarrow as pa # ignore-banned-import() - return pa.Table.from_pandas(self._native_frame) + return pa.Table.from_pandas(self.native) def sample( self: Self, @@ -1056,7 +1048,7 @@ def sample( seed: int | None, ) -> Self: return self._from_native_frame( - self._native_frame.sample( + self.native.sample( n=n, frac=fraction, replace=with_replacement, random_state=seed ), validate_column_names=False, @@ -1070,7 +1062,7 @@ def unpivot( value_name: str, ) -> Self: return self._from_native_frame( - self._native_frame.melt( + self.native.melt( id_vars=index, value_vars=on, var_name=variable_name, @@ -1094,10 +1086,10 @@ def explode(self: Self, columns: Sequence[str]) -> Self: if len(columns) == 1: return self._from_native_frame( - self._native_frame.explode(columns[0]), validate_column_names=False + self.native.explode(columns[0]), validate_column_names=False ) else: - native_frame = self._native_frame + native_frame = self.native anchor_series = native_frame[columns[0]].list.len() if not all( diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index 53e8d46461..3ccce65580 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -53,7 +53,7 @@ class PolarsDataFrame: clone: Method[Self] - collect: Method[CompliantDataFrame[Any, Any]] + collect: Method[CompliantDataFrame[Any, Any, Any]] drop_nulls: Method[Self] estimated_size: Method[int | float] explode: Method[Self] @@ -93,6 +93,10 @@ def __init__( self._version = version validate_backend_version(self._implementation, self._backend_version) + @property + def native(self) -> pl.DataFrame: + return self._native_frame + def __repr__(self: Self) -> str: # pragma: no cover return "PolarsDataFrame" @@ -113,7 +117,7 @@ def __native_namespace__(self: Self) -> ModuleType: def _change_version(self: Self, version: Version) -> Self: return self.__class__( - self._native_frame, backend_version=self._backend_version, version=version + self.native, backend_version=self._backend_version, version=version ) def _from_native_frame(self: Self, df: pl.DataFrame) -> Self: @@ -145,20 +149,20 @@ def _from_native_object( return obj def __len__(self) -> int: - return len(self._native_frame) + return len(self.native) def head(self, n: int) -> Self: - return self._from_native_frame(self._native_frame.head(n)) + return self._from_native_frame(self.native.head(n)) def tail(self, n: int) -> Self: - return self._from_native_frame(self._native_frame.tail(n)) + return self._from_native_frame(self.native.tail(n)) def __getattr__(self: Self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] try: return self._from_native_object( - getattr(self._native_frame, attr)(*args, **kwargs) + getattr(self.native, attr)(*args, **kwargs) ) except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?" @@ -175,8 +179,8 @@ def __array__( msg = "`copy` in `__array__` is only supported for Polars>=0.20.28" raise NotImplementedError(msg) if self._backend_version < (0, 20, 28): - return self._native_frame.__array__(dtype) - return self._native_frame.__array__(dtype) + return self.native.__array__(dtype) + return self.native.__array__(dtype) def collect_schema(self: Self) -> dict[str, DType]: if self._backend_version < (1,): @@ -184,10 +188,10 @@ def collect_schema(self: Self) -> dict[str, DType]: name: native_to_narwhals_dtype( dtype, self._version, self._backend_version ) - for name, dtype in self._native_frame.schema.items() + for name, dtype in self.native.schema.items() } else: - collected_schema = self._native_frame.collect_schema() + collected_schema = self.native.collect_schema() return { name: native_to_narwhals_dtype( dtype, self._version, self._backend_version @@ -197,11 +201,11 @@ def collect_schema(self: Self) -> dict[str, DType]: @property def shape(self: Self) -> tuple[int, int]: - return self._native_frame.shape + return self.native.shape def __getitem__(self: Self, item: Any) -> Any: if self._backend_version > (0, 20, 30): - return self._from_native_object(self._native_frame.__getitem__(item)) + return self._from_native_object(self.native.__getitem__(item)) else: # pragma: no cover # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum # Polars version we support @@ -212,20 +216,16 @@ def __getitem__(self: Self, item: Any) -> Any: if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): if item[1] == slice(None): if isinstance(item[0], Sequence) and not len(item[0]): - return self._from_native_frame(self._native_frame[0:0]) - return self._from_native_frame( - self._native_frame.__getitem__(item[0]) - ) + return self._from_native_frame(self.native[0:0]) + return self._from_native_frame(self.native.__getitem__(item[0])) if isinstance(item[1].start, str) or isinstance(item[1].stop, str): start, stop, step = convert_str_slice_to_int_slice(item[1], columns) return self._from_native_frame( - self._native_frame.select(columns[start:stop:step]).__getitem__( - item[0] - ) + self.native.select(columns[start:stop:step]).__getitem__(item[0]) ) if isinstance(item[1].start, int) or isinstance(item[1].stop, int): return self._from_native_frame( - self._native_frame.select( + self.native.select( columns[item[1].start : item[1].stop : item[1].step] ).__getitem__(item[0]) ) @@ -238,18 +238,18 @@ def __getitem__(self: Self, item: Any) -> Any: and is_sequence_but_not_str(item[1]) and (len(item[1]) == 0) ): - result = self._native_frame.select(item[1]) + result = self.native.select(item[1]) elif isinstance(item, slice) and ( isinstance(item.start, str) or isinstance(item.stop, str) ): start, stop, step = convert_str_slice_to_int_slice(item, columns) return self._from_native_frame( - self._native_frame.select(columns[start:stop:step]) + self.native.select(columns[start:stop:step]) ) elif is_sequence_but_not_str(item) and (len(item) == 0): - result = self._native_frame.slice(0, 0) + result = self.native.slice(0, 0) else: - result = self._native_frame.__getitem__(item) + result = self.native.__getitem__(item) if isinstance(result, pl.Series): from narwhals._polars.series import PolarsSeries @@ -259,7 +259,7 @@ def __getitem__(self: Self, item: Any) -> Any: return self._from_native_object(result) def simple_select(self, *column_names: str) -> Self: - return self._from_native_frame(self._native_frame.select(*column_names)) + return self._from_native_frame(self.native.select(*column_names)) def aggregate(self: Self, *exprs: Any) -> Self: return self.select(*exprs) @@ -268,7 +268,7 @@ def get_column(self: Self, name: str) -> PolarsSeries: from narwhals._polars.series import PolarsSeries return PolarsSeries( - self._native_frame.get_column(name), + self.native.get_column(name), backend_version=self._backend_version, version=self._version, ) @@ -276,33 +276,28 @@ def get_column(self: Self, name: str) -> PolarsSeries: def iter_columns(self) -> Iterator[PolarsSeries]: from narwhals._polars.series import PolarsSeries - for series in self._native_frame.iter_columns(): + for series in self.native.iter_columns(): yield PolarsSeries( series, backend_version=self._backend_version, version=self._version ) @property def columns(self: Self) -> list[str]: - return self._native_frame.columns + return self.native.columns @property def schema(self: Self) -> dict[str, DType]: - schema = self._native_frame.schema return { name: native_to_narwhals_dtype(dtype, self._version, self._backend_version) - for name, dtype in schema.items() + for name, dtype in self.native.schema.items() } def lazy( self: Self, *, backend: Implementation | None = None ) -> CompliantLazyFrame[Any, Any]: - from narwhals.utils import parse_version - if backend is None or backend is Implementation.POLARS: - from narwhals._polars.dataframe import PolarsLazyFrame - return PolarsLazyFrame( - self._native_frame.lazy(), + self.native.lazy(), backend_version=self._backend_version, version=self._version, ) @@ -311,9 +306,10 @@ def lazy( from narwhals._duckdb.dataframe import DuckDBLazyFrame + # NOTE: (F841) is a false positive df = self._native_frame # noqa: F841 return DuckDBLazyFrame( - df=duckdb.table("df"), + duckdb.table("df"), backend_version=parse_version(duckdb), version=self._version, ) @@ -324,7 +320,7 @@ def lazy( from narwhals._dask.dataframe import DaskLazyFrame return DaskLazyFrame( - native_dataframe=dd.from_pandas(self._native_frame.to_pandas()), + dd.from_pandas(self.native.to_pandas()), backend_version=parse_version(dask), version=self._version, ) @@ -339,8 +335,6 @@ def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: . def to_dict( self: Self, *, as_series: bool ) -> dict[str, PolarsSeries] | dict[str, list[Any]]: - df = self._native_frame - if as_series: from narwhals._polars.series import PolarsSeries @@ -348,10 +342,10 @@ def to_dict( name: PolarsSeries( col, backend_version=self._backend_version, version=self._version ) - for name, col in df.to_dict(as_series=True).items() + for name, col in self.native.to_dict().items() } else: - return df.to_dict(as_series=False) + return self.native.to_dict(as_series=False) def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy: from narwhals._polars.group_by import PolarsGroupBy @@ -360,14 +354,14 @@ def group_by(self: Self, *by: str, drop_null_keys: bool) -> PolarsGroupBy: def with_row_index(self: Self, name: str) -> Self: if self._backend_version < (0, 20, 4): - return self._from_native_frame(self._native_frame.with_row_count(name)) - return self._from_native_frame(self._native_frame.with_row_index(name)) + return self._from_native_frame(self.native.with_row_count(name)) + return self._from_native_frame(self.native.with_row_index(name)) def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self: to_drop = parse_columns_to_drop( compliant_frame=self, columns=columns, strict=strict ) - return self._from_native_frame(self._native_frame.drop(to_drop)) + return self._from_native_frame(self.native.drop(to_drop)) def unpivot( self: Self, @@ -378,7 +372,7 @@ def unpivot( ) -> Self: if self._backend_version < (1, 0, 0): return self._from_native_frame( - self._native_frame.melt( + self.native.melt( id_vars=index, value_vars=on, variable_name=variable_name, @@ -386,7 +380,7 @@ def unpivot( ) ) return self._from_native_frame( - self._native_frame.unpivot( + self.native.unpivot( on=on, index=index, variable_name=variable_name, value_name=value_name ) ) @@ -408,7 +402,7 @@ def pivot( msg = "`pivot` is only supported for Polars>=1.0.0" raise NotImplementedError(msg) try: - result = self._native_frame.pivot( + result = self.native.pivot( on, index=index, values=values, @@ -421,7 +415,7 @@ def pivot( return self._from_native_object(result) def to_polars(self: Self) -> pl.DataFrame: - return self._native_frame + return self.native class PolarsLazyFrame: @@ -535,7 +529,7 @@ def collect( self: Self, backend: Implementation | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: try: result = self.native.collect(**kwargs) except Exception as e: # noqa: BLE001 diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index f01fb0f0a9..0b829a2f23 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -174,7 +174,7 @@ def collect( self: Self, backend: ModuleType | Implementation | str | None, **kwargs: Any, - ) -> CompliantDataFrame[Any, Any]: + ) -> CompliantDataFrame[Any, Any, Any]: if backend is Implementation.PANDAS: import pandas as pd # ignore-banned-import diff --git a/narwhals/utils.py b/narwhals/utils.py index 97dd571904..a19b287306 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -58,7 +58,7 @@ from narwhals._compliant import CompliantExpr from narwhals._compliant import CompliantFrameT from narwhals._compliant import CompliantSeriesOrNativeExprT_co - from narwhals._compliant.typing import NativeFrameT_co + from narwhals._compliant import NativeFrameT_co from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.dtypes import DType @@ -1476,8 +1476,9 @@ def _hasattr_static(obj: Any, attr: str) -> bool: def is_compliant_dataframe( - obj: CompliantDataFrame[CompliantSeriesT_co, CompliantExprT_co] | Any, -) -> TypeIs[CompliantDataFrame[CompliantSeriesT_co, CompliantExprT_co]]: + obj: CompliantDataFrame[CompliantSeriesT_co, CompliantExprT_co, NativeFrameT_co] + | Any, +) -> TypeIs[CompliantDataFrame[CompliantSeriesT_co, CompliantExprT_co, NativeFrameT_co]]: return _hasattr_static(obj, "__narwhals_dataframe__")