Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions narwhals/_plan/_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Common type guards, mostly with inline imports."""

from __future__ import annotations

import datetime as dt
from decimal import Decimal
from typing import TYPE_CHECKING, Any, TypeVar

from narwhals._utils import _hasattr_static

if TYPE_CHECKING:
from typing_extensions import TypeIs

from narwhals._plan import expr
from narwhals._plan.dummy import Expr, Series
from narwhals._plan.protocols import CompliantSeries
from narwhals._plan.typing import NativeSeriesT, Seq
from narwhals.typing import NonNestedLiteral

T = TypeVar("T")

_NON_NESTED_LITERAL_TPS = (
int,
float,
str,
dt.date,
dt.time,
dt.timedelta,
bytes,
Decimal,
)


def _dummy(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202
from narwhals._plan import dummy

return dummy


def _expr(*_: Any): # type: ignore[no-untyped-def] # noqa: ANN202
from narwhals._plan import expr

return expr


def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]:
return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS)


def is_expr(obj: Any) -> TypeIs[Expr]:
return isinstance(obj, _dummy().Expr)


def is_column(obj: Any) -> TypeIs[Expr]:
"""Indicate if the given object is a basic/unaliased column."""
return is_expr(obj) and obj.meta.is_column()


def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]:
return isinstance(obj, _dummy().Series)


def is_compliant_series(
obj: CompliantSeries[NativeSeriesT] | Any,
) -> TypeIs[CompliantSeries[NativeSeriesT]]:
return _hasattr_static(obj, "__narwhals_series__")


def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | Series | CompliantSeries]:
return isinstance(obj, (str, bytes, _dummy().Series)) or is_compliant_series(obj)


def is_window_expr(obj: Any) -> TypeIs[expr.WindowExpr]:
return isinstance(obj, _expr().WindowExpr)


def is_function_expr(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]:
return isinstance(obj, _expr().FunctionExpr)


def is_binary_expr(obj: Any) -> TypeIs[expr.BinaryExpr]:
return isinstance(obj, _expr().BinaryExpr)


def is_agg_expr(obj: Any) -> TypeIs[expr.AggExpr]:
return isinstance(obj, _expr().AggExpr)


def is_aggregation(obj: Any) -> TypeIs[expr.AggExpr | expr.FunctionExpr[Any]]:
"""Superset of `ExprIR.is_scalar`, excludes literals & len."""
return is_agg_expr(obj) or (is_function_expr(obj) and obj.is_scalar)


def is_literal(obj: Any) -> TypeIs[expr.Literal[Any]]:
return isinstance(obj, _expr().Literal)


def is_horizontal_reduction(obj: Any) -> TypeIs[expr.FunctionExpr[Any]]:
return is_function_expr(obj) and obj.options.is_input_wildcard_expansion()


def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]:
return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp))
151 changes: 151 additions & 0 deletions narwhals/_plan/_immutable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any, Callable

from typing_extensions import Never, Self, dataclass_transform

else:
# https://docs.python.org/3/library/typing.html#typing.dataclass_transform
def dataclass_transform(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_specifiers: tuple[type[Any] | Callable[..., Any], ...] = (),
**kwargs: Any,
) -> Callable[[T], T]:
def decorator(cls_or_fn: T) -> T:
cls_or_fn.__dataclass_transform__ = {
"eq_default": eq_default,
"order_default": order_default,
"kw_only_default": kw_only_default,
"frozen_default": frozen_default,
"field_specifiers": field_specifiers,
"kwargs": kwargs,
}
return cls_or_fn

return decorator


T = TypeVar("T")
_IMMUTABLE_HASH_NAME: Literal["__immutable_hash_value__"] = "__immutable_hash_value__"


@dataclass_transform(kw_only_default=True, frozen_default=True)
class Immutable:
"""A poor man's frozen dataclass.

- Keyword-only constructor (IDE supported)
- Manual `__slots__` required
- Compatible with [`copy.replace`]
- No handling for default arguments

[`copy.replace`]: https://docs.python.org/3.13/library/copy.html#copy.replace
"""

__slots__ = (_IMMUTABLE_HASH_NAME,)
__immutable_hash_value__: int

@property
def __immutable_keys__(self) -> Iterator[str]:
slots: tuple[str, ...] = self.__slots__
for name in slots:
if name != _IMMUTABLE_HASH_NAME:
yield name

@property
def __immutable_values__(self) -> Iterator[Any]:
for name in self.__immutable_keys__:
yield getattr(self, name)

@property
def __immutable_items__(self) -> Iterator[tuple[str, Any]]:
for name in self.__immutable_keys__:
yield name, getattr(self, name)

@property
def __immutable_hash__(self) -> int:
if hasattr(self, _IMMUTABLE_HASH_NAME):
return self.__immutable_hash_value__
hash_value = hash((self.__class__, *self.__immutable_values__))
object.__setattr__(self, _IMMUTABLE_HASH_NAME, hash_value)
return self.__immutable_hash_value__

def __setattr__(self, name: str, value: Never) -> Never:
msg = f"{type(self).__name__!r} is immutable, {name!r} cannot be set."
raise AttributeError(msg)

def __replace__(self, **changes: Any) -> Self:
"""https://docs.python.org/3.13/library/copy.html#copy.replace""" # noqa: D415
if len(changes) == 1:
# The most common case is a single field replacement.
# Iff that field happens to be equal, we can noop, preserving the current object's hash.
name, value_changed = next(iter(changes.items()))
if getattr(self, name) == value_changed:
return self
changes = dict(self.__immutable_items__, **changes)
else:
for name, value_current in self.__immutable_items__:
if name not in changes or value_current == changes[name]:
changes[name] = value_current
return type(self)(**changes)

