diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index a712b458f3..ab1e0e54b0 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -87,12 +87,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: @classmethod def from_column_names( cls: type[Self], - *column_names: str, + evaluate_column_names: Callable[[ArrowDataFrame], Sequence[str]], + /, + *, + function_name: str, backend_version: tuple[int, ...], version: Version, ) -> Self: - from narwhals._arrow.series import ArrowSeries - def func(df: ArrowDataFrame) -> list[ArrowSeries]: try: return [ @@ -102,10 +103,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: backend_version=df._backend_version, version=df._version, ) - for column_name in column_names + for column_name in evaluate_column_names(df) ] except KeyError as e: - missing_columns = [x for x in column_names if x not in df.columns] + missing_columns = [ + x for x in evaluate_column_names(df) if x not in df.columns + ] raise ColumnNotFoundError.from_missing_and_available_column_names( missing_columns=missing_columns, available_columns=df.columns ) from e @@ -113,8 +116,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return cls( func, depth=0, - function_name="col", - evaluate_output_names=lambda _df: column_names, + function_name=function_name, + evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=backend_version, version=version, diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index e740b85929..ac49998066 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any @@ -27,8 +28,10 @@ from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing import Callable @@ -118,36 +121,18 @@ def col(self: Self, *column_names: str) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version + passthrough_column_names(column_names), + function_name="col", + backend_version=self._backend_version, + version=self._version, ) def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr: - from narwhals._arrow.series import ArrowSeries - - def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: - return [ - column_name - for column_name in df.columns - if column_name not in excluded_names - ] - - def func(df: ArrowDataFrame) -> list[ArrowSeries]: - return [ - ArrowSeries( - df._native_frame[column_name], - name=column_name, - backend_version=df._backend_version, - version=df._version, - ) - for column_name in evaluate_output_names(df) - ] - - return self._create_expr_from_callable( - func, - depth=0, + return ArrowExpr.from_column_names( + partial(exclude_column_names, names=excluded_names), function_name="exclude", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, + backend_version=self._backend_version, + version=self._version, ) def nth(self: Self, *column_indices: int) -> ArrowExpr: @@ -177,23 +162,9 @@ def len(self: Self) -> ArrowExpr: ) def all(self: Self) -> ArrowExpr: - from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.series import ArrowSeries - - return ArrowExpr( - lambda df: [ - ArrowSeries( - df._native_frame[column_name], - name=column_name, - backend_version=df._backend_version, - version=df._version, - ) - for column_name in df.columns - ], - depth=0, + return ArrowExpr.from_column_names( + get_column_names, function_name="all", - evaluate_output_names=get_column_names, - alias_output_names=None, backend_version=self._backend_version, version=self._version, ) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index c6a861c3d8..b1930e5628 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -90,15 +90,23 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: @classmethod def from_column_names( cls: type[Self], - *column_names: str, + evaluate_column_names: Callable[[DaskLazyFrame], Sequence[str]], + /, + *, + function_name: str, backend_version: tuple[int, ...], version: Version, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: - return [df._native_frame[column_name] for column_name in column_names] + return [ + df._native_frame[column_name] + for column_name in evaluate_column_names(df) + ] except KeyError as e: - missing_columns = [x for x in column_names if x not in df.columns] + missing_columns = [ + x for x in evaluate_column_names(df) if x not in df.columns + ] raise ColumnNotFoundError.from_missing_and_available_column_names( missing_columns=missing_columns, available_columns=df.columns, @@ -107,8 +115,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return cls( func, depth=0, - function_name="col", - evaluate_output_names=lambda _df: column_names, + function_name=function_name, + evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=backend_version, version=version, diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 8d64cf9d11..eddbe7925f 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any @@ -25,7 +26,9 @@ from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -53,43 +56,25 @@ def __init__( self._version = version def all(self: Self) -> DaskExpr: - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [df._native_frame[column_name] for column_name in df.columns] - - return DaskExpr( - func, - depth=0, + return DaskExpr.from_column_names( + get_column_names, function_name="all", - evaluate_output_names=get_column_names, - alias_output_names=None, backend_version=self._backend_version, version=self._version, ) def col(self: Self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version + passthrough_column_names(column_names), + function_name="col", + backend_version=self._backend_version, + version=self._version, ) def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: - def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: - return [ - column_name - for column_name in df.columns - if column_name not in excluded_names - ] - - def func(df: DaskLazyFrame) -> list[dx.Series]: - return [ - df._native_frame[column_name] for column_name in evaluate_output_names(df) - ] - - return DaskExpr( - func, - depth=0, + return DaskExpr.from_column_names( + partial(exclude_column_names, names=excluded_names), function_name="exclude", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, backend_version=self._backend_version, version=self._version, ) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 51f4ac6e47..2f47942b55 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -79,17 +79,20 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se @classmethod def from_column_names( cls: type[Self], - *column_names: str, + evaluate_column_names: Callable[[DuckDBLazyFrame], Sequence[str]], + /, + *, + function_name: str, backend_version: tuple[int, ...], version: Version, ) -> Self: - def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ColumnExpression(col_name) for col_name in column_names] + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ColumnExpression(col_name) for col_name in evaluate_column_names(df)] return cls( func, - function_name="col", - evaluate_output_names=lambda _df: column_names, + function_name=function_name, + evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=backend_version, version=version, diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 45c877965e..f8780fbc44 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations -import functools import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any @@ -12,7 +12,6 @@ from duckdb import CaseExpression from duckdb import CoalesceOperator -from duckdb import ColumnExpression from duckdb import FunctionExpression from duckdb.typing import BIGINT from duckdb.typing import VARCHAR @@ -26,7 +25,9 @@ from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation +from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: import duckdb @@ -51,14 +52,9 @@ def selectors(self: Self) -> DuckDBSelectorNamespace: return DuckDBSelectorNamespace(self) def all(self: Self) -> DuckDBExpr: - def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ColumnExpression(col_name) for col_name in df.columns] - - return DuckDBExpr( - call=_all, + return DuckDBExpr.from_column_names( + get_column_names, function_name="all", - evaluate_output_names=get_column_names, - alias_output_names=None, backend_version=self._backend_version, version=self._version, ) @@ -80,9 +76,7 @@ def concat( if how == "vertical" and not all(x.schema == schema for x in items[1:]): msg = "inputs should all have the same schema" raise TypeError(msg) - res = functools.reduce( - lambda x, y: x.union(y), (item._native_frame for item in items) - ) + res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items)) return first._from_native_frame(res) def concat_str( @@ -238,27 +232,16 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen: def col(self: Self, *column_names: str) -> DuckDBExpr: return DuckDBExpr.from_column_names( - *column_names, backend_version=self._backend_version, version=self._version + passthrough_column_names(column_names), + function_name="col", + backend_version=self._backend_version, + version=self._version, ) def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: - def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: - return [ - column_name - for column_name in df.columns - if column_name not in excluded_names - ] - - def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: - return [ - ColumnExpression(column_name) for column_name in evaluate_output_names(df) - ] - - return DuckDBExpr( - func, + return DuckDBExpr.from_column_names( + partial(exclude_column_names, names=excluded_names), function_name="exclude", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, backend_version=self._backend_version, version=self._version, ) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 857f1afaad..3caf6c9828 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -115,7 +115,10 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @classmethod def from_column_names( cls: type[Self], - *column_names: str, + evaluate_column_names: Callable[[PandasLikeDataFrame], Sequence[str]], + /, + *, + function_name: str, implementation: Implementation, backend_version: tuple[int, ...], version: Version, @@ -129,10 +132,12 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: backend_version=df._backend_version, version=df._version, ) - for column_name in column_names + for column_name in evaluate_column_names(df) ] except KeyError as e: - missing_columns = [x for x in column_names if x not in df.columns] + missing_columns = [ + x for x in evaluate_column_names(df) if x not in df.columns + ] raise ColumnNotFoundError.from_missing_and_available_column_names( missing_columns=missing_columns, available_columns=df.columns, @@ -141,8 +146,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return cls( func, depth=0, - function_name="col", - evaluate_output_names=lambda _df: column_names, + function_name=function_name, + evaluate_output_names=evaluate_column_names, alias_output_names=None, implementation=implementation, backend_version=backend_version, diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 75b5cec617..4b4994c6d0 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any @@ -23,8 +24,10 @@ from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat from narwhals.typing import CompliantNamespace +from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -110,37 +113,20 @@ def _create_compliant_series(self: Self, value: Any) -> PandasLikeSeries: # --- selection --- def col(self: Self, *column_names: str) -> PandasLikeExpr: return PandasLikeExpr.from_column_names( - *column_names, + passthrough_column_names(column_names), + function_name="col", implementation=self._implementation, backend_version=self._backend_version, version=self._version, ) def exclude(self: Self, excluded_names: Container[str]) -> PandasLikeExpr: - def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: - return [ - column_name - for column_name in df.columns - if column_name not in excluded_names - ] - - def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - return [ - PandasLikeSeries( - df._native_frame[column_name], - implementation=df._implementation, - backend_version=df._backend_version, - version=df._version, - ) - for column_name in evaluate_output_names(df) - ] - - return self._create_expr_from_callable( - func, - depth=0, - evaluate_output_names=evaluate_output_names, + return PandasLikeExpr.from_column_names( + partial(exclude_column_names, names=excluded_names), function_name="exclude", - alias_output_names=None, + implementation=self._implementation, + backend_version=self._backend_version, + version=self._version, ) def nth(self: Self, *column_indices: int) -> PandasLikeExpr: @@ -152,20 +138,9 @@ def nth(self: Self, *column_indices: int) -> PandasLikeExpr: ) def all(self: Self) -> PandasLikeExpr: - return PandasLikeExpr( - lambda df: [ - PandasLikeSeries( - df._native_frame[column_name], - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - for column_name in df.columns - ], - depth=0, + return PandasLikeExpr.from_column_names( + get_column_names, function_name="all", - evaluate_output_names=get_column_names, - alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 41bbe46c5e..f2c09ea8b9 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -131,18 +131,21 @@ def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: # pragma: no cove @classmethod def from_column_names( cls: type[Self], - *column_names: str, + evaluate_column_names: Callable[[SparkLikeLazyFrame], Sequence[str]], + /, + *, + function_name: str, + implementation: Implementation, backend_version: tuple[int, ...], version: Version, - implementation: Implementation, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(col_name) for col_name in column_names] + return [df._F.col(col_name) for col_name in evaluate_column_names(df)] return cls( func, - function_name="col", - evaluate_output_names=lambda _df: column_names, + function_name=function_name, + evaluate_output_names=evaluate_column_names, alias_output_names=None, backend_version=backend_version, version=version, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 91cbd4f30c..91cad2bbb6 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any @@ -19,7 +20,9 @@ from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.typing import CompliantNamespace +from narwhals.utils import exclude_column_names from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from pyspark.sql import Column @@ -48,46 +51,30 @@ def selectors(self: Self) -> SparkLikeSelectorNamespace: return SparkLikeSelectorNamespace(self) def all(self: Self) -> SparkLikeExpr: - def _all(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(col_name) for col_name in df.columns] - - return SparkLikeExpr( - call=_all, + return SparkLikeExpr.from_column_names( + get_column_names, function_name="all", - evaluate_output_names=get_column_names, - alias_output_names=None, + implementation=self._implementation, backend_version=self._backend_version, version=self._version, - implementation=self._implementation, ) def col(self: Self, *column_names: str) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( - *column_names, + passthrough_column_names(column_names), + function_name="col", + implementation=self._implementation, backend_version=self._backend_version, version=self._version, - implementation=self._implementation, ) def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: - def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: - return [ - column_name - for column_name in df.columns - if column_name not in excluded_names - ] - - def func(df: SparkLikeLazyFrame) -> list[Column]: - return [df._F.col(column_name) for column_name in evaluate_output_names(df)] - - return SparkLikeExpr( - func, + return SparkLikeExpr.from_column_names( + partial(exclude_column_names, names=excluded_names), function_name="exclude", - evaluate_output_names=evaluate_output_names, - alias_output_names=None, + implementation=self._implementation, backend_version=self._backend_version, version=self._version, - implementation=self._implementation, ) def nth(self: Self, *column_indices: int) -> SparkLikeExpr: diff --git a/narwhals/utils.py b/narwhals/utils.py index f5959bb550..a75ead5f3d 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -9,6 +9,8 @@ from secrets import token_hex from typing import TYPE_CHECKING from typing import Any +from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -1356,6 +1358,17 @@ def get_column_names(frame: _StoresColumns, /) -> Sequence[str]: return frame.columns +def exclude_column_names(frame: _StoresColumns, names: Container[str]) -> Sequence[str]: + return [col_name for col_name in frame.columns if col_name not in names] + + +def passthrough_column_names(names: Sequence[str], /) -> Callable[[Any], Sequence[str]]: + def fn(_frame: Any, /) -> Sequence[str]: + return names + + return fn + + def _hasattr_static(obj: Any, attr: str) -> bool: sentinel = object() return getattr_static(obj, attr, sentinel) is not sentinel