Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries
from narwhals._plan.common import ExprIR
from narwhals._plan.protocols import EagerDataFrame, namespace
from narwhals._plan.common import ExprIR, namespace
from narwhals._plan.protocols import EagerDataFrame
from narwhals._utils import Version

if t.TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from narwhals._plan.arrow.functions import lit
from narwhals._plan.arrow.series import ArrowSeries
from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co
from narwhals._plan.common import ExprIR, NamedIR, into_dtype
from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace
from narwhals._plan.common import ExprIR, NamedIR, into_dtype, namespace
from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch
from narwhals._utils import (
Implementation,
Version,
Expand Down
146 changes: 116 additions & 30 deletions narwhals/_plan/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from collections.abc import Iterable
from decimal import Decimal
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, cast, overload

from narwhals._plan.typing import (
Accessor,
Expand All @@ -25,9 +25,9 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any, Callable, Literal
from typing import Any, Callable

from typing_extensions import Never, Self, TypeIs, dataclass_transform
from typing_extensions import Never, Self, TypeAlias, TypeIs, dataclass_transform

from narwhals._plan import expr
from narwhals._plan.dummy import Expr, Selector, Series
Expand All @@ -42,7 +42,11 @@
)
from narwhals._plan.meta import IRMetaNamespace
from narwhals._plan.options import FunctionOptions
from narwhals._plan.protocols import CompliantSeries
from narwhals._plan.protocols import (
CompliantSeries,
NamespaceT_co,
SupportsNarwhalsNamespace,
)
from narwhals.typing import NonNestedDType, NonNestedLiteral

else:
Expand Down Expand Up @@ -199,16 +203,123 @@ def _field_str(name: str, value: Any) -> str:
return f"{name}={value}"


def _tp_repr(tp: type[Any], /) -> str:
    """Return the snake_case spelling of `tp.__name__`."""
    class_name = tp.__name__
    return _pascal_to_snake_case(class_name)


# TODO @dangotbanned: Add caching strategy?
def _function_repr(tp: type[Function], /) -> str:
    """Return the display name of a function type.

    This is the snake_case class name, prefixed with `"<accessor>."` when
    `tp._accessor` is truthy.
    """
    name = _tp_repr(tp)
    ns_name = tp._accessor
    if ns_name:
        return f"{ns_name}.{name}"
    return name


# A run of uppercase letters followed by an uppercase + lowercase pair.
_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
# A lowercase letter directly followed by an uppercase letter.
_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])")


def _re_repl_snake(match: re.Match[str], /) -> str:
    """Join the first two captured groups of `match` with an underscore."""
    left, right = match.group(1, 2)
    return f"{left}_{right}"


def _pascal_to_snake_case(s: str) -> str:
    """Convert a PascalCase, camelCase string to snake_case.

    An underscore is inserted at each boundary found by the two patterns
    (applied in order), then the whole string is lower-cased.

    Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62
    """
    result = s
    # Order matters: split acronym runs first, then lower -> upper transitions.
    for pattern in (_PATTERN_UPPER_LOWER, _PATTERN_LOWER_UPPER):
        result = pattern.sub(_re_repl_snake, result)
    return result.lower()


# Where the method implementing a node is looked up by `_dispatch_generate`:
# "expr" -> on the context object itself, "__narwhals_namespace__" -> on the
# object returned by `namespace(ctx)`. "expr-accessor" is accepted by the type
# but currently makes `_dispatch_generate` raise `NotImplementedError`.
DispatchOrigin: TypeAlias = Literal["expr", "expr-accessor", "__narwhals_namespace__"]
# String alias for `Any`; presumably marks annotations that are not precisely typed yet.
Incomplete: TypeAlias = "Any"


def namespace(obj: SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co:
    """Return the compliant namespace.

    This is whatever `obj.__narwhals_namespace__()` returns.
    """
    get_namespace = obj.__narwhals_namespace__
    return get_namespace()


class _ExprIRConfig(Immutable):
    """Per-node dispatch settings, read by `_dispatch_generate`.

    - `origin`: where the implementing method is looked up.
    - `override_name`: method name to call; when empty, the snake_case class name is used.
    - `no_dispatch`: when true, the generated dispatcher always raises `TypeError`.
    """

    __slots__ = ("no_dispatch", "origin", "override_name")
    origin: DispatchOrigin
    override_name: str
    no_dispatch: bool

    def __repr__(self) -> str:
        # `__str__` is not defined on this class; presumably inherited from `Immutable`.
        text = self.__str__()
        return text


def dispatch_config(
    *, origin: DispatchOrigin = "expr", override_name: str = "", no_dispatch: bool = False
) -> _ExprIRConfig:
    """Create the `_ExprIRConfig` describing how a node type is dispatched.

    Arguments:
        origin: Where the implementing method is looked up.
        override_name: Method name to call instead of the snake_case class name.
        no_dispatch: Mark the node type as one whose dispatcher always raises.
    """
    config = _ExprIRConfig(
        origin=origin, override_name=override_name, no_dispatch=no_dispatch
    )
    return config


def _dispatch_generate(
    tp: type[ExprIRT], /
) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]:
    """Build the dispatcher function for the node type `tp`.

    The returned callable is invoked as `(ctx, node, frame, name)` and is selected
    from `tp.__expr_ir_config__`:

    - `no_dispatch` is set: every call raises `TypeError`.
    - origin `"expr"`: forwards to the method found on `ctx`.
    - origin `"__narwhals_namespace__"`: forwards to the method found on `namespace(ctx)`.

    The method name is `override_name` when non-empty, otherwise the snake_case
    form of the class name.

    Raises:
        NotImplementedError: When the origin is anything else (i.e. `"expr-accessor"`).
    """
    config = tp.__expr_ir_config__
    if config.no_dispatch:

        def _refuse(self: Any, node: ExprIRT, frame: Any, name: str) -> Any:  # noqa: ARG001
            tp_name = type(node).__name__
            msg = (
                f"{tp_name!r} should not appear at the compliant-level.\n\n"
                f"Make sure to expand all expressions first, got:\n{self!r}\n{node!r}\n{name!r}"
            )
            raise TypeError(msg)

        return _refuse

    # Resolved once here, so the generated closure only performs the lookup + call.
    method_name = config.override_name or _tp_repr(tp)
    origin = config.origin
    if origin not in ("expr", "__narwhals_namespace__"):
        msg = f"`FunctionExpr` can't work this way, the dispatch mostly happens on `.function`, which has the accessor.\n\nGot: {tp.__name__}"
        raise NotImplementedError(msg)

    if origin == "expr":

        def _on_context(self: Any, node: ExprIRT, frame: Any, name: str) -> Any:
            method = getattr(self, method_name)
            return method(node, frame, name)

        return _on_context

    def _on_namespace(self: Any, node: ExprIRT, frame: Any, name: str) -> Any:
        method = getattr(namespace(self), method_name)
        return method(node, frame, name)

    return _on_namespace


