diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index f828a33da2..fced41ef45 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -12,6 +12,7 @@ Here are the top-level functions available in Narwhals. - col - concat - concat_str + - exclude - from_arrow - from_dict - from_native diff --git a/narwhals/__init__.py b/narwhals/__init__.py index c90179b9e1..eea037a692 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -40,6 +40,7 @@ from narwhals.functions import col from narwhals.functions import concat from narwhals.functions import concat_str +from narwhals.functions import exclude from narwhals.functions import from_arrow from narwhals.functions import from_dict from narwhals.functions import from_numpy @@ -123,6 +124,7 @@ "dependencies", "dtypes", "exceptions", + "exclude", "from_arrow", "from_dict", "from_native", diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index a58c595aa8..e740b85929 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -120,6 +121,35 @@ def col(self: Self, *column_names: str) -> ArrowExpr: *column_names, backend_version=self._backend_version, version=self._version ) + def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr: + from narwhals._arrow.series import ArrowSeries + + def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: + return [ + column_name + for column_name in df.columns + if column_name not in excluded_names + ] + + def func(df: ArrowDataFrame) -> list[ArrowSeries]: + return [ + ArrowSeries( + df._native_frame[column_name], + name=column_name, + backend_version=df._backend_version, + version=df._version, + ) + for column_name in evaluate_output_names(df) + ] + + return self._create_expr_from_callable( + func, + depth=0, + function_name="exclude", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + ) + def nth(self: Self, *column_indices: int) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 5f9fa982f6..828ad2afdd 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -67,6 +68,29 @@ def col(self: Self, *column_names: str) -> DaskExpr: *column_names, backend_version=self._backend_version, version=self._version ) + def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: + def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: + return [ + column_name + for column_name in df.columns + if column_name not in excluded_names + ] + + def func(df: DaskLazyFrame) -> list[dx.Series]: + return [ + df._native_frame[column_name] for column_name in evaluate_output_names(df) + ] + + return DaskExpr( + func, + depth=0, + function_name="exclude", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + def nth(self: Self, *column_indices: int) -> DaskExpr: return DaskExpr.from_column_indices( *column_indices, backend_version=self._backend_version, version=self._version diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 4ea72fae3b..047355e940 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Literal from typing import Sequence @@ -237,6 +238,28 @@ def col(self: Self, *column_names: str) -> DuckDBExpr: *column_names, backend_version=self._backend_version, version=self._version ) + def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: + def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]: + return [ + column_name + for column_name in df.columns + if column_name not in excluded_names + ] + + def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: + return [ + ColumnExpression(column_name) for column_name in evaluate_output_names(df) + ] + + return DuckDBExpr( + func, + function_name="exclude", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + ) + def nth(self: Self, *column_indices: int) -> DuckDBExpr: return DuckDBExpr.from_column_indices( *column_indices, backend_version=self._backend_version, version=self._version diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 4e1b887fc0..75b5cec617 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -115,6 +116,33 @@ def col(self: Self, *column_names: str) -> PandasLikeExpr: version=self._version, ) + def exclude(self: Self, excluded_names: Container[str]) -> PandasLikeExpr: + def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: + return [ + column_name + for column_name in df.columns + if column_name not in excluded_names + ] + + def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + return [ + PandasLikeSeries( + df._native_frame[column_name], + implementation=df._implementation, + backend_version=df._backend_version, + version=df._version, + ) + for column_name in evaluate_output_names(df) + ] + + return self._create_expr_from_callable( + func, + depth=0, + evaluate_output_names=evaluate_output_names, + function_name="exclude", + alias_output_names=None, + ) + def nth(self: Self, *column_indices: int) -> PandasLikeExpr: return PandasLikeExpr.from_column_indices( *column_indices, diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 07303aee67..91cbd4f30c 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -68,6 +69,27 @@ def col(self: Self, *column_names: str) -> SparkLikeExpr: implementation=self._implementation, ) + def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: + def evaluate_output_names(df: SparkLikeLazyFrame) -> Sequence[str]: + return [ + column_name + for column_name in df.columns + if column_name not in excluded_names + ] + + def func(df: SparkLikeLazyFrame) -> list[Column]: + return [df._F.col(column_name) for column_name in evaluate_output_names(df)] + + return SparkLikeExpr( + func, + function_name="exclude", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, + backend_version=self._backend_version, + version=self._version, + implementation=self._implementation, + ) + def nth(self: Self, *column_indices: int) -> SparkLikeExpr: return SparkLikeExpr.from_column_indices( *column_indices, diff --git a/narwhals/functions.py b/narwhals/functions.py index 5e80e05c64..e1f85f5d1f 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -1040,6 +1040,43 @@ def func(plx: Any) -> Any: return Expr(func, ExprMetadata.selector()) +def exclude(*names: str | Iterable[str]) -> Expr: + """Creates an expression that excludes columns by their name(s). + + Arguments: + names: Name(s) of the columns to exclude. + + Returns: + A new expression. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> + >>> df_native = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": ["x", "z"]}) + >>> nw.from_native(df_native).select(nw.exclude("c", "a")) + ┌──────────────────┐ + |Narwhals DataFrame| + |------------------| + | shape: (2, 1) | + | ┌─────┐ | + | │ b │ | + | │ --- │ | + | │ i64 │ | + | ╞═════╡ | + | │ 3 │ | + | │ 4 │ | + | └─────┘ | + └──────────────────┘ + """ + exclude_names = frozenset(flatten(names)) + + def func(plx: Any) -> Any: + return plx.exclude(exclude_names) + + return Expr(func, ExprMetadata.selector()) + + def nth(*indices: int | Sequence[int]) -> Expr: """Creates an expression that references one or more columns by their index(es). diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8a0e2b1335..fbdc935112 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -1826,6 +1826,18 @@ def col(*names: str | Iterable[str]) -> Expr: return _stableify(nw.col(*names)) +def exclude(*names: str | Iterable[str]) -> Expr: + """Creates an expression that excludes columns by their name(s). + + Arguments: + names: Name(s) of the columns to exclude. + + Returns: + A new expression. + """ + return _stableify(nw.exclude(*names)) + + def nth(*indices: int | Sequence[int]) -> Expr: """Creates an expression that references one or more columns by their index(es). @@ -2417,6 +2429,7 @@ def scan_parquet( "dependencies", "dtypes", "exceptions", + "exclude", "from_arrow", "from_dict", "from_native", diff --git a/tests/expr_and_series/exclude_test.py b/tests/expr_and_series/exclude_test.py new file mode 100644 index 0000000000..a6679faa71 --- /dev/null +++ b/tests/expr_and_series/exclude_test.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import assert_equal_data + +if TYPE_CHECKING: + from tests.utils import Constructor + + +@pytest.mark.parametrize( + ("exclude_selector", "expected_cols"), + [ + (nw.exclude("a"), ["b", "z"]), + (nw.exclude("b", "z"), ["a"]), + (nw.exclude(["a"]), ["b", "z"]), + (nw.exclude(["b", "z"]), ["a"]), + ], +) +def test_exclude( + constructor: Constructor, exclude_selector: nw.Expr, expected_cols: list[str] +) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} + + df = nw.from_native(constructor(data)) + result = df.select(exclude_selector) + + expected = {col: data[col] for col in expected_cols} + assert_equal_data(result, expected)