Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c7f6e2f
ci(typing): add a basic `pyright` config
dangotbanned Feb 17, 2025
0a4c7c4
ci: add `pyright` to `[typing]`
dangotbanned Feb 17, 2025
9c74ceb
ci: match same files as mypy
dangotbanned Feb 17, 2025
68cef04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2025
c9e3a31
ci(typing): add "Run pyright" step
dangotbanned Feb 17, 2025
af0677e
try activating first?
dangotbanned Feb 17, 2025
6f2b75e
fix(typing): use class instead of generic alias
dangotbanned Feb 17, 2025
cf1fb46
ci(typing): disable `[reportUnusedExpression]`
dangotbanned Feb 17, 2025
4a8dd5d
fix: Variable not allowed in type expression Pylance [reportInvalidTy…
dangotbanned Feb 17, 2025
8a4db22
ignore incorrect stub
dangotbanned Feb 17, 2025
b3d8d53
fix(typing): resolve some `_dask.group_by`
dangotbanned Feb 17, 2025
65e7ff0
ignore `numpy` test shape
dangotbanned Feb 17, 2025
d87705c
Merge remote-tracking branch 'upstream/main' into dev-pyright
dangotbanned Feb 18, 2025
e182c9c
chore: fix typing, simplify `_polars.group_by`
dangotbanned Feb 18, 2025
01ff424
fix(typing): try public re-export of `_expr`
dangotbanned Feb 18, 2025
b26c94c
revert(typing): reintroduce arrow errors `evaluate_output_names`
dangotbanned Feb 18, 2025
1acc660
fix(typing): resolve dask group by
dangotbanned Feb 18, 2025
bde040e
refactor: move `functools` import
dangotbanned Feb 18, 2025
cd8352f
Merge remote-tracking branch 'upstream/main' into dev-pyright
dangotbanned Feb 18, 2025
6f54657
Merge branch 'main' into dev-pyright
dangotbanned Feb 20, 2025
dc49d5f
chore(typing): ignore all `_evaluate_output_names` errors
dangotbanned Feb 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/typing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ jobs:
run: |
source .venv/bin/activate
make typing
- name: Run pyright
run: |
source .venv/bin/activate
pyright
2 changes: 1 addition & 1 deletion narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def join(
df = self._native_frame.merge(
other_native,
how="outer",
indicator=indicator_token,
indicator=indicator_token, # pyright: ignore[reportArgumentType]
left_on=left_on,
right_on=left_on,
)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
Expand Down
52 changes: 24 additions & 28 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence

import dask.dataframe as dd
Expand All @@ -18,48 +20,42 @@

if TYPE_CHECKING:
import pandas as pd
from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals.typing import CompliantExpr

PandasSeriesGroupBy: TypeAlias = "_PandasSeriesGroupBy[Any, Any]"
_AggFn: TypeAlias = Callable[..., Any]
Aggregation: TypeAlias = "str | _AggFn"

def n_unique() -> dd.Aggregation:
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]:
return s.nunique(dropna=False) # type: ignore[no-any-return]
from dask_expr._groupby import GroupBy as _DaskGroupBy
else:
_DaskGroupBy = dx._groupby.GroupBy

def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]:
return s0.sum() # type: ignore[no-any-return]

return dd.Aggregation(
name="nunique",
chunk=chunk,
agg=agg,
)
def n_unique() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.nunique(dropna=False)

def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.sum()

def var(
ddof: int = 1,
) -> Callable[
[pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
]:
from functools import partial
return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)

return partial(dx._groupby.GroupBy.var, ddof=ddof)

def var(ddof: int = 1) -> _AggFn:
return partial(_DaskGroupBy.var, ddof=ddof)

