Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c9c7360
refactor(expr-ir): Add `namespace` helper
dangotbanned Aug 31, 2025
275cfba
Merge branch 'expr-ir/shrink-main' into expr-ir/nw-namespace
dangotbanned Sep 1, 2025
0cffdc0
Merge branch 'expr-ir/shrink-main' into expr-ir/nw-namespace
dangotbanned Sep 1, 2025
92666dd
refactor(expr-ir): Generate dispatch methods (wip) (#3073)
dangotbanned Sep 1, 2025
b9611ab
refactor(expr-ir): Remove all default redefs
dangotbanned Sep 1, 2025
76d5b04
refactor: encode most existing dispatch stuff
dangotbanned Sep 1, 2025
c2b7d98
refactor: Add some alt `ExprIRConfig` constructors
dangotbanned Sep 1, 2025
ecaa33f
refactor: Rename `Ternary` -> `TernaryExpr`
dangotbanned Sep 1, 2025
8b2d3f4
refactor: Split out `Immutable`
dangotbanned Sep 1, 2025
75a4a67
refactor: Fix import cycle properly 😭
dangotbanned Sep 1, 2025
c939848
Start prepping `Function` version
dangotbanned Sep 1, 2025
e14f40b
feat(expr-ir): Fill out `Function` version
dangotbanned Sep 2, 2025
9e5b037
refactor: Replace `int_range` special casing
dangotbanned Sep 2, 2025
d867abf
refactor: Rename `ConcatHorizontal` -> `ConcatStr`
dangotbanned Sep 2, 2025
978272e
update `Not`
dangotbanned Sep 2, 2025
142e0c8
refactor: Add `HorizontalFunction`
dangotbanned Sep 2, 2025
8c8923a
refactor: Remove all default dispatch overrides
dangotbanned Sep 2, 2025
a1a313d
refactor: trim some fat
dangotbanned Sep 2, 2025
ab2fc94
refactor: Replace `with_accessor`
dangotbanned Sep 2, 2025
7122afc
refactor: Align *most* of the dispatch functions
dangotbanned Sep 2, 2025
bf8bc89
revert: Move `namespace` back
dangotbanned Sep 2, 2025
1fa42aa
add missing rename
dangotbanned Sep 2, 2025
011c70c
refactor: Simplify `FunctionOptions` usage
dangotbanned Sep 2, 2025
24dd22a
simplify, document dispatch overrides
dangotbanned Sep 2, 2025
8d64b4c
refactor: Remove unused dispatch override feature
dangotbanned Sep 2, 2025
2a6a7d6
refactor: Start removing need for `ExprDispatch._dispatch`
dangotbanned Sep 2, 2025
9c089ab
refactor: Remove `ExprDispatch._dispatch`
dangotbanned Sep 2, 2025
5569326
refactor: Reuse `HorizontalFunction`
dangotbanned Sep 2, 2025
f843716
remove some notes
dangotbanned Sep 2, 2025
78ef884
refactor: rename `*Config` classes
dangotbanned Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions narwhals/_plan/arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from narwhals._plan.arrow import functions as fn
from narwhals._plan.arrow.series import ArrowSeries
from narwhals._plan.common import ExprIR
from narwhals._plan.protocols import EagerDataFrame
from narwhals._plan.protocols import EagerDataFrame, namespace
from narwhals._utils import Version

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +89,7 @@ def to_dict(
return {ser.name: ser.to_list() for ser in it}

def _evaluate_irs(self, nodes: Iterable[NamedIR[ExprIR]], /) -> Iterator[ArrowSeries]:
ns = self.__narwhals_namespace__()
ns = namespace(self)
from_named_ir = ns._expr.from_named_ir
yield from ns._expr.align(from_named_ir(e, self) for e in nodes)

Expand Down
6 changes: 3 additions & 3 deletions narwhals/_plan/arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from narwhals._plan.arrow.series import ArrowSeries
from narwhals._plan.arrow.typing import NativeScalar, StoresNativeT_co
from narwhals._plan.common import ExprIR, NamedIR, into_dtype
from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch
from narwhals._plan.protocols import EagerExpr, EagerScalar, ExprDispatch, namespace
from narwhals._utils import (
Implementation,
Version,
Expand Down Expand Up @@ -245,8 +245,8 @@ def sort_by(self, node: expr.SortBy, frame: ArrowDataFrame, name: str) -> ArrowE
self._dispatch_expr(e, frame, f"<TEMP>_{idx}")
for idx, e in enumerate(node.by)
)
ns = self.__narwhals_namespace__()
df = ns._concat_horizontal((series, *by))

df = namespace(self)._concat_horizontal((series, *by))
names = df.columns[1:]
indices = pc.sort_indices(df.native, options=node.options.to_arrow(names))
result: ChunkedArrayAny = df.native.column(0).take(indices)
Expand Down
78 changes: 37 additions & 41 deletions narwhals/_plan/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@
LazyScalarT_co = TypeVar("LazyScalarT_co", bound=LazyScalarAny, covariant=True)


def namespace(obj: _SupportsNarwhalsNamespace[NamespaceT_co], /) -> NamespaceT_co:
return obj.__narwhals_namespace__()


class _SupportsNarwhalsNamespace(Protocol[NamespaceT_co]):
def __narwhals_namespace__(self) -> NamespaceT_co: ...


# NOTE: Unlike the version in `nw._utils`, here `.version` it is public
class StoresVersion(Protocol):
_version: Version
Expand Down Expand Up @@ -144,15 +152,13 @@ def _length_required(

class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
_DISPATCH: ClassVar[Mapping[type[ExprIR], Callable[[Any, ExprIR, Any, str], Any]]] = {
expr.Column: lambda self, node, frame, name: self.__narwhals_namespace__().col(
expr.Column: lambda self, node, frame, name: namespace(self).col(
node, frame, name
),
expr.Literal: lambda self, node, frame, name: self.__narwhals_namespace__().lit(
node, frame, name
),
expr.Len: lambda self, node, frame, name: self.__narwhals_namespace__().len(
expr.Literal: lambda self, node, frame, name: namespace(self).lit(
node, frame, name
),
expr.Len: lambda self, node, frame, name: namespace(self).len(node, frame, name),
expr.Cast: lambda self, node, frame, name: self.cast(node, frame, name),
expr.Sort: lambda self, node, frame, name: self.sort(node, frame, name),
expr.SortBy: lambda self, node, frame, name: self.sort_by(node, frame, name),
Expand Down Expand Up @@ -185,10 +191,9 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
),
# NOTE: Keeping it simple for now
# When adding other `*_range` functions, this should instead map to `range_expr`
expr.RangeExpr: lambda self,
node,
frame,
name: self.__narwhals_namespace__().int_range(node, frame, name),
expr.RangeExpr: lambda self, node, frame, name: namespace(self).int_range(
node, frame, name
),
expr.OrderedWindowExpr: lambda self, node, frame, name: self.over_ordered(
node, frame, name
),
Expand All @@ -200,34 +205,27 @@ class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
_DISPATCH_FUNCTION: ClassVar[
Mapping[type[Function], Callable[[Any, FunctionExpr, Any, str], Any]]
] = {
boolean.AnyHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().any_horizontal(node, frame, name),
boolean.AllHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().all_horizontal(node, frame, name),
F.SumHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().sum_horizontal(node, frame, name),
F.MinHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().min_horizontal(node, frame, name),
F.MaxHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().max_horizontal(node, frame, name),
F.MeanHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().mean_horizontal(node, frame, name),
strings.ConcatHorizontal: lambda self,
node,
frame,
name: self.__narwhals_namespace__().concat_str(node, frame, name),
boolean.AnyHorizontal: lambda self, node, frame, name: namespace(
self
).any_horizontal(node, frame, name),
boolean.AllHorizontal: lambda self, node, frame, name: namespace(
self
).all_horizontal(node, frame, name),
F.SumHorizontal: lambda self, node, frame, name: namespace(self).sum_horizontal(
node, frame, name
),
F.MinHorizontal: lambda self, node, frame, name: namespace(self).min_horizontal(
node, frame, name
),
F.MaxHorizontal: lambda self, node, frame, name: namespace(self).max_horizontal(
node, frame, name
),
F.MeanHorizontal: lambda self, node, frame, name: namespace(self).mean_horizontal(
node, frame, name
),
strings.ConcatHorizontal: lambda self, node, frame, name: namespace(
self
).concat_str(node, frame, name),
F.Pow: lambda self, node, frame, name: self.pow(node, frame, name),
F.FillNull: lambda self, node, frame, name: self.fill_null(node, frame, name),
boolean.IsBetween: lambda self, node, frame, name: self.is_between(
Expand Down Expand Up @@ -704,12 +702,10 @@ class EagerDataFrame(
):
def __narwhals_namespace__(self) -> EagerNamespace[Self, SeriesT, Any, Any]: ...
def select(self, irs: Seq[NamedIR]) -> Self:
ns = self.__narwhals_namespace__()
return ns._concat_horizontal(self._evaluate_irs(irs))
return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs))

def with_columns(self, irs: Seq[NamedIR]) -> Self:
ns = self.__narwhals_namespace__()
return ns._concat_horizontal(self._evaluate_irs(irs))
return self.__narwhals_namespace__()._concat_horizontal(self._evaluate_irs(irs))


class CompliantSeries(StoresVersion, Protocol[NativeSeriesT]):
Expand Down
Loading