Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8550da7
wip
MarcoGorelli Mar 22, 2025
f7f0a90
lint
MarcoGorelli Mar 22, 2025
2a799c8
remove unnecessary check
MarcoGorelli Mar 22, 2025
40aea92
clean up duckdb
MarcoGorelli Mar 22, 2025
bac71ce
clean up spark-like
MarcoGorelli Mar 22, 2025
c8f483d
snake case
MarcoGorelli Mar 22, 2025
3fd7284
factor out `resolve_expansion_kind`
MarcoGorelli Mar 22, 2025
d06b4e5
completely remove function_name from duckdb and pyspark
MarcoGorelli Mar 22, 2025
291bd10
remove --offline
MarcoGorelli Mar 22, 2025
385fd13
remove outdated arg
MarcoGorelli Mar 22, 2025
2d4dd8a
chore(typing): Ignore issues from #2263
dangotbanned Mar 23, 2025
a58c583
refactor: Add `DepthTrackingExpr`
dangotbanned Mar 23, 2025
8e4bed2
Merge remote-tracking branch 'upstream/main' into pr/MarcoGorelli/2266
dangotbanned Mar 23, 2025
274fd60
revert: Add both classmethods back
dangotbanned Mar 23, 2025
e333fa3
refactor: Add `DepthTrackingNamespace`
dangotbanned Mar 23, 2025
5e66ae5
revert: Undo copy/paste from (https://github.com/narwhals-dev/narwhal…
dangotbanned Mar 23, 2025
94cc379
ci(typing): ignore, unused ignore
dangotbanned Mar 23, 2025
b5b1c92
test: Kinda unbreak `sqlframe` test
dangotbanned Mar 23, 2025
8fa9f1c
refactor: Hide `CompliantExpr` internals from `LazyGroupBy`
dangotbanned Mar 23, 2025
3129332
Remove comment
dangotbanned Mar 23, 2025
8dbeb0e
Merge remote-tracking branch 'upstream/main' into expansion-kind
MarcoGorelli Mar 23, 2025
ba235aa
Merge remote-tracking branch 'upstream/main' into pr/MarcoGorelli/2266
dangotbanned Mar 23, 2025
3cdc967
redo (94cc379edf25f749693306ac0a126517f61bb9ad)
dangotbanned Mar 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def from_column_names(
evaluate_column_names: Callable[[ArrowDataFrame], Sequence[str]],
/,
*,
function_name: str,
context: _FullContext,
function_name: str = "",
) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
try:
Expand Down
2 changes: 0 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def __init__(
self._implementation = Implementation.PYARROW
self._version = version

# --- selection ---

def len(self: Self) -> ArrowExpr:
# coverage bug? this is definitely hit
return self._expr( # pragma: no cover
Expand Down
69 changes: 52 additions & 17 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ class CompliantExpr(Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT_co]
_version: Version
_evaluate_output_names: Callable[[CompliantFrameT], Sequence[str]]
_alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None
_depth: int
_function_name: str
_metadata: ExprMetadata | None

def __call__(
Expand All @@ -101,11 +99,12 @@ def from_column_names(
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
/,
*,
function_name: str,
context: _FullContext,
) -> Self: ...
@classmethod
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ...
def from_column_indices(
cls: type[Self], *column_indices: int, context: _FullContext
) -> Self: ...

def _with_metadata(self, metadata: ExprMetadata) -> Self: ...

Expand Down Expand Up @@ -272,25 +271,64 @@ def __invert__(self) -> Self: ...
def broadcast(
self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]
) -> Self: ...
def _is_multi_output_agg(self) -> bool:
"""Return `True` for multi-output aggregations.
def _is_multi_output_unnamed(self) -> bool:
"""Return `True` for multi-output aggregations without names.

For example, column `'a'` only appears in the output as a grouping key:

Here we skip the keys, else they would appear duplicated in the output:
df.group_by('a').agg(nw.all().sum())

df.group_by("a").agg(nw.all().mean())
It does not get included in:

nw.all().sum().
"""
return self._function_name.split("->", maxsplit=1)[0] in {"all", "selector"}
assert self._metadata is not None # noqa: S101
return self._metadata.expansion_kind.is_multi_unnamed()
Copy link
Member

@dangotbanned dangotbanned Mar 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(#2266 (comment)) can be taken a step further by adding:

ExprMetadata.is_multi_unnamed

So this part wouldn't need to even know what an ExpansionKind is, and having the line here be:

return self._metadata.is_multi_unnamed()

But this works as well



class DepthTrackingExpr(
CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co],
Protocol38[CompliantFrameT, CompliantSeriesOrNativeExprT_co],
):
_depth: int
_function_name: str

@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
/,
*,
context: _FullContext,
function_name: str = "",
) -> Self: ...

def _is_elementary(self) -> bool:
"""Check if expr is elementary.

Examples:
- nw.col('a').mean() # depth 1
- nw.mean('a') # depth 1
- nw.len() # depth 0

as opposed to, say

- nw.col('a').filter(nw.col('b')>nw.col('c')).max()

Elementary expressions are the only ones supported properly in
pandas, PyArrow, and Dask.
Comment on lines +318 to +319
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied directly from is_elementary_expression - but this line might be redundant now?

Maybe a description in the class doc could have some stuff from https://narwhals-dev.github.io/narwhals/how_it_works

"""
return self._depth < 2

def __repr__(self) -> str: # pragma: no cover
return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})"


class EagerExpr(
CompliantExpr[EagerDataFrameT, EagerSeriesT],
DepthTrackingExpr[EagerDataFrameT, EagerSeriesT],
Protocol38[EagerDataFrameT, EagerSeriesT],
):
_call: Callable[[EagerDataFrameT], Sequence[EagerSeriesT]]
_depth: int
_function_name: str
_evaluate_output_names: Any
_alias_output_names: Any
_call_kwargs: dict[str, Any]

def __init__(
Expand All @@ -310,9 +348,6 @@ def __init__(
def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
return self._call(df)

def __repr__(self) -> str: # pragma: no cover
return f"{type(self).__name__}(depth={self._depth}, function_name={self._function_name})"

def __narwhals_namespace__(
self,
) -> EagerNamespace[EagerDataFrameT, EagerSeriesT, Self]: ...
Expand Down
27 changes: 16 additions & 11 deletions narwhals/_compliant/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
from typing import TypeVar

from narwhals._compliant.typing import CompliantDataFrameT_co
from narwhals._compliant.typing import CompliantExprAny
from narwhals._compliant.typing import CompliantExprT_contra
from narwhals._compliant.typing import CompliantFrameT_co
from narwhals._compliant.typing import CompliantLazyFrameT_co
from narwhals._compliant.typing import EagerExprT_contra
from narwhals._compliant.typing import LazyExprT_contra
from narwhals._compliant.typing import NativeExprT_co
from narwhals._expression_parsing import is_elementary_expression

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from narwhals._compliant.expr import DepthTrackingExpr

if not TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 9):
from typing import Protocol as Protocol38
Expand All @@ -49,6 +50,10 @@
NarwhalsAggregation: TypeAlias = Literal[
"sum", "mean", "median", "max", "min", "std", "var", "len", "n_unique", "count"
]
DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
DepthTrackingExprT_contra = TypeVar(
"DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)


_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
Expand All @@ -75,8 +80,8 @@ def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...


class DepthTrackingGroupBy(
CompliantGroupBy[CompliantFrameT_co, CompliantExprT_contra],
Protocol38[CompliantFrameT_co, CompliantExprT_contra, NativeAggregationT_co],
CompliantGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""

Expand All @@ -87,7 +92,7 @@ class DepthTrackingGroupBy(
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
"""

def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None:
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
for expr in exprs:
if not self._is_simple(expr):
name = self.compliant._implementation.name.lower()
Expand All @@ -104,9 +109,9 @@ def _ensure_all_simple(self, exprs: Sequence[CompliantExprT_contra]) -> None:
raise ValueError(msg)

@classmethod
def _is_simple(cls, expr: CompliantExprAny, /) -> bool:
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
return is_elementary_expression(expr) and cls._leaf_name(expr) in cls._REMAP_AGGS
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS

@classmethod
def _remap_expr_name(
Expand All @@ -123,14 +128,14 @@ def _remap_expr_name(
return cls._REMAP_AGGS.get(name, name)

@classmethod
def _leaf_name(cls, expr: CompliantExprAny, /) -> NarwhalsAggregation | Any:
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
"""Return the last function name in the chain defined by `expr`."""
return _RE_LEAF_NAME.sub("", expr._function_name)


class EagerGroupBy(
DepthTrackingGroupBy[CompliantDataFrameT_co, CompliantExprT_contra, str],
Protocol38[CompliantDataFrameT_co, CompliantExprT_contra],
DepthTrackingGroupBy[CompliantDataFrameT_co, EagerExprT_contra, str],
Protocol38[CompliantDataFrameT_co, EagerExprT_contra],
):
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...

Expand All @@ -147,7 +152,7 @@ def _evaluate_expr(self, expr: LazyExprT_contra, /) -> Iterator[NativeExprT_co]:
else output_names
)
native_exprs = expr(self.compliant)
if expr._is_multi_output_agg():
if expr._is_multi_output_unnamed():
for native_expr, name, alias in zip(native_exprs, output_names, aliases):
if name not in self._keys:
yield native_expr.alias(alias)
Expand Down
39 changes: 31 additions & 8 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Iterable
from typing import Literal
from typing import Protocol
from typing import TypeVar

from narwhals._compliant.typing import CompliantExprT
from narwhals._compliant.typing import CompliantFrameT
Expand All @@ -20,6 +21,7 @@
if TYPE_CHECKING:
from typing_extensions import TypeAlias

from narwhals._compliant.expr import DepthTrackingExpr
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen
from narwhals._compliant.when_then import EagerWhen
Expand All @@ -31,27 +33,26 @@

__all__ = ["CompliantNamespace", "EagerNamespace"]

DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)


class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version

def all(self) -> CompliantExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)
return self._expr.from_column_names(get_column_names, context=self)

def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
passthrough_column_names(column_names), context=self
)

def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
partial(exclude_column_names, names=excluded_names), context=self
)

