diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 2cb84475d0..e0709a84ff 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -164,6 +164,8 @@ def iter_columns(self) -> Iterator[ArrowSeries]: version=self._version, ) + _iter_columns = iter_columns + def iter_rows( self: Self, *, named: bool, buffer_size: int ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]: diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index ec045c9e15..d9c74be112 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -1,130 +1,45 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Iterable -from typing import Sequence from narwhals._arrow.expr import ArrowExpr -from narwhals.utils import _parse_time_unit_and_time_zone -from narwhals.utils import dtype_matches_time_unit_and_time_zone -from narwhals.utils import import_dtypes_module +from narwhals._selectors import CompliantSelector +from narwhals._selectors import EagerSelectorNamespace if TYPE_CHECKING: - from datetime import timezone - from typing_extensions import Self from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit - from narwhals.utils import _LimitedContext - - -class ArrowSelectorNamespace: - def __init__(self: Self, context: _LimitedContext, /) -> None: - self._backend_version = context._backend_version - self._version = context._version - - def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector: - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - return [df[col] for col in df.columns if df.schema[col] in dtypes] - - def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: - return [col for col in df.columns if df.schema[col] in dtypes] - - return selector(self, 
func, evaluate_output_names) - - def matches(self: Self, pattern: str) -> ArrowSelector: - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - return [df[col] for col in df.columns if re.search(pattern, col)] - - def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: - return [col for col in df.columns if re.search(pattern, col)] + from narwhals._selectors import EvalNames + from narwhals._selectors import EvalSeries + from narwhals.utils import _FullContext - return selector(self, func, evaluate_output_names) - def numeric(self: Self) -> ArrowSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype( - [ - dtypes.Int128, - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt128, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, - ], - ) - - def categorical(self: Self) -> ArrowSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Categorical]) - - def string(self: Self) -> ArrowSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.String]) - - def boolean(self: Self) -> ArrowSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype([dtypes.Boolean]) - - def all(self: Self) -> ArrowSelector: - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - return [df[col] for col in df.columns] - - return selector(self, func, lambda df: df.columns) - - def datetime( - self: Self, - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, +class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]): + def _selector( + self, + call: EvalSeries[ArrowDataFrame, ArrowSeries], + evaluate_output_names: EvalNames[ArrowDataFrame], + /, ) -> ArrowSelector: - dtypes = import_dtypes_module(version=self._version) - time_units, time_zones = _parse_time_unit_and_time_zone( - time_unit=time_unit, 
time_zone=time_zone + return ArrowSelector( + call, + depth=0, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, ) - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - return [ - df[col] - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: - return [ - col - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - return selector(self, func, evaluate_output_names) - + def __init__(self: Self, context: _FullContext, /) -> None: + self._implementation = context._implementation + self._backend_version = context._backend_version + self._version = context._version -class ArrowSelector(ArrowExpr): - def __repr__(self: Self) -> str: # pragma: no cover - return f"ArrowSelector(depth={self._depth}, function_name={self._function_name})" +class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc] def _to_expr(self: Self) -> ArrowExpr: return ArrowExpr( self._call, @@ -135,82 +50,3 @@ def _to_expr(self: Self) -> ArrowExpr: backend_version=self._backend_version, version=self._version, ) - - def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any: - if isinstance(other, ArrowSelector): - - def call(df: ArrowDataFrame) -> list[ArrowSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] - - def evaluate_output_names(df: ArrowDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x not 
in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() - other - - def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any: - if isinstance(other, ArrowSelector): - - def call(df: ArrowDataFrame) -> list[ArrowSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - rhs = other._call(df) - return [ - *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), - *rhs, - ] - - def evaluate_output_names(df: ArrowDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() | other - - def __and__(self: Self, other: Self | Any) -> ArrowSelector | Any: - if isinstance(other, ArrowSelector): - - def call(df: ArrowDataFrame) -> list[ArrowSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] - - def evaluate_output_names(df: ArrowDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x in rhs_names] - - return selector(self, call, evaluate_output_names) - - else: - return self._to_expr() & other - - def __invert__(self: Self) -> ArrowSelector: - return ArrowSelectorNamespace(self).all() - self - - -def selector( - context: _LimitedContext, - call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], - evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], - /, -) -> ArrowSelector: - return ArrowSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=context._backend_version, - version=context._version, 
- ) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index e77d05d742..e8dde409a3 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Iterator from typing import Literal from typing import Sequence @@ -24,6 +25,7 @@ if TYPE_CHECKING: from types import ModuleType + import dask.dataframe.dask_expr as dx from typing_extensions import Self from narwhals._dask.expr import DaskExpr @@ -79,6 +81,10 @@ def _from_native_frame(self: Self, df: Any) -> Self: version=self._version, ) + def _iter_columns(self) -> Iterator[dx.Series]: + for _col, ser in self._native_frame.items(): # noqa: PERF102 + yield ser + def with_columns(self: Self, *exprs: DaskExpr) -> Self: df = self._native_frame new_series = evaluate_exprs(self, *exprs) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 828ad2afdd..8d64cf9d11 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -24,6 +24,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace +from narwhals.utils import Implementation from narwhals.utils import get_column_names if TYPE_CHECKING: @@ -39,6 +40,8 @@ class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] (#2044) + _implementation: Implementation = Implementation.DASK + @property def selectors(self: Self) -> DaskSelectorNamespace: return DaskSelectorNamespace(self) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index a0e6f49b1f..9533721d49 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -1,16 +1,10 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Iterable 
-from typing import Sequence from narwhals._dask.expr import DaskExpr -from narwhals.utils import _parse_time_unit_and_time_zone -from narwhals.utils import dtype_matches_time_unit_and_time_zone -from narwhals.utils import import_dtypes_module +from narwhals._selectors import CompliantSelector +from narwhals._selectors import LazySelectorNamespace if TYPE_CHECKING: try: @@ -18,126 +12,38 @@ except ModuleNotFoundError: import dask_expr as dx - from datetime import timezone - from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit - from narwhals.utils import _LimitedContext - - try: - import dask.dataframe.dask_expr as dx - except ModuleNotFoundError: - import dask_expr as dx - + from narwhals._selectors import EvalNames + from narwhals._selectors import EvalSeries + from narwhals.utils import _FullContext + + +class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments] + def _selector( + self, + call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm] + evaluate_output_names: EvalNames[DaskLazyFrame], + /, + ) -> DaskSelector: + return DaskSelector( + call, + depth=0, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) -class DaskSelectorNamespace: - def __init__(self: Self, context: _LimitedContext, /) -> None: + def __init__(self: Self, context: _FullContext, /) -> None: + self._implementation = context._implementation self._backend_version = context._backend_version self._version = context._version - def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DaskSelector: - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [ - df._native_frame[col] for col in df.columns if df.schema[col] in dtypes - ] - - def 
evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: - return [col for col in df.columns if df.schema[col] in dtypes] - - return selector(self, func, evaluate_output_names) - - def matches(self: Self, pattern: str) -> DaskSelector: - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [ - df._native_frame[col] for col in df.columns if re.search(pattern, col) - ] - - def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: - return [col for col in df.columns if re.search(pattern, col)] - - return selector(self, func, evaluate_output_names) - - def numeric(self: Self) -> DaskSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype( - { - dtypes.Int128, - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt128, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, - }, - ) - - def categorical(self: Self) -> DaskSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Categorical}) - - def string(self: Self) -> DaskSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.String}) - - def boolean(self: Self) -> DaskSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Boolean}) - - def all(self: Self) -> DaskSelector: - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [df._native_frame[col] for col in df.columns] - - return selector(self, func, lambda df: df.columns) - - def datetime( - self: Self, - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, - ) -> DaskSelector: # pragma: no cover - dtypes = import_dtypes_module(version=self._version) - time_units, time_zones = _parse_time_unit_and_time_zone( - time_unit=time_unit, time_zone=time_zone - ) - - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [ - df._native_frame[col] - for col in df.columns - if 
dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: - return [ - col - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - return selector(self, func, evaluate_output_names) - - -class DaskSelector(DaskExpr): - def __repr__(self: Self) -> str: # pragma: no cover - return f"DaskSelector(depth={self._depth}, function_name={self._function_name})" +class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr): # type: ignore[misc] def _to_expr(self: Self) -> DaskExpr: return DaskExpr( self._call, @@ -148,81 +54,3 @@ def _to_expr(self: Self) -> DaskExpr: backend_version=self._backend_version, version=self._version, ) - - def __sub__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: - if isinstance(other, DaskSelector): - - def call(df: DaskLazyFrame) -> list[dx.Series]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] - - def evaluate_output_names(df: DaskLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x not in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() - other - - def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: - if isinstance(other, DaskSelector): - - def call(df: DaskLazyFrame) -> list[dx.Series]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - rhs = other._call(df) - return [ - *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), - *rhs, - ] - - def 
evaluate_output_names(df: DaskLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() | other - - def __and__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: - if isinstance(other, DaskSelector): - - def call(df: DaskLazyFrame) -> list[dx.Series]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] - - def evaluate_output_names(df: DaskLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() & other - - def __invert__(self: Self) -> DaskSelector: - return DaskSelectorNamespace(self).all() - self - - -def selector( - context: _LimitedContext, - call: Callable[[DaskLazyFrame], Sequence[dx.Series]], - evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]], - /, -) -> DaskSelector: - return DaskSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=context._backend_version, - version=context._version, - ) diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index dc74eae824..e1fa303bd9 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Iterator from typing import Literal from typing import Sequence @@ -85,6 +86,10 @@ def __getitem__(self: Self, item: str) -> DuckDBInterchangeSeries: self._native_frame.select(item), version=self._version ) + def _iter_columns(self) -> 
Iterator[duckdb.Expression]: + for col in self.columns: + yield ColumnExpression(col) + def collect( self: Self, backend: ModuleType | Implementation | str | None, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 047355e940..45c877965e 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -25,6 +25,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace +from narwhals.utils import Implementation from narwhals.utils import get_column_names if TYPE_CHECKING: @@ -37,6 +38,8 @@ class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]): # type: ignore[type-var] + _implementation: Implementation = Implementation.DUCKDB + def __init__( self: Self, *, backend_version: tuple[int, ...], version: Version ) -> None: diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index a30cec06cb..0e54fd3c76 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -1,136 +1,49 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Iterable -from typing import Sequence - -from duckdb import ColumnExpression from narwhals._duckdb.expr import DuckDBExpr -from narwhals.utils import _parse_time_unit_and_time_zone -from narwhals.utils import dtype_matches_time_unit_and_time_zone -from narwhals.utils import import_dtypes_module +from narwhals._selectors import CompliantSelector +from narwhals._selectors import LazySelectorNamespace if TYPE_CHECKING: - from datetime import timezone - import duckdb from typing_extensions import Self from narwhals._duckdb.dataframe import DuckDBLazyFrame - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit - from narwhals.utils import _LimitedContext - - -class 
DuckDBSelectorNamespace: - def __init__(self: Self, context: _LimitedContext, /) -> None: - self._backend_version = context._backend_version - self._version = context._version - - def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DuckDBSelector: - def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ - ColumnExpression(col) for col in df.columns if df.schema[col] in dtypes - ] - - def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: - return [col for col in df.columns if df.schema[col] in dtypes] - - return selector(self, func, evaluate_output_names) - - def matches(self: Self, pattern: str) -> DuckDBSelector: - def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ - ColumnExpression(col) for col in df.columns if re.search(pattern, col) - ] - - def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: - return [col for col in df.columns if re.search(pattern, col)] - - return selector(self, func, evaluate_output_names) - - def numeric(self: Self) -> DuckDBSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype( - { - dtypes.Int128, - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt128, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, - }, - ) - - def categorical(self: Self) -> DuckDBSelector: # pragma: no cover - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Categorical}) - - def string(self: Self) -> DuckDBSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.String}) - - def boolean(self: Self) -> DuckDBSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Boolean}) - - def all(self: Self) -> DuckDBSelector: - def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ColumnExpression(col) for col in df.columns] - - return selector(self, func, lambda df: df.columns) - - def datetime( - self: 
Self, - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, + from narwhals._selectors import EvalNames + from narwhals._selectors import EvalSeries + from narwhals.utils import _FullContext + + +class DuckDBSelectorNamespace( + LazySelectorNamespace["DuckDBLazyFrame", "duckdb.Expression"] # type: ignore[type-var] +): + def _selector( + self, + call: EvalSeries[DuckDBLazyFrame, duckdb.Expression], # type: ignore[type-var] + evaluate_output_names: EvalNames[DuckDBLazyFrame], + /, ) -> DuckDBSelector: - dtypes = import_dtypes_module(version=self._version) - time_units, time_zones = _parse_time_unit_and_time_zone( - time_unit=time_unit, time_zone=time_zone + return DuckDBSelector( + call, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, ) - def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ - ColumnExpression(col) - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: - return [ - col - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - return selector(self, func, evaluate_output_names) - + def __init__(self: Self, context: _FullContext, /) -> None: + self._implementation = context._implementation + self._backend_version = context._backend_version + self._version = context._version -class DuckDBSelector(DuckDBExpr): - def __repr__(self: Self) -> str: # pragma: no cover - return f"DuckDBSelector(function_name={self._function_name})" +class DuckDBSelector( # type: ignore[misc] + CompliantSelector["DuckDBLazyFrame", "duckdb.Expression"], # type: ignore[type-var] + DuckDBExpr, 
+): def _to_expr(self: Self) -> DuckDBExpr: return DuckDBExpr( self._call, @@ -140,80 +53,3 @@ def _to_expr(self: Self) -> DuckDBExpr: backend_version=self._backend_version, version=self._version, ) - - def __sub__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any: - if isinstance(other, DuckDBSelector): - - def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] - - def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x not in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() - other - - def __or__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any: - if isinstance(other, DuckDBSelector): - - def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - rhs = other._call(df) - return [ - *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), - *rhs, - ] - - def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() | other - - def __and__(self: Self, other: DuckDBSelector | Any) -> DuckDBSelector | Any: - if isinstance(other, DuckDBSelector): - - def call(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] - - def 
evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() & other - - def __invert__(self: Self) -> DuckDBSelector: - return DuckDBSelectorNamespace(self).all() - self - - -def selector( - context: _LimitedContext, - call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], - evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], - /, -) -> DuckDBSelector: - return DuckDBSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=context._backend_version, - version=context._version, - ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 11819495e6..cb3f132ac2 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -52,8 +52,7 @@ def is_expr(obj: Any) -> TypeIs[Expr]: def evaluate_into_expr( - df: CompliantFrameT, - expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], + df: CompliantFrameT, expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co] ) -> Sequence[CompliantSeriesT_co]: """Return list of raw columns. @@ -87,8 +86,7 @@ def evaluate_into_exprs( @overload def maybe_evaluate_expr( - df: CompliantFrameT, - expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co], + df: CompliantFrameT, expr: CompliantExpr[CompliantFrameT, CompliantSeriesT_co] ) -> CompliantSeriesT_co: ... 
diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index f2a336146e..e958575461 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -350,6 +350,8 @@ def iter_columns(self) -> Iterator[PandasLikeSeries]: version=self._version, ) + _iter_columns = iter_columns + def iter_rows( self: Self, *, diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index d6436514d0..bdf5cf33cd 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -1,133 +1,52 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Iterable -from typing import Sequence +from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr -from narwhals.utils import _parse_time_unit_and_time_zone -from narwhals.utils import dtype_matches_time_unit_and_time_zone -from narwhals.utils import import_dtypes_module +from narwhals._pandas_like.series import PandasLikeSeries +from narwhals._selectors import CompliantSelector +from narwhals._selectors import EagerSelectorNamespace if TYPE_CHECKING: - from datetime import timezone - from typing_extensions import Self from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.series import PandasLikeSeries - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit + from narwhals._selectors import EvalNames + from narwhals._selectors import EvalSeries from narwhals.utils import _FullContext -class PandasSelectorNamespace: +class PandasSelectorNamespace( + EagerSelectorNamespace["PandasLikeDataFrame", "PandasLikeSeries"] +): + def _selector( + self, + call: EvalSeries[PandasLikeDataFrame, PandasLikeSeries], + evaluate_output_names: EvalNames[PandasLikeDataFrame], + /, + ) -> PandasSelector: + return PandasSelector( + call, + depth=0, + 
function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, + ) + def __init__(self: Self, context: _FullContext, /) -> None: self._implementation = context._implementation self._backend_version = context._backend_version self._version = context._version - def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> PandasSelector: - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - return [df[col] for col in df.columns if df.schema[col] in dtypes] - - def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: - return [col for col in df.columns if df.schema[col] in dtypes] - - return selector(self, func, evaluate_output_names) - - def matches(self: Self, pattern: str) -> PandasSelector: - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - return [df[col] for col in df.columns if re.search(pattern, col)] - - def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: - return [col for col in df.columns if re.search(pattern, col)] - - return selector(self, func, evaluate_output_names) - - def numeric(self: Self) -> PandasSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype( - { - dtypes.Int128, - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt128, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, - } - ) - - def categorical(self: Self) -> PandasSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Categorical}) - - def string(self: Self) -> PandasSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.String}) - - def boolean(self: Self) -> PandasSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Boolean}) - - def all(self: Self) -> PandasSelector: - def func(df: 
PandasLikeDataFrame) -> list[PandasLikeSeries]: - return [df[col] for col in df.columns] - - return selector(self, func, lambda df: df.columns) - - def datetime( - self: Self, - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, - ) -> PandasSelector: - dtypes = import_dtypes_module(version=self._version) - time_units, time_zones = _parse_time_unit_and_time_zone( - time_unit=time_unit, time_zone=time_zone - ) - - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - return [ - df[col] - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: - return [ - col - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - return selector(self, func, evaluate_output_names) - - -class PandasSelector(PandasLikeExpr): - def __repr__(self) -> str: # pragma: no cover - return ( - f"PandasSelector(depth={self._depth}, function_name={self._function_name}, " - ) +class PandasSelector( # type: ignore[misc] + CompliantSelector["PandasLikeDataFrame", "PandasLikeSeries"], PandasLikeExpr +): def _to_expr(self: Self) -> PandasLikeExpr: return PandasLikeExpr( self._call, @@ -138,84 +57,4 @@ def _to_expr(self: Self) -> PandasLikeExpr: implementation=self._implementation, backend_version=self._backend_version, version=self._version, - call_kwargs=self._call_kwargs, ) - - def __sub__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: - if isinstance(other, PandasSelector): - - def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name not in 
rhs_names] - - def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x not in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() - other - - def __or__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: - if isinstance(other, PandasSelector): - - def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - rhs = other._call(df) - return [ - *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), - *rhs, - ] - - def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() | other - - def __and__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: - if isinstance(other, PandasSelector): - - def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] - - def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() & other - - def __invert__(self: Self) -> PandasSelector: - return PandasSelectorNamespace(self).all() - self - - -def selector( - context: _FullContext, - call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], - evaluate_output_names: 
Callable[[PandasLikeDataFrame], Sequence[str]], - /, -) -> PandasSelector: - return PandasSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=context._implementation, - backend_version=context._backend_version, - version=context._version, - ) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index bc25184e11..4683d59e7a 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -438,6 +438,9 @@ def func(*args: Any, **kwargs: Any) -> Any: return func + def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover + yield from self.collect(self._implementation).iter_columns() + @property def columns(self: Self) -> list[str]: return self._native_frame.columns diff --git a/narwhals/_selectors.py b/narwhals/_selectors.py new file mode 100644 index 0000000000..639b0eb740 --- /dev/null +++ b/narwhals/_selectors.py @@ -0,0 +1,298 @@ +"""Almost entirely complete, generic `selectors` implementation.""" + +from __future__ import annotations + +import re +from functools import partial +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable +from typing import Collection +from typing import Iterable +from typing import Iterator +from typing import Sequence +from typing import TypeVar +from typing import overload + +from narwhals.typing import CompliantExpr +from narwhals.utils import _parse_time_unit_and_time_zone +from narwhals.utils import dtype_matches_time_unit_and_time_zone +from narwhals.utils import get_column_names +from narwhals.utils import import_dtypes_module +from narwhals.utils import is_compliant_dataframe +from narwhals.utils import is_tracks_depth + +if not TYPE_CHECKING: # pragma: no cover + # TODO @dangotbanned: Remove after dropping `3.8` (#2084) + # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 + import sys + + if sys.version_info >= (3, 9): + from 
typing import Protocol + else: + from typing import Generic + + Protocol = Generic +else: # pragma: no cover + from typing import Protocol + +if TYPE_CHECKING: + from datetime import timezone + + from typing_extensions import Self + from typing_extensions import TypeAlias + from typing_extensions import TypeIs + + from narwhals.dtypes import DType + from narwhals.typing import CompliantDataFrame + from narwhals.typing import CompliantLazyFrame + from narwhals.typing import CompliantSeries + from narwhals.typing import TimeUnit + from narwhals.utils import Implementation + from narwhals.utils import Version + + +SeriesT = TypeVar("SeriesT", bound="CompliantSeries") +FrameT = TypeVar("FrameT", bound="CompliantDataFrame[Any] | CompliantLazyFrame") +DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any]") +LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrame") +SelectorOrExpr: TypeAlias = ( + "CompliantSelector[FrameT, SeriesT] | CompliantExpr[FrameT, SeriesT]" +) +EvalSeries: TypeAlias = Callable[[FrameT], Sequence[SeriesT]] +EvalNames: TypeAlias = Callable[[FrameT], Sequence[str]] + + +class CompliantSelectorNamespace(Protocol[FrameT, SeriesT]): + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def _selector( + self, + call: EvalSeries[FrameT, SeriesT], + evaluate_output_names: EvalNames[FrameT], + /, + ) -> CompliantSelector[FrameT, SeriesT]: ... + + def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesT]: ... 
+
+    def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]:
+        """Yield a `(name, dtype)` pair for each column produced by `_iter_columns`."""
+        for ser in self._iter_columns(df):
+            yield ser.name, ser.dtype
+
+    def _iter_columns_dtypes(self, df: FrameT, /) -> Iterator[tuple[SeriesT, DType]]:
+        """Yield a `(column, dtype)` pair for each column produced by `_iter_columns`."""
+        # NOTE: Defined to be overridden for lazy
+        # - Their `SeriesT` is a **native** object
+        # - `.dtype` won't return a `nw.DType` (or maybe anything) for lazy backends
+        # - See (https://github.com/narwhals-dev/narwhals/issues/2044)
+        for ser in self._iter_columns(df):
+            yield ser, ser.dtype
+
+    def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesT, str]]:
+        """Pair each column from `_iter_columns` with the matching entry of `df.columns`."""
+        yield from zip(self._iter_columns(df), df.columns)
+
+    def _is_dtype(
+        self: CompliantSelectorNamespace[FrameT, SeriesT], dtype: type[DType], /
+    ) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for columns whose dtype is an instance of `dtype`."""
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            return [
+                ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
+            ]
+
+        def names(df: FrameT) -> Sequence[str]:
+            return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
+
+        return self._selector(series, names)
+
+    def by_dtype(
+        self: Self, dtypes: Collection[DType | type[DType]]
+    ) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for columns whose dtype is contained in `dtypes`."""
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
+
+        def names(df: FrameT) -> Sequence[str]:
+            return [name for name, tp in self._iter_schema(df) if tp in dtypes]
+
+        return self._selector(series, names)
+
+    def matches(self: Self, pattern: str) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for columns whose name contains a match for `pattern`.
+
+        Uses `re.search`, so the pattern may match anywhere in the name.
+        """
+        # Compiled once here and shared by both closures.
+        p = re.compile(pattern)
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            # Frames accepted by `is_compliant_dataframe` (DuckDB excluded) are
+            # indexed by name; everything else pairs columns with `df.columns`.
+            if is_compliant_dataframe(df) and not self._implementation.is_duckdb():
+                return [df.get_column(col) for col in df.columns if p.search(col)]
+
+            return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
+
+        def names(df: FrameT) -> Sequence[str]:
+            return [col for col in df.columns if p.search(col)]
+
+        return self._selector(series, names)
+
+    def numeric(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for columns whose dtype reports `is_numeric()`."""
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
+
+        def names(df: FrameT) -> Sequence[str]:
+            return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
+
+        return self._selector(series, names)
+
+    def categorical(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for `Categorical` columns (dtype taken from `self._version`)."""
+        return self._is_dtype(import_dtypes_module(self._version).Categorical)
+
+    def string(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for `String` columns (dtype taken from `self._version`)."""
+        return self._is_dtype(import_dtypes_module(self._version).String)
+
+    def boolean(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for `Boolean` columns (dtype taken from `self._version`)."""
+        return self._is_dtype(import_dtypes_module(self._version).Boolean)
+
+    def all(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for every column; names come from `get_column_names`."""
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            return list(self._iter_columns(df))
+
+        return self._selector(series, get_column_names)
+
+    def datetime(
+        self: Self,
+        time_unit: TimeUnit | Iterable[TimeUnit] | None,
+        time_zone: str | timezone | Iterable[str | timezone | None] | None,
+    ) -> CompliantSelector[FrameT, SeriesT]:
+        """Build a selector for columns accepted by `dtype_matches_time_unit_and_time_zone`.
+
+        `time_unit` / `time_zone` are normalised by `_parse_time_unit_and_time_zone`
+        before being bound into the predicate.
+        """
+        time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
+        # Bind everything except the dtype so both closures share one predicate.
+        matches = partial(
+            dtype_matches_time_unit_and_time_zone,
+            dtypes=import_dtypes_module(version=self._version),
+            time_units=time_units,
+            time_zones=time_zones,
+        )
+
+        def series(df: FrameT) -> Sequence[SeriesT]:
+            return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
+
+        def names(df: FrameT) -> Sequence[str]:
+            return [name for name, tp in self._iter_schema(df) if matches(tp)]
+
+        return self._selector(series, names)
+
+
+class EagerSelectorNamespace(
+    CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
+):
+    def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
+        """Yield the columns reported by the frame's public `iter_columns()`."""
+        yield from df.iter_columns()
+
+
+class LazySelectorNamespace(
+    CompliantSelectorNamespace[LazyFrameT, SeriesT], Protocol[LazyFrameT, SeriesT]
+):
+    def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
+        """Yield `(name, dtype)` pairs straight from `df.schema`."""
+        yield from df.schema.items()
+
+    def _iter_columns(self, df: LazyFrameT) -> Iterator[SeriesT]:
+        """Yield the columns reported by the frame's private `_iter_columns()`."""
+        yield from df._iter_columns()
+
+    def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
+        """Pair each column with the dtype at the same position in `df.schema`."""
+        yield from zip(self._iter_columns(df), df.schema.values())
+
+
+class CompliantSelector(CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT]):
+    @property
+    def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesT]:
+        """Selector namespace obtained through `__narwhals_namespace__()`."""
+        return self.__narwhals_namespace__().selectors
+
+    def _to_expr(self: Self) -> CompliantExpr[FrameT, SeriesT]: ...
+
+    def _is_selector(
+        self: Self, other: Self | CompliantExpr[FrameT, SeriesT]
+    ) -> TypeIs[CompliantSelector[FrameT, SeriesT]]:
+        """Return whether `other` is an instance of this selector's own class."""
+        return isinstance(other, type(self))
+
+    @overload
+    def __sub__(self: Self, other: Self) -> Self: ...
+    @overload
+    def __sub__(
+        self: Self, other: CompliantExpr[FrameT, SeriesT]
+    ) -> CompliantExpr[FrameT, SeriesT]: ...
+    def __sub__(
+        self: Self, other: SelectorOrExpr[FrameT, SeriesT]
+    ) -> SelectorOrExpr[FrameT, SeriesT]:
+        """Difference of two selectors, or expression-level `-` otherwise.
+
+        For two selectors, keeps the left-hand columns whose output name is
+        not among the output names of `other`.
+        """
+        if self._is_selector(other):
+
+            def series(df: FrameT) -> Sequence[SeriesT]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [
+                    x for x, name in zip(self(df), lhs_names) if name not in rhs_names
+                ]
+
+            def names(df: FrameT) -> Sequence[str]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [x for x in lhs_names if x not in rhs_names]
+
+            return self.selectors._selector(series, names)
+        else:
+            return self._to_expr() - other
+
+    @overload
+    def __or__(self: Self, other: Self) -> Self: ...
+    @overload
+    def __or__(
+        self: Self, other: CompliantExpr[FrameT, SeriesT]
+    ) -> CompliantExpr[FrameT, SeriesT]: ...
+    def __or__(
+        self: Self, other: SelectorOrExpr[FrameT, SeriesT]
+    ) -> SelectorOrExpr[FrameT, SeriesT]:
+        """Union of two selectors, or expression-level `|` otherwise.
+
+        For two selectors, keeps the left-hand columns whose output name is
+        not among the output names of `other`, followed by every column of
+        `other`.
+        """
+        if self._is_selector(other):
+
+            def series(df: FrameT) -> Sequence[SeriesT]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [
+                    *(x for x, name in zip(self(df), lhs_names) if name not in rhs_names),
+                    *other(df),
+                ]
+
+            def names(df: FrameT) -> Sequence[str]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
+
+            # `_selector` takes the column callback first, then the name callback.
+            return self.selectors._selector(series, names)
+        else:
+            return self._to_expr() | other
+
+    @overload
+    def __and__(self: Self, other: Self) -> Self: ...
+    @overload
+    def __and__(
+        self: Self, other: CompliantExpr[FrameT, SeriesT]
+    ) -> CompliantExpr[FrameT, SeriesT]: ...
+    def __and__(
+        self: Self, other: SelectorOrExpr[FrameT, SeriesT]
+    ) -> SelectorOrExpr[FrameT, SeriesT]:
+        """Intersection of two selectors, or expression-level `&` otherwise.
+
+        For two selectors, keeps the left-hand columns whose output name is
+        also among the output names of `other`.
+        """
+        if self._is_selector(other):
+
+            def series(df: FrameT) -> Sequence[SeriesT]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [x for x, name in zip(self(df), lhs_names) if name in rhs_names]
+
+            def names(df: FrameT) -> Sequence[str]:
+                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
+                return [x for x in lhs_names if x in rhs_names]
+
+            return self.selectors._selector(series, names)
+        else:
+            return self._to_expr() & other
+
+    def __invert__(self: Self) -> CompliantSelector[FrameT, SeriesT]:
+        """Complement: every column (`selectors.all()`) minus this selector."""
+        return self.selectors.all() - self  # type: ignore[no-any-return]
+
+    def __repr__(self: Self) -> str:  # pragma: no cover
+        # `depth=` is only shown when `is_tracks_depth` accepts the implementation.
+        s = f"depth={self._depth}, " if is_tracks_depth(self._implementation) else ""
+        return f"{type(self).__name__}({s}function_name={self._function_name})"
+
+
+def _eval_lhs_rhs(
+    df: CompliantDataFrame[Any] | CompliantLazyFrame,
+    lhs: CompliantExpr[Any, Any],
+    rhs: CompliantExpr[Any, Any],
+) -> tuple[Sequence[str], Sequence[str]]:
+    """Return the output names of `lhs` and of `rhs`, both evaluated against `df`."""
+    return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
diff --git
a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index d4a792e678..ec5d7aafd7 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -4,6 +4,7 @@ from importlib import import_module from typing import TYPE_CHECKING from typing import Any +from typing import Iterator from typing import Literal from typing import Sequence from typing import cast @@ -195,6 +196,10 @@ def _collect_to_arrow(self) -> pa.Table: to_arrow: Incomplete = self._native_frame.toArrow return to_arrow() + def _iter_columns(self) -> Iterator[Column]: + for col in self.columns: + yield self._F.col(col) + @property def columns(self: Self) -> list[str]: return list(self.schema) diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index e037e1f8a3..eb7ab72fae 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -1,131 +1,46 @@ from __future__ import annotations -import re from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Iterable -from typing import Sequence +from narwhals._selectors import CompliantSelector +from narwhals._selectors import LazySelectorNamespace from narwhals._spark_like.expr import SparkLikeExpr -from narwhals.utils import _parse_time_unit_and_time_zone -from narwhals.utils import dtype_matches_time_unit_and_time_zone -from narwhals.utils import import_dtypes_module if TYPE_CHECKING: - from datetime import timezone - from pyspark.sql import Column from typing_extensions import Self + from narwhals._selectors import EvalNames + from narwhals._selectors import EvalSeries from narwhals._spark_like.dataframe import SparkLikeLazyFrame - from narwhals.dtypes import DType - from narwhals.typing import TimeUnit from narwhals.utils import _FullContext -class SparkLikeSelectorNamespace: +# NOTE: See issue regarding ignores (#2044) +class SparkLikeSelectorNamespace(LazySelectorNamespace["SparkLikeLazyFrame", 
"Column"]): # type: ignore[type-var] + def _selector( + self, + call: EvalSeries[SparkLikeLazyFrame, Column], # type: ignore[type-var] + evaluate_output_names: EvalNames[SparkLikeLazyFrame], + /, + ) -> SparkLikeSelector: + return SparkLikeSelector( + call, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + def __init__(self: Self, context: _FullContext, /) -> None: self._backend_version = context._backend_version self._version = context._version self._implementation = context._implementation - def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> SparkLikeSelector: - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(col) for col in df.columns if df.schema[col] in dtypes] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: - return [col for col in df.columns if df.schema[col] in dtypes] - - return selector(self, func, evaluate_output_names) - - def matches(self: Self, pattern: str) -> SparkLikeSelector: - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(col) for col in df.columns if re.search(pattern, col)] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: - return [col for col in df.columns if re.search(pattern, col)] - - return selector(self, func, evaluate_output_names) - - def numeric(self: Self) -> SparkLikeSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype( - { - dtypes.Int128, - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt128, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, - }, - ) - - def categorical(self: Self) -> SparkLikeSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Categorical}) - - def string(self: Self) -> SparkLikeSelector: - dtypes = 
import_dtypes_module(self._version) - return self.by_dtype({dtypes.String}) - - def boolean(self: Self) -> SparkLikeSelector: - dtypes = import_dtypes_module(self._version) - return self.by_dtype({dtypes.Boolean}) - - def all(self: Self) -> SparkLikeSelector: - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(col) for col in df.columns] - - return selector(self, func, lambda df: df.columns) - - def datetime( - self: Self, - time_unit: TimeUnit | Iterable[TimeUnit] | None, - time_zone: str | timezone | Iterable[str | timezone | None] | None, - ) -> SparkLikeSelector: - dtypes = import_dtypes_module(version=self._version) - time_units, time_zones = _parse_time_unit_and_time_zone( - time_unit=time_unit, time_zone=time_zone - ) - - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [ - df._F.col(col) - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: - return [ - col - for col in df.columns - if dtype_matches_time_unit_and_time_zone( - dtype=df.schema[col], - dtypes=dtypes, - time_units=time_units, - time_zones=time_zones, - ) - ] - - return selector(self, func, evaluate_output_names) - - -class SparkLikeSelector(SparkLikeExpr): - def __repr__(self: Self) -> str: # pragma: no cover - return f"SparkLikeSelector(function_name={self._function_name})" +class SparkLikeSelector(CompliantSelector["SparkLikeLazyFrame", "Column"], SparkLikeExpr): # type: ignore[type-var, misc] def _to_expr(self: Self) -> SparkLikeExpr: return SparkLikeExpr( self._call, @@ -136,81 +51,3 @@ def _to_expr(self: Self) -> SparkLikeExpr: version=self._version, implementation=self._implementation, ) - - def __sub__(self: Self, other: SparkLikeSelector | Any) -> SparkLikeSelector | Any: - if isinstance(other, SparkLikeSelector): - - def call(df: SparkLikeLazyFrame) -> list[Column]: - 
lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x not in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() - other - - def __or__(self: Self, other: SparkLikeSelector | Any) -> SparkLikeSelector | Any: - if isinstance(other, SparkLikeSelector): - - def call(df: SparkLikeLazyFrame) -> list[Column]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - rhs = other._call(df) - return [ - *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), - *rhs, - ] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() | other - - def __and__(self: Self, other: SparkLikeSelector | Any) -> SparkLikeSelector | Any: - if isinstance(other, SparkLikeSelector): - - def call(df: SparkLikeLazyFrame) -> list[Column]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - lhs = self._call(df) - return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] - - def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: - lhs_names = self._evaluate_output_names(df) - rhs_names = other._evaluate_output_names(df) - return [x for x in lhs_names if x in rhs_names] - - return selector(self, call, evaluate_output_names) - else: - return self._to_expr() & other - - def __invert__(self: Self) -> SparkLikeSelector: - return 
SparkLikeSelectorNamespace(self).all() - self - - -def selector( - context: _FullContext, - call: Callable[[SparkLikeLazyFrame], Sequence[Column]], - evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], - /, -) -> SparkLikeSelector: - return SparkLikeSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=context._backend_version, - version=context._version, - implementation=context._implementation, - ) diff --git a/narwhals/typing.py b/narwhals/typing.py index ceb4677220..d131a32ec5 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -3,14 +3,28 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Iterator from typing import Literal from typing import Protocol from typing import Sequence from typing import TypeVar from typing import Union +if not TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 9): + from typing import Protocol as Protocol38 + else: + from typing import Generic as Protocol38 +else: + # TODO @dangotbanned: Remove after dropping `3.8` (#2084) + # - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386 + from typing import Protocol as Protocol38 + if TYPE_CHECKING: from types import ModuleType + from typing import Mapping import numpy as np from typing_extensions import Self @@ -18,6 +32,7 @@ from narwhals import dtypes from narwhals._expression_parsing import ExprKind + from narwhals._selectors import CompliantSelectorNamespace from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.dtypes import DType @@ -68,7 +83,10 @@ def aggregate(self, *exprs: Any) -> Self: @property def columns(self) -> Sequence[str]: ... + @property + def schema(self) -> Mapping[str, DType]: ... def get_column(self, name: str) -> CompliantSeriesT_co: ... + def iter_columns(self) -> Iterator[CompliantSeriesT_co]: ... 
class CompliantLazyFrame(Protocol):
@@ -83,6 +101,9 @@ def aggregate(self, *exprs: Any) -> Self:
 
     @property
     def columns(self) -> Sequence[str]: ...
