Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,10 @@ jobs:
run: |
uv pip uninstall dask dask-expr --system
python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask
- name: install duckdb
- name: install duckdb nightly
run: |
python -m pip install -U --pre duckdb
uv pip uninstall duckdb --system
uv pip install -U --pre duckdb --system
- name: show-deps
run: uv pip freeze
- name: Assert nightlies dependencies
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ jobs:
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
run: uv pip install -e ".[modin, dask]" --group core-tests --group extra --system
- name: install duckdb nightly
run: uv pip install -U --pre duckdb --system
Comment on lines +87 to +88
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we wait for the duckdb release before merging this PR? You are ahead of them 😂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah let's be one step ahead so when they release we're ready 🔥

- name: install pyspark
run: uv pip install -e ".[pyspark]" --system
# PySpark is not yet available on Python3.12+
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ help: ## Display this help screen

.PHONY: typing
typing: ## Run typing checks
# install duckdb nightly so mypy recognises duckdb.SQLExpression
$(VENV_BIN)/uv pip install -U --pre duckdb
$(VENV_BIN)/uv pip install -e . --group typing
$(VENV_BIN)/mypy
6 changes: 4 additions & 2 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from duckdb import FunctionExpression

from narwhals._duckdb.utils import evaluate_exprs
from narwhals._duckdb.utils import generate_partition_by_sql
from narwhals._duckdb.utils import lit
from narwhals._duckdb.utils import native_to_narwhals_dtype
from narwhals.dependencies import get_duckdb
Expand Down Expand Up @@ -363,11 +364,12 @@ def unique(
keep_condition = f"where {count_name}=1"
else:
keep_condition = f"where {idx_name}=1"
partition_by_sql = generate_partition_by_sql(*subset)
query = f"""
with cte as (
select *,
row_number() over (partition by {",".join(subset)}) as {idx_name},
count(*) over (partition by {",".join(subset)}) as {count_name}
row_number() over ({partition_by_sql}) as {idx_name},
count(*) over ({partition_by_sql}) as {count_name}
from rel
)
select * exclude ({idx_name}, {count_name}) from cte {keep_condition}
Expand Down
120 changes: 113 additions & 7 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import operator
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -20,6 +21,8 @@
from narwhals._duckdb.expr_name import DuckDBExprNameNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
from narwhals._duckdb.utils import generate_order_by_sql
from narwhals._duckdb.utils import generate_partition_by_sql
from narwhals._duckdb.utils import lit
from narwhals._duckdb.utils import maybe_evaluate_expr
from narwhals._duckdb.utils import narwhals_to_native_dtype
Expand All @@ -33,11 +36,15 @@

from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.typing import WindowFunction
from narwhals._expression_parsing import ExprMetadata
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.utils import _FullContext

with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
from duckdb import SQLExpression

Comment on lines +45 to +47
Copy link
Member

@dangotbanned dangotbanned Mar 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli this is still causing typing issues outside of CI on a fresh install.

Fix 1

I've been using this locally, but it is more of a workaround:

diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py
index fd371b73..0ec4681d 100644
--- a/narwhals/_duckdb/expr.py
+++ b/narwhals/_duckdb/expr.py
@@ -41,8 +41,11 @@ if TYPE_CHECKING:
     from narwhals.utils import Version
     from narwhals.utils import _FullContext
 
-with contextlib.suppress(ImportError):  # requires duckdb>=1.3.0
-    from duckdb import SQLExpression
+if not TYPE_CHECKING:
+    with contextlib.suppress(ImportError):  # requires duckdb>=1.3.0
+        from duckdb import SQLExpression
+else:
+    from duckdb import Expression as SQLExpression
 
 
 class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):

Fix 2

Specifying this requirement in pyproject.toml somewhere.
Not sure on exactly how though

narwhals/Makefile

Lines 22 to 24 in de9f375

typing: ## Run typing checks
# install duckdb nightly so mypy recognises duckdb.SQLExpression
$(VENV_BIN)/uv pip install -U --pre duckdb

uv pip install -U --pre duckdb --system


class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
_implementation = Implementation.DUCKDB
Expand All @@ -59,6 +66,7 @@ def __init__(
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
self._window_function: WindowFunction | None = None
self._metadata: ExprMetadata | None = None

def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
Expand All @@ -83,15 +91,31 @@ def _with_metadata(self, metadata: ExprMetadata) -> Self:
backend_version=self._backend_version,
version=self._version,
)
if func := self._window_function:
expr = expr._with_window_function(func)
expr._metadata = metadata
return expr

def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
if kind is ExprKind.AGGREGATION:
msg = "Broadcasting aggregations is not yet supported for DuckDB."
if kind is ExprKind.LITERAL:
return self
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
raise NotImplementedError(msg)
# For literals, DuckDB does its own broadcasting.
return self

template = "{expr} over ()"

def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
return [SQLExpression(template.format(expr=expr)) for expr in self(df)]

return self.__class__(
func,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
)

@classmethod
def from_column_names(
Expand Down Expand Up @@ -167,6 +191,21 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
version=self._version,
)

def _with_window_function(
self: Self,
window_function: WindowFunction,
) -> Self:
result = self.__class__(
self._call,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
)
result._window_function = window_function
return result

def __and__(self: Self, other: DuckDBExpr) -> Self:
return self._from_call(
lambda _input, other: _input & other,
Expand Down Expand Up @@ -438,6 +477,40 @@ def null_count(self: Self) -> Self:
"null_count",
)

def over(
self: Self,
partition_by: Sequence[str],
order_by: Sequence[str] | None,
) -> Self:
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for `over` operation."
raise NotImplementedError(msg)
if (window_function := self._window_function) is not None:
Copy link
Member

@dangotbanned dangotbanned Mar 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, you only need the is not None when the type overrides __bool__.

Going to guess this is from experience battling pandas, polars, and (maybe) numpy who all tell you off for using __bool__

Suggested change
if (window_function := self._window_function) is not None:
if (window_function := self._window_function):

assert order_by is not None # noqa: S101

def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [
window_function(expr, partition_by, order_by)
for expr in self._call(df)
]
else:
partition_by_sql = generate_partition_by_sql(*partition_by)
template = f"{{expr}} over ({partition_by_sql})"

def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
return [
SQLExpression(template.format(expr=expr)) for expr in self._call(df)
]

return self.__class__(
func,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
)

def is_null(self: Self) -> Self:
return self._from_call(lambda _input: _input.isnull(), "is_null")

Expand All @@ -461,6 +534,42 @@ def round(self: Self, decimals: int) -> Self:
lambda _input: FunctionExpression("round", _input, lit(decimals)), "round"
)

def cum_sum(self, *, reverse: bool) -> Self:
    """Attach a cumulative-sum window function to this expression.

    The SQL is not built here: a callable is registered via
    `_with_window_function` and produces the windowed expression once the
    partitioning and ordering columns are supplied.

    Arguments:
        reverse: When true the window is ordered descending, otherwise
            ascending.
    """
    frame = "rows between unbounded preceding and current row"

    def func(
        _input: duckdb.Expression,
        partition_by: Sequence[str],
        order_by: Sequence[str],
    ) -> duckdb.Expression:
        partition_by_sql = generate_partition_by_sql(*partition_by)
        order_by_sql = generate_order_by_sql(*order_by, ascending=not reverse)
        window = f"{partition_by_sql} {order_by_sql} {frame}"
        return SQLExpression(f"sum ({_input}) over ({window})")

    return self._with_window_function(func)

def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
    """Attach a rolling-sum window function to this expression.

    Arguments:
        window_size: Number of rows in the window, current row included.
        min_samples: Minimum count of the input over the window required
            for a sum to be emitted; below that the result is null.
        center: When true the window straddles the current row, with the
            odd leftover row (if any) placed on the preceding side.
            Otherwise the window ends at the current row.
    """
    lookback = window_size - 1
    if center:
        following, extra = divmod(lookback, 2)
        start = f"{following + extra} preceding"
        end = f"{following} following"
    else:
        start, end = f"{lookback} preceding", "current row"
    frame = f"rows between {start} and {end}"

    def func(
        _input: duckdb.Expression,
        partition_by: Sequence[str],
        order_by: Sequence[str],
    ) -> duckdb.Expression:
        partition_by_sql = generate_partition_by_sql(*partition_by)
        order_by_sql = generate_order_by_sql(*order_by, ascending=True)
        window = f"({partition_by_sql} {order_by_sql} {frame})"
        counted = f"count({_input}) over {window}"
        summed = f"sum({_input}) over {window}"
        return SQLExpression(
            f"case when {counted} >= {min_samples} then {summed} else null end"
        )

    return self._with_window_function(func)

def fill_null(
self: Self, value: Self | Any, strategy: Any, limit: int | None
) -> Self:
Expand Down Expand Up @@ -507,10 +616,7 @@ def struct(self: Self) -> DuckDBExprStructNamespace:
is_unique = not_implemented()
is_first_distinct = not_implemented()
is_last_distinct = not_implemented()
cum_sum = not_implemented()
cum_count = not_implemented()
cum_min = not_implemented()
cum_max = not_implemented()
cum_prod = not_implemented()
over = not_implemented()
rolling_sum = not_implemented()
16 changes: 16 additions & 0 deletions narwhals/_duckdb/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Protocol
from typing import Sequence

if TYPE_CHECKING:
import duckdb

class WindowFunction(Protocol):
    """Structural type for callables that build a windowed expression."""

    def __call__(
        self,
        _input: duckdb.Expression,
        partition_by: Sequence[str],
        order_by: Sequence[str],
    ) -> duckdb.Expression:
        """Build the windowed expression.

        Arguments:
            _input: Native expression the window is applied to.
            partition_by: Names to partition the window by.
            order_by: Names to order the window by.
        """
        ...
Comment on lines +10 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been meaning to ask about this for a while now.

Is that naming convention intended to signal positional-only for _input?

It seems similar to a convention that was common before PEP 570 – Python Positional-Only Parameters.

Current

    class WindowFunction(Protocol):
        def __call__(
            self,
            _input: duckdb.Expression,
            partition_by: Sequence[str],
            order_by: Sequence[str],
        ) -> duckdb.Expression: ...

Before PEP 570

    class WindowFunction(Protocol):
        def __call__(
            self,
            __input: duckdb.Expression,
            partition_by: Sequence[str],
            order_by: Sequence[str],
        ) -> duckdb.Expression: ...

After PEP 570

    class WindowFunction(Protocol):
        def __call__(
            self,
            input: duckdb.Expression,
            /,
            partition_by: Sequence[str],
            order_by: Sequence[str],
        ) -> duckdb.Expression: ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah _input wasn't a very good name, i don't think i'd put much thought into it at the time and now it's all over

we could use native_expr in its place?

Copy link
Member

@dangotbanned dangotbanned Mar 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could use native_expr in its place?

Whatever you feel works 👍

Side note

One of the cool things about positional-only args is that you can use different names - and it still works the same at runtime and to a type checker.
Was a bit of a 🤯 when I learned that.

Here, that might mean you could use expr when it is unambiguous; but native_expr when either compliant_expr or native_expr could be confused

15 changes: 15 additions & 0 deletions narwhals/_duckdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,18 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> st
return f"{duckdb_inner}{duckdb_shape_fmt}"
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)


def generate_partition_by_sql(*partition_by: str) -> str:
    """Build a SQL `partition by` clause from column names.

    Each name is emitted as a double-quoted identifier. A double quote
    embedded in a name is doubled, so it cannot terminate the identifier
    early and corrupt the surrounding SQL.

    Arguments:
        *partition_by: Column names to partition by.

    Returns:
        The clause, e.g. `partition by "a", "b"`, or an empty string when no
        names are given.
    """
    if not partition_by:
        return ""
    by_sql = ", ".join(
        '"' + name.replace('"', '""') + '"' for name in partition_by
    )
    return f"partition by {by_sql}"


def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
    """Build a SQL `order by` clause from column names.

    Each name is emitted as a double-quoted identifier, with any embedded
    double quote doubled so it cannot terminate the identifier early.
    Ascending order places nulls first; descending order places nulls last.

    Arguments:
        *order_by: Column names to order by.
        ascending: Sort direction applied to every column.

    Returns:
        The clause, e.g. `order by "a" asc nulls first`, or an empty string
        when no names are given (a bare `order by` is not valid SQL).
    """
    if not order_by:
        return ""
    direction = "asc nulls first" if ascending else "desc nulls last"
    by_sql = ", ".join(
        '"' + name.replace('"', '""') + '" ' + direction for name in order_by
    )
    return f"order by {by_sql}"
13 changes: 6 additions & 7 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,13 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]:
return self._call(df)

def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
if kind is ExprKind.LITERAL:
return self

def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
if kind is ExprKind.AGGREGATION:
return [
result.over(df._Window().partitionBy(df._F.lit(1)))
for result in self(df)
]
# Let PySpark do its own broadcasting for literals.
return self(df)
return [
result.over(df._Window().partitionBy(df._F.lit(1))) for result in self(df)
]

return self.__class__(
func,
Expand Down
9 changes: 5 additions & 4 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,11 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:

def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]:
"""Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
return bool(
(sqlframe := get_sqlframe()) is not None
and isinstance(df, sqlframe.base.dataframe.BaseDataFrame)
)
if get_sqlframe() is not None:
from sqlframe.base.dataframe import BaseDataFrame

return isinstance(df, BaseDataFrame)
return False # pragma: no cover


def is_numpy_array(arr: Any | _NDArray[_ShapeT]) -> TypeIs[_NDArray[_ShapeT]]:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ omit = [
'narwhals/stable/v1/typing.py',
'narwhals/this.py',
'narwhals/_arrow/typing.py',
'narwhals/_duckdb/typing.py',
'narwhals/_spark_like/typing.py',
# we can't run this in every environment that we measure coverage on due to upper-bound constraints
'narwhals/_ibis/*',
Expand Down
Loading
Loading