class ExprIR(Immutable):
"""Anything that can be a node on a graph of expressions."""

_child: ClassVar[Seq[str]] = ()
"""Nested node names, in iteration order."""

def __init_subclass__(cls, *args: Any, child: Seq[str] = (), **kwds: Any) -> None:
__expr_ir_config__: ClassVar[_ExprIRConfig] = dispatch_config()
__expr_ir_dispatch__: ClassVar[
staticmethod[[Incomplete, Self, Incomplete, str], Incomplete]
]

def __init_subclass__(
cls: type[Self], # `mypy` doesn't understand without
*args: Any,
child: Seq[str] = (),
config: _ExprIRConfig | None = None,
**kwds: Any,
) -> None:
super().__init_subclass__(*args, **kwds)
if child:
cls._child = child
if config:
cls.__expr_ir_config__ = config
cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls))

def dispatch(self, ctx: Incomplete, frame: Incomplete, name: str, /) -> Incomplete:
"""Evaluate expression in `frame`, using `ctx` for implementation(s)."""
# NOTE: `mypy` would require `Self` on `self` but that conflicts w/ pre-commit
return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name)

def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
from narwhals._plan import dummy
Expand Down Expand Up @@ -500,31 +611,6 @@ def __repr__(self) -> str:
return _function_repr(type(self))


# TODO @dangotbanned: Add caching strategy?
def _function_repr(tp: type[Function], /) -> str:
    """Return the display name of a function type.

    This is the snake_case class name, prefixed with `"<accessor>."` when
    `tp._accessor` is truthy.
    """
    name = _pascal_to_snake_case(tp.__name__)
    ns_name = tp._accessor
    if not ns_name:
        return name
    return f"{ns_name}.{name}"


# A run of uppercase letters followed by an uppercase + lowercase pair.
_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
# A lowercase letter directly followed by an uppercase letter.
_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])")


def _re_repl_snake(match: re.Match[str], /) -> str:
    """Join the first two captured groups of `match` with an underscore."""
    left, right = match.group(1, 2)
    return f"{left}_{right}"


def _pascal_to_snake_case(s: str) -> str:
    """Convert a PascalCase, camelCase string to snake_case.

    An underscore is inserted at each boundary found by the two patterns
    (applied in order), then the whole string is lower-cased.

    Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62
    """
    result = s
    # Order matters: split acronym runs first, then lower -> upper transitions.
    for pattern in (_PATTERN_UPPER_LOWER, _PATTERN_LOWER_UPPER):
        result = pattern.sub(_re_repl_snake, result)
    return result.lower()


_NON_NESTED_LITERAL_TPS = (
int,
float,
Expand Down
11 changes: 4 additions & 7 deletions narwhals/_plan/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, overload

from narwhals._plan import aggregation as agg, boolean, expr, functions as F, strings
from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe
from narwhals._plan.common import ExprIR, Function, NamedIR, flatten_hash_safe, namespace
from narwhals._plan.typing import NativeDataFrameT, NativeFrameT, NativeSeriesT, Seq
from narwhals._typing_compat import TypeVar
from narwhals._utils import Version, _hasattr_static
Expand Down Expand Up @@ -70,11 +70,7 @@
# Covariant type variable bound to `LazyScalarAny`.
LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True)


def namespace(obj: _SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co:
    """Return whatever `obj.__narwhals_namespace__()` returns."""
    get_namespace = obj.__narwhals_namespace__
    return get_namespace()


class _SupportsNarwhalsNamespace(Protocol[NamespaceT_co]):
class SupportsNarwhalsNamespace(Protocol[NamespaceT_co]):
    """Structural type for objects exposing a `__narwhals_namespace__()` method."""

    def __narwhals_namespace__(self) -> NamespaceT_co:
        """Return the namespace object associated with this instance."""


Expand Down Expand Up @@ -159,7 +155,7 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
node, frame, name
),
expr.Len: lambda self, node, frame, name: namespace(self).len(node, frame, name),
expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name),
# expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name),
expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name),
expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name),
expr.Filter: lambda self, node, frame, name: self.filter(node, frame, name),
Expand Down Expand Up @@ -246,6 +242,7 @@ def _dispatch(self, node: ExprIR, frame: FrameT_contra, name: str) -> R_co:
result := method(self, node, frame, name)
):
return result # type: ignore[no-any-return]
return node.dispatch(self, frame, name) # type: ignore[no-any-return]
msg = f"Support for {node.__class__.__name__!r} is not yet implemented, got:\n{node!r}"
raise NotImplementedError(msg)

Expand Down
Loading