Skip to content
17 changes: 10 additions & 7 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
@classmethod
def from_column_names(
cls: type[Self],
*column_names: str,
evaluate_column_names: Callable[[ArrowDataFrame], Sequence[str]],
/,
*,
function_name: str,
backend_version: tuple[int, ...],
version: Version,
) -> Self:
from narwhals._arrow.series import ArrowSeries

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
try:
return [
Expand All @@ -102,19 +103,21 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
backend_version=df._backend_version,
version=df._version,
)
for column_name in column_names
for column_name in evaluate_column_names(df)
]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
missing_columns = [
x for x in evaluate_column_names(df) if x not in df.columns
]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns, available_columns=df.columns
) from e

return cls(
func,
depth=0,
function_name="col",
evaluate_output_names=lambda _df: column_names,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
backend_version=backend_version,
version=version,
Expand Down
55 changes: 13 additions & 42 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import operator
from functools import partial
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -27,8 +28,10 @@
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.typing import CompliantNamespace
from narwhals.utils import Implementation
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import import_dtypes_module
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
from typing import Callable
Expand Down Expand Up @@ -118,36 +121,18 @@ def col(self: Self, *column_names: str) -> ArrowExpr:
from narwhals._arrow.expr import ArrowExpr

return ArrowExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
passthrough_column_names(column_names),
function_name="col",
backend_version=self._backend_version,
version=self._version,
)

def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
from narwhals._arrow.series import ArrowSeries

def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [
column_name
for column_name in df.columns
if column_name not in excluded_names
]

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [
ArrowSeries(
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
version=df._version,
)
for column_name in evaluate_output_names(df)
]

return self._create_expr_from_callable(
func,
depth=0,
return ArrowExpr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)

def nth(self: Self, *column_indices: int) -> ArrowExpr:
Expand Down Expand Up @@ -177,23 +162,9 @@ def len(self: Self) -> ArrowExpr:
)

def all(self: Self) -> ArrowExpr:
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries

return ArrowExpr(
lambda df: [
ArrowSeries(
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
version=df._version,
)
for column_name in df.columns
],
depth=0,
return ArrowExpr.from_column_names(
get_column_names,
function_name="all",
evaluate_output_names=get_column_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)
Expand Down
18 changes: 13 additions & 5 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,23 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
@classmethod
def from_column_names(
cls: type[Self],
*column_names: str,
evaluate_column_names: Callable[[DaskLazyFrame], Sequence[str]],
/,
*,
function_name: str,
backend_version: tuple[int, ...],
version: Version,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [df._native_frame[column_name] for column_name in column_names]
return [
df._native_frame[column_name]
for column_name in evaluate_column_names(df)
]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
missing_columns = [
x for x in evaluate_column_names(df) if x not in df.columns
]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
Expand All @@ -107,8 +115,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
return cls(
func,
depth=0,
function_name="col",
evaluate_output_names=lambda _df: column_names,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
backend_version=backend_version,
version=version,
Expand Down
37 changes: 11 additions & 26 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import operator
from functools import partial
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -25,7 +26,9 @@
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.typing import CompliantNamespace
from narwhals.utils import Implementation
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -53,43 +56,25 @@ def __init__(
self._version = version

def all(self: Self) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df._native_frame[column_name] for column_name in df.columns]

return DaskExpr(
func,
depth=0,
return DaskExpr.from_column_names(
get_column_names,
function_name="all",
evaluate_output_names=get_column_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)

def col(self: Self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
passthrough_column_names(column_names),
function_name="col",
backend_version=self._backend_version,
version=self._version,
)

def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr:
def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]:
return [
column_name
for column_name in df.columns
if column_name not in excluded_names
]

def func(df: DaskLazyFrame) -> list[dx.Series]:
return [
df._native_frame[column_name] for column_name in evaluate_output_names(df)
]

return DaskExpr(
func,
depth=0,
return DaskExpr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)
Expand Down
13 changes: 8 additions & 5 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,20 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se
@classmethod
def from_column_names(
cls: type[Self],
*column_names: str,
evaluate_column_names: Callable[[DuckDBLazyFrame], Sequence[str]],
/,
*,
function_name: str,
backend_version: tuple[int, ...],
version: Version,
) -> Self:
def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [ColumnExpression(col_name) for col_name in column_names]
def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [ColumnExpression(col_name) for col_name in evaluate_column_names(df)]

return cls(
func,
function_name="col",
evaluate_output_names=lambda _df: column_names,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
backend_version=backend_version,
version=version,
Expand Down
41 changes: 12 additions & 29 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
import operator
from functools import partial
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -12,7 +12,6 @@

from duckdb import CaseExpression
from duckdb import CoalesceOperator
from duckdb import ColumnExpression
from duckdb import FunctionExpression
from duckdb.typing import BIGINT
from duckdb.typing import VARCHAR
Expand All @@ -26,7 +25,9 @@
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.typing import CompliantNamespace
from narwhals.utils import Implementation
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
import duckdb
Expand All @@ -51,14 +52,9 @@ def selectors(self: Self) -> DuckDBSelectorNamespace:
return DuckDBSelectorNamespace(self)

def all(self: Self) -> DuckDBExpr:
def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [ColumnExpression(col_name) for col_name in df.columns]

return DuckDBExpr(
call=_all,
return DuckDBExpr.from_column_names(
get_column_names,
function_name="all",
evaluate_output_names=get_column_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)
Expand All @@ -80,9 +76,7 @@ def concat(
if how == "vertical" and not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
res = functools.reduce(
lambda x, y: x.union(y), (item._native_frame for item in items)
)
res = reduce(lambda x, y: x.union(y), (item._native_frame for item in items))
return first._from_native_frame(res)

def concat_str(
Expand Down Expand Up @@ -238,27 +232,16 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen:

def col(self: Self, *column_names: str) -> DuckDBExpr:
return DuckDBExpr.from_column_names(
*column_names, backend_version=self._backend_version, version=self._version
passthrough_column_names(column_names),
function_name="col",
backend_version=self._backend_version,
version=self._version,
)

def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr:
def evaluate_output_names(df: DuckDBLazyFrame) -> Sequence[str]:
return [
column_name
for column_name in df.columns
if column_name not in excluded_names
]

def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [
ColumnExpression(column_name) for column_name in evaluate_output_names(df)
]

return DuckDBExpr(
func,
return DuckDBExpr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
)
Expand Down
15 changes: 10 additions & 5 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
@classmethod
def from_column_names(
cls: type[Self],
*column_names: str,
evaluate_column_names: Callable[[PandasLikeDataFrame], Sequence[str]],
/,
*,
function_name: str,
implementation: Implementation,
backend_version: tuple[int, ...],
version: Version,
Expand All @@ -129,10 +132,12 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
backend_version=df._backend_version,
version=df._version,
)
for column_name in column_names
for column_name in evaluate_column_names(df)
]
except KeyError as e:
missing_columns = [x for x in column_names if x not in df.columns]
missing_columns = [
x for x in evaluate_column_names(df) if x not in df.columns
]
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns=missing_columns,
available_columns=df.columns,
Expand All @@ -141,8 +146,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
return cls(
func,
depth=0,
function_name="col",
evaluate_output_names=lambda _df: column_names,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
implementation=implementation,
backend_version=backend_version,
Expand Down
Loading
Loading