-
Notifications
You must be signed in to change notification settings - Fork 180
feat: support window operations for DuckDB #2263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d6fd44a
1655fbc
8c36db7
ba5b646
8235348
f2cc84d
06de928
35e603c
38a4710
07872ca
1c08990
eaf0172
6485303
5e50682
65ae7cb
bc515af
eecff03
f1a50e8
231f1f3
6fb1b76
de9f375
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| import contextlib | ||||||||||
| import operator | ||||||||||
| from typing import TYPE_CHECKING | ||||||||||
| from typing import Any | ||||||||||
|
|
@@ -20,6 +21,8 @@ | |||||||||
| from narwhals._duckdb.expr_name import DuckDBExprNameNamespace | ||||||||||
| from narwhals._duckdb.expr_str import DuckDBExprStringNamespace | ||||||||||
| from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace | ||||||||||
| from narwhals._duckdb.utils import generate_order_by_sql | ||||||||||
| from narwhals._duckdb.utils import generate_partition_by_sql | ||||||||||
| from narwhals._duckdb.utils import lit | ||||||||||
| from narwhals._duckdb.utils import maybe_evaluate_expr | ||||||||||
| from narwhals._duckdb.utils import narwhals_to_native_dtype | ||||||||||
|
|
@@ -33,11 +36,15 @@ | |||||||||
|
|
||||||||||
| from narwhals._duckdb.dataframe import DuckDBLazyFrame | ||||||||||
| from narwhals._duckdb.namespace import DuckDBNamespace | ||||||||||
| from narwhals._duckdb.typing import WindowFunction | ||||||||||
| from narwhals._expression_parsing import ExprMetadata | ||||||||||
| from narwhals.dtypes import DType | ||||||||||
| from narwhals.utils import Version | ||||||||||
| from narwhals.utils import _FullContext | ||||||||||
|
|
||||||||||
| with contextlib.suppress(ImportError): # requires duckdb>=1.3.0 | ||||||||||
| from duckdb import SQLExpression | ||||||||||
|
|
||||||||||
|
Comment on lines
+45
to
+47
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
@MarcoGorelli this is still causing typing issues outside of CI on a fresh install.
Fix 1: I've been using this locally, but it is more of a workaround:
diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py
index fd371b73..0ec4681d 100644
--- a/narwhals/_duckdb/expr.py
+++ b/narwhals/_duckdb/expr.py
@@ -41,8 +41,11 @@ if TYPE_CHECKING:
from narwhals.utils import Version
from narwhals.utils import _FullContext
-with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
- from duckdb import SQLExpression
+if not TYPE_CHECKING:
+ with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
+ from duckdb import SQLExpression
+else:
+ from duckdb import Expression as SQLExpression
class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
Fix 2: Specifying this requirement in Lines 22 to 24 in de9f375
narwhals/.github/workflows/extremes.yml Line 185 in de9f375
|
||||||||||
|
|
||||||||||
| class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): | ||||||||||
| _implementation = Implementation.DUCKDB | ||||||||||
|
|
@@ -59,6 +66,7 @@ def __init__( | |||||||||
| self._alias_output_names = alias_output_names | ||||||||||
| self._backend_version = backend_version | ||||||||||
| self._version = version | ||||||||||
| self._window_function: WindowFunction | None = None | ||||||||||
| self._metadata: ExprMetadata | None = None | ||||||||||
|
|
||||||||||
| def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: | ||||||||||
|
|
@@ -83,15 +91,31 @@ def _with_metadata(self, metadata: ExprMetadata) -> Self: | |||||||||
| backend_version=self._backend_version, | ||||||||||
| version=self._version, | ||||||||||
| ) | ||||||||||
| if func := self._window_function: | ||||||||||
| expr = expr._with_window_function(func) | ||||||||||
| expr._metadata = metadata | ||||||||||
| return expr | ||||||||||
|
|
||||||||||
| def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self: | ||||||||||
| if kind is ExprKind.AGGREGATION: | ||||||||||
| msg = "Broadcasting aggregations is not yet supported for DuckDB." | ||||||||||
| if kind is ExprKind.LITERAL: | ||||||||||
| return self | ||||||||||
| if self._backend_version < (1, 3): | ||||||||||
| msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns." | ||||||||||
| raise NotImplementedError(msg) | ||||||||||
| # For literals, DuckDB does its own broadcasting. | ||||||||||
| return self | ||||||||||
|
|
||||||||||
| template = "{expr} over ()" | ||||||||||
|
|
||||||||||
| def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: | ||||||||||
| return [SQLExpression(template.format(expr=expr)) for expr in self(df)] | ||||||||||
|
|
||||||||||
| return self.__class__( | ||||||||||
| func, | ||||||||||
| function_name=self._function_name, | ||||||||||
| evaluate_output_names=self._evaluate_output_names, | ||||||||||
| alias_output_names=self._alias_output_names, | ||||||||||
| backend_version=self._backend_version, | ||||||||||
| version=self._version, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| @classmethod | ||||||||||
| def from_column_names( | ||||||||||
|
|
@@ -167,6 +191,21 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: | |||||||||
| version=self._version, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def _with_window_function( | ||||||||||
| self: Self, | ||||||||||
| window_function: WindowFunction, | ||||||||||
| ) -> Self: | ||||||||||
| result = self.__class__( | ||||||||||
| self._call, | ||||||||||
| function_name=self._function_name, | ||||||||||
| evaluate_output_names=self._evaluate_output_names, | ||||||||||
| alias_output_names=self._alias_output_names, | ||||||||||
| backend_version=self._backend_version, | ||||||||||
| version=self._version, | ||||||||||
| ) | ||||||||||
| result._window_function = window_function | ||||||||||
| return result | ||||||||||
|
|
||||||||||
| def __and__(self: Self, other: DuckDBExpr) -> Self: | ||||||||||
| return self._from_call( | ||||||||||
| lambda _input, other: _input & other, | ||||||||||
|
|
@@ -438,6 +477,40 @@ def null_count(self: Self) -> Self: | |||||||||
| "null_count", | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def over( | ||||||||||
| self: Self, | ||||||||||
| partition_by: Sequence[str], | ||||||||||
| order_by: Sequence[str] | None, | ||||||||||
| ) -> Self: | ||||||||||
| if self._backend_version < (1, 3): | ||||||||||
| msg = "At least version 1.3 of DuckDB is required for `over` operation." | ||||||||||
| raise NotImplementedError(msg) | ||||||||||
| if (window_function := self._window_function) is not None: | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
AFAIK, you only need the Going to guess this is from experience battling
Suggested change
|
||||||||||
| assert order_by is not None # noqa: S101 | ||||||||||
|
|
||||||||||
| def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: | ||||||||||
| return [ | ||||||||||
| window_function(expr, partition_by, order_by) | ||||||||||
| for expr in self._call(df) | ||||||||||
| ] | ||||||||||
| else: | ||||||||||
| partition_by_sql = generate_partition_by_sql(*partition_by) | ||||||||||
| template = f"{{expr}} over ({partition_by_sql})" | ||||||||||
|
|
||||||||||
| def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: | ||||||||||
| return [ | ||||||||||
| SQLExpression(template.format(expr=expr)) for expr in self._call(df) | ||||||||||
| ] | ||||||||||
|
|
||||||||||
| return self.__class__( | ||||||||||
| func, | ||||||||||
| function_name=self._function_name + "->over", | ||||||||||
| evaluate_output_names=self._evaluate_output_names, | ||||||||||
| alias_output_names=self._alias_output_names, | ||||||||||
| backend_version=self._backend_version, | ||||||||||
| version=self._version, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def is_null(self: Self) -> Self: | ||||||||||
| return self._from_call(lambda _input: _input.isnull(), "is_null") | ||||||||||
|
|
||||||||||
|
|
@@ -461,6 +534,42 @@ def round(self: Self, decimals: int) -> Self: | |||||||||
| lambda _input: FunctionExpression("round", _input, lit(decimals)), "round" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def cum_sum(self, *, reverse: bool) -> Self: | ||||||||||
| def func( | ||||||||||
| _input: duckdb.Expression, | ||||||||||
| partition_by: Sequence[str], | ||||||||||
| order_by: Sequence[str], | ||||||||||
| ) -> duckdb.Expression: | ||||||||||
| order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse) | ||||||||||
| partition_by_sql = generate_partition_by_sql(*partition_by) | ||||||||||
| sql = f"sum ({_input}) over ({partition_by_sql} {order_by_sql} rows between unbounded preceding and current row)" | ||||||||||
| return SQLExpression(sql) | ||||||||||
|
|
||||||||||
| return self._with_window_function(func) | ||||||||||
|
|
||||||||||
| def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self: | ||||||||||
| if center: | ||||||||||
| half = (window_size - 1) // 2 | ||||||||||
| remainder = (window_size - 1) % 2 | ||||||||||
| start = f"{half + remainder} preceding" | ||||||||||
| end = f"{half} following" | ||||||||||
| else: | ||||||||||
| start = f"{window_size - 1} preceding" | ||||||||||
| end = "current row" | ||||||||||
|
|
||||||||||
| def func( | ||||||||||
| _input: duckdb.Expression, | ||||||||||
| partition_by: Sequence[str], | ||||||||||
| order_by: Sequence[str], | ||||||||||
| ) -> duckdb.Expression: | ||||||||||
| order_by_sql = generate_order_by_sql(*order_by, ascending=True) | ||||||||||
| partition_by_sql = generate_partition_by_sql(*partition_by) | ||||||||||
| window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})" | ||||||||||
| sql = f"case when count({_input}) over {window} >= {min_samples} then sum({_input}) over {window} else null end" | ||||||||||
| return SQLExpression(sql) | ||||||||||
|
|
||||||||||
| return self._with_window_function(func) | ||||||||||
|
|
||||||||||
| def fill_null( | ||||||||||
| self: Self, value: Self | Any, strategy: Any, limit: int | None | ||||||||||
| ) -> Self: | ||||||||||
|
|
@@ -507,10 +616,7 @@ def struct(self: Self) -> DuckDBExprStructNamespace: | |||||||||
| is_unique = not_implemented() | ||||||||||
| is_first_distinct = not_implemented() | ||||||||||
| is_last_distinct = not_implemented() | ||||||||||
| cum_sum = not_implemented() | ||||||||||
| cum_count = not_implemented() | ||||||||||
| cum_min = not_implemented() | ||||||||||
| cum_max = not_implemented() | ||||||||||
| cum_prod = not_implemented() | ||||||||||
| over = not_implemented() | ||||||||||
| rolling_sum = not_implemented() | ||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
| from typing import Protocol | ||
| from typing import Sequence | ||
|
|
||
| if TYPE_CHECKING: | ||
| import duckdb | ||
|
|
||
| class WindowFunction(Protocol): | ||
| def __call__( | ||
| self, | ||
| _input: duckdb.Expression, | ||
| partition_by: Sequence[str], | ||
| order_by: Sequence[str], | ||
| ) -> duckdb.Expression: ... | ||
|
Comment on lines
+10
to
+16
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I've been meaning to ask about this for a while now. Is that naming convention intended to signal positional-only for `_input`? It seems similar to a convention that was common before PEP 570 – Python Positional-Only Parameters.
Current:
class WindowFunction(Protocol):
def __call__(
self,
_input: duckdb.Expression,
partition_by: Sequence[str],
order_by: Sequence[str],
) -> duckdb.Expression: ...
Before PEP 570:
class WindowFunction(Protocol):
def __call__(
self,
__input: duckdb.Expression,
partition_by: Sequence[str],
order_by: Sequence[str],
) -> duckdb.Expression: ...
After PEP 570:
class WindowFunction(Protocol):
def __call__(
self,
input: duckdb.Expression,
/,
partition_by: Sequence[str],
order_by: Sequence[str],
) -> duckdb.Expression: ...
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
yeah we could use
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Whatever you feel works 🙂
Side note: One of the cool things about positional-only args is that you can use different names - and it still works the same at runtime and to a type checker. Here, that might mean you could use |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we wait for the duckdb release before merging this PR? You are ahead of them 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nah let's be one step ahead so when they release we're ready 🔥