+    @property
+    def schema(self) -> Mapping[str, DType]: ...
+    def _iter_columns(self) -> Iterator[Any]: ...
 
 
 CompliantFrameT = TypeVar(
@@ -90,7 +111,7 @@ def columns(self) -> Sequence[str]: ...
 )
 
 
-class CompliantExpr(Protocol[CompliantFrameT, CompliantSeriesT_co]):
+class CompliantExpr(Protocol38[CompliantFrameT, CompliantSeriesT_co]):
     _implementation: Implementation
     _backend_version: tuple[int, ...]
     _version: Version
@@ -132,6 +153,8 @@ def col(
     def lit(
         self, value: Any, dtype: DType | None
     ) -> CompliantExpr[CompliantFrameT, CompliantSeriesT_co]: ...
+    @property
+    def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
 
 
 class SupportsNativeNamespace(Protocol):
diff --git a/narwhals/utils.py b/narwhals/utils.py
index 6cf39d7ba4..f5959bb550 100644
--- a/narwhals/utils.py
+++ b/narwhals/utils.py
@@ -10,6 +10,7 @@
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Iterable
+from typing import Literal
 from typing import Sequence
 from typing import TypeVar
 from typing import Union
@@ -46,6 +47,7 @@
 
     import pandas as pd
     from typing_extensions import Self
+    from typing_extensions import TypeAlias
     from typing_extensions import TypeIs
 
     from narwhals.dataframe import DataFrame
@@ -73,6 +75,8 @@
     _T2 = TypeVar("_T2")
     _T3 = TypeVar("_T3")
 
+    _TracksDepth: TypeAlias = "Literal[Implementation.DASK,Implementation.CUDF,Implementation.MODIN,Implementation.PANDAS,Implementation.PYARROW]"  # keep in sync with `is_tracks_depth`
+
     class _SupportsVersion(Protocol):
         __version__: str
 
@@ -1383,3 +1387,8 @@ def has_native_namespace(obj: Any) -> TypeIs[SupportsNativeNamespace]:
 
 def _supports_dataframe_interchange(obj: Any) -> TypeIs[DataFrameLike]:
     return hasattr(obj, "__dataframe__")
+
+
+def is_tracks_depth(obj: Implementation, /) -> TypeIs[_TracksDepth]:  # pragma: no cover
+    # Return `True` for implementations that utilize `CompliantExpr._depth`.
+    return obj.is_pandas_like() or obj in {Implementation.PYARROW, Implementation.DASK}