def nth(self, *column_indices: int) -> CompliantExprT:
Expand Down Expand Up @@ -86,8 +87,30 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
def _expr(self) -> type[CompliantExprT]: ...


class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
):
def all(self) -> DepthTrackingExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)

def col(self, *column_names: str) -> DepthTrackingExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)

def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)


class EagerNamespace(
CompliantNamespace[EagerDataFrameT, EagerExprT],
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT],
):
@property
Expand Down
5 changes: 0 additions & 5 deletions narwhals/_compliant/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from narwhals.utils import get_column_names
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_compliant_dataframe
from narwhals.utils import is_tracks_depth

if not TYPE_CHECKING: # pragma: no cover
# TODO @dangotbanned: Remove after dropping `3.8` (#2084)
Expand Down Expand Up @@ -302,10 +301,6 @@ def names(df: FrameT) -> Sequence[str]:
def __invert__(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self.selectors.all() - self # type: ignore[no-any-return]

def __repr__(self: Self) -> str: # pragma: no cover
s = f"depth={self._depth}, " if is_tracks_depth(self._implementation) else ""
return f"{type(self).__name__}({s}function_name={self._function_name})"


def _eval_lhs_rhs(
df: CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any],
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT,
_call: Callable[[FrameT], Sequence[SeriesT]]
_when_value: CompliantWhen[FrameT, SeriesT, ExprT]
_function_name: str
_depth: int
Copy link
Member

@dangotbanned dangotbanned Mar 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nuisance, can fix properly in a follow-up and align DaskWhen with the current EagerWhen as DepthTrackingWhen

_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
Expand Down
11 changes: 7 additions & 4 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Sequence

from narwhals._compliant import LazyExpr
from narwhals._compliant.expr import DepthTrackingExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_name import DaskExprNameNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
Expand All @@ -16,7 +17,6 @@
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import is_elementary_expression
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
Expand All @@ -42,7 +42,10 @@
from narwhals.utils import _FullContext


class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]):
class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"],
DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
_implementation: Implementation = Implementation.DASK

