Skip to content
27 changes: 27 additions & 0 deletions narwhals/_pandas_like/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,40 @@

if TYPE_CHECKING:
    # Everything in this block is evaluated by static type checkers only, so
    # none of the imports below (cudf, modin, pandas, typing_extensions) are
    # executed when the module is imported at runtime.
    import sys
    from typing import Any

    # The TypeVars below pass `default=`; the stdlib `typing.TypeVar` is used
    # from 3.13 onwards and the `typing_extensions` backport before that.
    if sys.version_info >= (3, 13):
        from typing import TypeVar
    else:
        from typing_extensions import TypeVar

    # Same pattern for `TypeAlias`: stdlib from 3.10, backport otherwise.
    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    import cudf
    import modin.pandas as mpd
    import pandas as pd

    from narwhals._pandas_like.expr import PandasLikeExpr
    from narwhals._pandas_like.series import PandasLikeSeries

    # Anything accepted where a pandas-like expression is expected.
    # NOTE(review): `Union` is not imported inside this block; it is presumably
    # imported at the top of the module, outside the visible hunk - confirm.
    IntoPandasLikeExpr: TypeAlias = Union[PandasLikeExpr, PandasLikeSeries]

    # Constrained to the native dataframe classes of pandas, modin and cuDF;
    # resolves to `pd.DataFrame` when left unsolved.
    DataFrameT = TypeVar(
        "DataFrameT", pd.DataFrame, mpd.DataFrame, cudf.DataFrame, default=pd.DataFrame
    )
    # Constrained to the native series classes of pandas, modin and cuDF;
    # resolves to `pd.Series[Any]` when left unsolved.
    SeriesT = TypeVar(
        "SeriesT", pd.Series[Any], mpd.Series, cudf.Series[Any], default=pd.Series[Any]
    )
    # Constrained to any native frame-like object, i.e. every dataframe and
    # series class listed above; resolves to `pd.DataFrame` when left unsolved.
    NDFrameT = TypeVar(
        "NDFrameT",
        pd.DataFrame,
        mpd.DataFrame,
        cudf.DataFrame,
        pd.Series[Any],
        mpd.Series,
        cudf.Series[Any],
        default=pd.DataFrame,
    )
