diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index bf85c75a25..062ab0aaa7 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -348,9 +348,7 @@ def concat( @property def selectors(self: Self) -> ArrowSelectorNamespace: - return ArrowSelectorNamespace( - backend_version=self._backend_version, version=self._version - ) + return ArrowSelectorNamespace(self) def when(self: Self, predicate: ArrowExpr) -> ArrowWhen: return ArrowWhen(predicate, self._backend_version, version=self._version) diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 5065aaf801..f5a6a71a2c 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -3,11 +3,11 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence from narwhals._arrow.expr import ArrowExpr -from narwhals.utils import Implementation from narwhals.utils import _parse_time_unit_and_time_zone from narwhals.utils import dtype_matches_time_unit_and_time_zone from narwhals.utils import import_dtypes_module @@ -21,16 +21,13 @@ from narwhals._arrow.series import ArrowSeries from narwhals.dtypes import DType from narwhals.typing import TimeUnit - from narwhals.utils import Version + from narwhals.utils import _LimitedContext class ArrowSelectorNamespace: - def __init__( - self: Self, *, backend_version: tuple[int, ...], version: Version - ) -> None: - self._backend_version = backend_version - self._implementation = Implementation.PYARROW - self._version = version + def __init__(self: Self, context: _LimitedContext, /) -> None: + self._backend_version = context._backend_version + self._version = context._version def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: @@ -39,16 +36,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return ArrowSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={"dtypes": dtypes}, - ) + return selector(self, func, evaluate_output_names, {"dtypes": dtypes}) def matches(self: Self, pattern: str) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: @@ -57,16 +45,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return ArrowSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={"pattern": pattern}, - ) + return selector(self, func, evaluate_output_names, {"pattern": pattern}) def numeric(self: Self) -> ArrowSelector: dtypes = import_dtypes_module(self._version) @@ -103,16 +82,7 @@ def all(self: Self) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: return [df[col] for col in df.columns] - return ArrowSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=lambda df: df.columns, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, lambda df: df.columns, {}) def datetime( self: Self, @@ -148,16 +118,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: ) ] - return ArrowSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, evaluate_output_names, {}) class ArrowSelector(ArrowExpr): @@ -190,15 +151,8 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return ArrowSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) else: return self._to_expr() - other @@ -221,15 +175,8 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return ArrowSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) else: return self._to_expr() | other @@ -248,23 +195,31 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return ArrowSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) + else: return self._to_expr() & other def __invert__(self: Self) -> ArrowSelector: - return ( - ArrowSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) + return ArrowSelectorNamespace(self).all() - self + + +def selector( + context: _LimitedContext, + call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], + evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], + kwargs: dict[str, Any], + /, +) -> ArrowSelector: + return ArrowSelector( + call, + depth=0, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + kwargs=kwargs, + ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index cf825d27f7..a8c3e41798 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -41,9 +41,7 @@ class DaskNamespace(CompliantNamespace["dx.Series"]): @property def selectors(self: Self) -> DaskSelectorNamespace: - return DaskSelectorNamespace( - backend_version=self._backend_version, version=self._version - ) + return DaskSelectorNamespace(self) def __init__( self: Self, *, backend_version: tuple[int, ...], version: Version diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 62df505289..123da1212d 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence @@ -24,7 +25,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals.dtypes import DType from narwhals.typing import TimeUnit - from narwhals.utils import Version + from narwhals.utils import _LimitedContext try: import dask.dataframe.dask_expr as dx @@ -33,11 +34,9 @@ class DaskSelectorNamespace: - def __init__( - self: Self, *, backend_version: tuple[int, ...], version: Version - ) -> None: - self._backend_version = backend_version - self._version = version + def __init__(self: Self, context: _LimitedContext, /) -> None: + self._backend_version = context._backend_version + self._version = context._version def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DaskSelector: def func(df: DaskLazyFrame) -> list[dx.Series]: @@ -48,16 +47,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return DaskSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, evaluate_output_names) def matches(self: Self, pattern: str) -> DaskSelector: def func(df: DaskLazyFrame) -> list[dx.Series]: @@ -68,16 +58,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return DaskSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, evaluate_output_names) def numeric(self: Self) -> DaskSelector: dtypes = import_dtypes_module(self._version) @@ -114,16 +95,7 @@ def all(self: Self) -> DaskSelector: def func(df: DaskLazyFrame) -> list[dx.Series]: return [df._native_frame[col] for col in df.columns] - return DaskSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=lambda df: df.columns, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, lambda df: df.columns) def datetime( self: Self, @@ -159,16 +131,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: ) ] - return DaskSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, evaluate_output_names) class DaskSelector(DaskExpr): @@ -201,16 +164,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return DaskSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() - other @@ -232,16 +186,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return DaskSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() | other @@ -259,23 +204,27 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return DaskSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() & other def __invert__(self: Self) -> DaskSelector: - return ( - DaskSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) + return DaskSelectorNamespace(self).all() - self + + +def selector( + context: _LimitedContext, + call: Callable[[DaskLazyFrame], Sequence[dx.Series]], + evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]], + /, +) -> DaskSelector: + return DaskSelector( + call, + depth=0, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + kwargs={}, + ) diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 8c5191a3be..2f50a70724 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -43,9 +43,7 @@ def __init__( @property def selectors(self: Self) -> DuckDBSelectorNamespace: - return DuckDBSelectorNamespace( - backend_version=self._backend_version, version=self._version - ) + return DuckDBSelectorNamespace(self) def all(self: Self) -> DuckDBExpr: def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: diff --git a/narwhals/_duckdb/selectors.py b/narwhals/_duckdb/selectors.py index 8a84450641..a30cec06cb 100644 --- a/narwhals/_duckdb/selectors.py +++ b/narwhals/_duckdb/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence @@ -22,15 +23,13 @@ from narwhals._duckdb.dataframe import DuckDBLazyFrame from narwhals.dtypes import DType from narwhals.typing import TimeUnit - from narwhals.utils import Version + from narwhals.utils import _LimitedContext class DuckDBSelectorNamespace: - def __init__( - self: Self, *, backend_version: tuple[int, ...], version: Version - ) -> None: - self._backend_version = backend_version - self._version = version + def __init__(self: Self, context: _LimitedContext, /) -> None: + self._backend_version = context._backend_version + self._version = context._version def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DuckDBSelector: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: @@ -41,14 +40,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return DuckDBSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, func, evaluate_output_names) def matches(self: Self, pattern: str) -> DuckDBSelector: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: @@ -59,14 +51,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return DuckDBSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, func, evaluate_output_names) def numeric(self: Self) -> DuckDBSelector: dtypes = import_dtypes_module(self._version) @@ -103,14 +88,7 @@ def all(self: Self) -> DuckDBSelector: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ColumnExpression(col) for col in df.columns] - return DuckDBSelector( - func, - function_name="selector", - evaluate_output_names=lambda df: df.columns, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, func, lambda df: df.columns) def datetime( self: Self, @@ -146,14 +124,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: ) ] - return DuckDBSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, func, evaluate_output_names) class DuckDBSelector(DuckDBExpr): @@ -184,14 +155,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return DuckDBSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() - other @@ -213,14 +177,7 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return DuckDBSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() | other @@ -238,21 +195,25 @@ def evaluate_output_names(df: DuckDBLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return DuckDBSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() & other def __invert__(self: Self) -> DuckDBSelector: - return ( - DuckDBSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) + return DuckDBSelectorNamespace(self).all() - self + + +def selector( + context: _LimitedContext, + call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]], + evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]], + /, +) -> DuckDBSelector: + return DuckDBSelector( + call, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 67b2314ef7..a0c8a3ac3c 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -36,11 +36,7 @@ class PandasLikeNamespace(CompliantNamespace[PandasLikeSeries]): @property def selectors(self: Self) -> PandasSelectorNamespace: - return PandasSelectorNamespace( - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) + return PandasSelectorNamespace(self) # --- not in spec --- def __init__( diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 471ae6407e..05afa05569 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -3,6 +3,7 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence @@ -20,21 +21,14 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals.dtypes import DType from narwhals.typing import TimeUnit - from narwhals.utils import Implementation - from narwhals.utils import Version + from narwhals.utils import _FullContext class PandasSelectorNamespace: - def __init__( - self: Self, - *, - implementation: Implementation, - backend_version: tuple[int, ...], - version: Version, - ) -> None: - self._implementation = implementation - self._backend_version = backend_version - self._version = version + def __init__(self: Self, context: _FullContext, /) -> None: + self._implementation = context._implementation + self._backend_version = context._backend_version + self._version = context._version def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -43,17 +37,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return PandasSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={"dtypes": dtypes}, - ) + return selector(self, func, evaluate_output_names, {"dtypes": dtypes}) def matches(self: Self, pattern: str) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -62,17 +46,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return PandasSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={"pattern": pattern}, - ) + return selector(self, func, evaluate_output_names, {"pattern": pattern}) def numeric(self: Self) -> PandasSelector: dtypes = import_dtypes_module(self._version) @@ -109,17 +83,7 @@ def all(self: Self) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [df[col] for col in df.columns] - return PandasSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=lambda df: df.columns, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, lambda df: df.columns, {}) def datetime( self: Self, @@ -155,17 +119,7 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: ) ] - return PandasSelector( - func, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={}, - ) + return selector(self, func, evaluate_output_names, {}) class PandasSelector(PandasLikeExpr): @@ -201,16 +155,8 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return PandasSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) else: return self._to_expr() - other @@ -233,16 +179,8 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return PandasSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) else: return self._to_expr() | other @@ -261,26 +199,31 @@ def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return PandasSelector( - call, - depth=0, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - kwargs={**self._kwargs, "other": other}, + return selector( + self, call, evaluate_output_names, {**self._kwargs, "other": other} ) else: return self._to_expr() & other def __invert__(self: Self) -> PandasSelector: - return ( - PandasSelectorNamespace( - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).all() - - self - ) + return PandasSelectorNamespace(self).all() - self + + +def selector( + context: _FullContext, + call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], + evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]], + kwargs: dict[str, Any], + /, +) -> PandasSelector: + return PandasSelector( + call, + depth=0, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + implementation=context._implementation, + backend_version=context._backend_version, + version=context._version, + kwargs=kwargs, + ) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 836c1d6a4d..b6f51a60f8 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -42,11 +42,7 @@ def __init__( @property def selectors(self: Self) -> SparkLikeSelectorNamespace: - return SparkLikeSelectorNamespace( - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return SparkLikeSelectorNamespace(self) def all(self: Self) -> SparkLikeExpr: def _all(df: SparkLikeLazyFrame) -> list[Column]: diff --git a/narwhals/_spark_like/selectors.py b/narwhals/_spark_like/selectors.py index d24e29d48b..e037e1f8a3 100644 --- a/narwhals/_spark_like/selectors.py +++ b/narwhals/_spark_like/selectors.py @@ -3,11 +3,11 @@ import re from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence from narwhals._spark_like.expr import SparkLikeExpr -from narwhals.utils import Implementation from narwhals.utils import _parse_time_unit_and_time_zone from narwhals.utils import dtype_matches_time_unit_and_time_zone from narwhals.utils import import_dtypes_module @@ -21,20 +21,14 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals.dtypes import DType from narwhals.typing import TimeUnit - from narwhals.utils import Version + from narwhals.utils import _FullContext class SparkLikeSelectorNamespace: - def __init__( - self: Self, - *, - backend_version: tuple[int, ...], - version: Version, - implementation: Implementation, - ) -> None: - self._backend_version = backend_version - self._version = version - self._implementation = implementation + def __init__(self: Self, context: _FullContext, /) -> None: + self._backend_version = context._backend_version + self._version = context._version + self._implementation = context._implementation def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: @@ -43,15 +37,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: return [col for col in df.columns if df.schema[col] in dtypes] - return SparkLikeSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, func, evaluate_output_names) def matches(self: Self, pattern: str) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: @@ -60,15 +46,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: return [col for col in df.columns if re.search(pattern, col)] - return SparkLikeSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, func, evaluate_output_names) def numeric(self: Self) -> SparkLikeSelector: dtypes = import_dtypes_module(self._version) @@ -105,15 +83,7 @@ def all(self: Self) -> SparkLikeSelector: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col) for col in df.columns] - return SparkLikeSelector( - func, - function_name="selector", - evaluate_output_names=lambda df: df.columns, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, func, lambda df: df.columns) def datetime( self: Self, @@ -149,15 +119,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: ) ] - return SparkLikeSelector( - func, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, func, evaluate_output_names) class SparkLikeSelector(SparkLikeExpr): @@ -189,15 +151,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x not in rhs_names] - return SparkLikeSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() - other @@ -219,15 +173,7 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] - return SparkLikeSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() | other @@ -245,24 +191,26 @@ def evaluate_output_names(df: SparkLikeLazyFrame) -> list[str]: rhs_names = other._evaluate_output_names(df) return [x for x in lhs_names if x in rhs_names] - return SparkLikeSelector( - call, - function_name="selector", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + return selector(self, call, evaluate_output_names) else: return self._to_expr() & other def __invert__(self: Self) -> SparkLikeSelector: - return ( - SparkLikeSelectorNamespace( - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ).all() - - self - ) + return SparkLikeSelectorNamespace(self).all() - self + + +def selector( + context: _FullContext, + call: Callable[[SparkLikeLazyFrame], Sequence[Column]], + evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], + /, +) -> SparkLikeSelector: + return SparkLikeSelector( + call, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, + ) diff --git a/narwhals/utils.py b/narwhals/utils.py index 0d0b50b76b..cb33a603d6 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -74,6 +74,33 @@ class _SupportsVersion(Protocol): __version__: str + class _StoresImplementation(Protocol): + _implementation: Implementation + """Implementation of native object (pandas, Polars, PyArrow, ...).""" + + class _StoresBackendVersion(Protocol): + _backend_version: tuple[int, ...] + """Version tuple for a native package.""" + + class _StoresVersion(Protocol): + _version: Version + """Narwhals API version (V1 or MAIN).""" + + class _LimitedContext(_StoresBackendVersion, _StoresVersion, Protocol): + """Provides 2 attributes. + + - `_backend_version` + - `_version` + """ + + class _FullContext(_StoresImplementation, _LimitedContext, Protocol): # noqa: PYI046 + """Provides 3 attributes. + + - `_implementation` + - `_backend_version` + - `_version` + """ + class Version(Enum): V1 = auto()