def std(
ddof: int = 1,
) -> Callable[
[pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy
]:
from functools import partial

return partial(dx._groupby.GroupBy.std, ddof=ddof)
def std(ddof: int = 1) -> _AggFn:
return partial(_DaskGroupBy.std, ddof=ddof)


POLARS_TO_DASK_AGGREGATIONS = {
POLARS_TO_DASK_AGGREGATIONS: Mapping[str, Aggregation] = {
"sum": "sum",
"mean": "mean",
"median": "median",
Expand Down Expand Up @@ -101,7 +97,7 @@ def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame:
from narwhals._dask.dataframe import DaskLazyFrame

return DaskLazyFrame(
df,
df, # pyright: ignore[reportArgumentType]
backend_version=self._df._backend_version,
version=self._df._version,
validate_column_names=True,
Expand Down Expand Up @@ -134,7 +130,7 @@ def agg_dask(
break

if all_simple_aggs:
simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys)
if expr._depth == 0:
Expand All @@ -149,7 +145,7 @@ def agg_dask(

# e.g. agg(nw.mean('a')) # noqa: ERA001
function_name = re.sub(r"(\w+->)", "", expr._function_name)
kwargs = (
kwargs: dict[str, Any] = (
{"ddof": expr._kwargs["ddof"]} if function_name in {"std", "var"} else {} # type: ignore[attr-defined]
)

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def __init__(
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._kwargs = kwargs

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx

if not dx._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
if not dx.expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
# If someone only operates on a Dask DataFrame via expressions, then this
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
) -> None:
self._call = call
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def __init__(
self._version = version
self._call = call
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names

def otherwise(self: Self, value: DuckDBExpr | Any) -> DuckDBExpr:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._implementation = implementation
self._backend_version = backend_version
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def __init__(
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._kwargs = kwargs

Expand Down
35 changes: 17 additions & 18 deletions narwhals/_polars/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from typing import TYPE_CHECKING
from typing import Iterator
from typing import cast

from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native

if TYPE_CHECKING:
from polars.dataframe.group_by import GroupBy as NativeGroupBy
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
from typing_extensions import Self

from narwhals._polars.dataframe import PolarsDataFrame
Expand All @@ -17,33 +20,29 @@ class PolarsGroupBy:
def __init__(
self: Self, df: PolarsDataFrame, keys: list[str], *, drop_null_keys: bool
) -> None:
self._compliant_frame = df
self.keys = keys
if drop_null_keys:
self._grouped = df.drop_nulls(keys)._native_frame.group_by(keys)
else:
self._grouped = df._native_frame.group_by(keys)
self._compliant_frame: PolarsDataFrame = df
self.keys: list[str] = keys
df = df.drop_nulls(keys) if drop_null_keys else df
self._grouped: NativeGroupBy = df._native_frame.group_by(keys)

def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame:
aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment]
return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs))
from_native = self._compliant_frame._from_native_frame
return from_native(self._grouped.agg(extract_native(arg) for arg in aggs))

def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
for key, df in self._grouped:
yield tuple(key), self._compliant_frame._from_native_frame(df)
yield tuple(cast("str", key)), self._compliant_frame._from_native_frame(df)


class PolarsLazyGroupBy:
def __init__(
self: Self, df: PolarsLazyFrame, keys: list[str], *, drop_null_keys: bool
) -> None:
self._compliant_frame = df
self.keys = keys
if drop_null_keys:
self._grouped = df.drop_nulls(keys)._native_frame.group_by(keys)
else:
self._grouped = df._native_frame.group_by(keys)
self._compliant_frame: PolarsLazyFrame = df
self.keys: list[str] = keys
df = df.drop_nulls(keys) if drop_null_keys else df
self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys)

def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame:
aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment]
return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs))
from_native = self._compliant_frame._from_native_frame
return from_native(self._grouped.agg(extract_native(arg) for arg in aggs))
2 changes: 1 addition & 1 deletion narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
) -> None:
self._call = call
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __init__(
self._version = version
self._call = call
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._implementation = implementation

Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ typing = [
"pandas-stubs",
"typing_extensions",
"mypy~=1.15.0",
"pyright"
]
dev = [
"pre-commit",
Expand Down Expand Up @@ -242,3 +243,19 @@ module = [
# TODO: remove follow_imports
follow_imports = "skip"
ignore_missing_imports = true

[tool.pyright]
pythonPlatform = "All"
# NOTE (stubs do unsafe `TypeAlias` and `TypeVar` imports)
# pythonVersion = "3.8"
reportMissingImports = "none"
reportMissingModuleSource = "none"
reportPrivateImportUsage = "none"
reportUnusedExpression = "none" # can enforce with `ruff`
typeCheckingMode = "basic"
include = ["narwhals", "tests"]
ignore = [
"../.venv/",
"../../../**/Lib", # stdlib
"../../../**/typeshed*" # typeshed-fallback
]
2 changes: 1 addition & 1 deletion tests/frame/with_columns_sequence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def test_with_columns(constructor_eager: ConstructorEager) -> None:
result = (
nw.from_native(constructor_eager(data))
.with_columns(d=np.array([4, 5]))
.with_columns(d=np.array([4, 5])) # pyright: ignore[reportArgumentType]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that happy about this.
Might need to change _(1|2)DArray to just both alias _AnyDArray

narwhals/narwhals/typing.py

Lines 270 to 273 in 9571ad6

_NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]"
_1DArray: TypeAlias = "_NDArray[tuple[int]]" # noqa: PYI042, PYI047
_2DArray: TypeAlias = "_NDArray[tuple[int, int]]" # noqa: PYI042, PYI047
_AnyDArray: TypeAlias = "_NDArray[tuple[int, ...]]" # noqa: PYI047

Maybe with a doc linking back to the numpy issue and a description of what we actually mean?

The problem stems from most functions returning shape: tuple[int, ...] - which fails both _(1|2)DArray.

.with_columns(e=nw.col("d") + 1)
.select("d", "e")
)
Expand Down
10 changes: 5 additions & 5 deletions tests/from_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_from_numpy(constructor: Constructor, request: pytest.FixtureRequest) ->
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
native_namespace = nw.get_native_namespace(df)
result = nw.from_numpy(arr, native_namespace=native_namespace)
result = nw.from_numpy(arr, native_namespace=native_namespace) # pyright: ignore[reportArgumentType]
assert_equal_data(result, expected)
assert isinstance(result, nw.DataFrame)

Expand All @@ -42,7 +42,7 @@ def test_from_numpy_schema_dict(
df = nw_v1.from_native(constructor(data))
native_namespace = nw_v1.get_native_namespace(df)
result = nw_v1.from_numpy(
arr,
arr, # pyright: ignore[reportArgumentType]
native_namespace=native_namespace,
schema=schema, # type: ignore[arg-type]
)
Expand All @@ -58,7 +58,7 @@ def test_from_numpy_schema_list(
df = nw_v1.from_native(constructor(data))
native_namespace = nw_v1.get_native_namespace(df)
result = nw_v1.from_numpy(
arr,
arr, # pyright: ignore[reportArgumentType]
native_namespace=native_namespace,
schema=schema,
)
Expand All @@ -83,7 +83,7 @@ def test_from_numpy_v1(constructor: Constructor, request: pytest.FixtureRequest)
request.applymarker(pytest.mark.xfail)
df = nw_v1.from_native(constructor(data))
native_namespace = nw_v1.get_native_namespace(df)
result = nw_v1.from_numpy(arr, native_namespace=native_namespace)
result = nw_v1.from_numpy(arr, native_namespace=native_namespace) # pyright: ignore[reportArgumentType]
assert_equal_data(result, expected)
assert isinstance(result, nw_v1.DataFrame)

Expand All @@ -92,4 +92,4 @@ def test_from_numpy_not2d(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
native_namespace = nw_v1.get_native_namespace(df)
with pytest.raises(ValueError, match="`from_numpy` only accepts 2D numpy arrays"):
nw.from_numpy(np.array([0]), native_namespace=native_namespace)
nw.from_numpy(np.array([0]), native_namespace=native_namespace) # pyright: ignore[reportArgumentType]
4 changes: 2 additions & 2 deletions tests/ibis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import ibis

from tests.utils import Constructor

ibis = pytest.importorskip("ibis")
else:
ibis = pytest.importorskip("ibis")


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/translate/to_py_scalar_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_to_py_scalar(
) -> None:
output = nw.to_py_scalar(input_value)
if expected == 1:
assert not isinstance(output, np.int64)
assert not isinstance(output, np.generic)
assert output == expected


Expand Down
Loading