61 changes: 33 additions & 28 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.typing import DataFrameT
from narwhals._pandas_like.typing import NDFrameT
from narwhals.dtypes import DType
from narwhals.typing import DTypeBackend
from narwhals.typing import TimeUnit
Expand Down Expand Up @@ -262,20 +264,20 @@ def native_series_from_iterable(


def set_index(
obj: T,
obj: NDFrameT,
index: Any,
*,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> T:
) -> NDFrameT:
"""Wrapper around pandas' set_axis to set object index.

We can set `copy` / `inplace` based on implementation/version.
"""
if implementation is Implementation.CUDF: # pragma: no cover
obj = obj.copy(deep=False) # type: ignore[attr-defined]
obj.index = index # type: ignore[attr-defined]
return obj
cudf_frame = obj.copy(deep=False)
cudf_frame.index = index
return cast("NDFrameT", cudf_frame) # type: ignore[redundant-cast]
if implementation is Implementation.PANDAS and (
backend_version < (1,)
): # pragma: no cover
Expand All @@ -288,24 +290,25 @@ def set_index(
kwargs["copy"] = False
else: # pragma: no cover
pass
return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined]
nd_frame = obj.set_axis(index, axis=0, **kwargs)
return cast("NDFrameT", nd_frame) # type: ignore[redundant-cast]


def set_columns(
obj: T,
obj: NDFrameT,
columns: list[str],
*,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> T:
) -> NDFrameT:
"""Wrapper around pandas' set_axis to set object columns.

We can set `copy` / `inplace` based on implementation/version.
"""
if implementation is Implementation.CUDF: # pragma: no cover
obj = obj.copy(deep=False) # type: ignore[attr-defined]
obj.columns = columns # type: ignore[attr-defined]
return obj
cudf_frame = obj.copy(deep=False)
cudf_frame.columns = cast("pd.Index[str]", columns)
return cast("NDFrameT", cudf_frame) # type: ignore[redundant-cast]
if implementation is Implementation.PANDAS and (
backend_version < (1,)
): # pragma: no cover
Expand All @@ -318,22 +321,24 @@ def set_columns(
kwargs["copy"] = False
else: # pragma: no cover
pass
return obj.set_axis(columns, axis=1, **kwargs) # type: ignore[attr-defined]
nd_frame = obj.set_axis(columns, axis=1, **kwargs)
return cast("NDFrameT", nd_frame) # type: ignore[redundant-cast]


def rename(
    obj: NDFrameT,
    *args: Any,
    implementation: Implementation,
    backend_version: tuple[int, ...],
    **kwargs: Any,
) -> NDFrameT:
    """Wrapper around pandas' rename so that we can set `copy` based on implementation/version.

    Arguments:
        obj: Native pandas-like dataframe or series to rename.
        *args: Positional arguments forwarded unchanged to `obj.rename`.
        implementation: Backend that `obj` belongs to.
        backend_version: Version tuple of that backend.
        **kwargs: Keyword arguments forwarded unchanged to `obj.rename`.
            Must not contain `inplace` or `copy`: both are supplied here and
            a duplicate keyword raises `TypeError`.

    Returns:
        The object returned by `obj.rename(..., inplace=False)`, typed as the
        same native type as `obj`.
    """
    if implementation.is_pandas() and backend_version >= (3,):
        # pandas >= 3: `copy` is deliberately not passed.
        # NOTE(review): presumably because the keyword is deprecated there
        # under Copy-on-Write - confirm against the pandas 3 release notes.
        nd_frame = obj.rename(*args, **kwargs, inplace=False)
    else:
        # Every other backend/version: ask for a rename without copying data.
        nd_frame = obj.rename(*args, **kwargs, copy=False, inplace=False)
    # `inplace=False` is always passed explicitly; the cast tells the type
    # checker the result has the same native type as the input.
    return cast("NDFrameT", nd_frame)  # type: ignore[redundant-cast]


@functools.lru_cache(maxsize=16)
Expand Down Expand Up @@ -749,34 +754,34 @@ def calculate_timestamp_date(s: pd.Series[int], time_unit: str) -> pd.Series[int


def select_columns_by_name(
df: T,
df: DataFrameT,
column_names: list[str] | _1DArray, # NOTE: Cannot be a tuple!
backend_version: tuple[int, ...],
implementation: Implementation,
) -> T:
) -> DataFrameT:
"""Select columns by name.

Prefer this over `df.loc[:, column_names]` as it's
generally more performant.
"""
if len(column_names) == df.shape[1] and all(column_names == df.columns): # type: ignore[attr-defined]
if len(column_names) == df.shape[1] and (df.columns == column_names).all():
return df
if (df.columns.dtype.kind == "b") or ( # type: ignore[attr-defined]
implementation is Implementation.PANDAS and backend_version < (1, 5)
if (df.columns.dtype.kind == "b") or (
implementation.is_pandas() and backend_version < (1, 5)
):
# See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
# for why we need this
available_columns = df.columns.tolist() # type: ignore[attr-defined]
available_columns = df.columns.tolist()
missing_columns = [x for x in column_names if x not in available_columns]
if missing_columns: # pragma: no cover
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns, available_columns
)
return df.loc[:, column_names] # type: ignore[attr-defined]
return cast("DataFrameT", df.loc[:, column_names]) # type: ignore[redundant-cast]
try:
return df[column_names] # type: ignore[index]
return df[column_names]
except KeyError as e:
available_columns = df.columns.tolist() # type: ignore[attr-defined]
available_columns = df.columns.tolist()
missing_columns = [x for x in column_names if x not in available_columns]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns, available_columns
Expand Down
Loading