def __init_subclass__(cls, *args: Any, **kwds: Any) -> None:
super().__init_subclass__(*args, **kwds)
if cls.__slots__:
...
else:
cls.__slots__ = ()

def __hash__(self) -> int:
return self.__immutable_hash__

def __eq__(self, other: object) -> bool:
if self is other:
return True
if type(self) is not type(other):
return False
return all(
getattr(self, key) == getattr(other, key) for key in self.__immutable_keys__
)

def __str__(self) -> str:
fields = ", ".join(f"{_field_str(k, v)}" for k, v in self.__immutable_items__)
return f"{type(self).__name__}({fields})"

def __init__(self, **kwds: Any) -> None:
required: set[str] = set(self.__immutable_keys__)
if not required and not kwds:
# NOTE: Fastpath for empty slots
...
elif required == set(kwds):
for name, value in kwds.items():
object.__setattr__(self, name, value)
elif missing := required.difference(kwds):
msg = (
f"{type(self).__name__!r} requires attributes {sorted(required)!r}, \n"
f"but missing values for {sorted(missing)!r}"
)
raise TypeError(msg)
else:
extra = set(kwds).difference(required)
msg = (
f"{type(self).__name__!r} only supports attributes {sorted(required)!r}, \n"
f"but got unknown arguments {sorted(extra)!r}"
)
raise TypeError(msg)


def _field_str(name: str, value: Any) -> str:
if isinstance(value, tuple):
inner = ", ".join(f"{v}" for v in value)
return f"{name}=[{inner}]"
if isinstance(value, str):
return f"{name}={value!r}"
return f"{name}={value}"
60 changes: 10 additions & 50 deletions narwhals/_plan/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@

from typing import TYPE_CHECKING, Any

from narwhals._plan.common import ExprIR, _pascal_to_snake_case, replace
from narwhals._plan.common import ExprIR, _pascal_to_snake_case
from narwhals._plan.exceptions import agg_scalar_error

if TYPE_CHECKING:
from collections.abc import Iterator

from typing_extensions import Self

from narwhals._plan.typing import MapIR
from narwhals.typing import RollingInterpolationMethod


class AggExpr(ExprIR):
class AggExpr(ExprIR, child=("expr",)):
__slots__ = ("expr",)
expr: ExprIR

Expand All @@ -25,50 +22,31 @@ def is_scalar(self) -> bool:
def __repr__(self) -> str:
return f"{self.expr!r}.{_pascal_to_snake_case(type(self).__name__)}()"

def iter_left(self) -> Iterator[ExprIR]:
yield from self.expr.iter_left()
yield self

def iter_right(self) -> Iterator[ExprIR]:
yield self
yield from self.expr.iter_right()

def iter_output_name(self) -> Iterator[ExprIR]:
yield from self.expr.iter_output_name()

def map_ir(self, function: MapIR, /) -> ExprIR:
return function(self.with_expr(self.expr.map_ir(function)))

def with_expr(self, expr: ExprIR, /) -> Self:
return replace(self, expr=expr)

def __init__(self, *, expr: ExprIR, **kwds: Any) -> None:
if expr.is_scalar:
raise agg_scalar_error(self, expr)
super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue]


# fmt: off
class Count(AggExpr): ...


class Max(AggExpr): ...


class Mean(AggExpr): ...


class Median(AggExpr): ...


class Min(AggExpr): ...


class NUnique(AggExpr): ...


class Sum(AggExpr): ...
class OrderableAggExpr(AggExpr): ...
class First(OrderableAggExpr): ...
class Last(OrderableAggExpr): ...
class ArgMin(OrderableAggExpr): ...
class ArgMax(OrderableAggExpr): ...
# fmt: on
class Quantile(AggExpr):
__slots__ = (*AggExpr.__slots__, "interpolation", "quantile")

quantile: float
interpolation: RollingInterpolationMethod

Expand All @@ -78,24 +56,6 @@ class Std(AggExpr):
ddof: int


class Sum(AggExpr): ...


class Var(AggExpr):
__slots__ = (*AggExpr.__slots__, "ddof")
ddof: int


class OrderableAggExpr(AggExpr): ...


class First(OrderableAggExpr): ...


class Last(OrderableAggExpr): ...


class ArgMin(OrderableAggExpr): ...


class ArgMax(OrderableAggExpr): ...
Loading
Loading