diff --git a/.github/workflows/typing.yml b/.github/workflows/typing.yml index 029ec73fd1..8df0cdf264 100644 --- a/.github/workflows/typing.yml +++ b/.github/workflows/typing.yml @@ -41,3 +41,7 @@ jobs: run: | source .venv/bin/activate make typing + - name: Run pyright + run: | + source .venv/bin/activate + pyright diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 904b6b6152..935624c4fa 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -318,7 +318,7 @@ def join( df = self._native_frame.merge( other_native, how="outer", - indicator=indicator_token, + indicator=indicator_token, # pyright: ignore[reportArgumentType] left_on=left_on, right_on=left_on, ) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 874c158040..c45acfb2c9 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -53,7 +53,7 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 5fd11225ba..f9cbb2d526 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -1,9 +1,11 @@ from __future__ import annotations import re +from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Mapping from typing import Sequence import dask.dataframe as dd @@ -18,48 +20,42 @@ if TYPE_CHECKING: import pandas as pd + from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy from typing_extensions import Self + from typing_extensions import TypeAlias from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals.typing import CompliantExpr + PandasSeriesGroupBy: TypeAlias = "_PandasSeriesGroupBy[Any, Any]" + _AggFn: TypeAlias = Callable[..., Any] + Aggregation: TypeAlias = "str | _AggFn" -def n_unique() -> dd.Aggregation: - def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]: - return s.nunique(dropna=False) # type: ignore[no-any-return] + from dask_expr._groupby import GroupBy as _DaskGroupBy +else: + _DaskGroupBy = dx._groupby.GroupBy - def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> pd.Series[Any]: - return s0.sum() # type: ignore[no-any-return] - return dd.Aggregation( - name="nunique", - chunk=chunk, - agg=agg, - ) +def n_unique() -> dd.Aggregation: + def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]: + return s.nunique(dropna=False) + def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]: + return s0.sum() -def var( - ddof: int = 1, -) -> Callable[ - [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy -]: - from functools import partial + return dd.Aggregation(name="nunique", chunk=chunk, agg=agg) - return partial(dx._groupby.GroupBy.var, ddof=ddof) +def var(ddof: int = 1) -> _AggFn: + return partial(_DaskGroupBy.var, ddof=ddof) -def std( - ddof: int = 1, -) -> Callable[ - [pd.core.groupby.generic.SeriesGroupBy], pd.core.groupby.generic.SeriesGroupBy -]: - from functools import partial - return partial(dx._groupby.GroupBy.std, ddof=ddof) +def std(ddof: int = 1) -> _AggFn: + return partial(_DaskGroupBy.std, ddof=ddof) -POLARS_TO_DASK_AGGREGATIONS = { +POLARS_TO_DASK_AGGREGATIONS: Mapping[str, Aggregation] = { "sum": "sum", "mean": "mean", "median": "median", @@ -101,7 +97,7 @@ def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame: from narwhals._dask.dataframe import DaskLazyFrame return DaskLazyFrame( - df, + df, # pyright: ignore[reportArgumentType] backend_version=self._df._backend_version, version=self._df._version, validate_column_names=True, @@ -134,7 +130,7 @@ def agg_dask( break if all_simple_aggs: - simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {} + simple_aggregations: dict[str, tuple[str, Aggregation]] = {} for expr in exprs: output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys) if expr._depth == 0: @@ -149,7 +145,7 @@ def agg_dask( # e.g. agg(nw.mean('a')) # noqa: ERA001 function_name = re.sub(r"(\w+->)", "", expr._function_name) - kwargs = ( + kwargs: dict[str, Any] = ( {"ddof": expr._kwargs["ddof"]} if function_name in {"std", "var"} else {} # type: ignore[attr-defined] ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 97e37567af..cf825d27f7 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -425,7 +425,7 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._kwargs = kwargs diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index c0af257158..5c1c4c0742 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -87,7 +87,7 @@ def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None: except ModuleNotFoundError: # pragma: no cover import dask_expr as dx - if not dx._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover + if not dx.expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover # are_co_aligned is a method which cheaply checks if two Dask expressions # have the same index, and therefore don't require index alignment. # If someone only operates on a Dask DataFrame via expressions, then this diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 05de163510..0dc9aab63f 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -49,7 +49,7 @@ def __init__( ) -> None: self._call = call self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 334e7eb8c2..8c5191a3be 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -334,7 +334,7 @@ def __init__( self._version = version self._call = call self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names def otherwise(self: Self, value: DuckDBExpr | Any) -> DuckDBExpr: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 79eca8c07a..ab13623e02 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -63,7 +63,7 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._implementation = implementation self._backend_version = backend_version diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 462cc02d6f..67b2314ef7 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -515,7 +515,7 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._kwargs = kwargs diff --git a/narwhals/_polars/group_by.py b/narwhals/_polars/group_by.py index 26b19cae68..655e0e9bb8 100644 --- a/narwhals/_polars/group_by.py +++ b/narwhals/_polars/group_by.py @@ -2,10 +2,13 @@ from typing import TYPE_CHECKING from typing import Iterator +from typing import cast -from narwhals._polars.utils import extract_args_kwargs +from narwhals._polars.utils import extract_native if TYPE_CHECKING: + from polars.dataframe.group_by import GroupBy as NativeGroupBy + from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy from typing_extensions import Self from narwhals._polars.dataframe import PolarsDataFrame @@ -17,33 +20,29 @@ class PolarsGroupBy: def __init__( self: Self, df: PolarsDataFrame, keys: list[str], *, drop_null_keys: bool ) -> None: - self._compliant_frame = df - self.keys = keys - if drop_null_keys: - self._grouped = df.drop_nulls(keys)._native_frame.group_by(keys) - else: - self._grouped = df._native_frame.group_by(keys) + self._compliant_frame: PolarsDataFrame = df + self.keys: list[str] = keys + df = df.drop_nulls(keys) if drop_null_keys else df + self._grouped: NativeGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame: - aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment] - return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs)) + from_native = self._compliant_frame._from_native_frame + return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]: for key, df in self._grouped: - yield tuple(key), self._compliant_frame._from_native_frame(df) + yield tuple(cast("str", key)), self._compliant_frame._from_native_frame(df) class PolarsLazyGroupBy: def __init__( self: Self, df: PolarsLazyFrame, keys: list[str], *, drop_null_keys: bool ) -> None: - self._compliant_frame = df - self.keys = keys - if drop_null_keys: - self._grouped = df.drop_nulls(keys)._native_frame.group_by(keys) - else: - self._grouped = df._native_frame.group_by(keys) + self._compliant_frame: PolarsLazyFrame = df + self.keys: list[str] = keys + df = df.drop_nulls(keys) if drop_null_keys else df + self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys) def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame: - aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment] - return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs)) + from_native = self._compliant_frame._from_native_frame + return from_native(self._grouped.agg(extract_native(arg) for arg in aggs)) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 9446c0966a..97c176061f 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -43,7 +43,7 @@ def __init__( ) -> None: self._call = call self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 9f6e213dee..836c1d6a4d 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -383,7 +383,7 @@ def __init__( self._version = version self._call = call self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._implementation = implementation diff --git a/pyproject.toml b/pyproject.toml index 311a035397..27ff136f10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ typing = [ "pandas-stubs", "typing_extensions", "mypy~=1.15.0", + "pyright" ] dev = [ "pre-commit", @@ -242,3 +243,19 @@ module = [ # TODO: remove follow_imports follow_imports = "skip" ignore_missing_imports = true + +[tool.pyright] +pythonPlatform = "All" +# NOTE (stubs do unsafe `TypeAlias` and `TypeVar` imports) +# pythonVersion = "3.8" +reportMissingImports = "none" +reportMissingModuleSource = "none" +reportPrivateImportUsage = "none" +reportUnusedExpression = "none" # can enforce with `ruff` +typeCheckingMode = "basic" +include = ["narwhals", "tests"] +ignore = [ + "../.venv/", + "../../../**/Lib", # stdlib + "../../../**/typeshed*" # typeshed-fallback +] diff --git a/tests/frame/with_columns_sequence_test.py b/tests/frame/with_columns_sequence_test.py index 11c6cb4995..846596162d 100644 --- a/tests/frame/with_columns_sequence_test.py +++ b/tests/frame/with_columns_sequence_test.py @@ -15,7 +15,7 @@ def test_with_columns(constructor_eager: ConstructorEager) -> None: result = ( nw.from_native(constructor_eager(data)) - .with_columns(d=np.array([4, 5])) + .with_columns(d=np.array([4, 5])) # pyright: ignore[reportArgumentType] .with_columns(e=nw.col("d") + 1) .select("d", "e") ) diff --git a/tests/from_numpy_test.py b/tests/from_numpy_test.py index 7a40136e71..0d19396993 100644 --- a/tests/from_numpy_test.py +++ b/tests/from_numpy_test.py @@ -23,7 +23,7 @@ def test_from_numpy(constructor: Constructor, request: pytest.FixtureRequest) -> request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) native_namespace = nw.get_native_namespace(df) - result = nw.from_numpy(arr, native_namespace=native_namespace) + result = nw.from_numpy(arr, native_namespace=native_namespace) # pyright: ignore[reportArgumentType] assert_equal_data(result, expected) assert isinstance(result, nw.DataFrame) @@ -42,7 +42,7 @@ def test_from_numpy_schema_dict( df = nw_v1.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) result = nw_v1.from_numpy( - arr, + arr, # pyright: ignore[reportArgumentType] native_namespace=native_namespace, schema=schema, # type: ignore[arg-type] ) @@ -58,7 +58,7 @@ def test_from_numpy_schema_list( df = nw_v1.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) result = nw_v1.from_numpy( - arr, + arr, # pyright: ignore[reportArgumentType] native_namespace=native_namespace, schema=schema, ) @@ -83,7 +83,7 @@ def test_from_numpy_v1(constructor: Constructor, request: pytest.FixtureRequest) request.applymarker(pytest.mark.xfail) df = nw_v1.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) - result = nw_v1.from_numpy(arr, native_namespace=native_namespace) + result = nw_v1.from_numpy(arr, native_namespace=native_namespace) # pyright: ignore[reportArgumentType] assert_equal_data(result, expected) assert isinstance(result, nw_v1.DataFrame) @@ -92,4 +92,4 @@ def test_from_numpy_not2d(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) with pytest.raises(ValueError, match="`from_numpy` only accepts 2D numpy arrays"): - nw.from_numpy(np.array([0]), native_namespace=native_namespace) + nw.from_numpy(np.array([0]), native_namespace=native_namespace) # pyright: ignore[reportArgumentType] diff --git a/tests/ibis_test.py b/tests/ibis_test.py index d4014f2669..8043b9b0e8 100644 --- a/tests/ibis_test.py +++ b/tests/ibis_test.py @@ -12,8 +12,8 @@ import ibis from tests.utils import Constructor - -ibis = pytest.importorskip("ibis") +else: + ibis = pytest.importorskip("ibis") @pytest.fixture diff --git a/tests/translate/to_py_scalar_test.py b/tests/translate/to_py_scalar_test.py index b859835012..d9b076d135 100644 --- a/tests/translate/to_py_scalar_test.py +++ b/tests/translate/to_py_scalar_test.py @@ -43,7 +43,7 @@ def test_to_py_scalar( ) -> None: output = nw.to_py_scalar(input_value) if expected == 1: - assert not isinstance(output, np.int64) + assert not isinstance(output, np.generic) assert output == expected