def __init__(
Expand Down Expand Up @@ -115,8 +118,8 @@ def from_column_names(
evaluate_column_names: Callable[[DaskLazyFrame], Sequence[str]],
/,
*,
function_name: str,
context: _FullContext,
function_name: str = "",
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
Expand Down Expand Up @@ -573,7 +576,7 @@ def over(
# which we can always easily support, as it doesn't require grouping.
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
return self(df.sort(*order_by, descending=False, nulls_last=False))
elif not is_elementary_expression(self): # pragma: no cover
elif not self._is_elementary(): # pragma: no cover
msg = (
"Only elementary expressions are supported for `.over` in dask.\n\n"
"Please see: "
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import dask.dataframe as dd
import pandas as pd

from narwhals._compliant import CompliantNamespace
from narwhals._compliant import CompliantThen
from narwhals._compliant import CompliantWhen
from narwhals._compliant.namespace import DepthTrackingNamespace
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
Expand All @@ -38,7 +38,7 @@
import dask_expr as dx


class DaskNamespace(CompliantNamespace[DaskLazyFrame, DaskExpr]):
class DaskNamespace(DepthTrackingNamespace[DaskLazyFrame, "DaskExpr"]):
_implementation: Implementation = Implementation.DASK

@property
Expand Down
Loading
Loading