diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b26a2aac2d..f80774f5fad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: rev: 'v1.13.0' hooks: - id: mypy - additional_dependencies: [types-cachetools, pyarrow-stubs, numpy, pytest] + additional_dependencies: [numpy, polars, pyarrow-stubs, pytest, types-cachetools] args: ["--config-file=pyproject.toml", "python/cudf/cudf", "python/pylibcudf/pylibcudf", diff --git a/python/cudf_polars/cudf_polars/containers/datatype.py b/python/cudf_polars/cudf_polars/containers/datatype.py index 5b23788212c..833549d9ee4 100644 --- a/python/cudf_polars/cudf_polars/containers/datatype.py +++ b/python/cudf_polars/cudf_polars/containers/datatype.py @@ -6,7 +6,7 @@ from __future__ import annotations from functools import cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, cast from typing_extensions import assert_never @@ -14,9 +14,12 @@ import pylibcudf as plc +from cudf_polars.utils.versions import POLARS_VERSION_LT_136 + if TYPE_CHECKING: from cudf_polars.typing import ( DataTypeHeader, + PolarsDataType, ) __all__ = ["DataType"] @@ -46,7 +49,18 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader: if name in SCALAR_NAME_TO_POLARS_TYPE_MAP: return {"kind": "scalar", "name": name} if isinstance(dtype, pl.Decimal): - return {"kind": "decimal", "precision": dtype.precision, "scale": dtype.scale} + # Workaround for incorrect polars stubs where precision is typed as int | None + # Fixed upstream: https://github.com/pola-rs/polars/pull/25227 + # TODO: Remove this workaround when polars >= 1.36 + if POLARS_VERSION_LT_136: + assert ( + dtype.precision is not None + ) # Decimal always has precision at runtime + return { + "kind": "decimal", + "precision": cast(int, dtype.precision), + "scale": dtype.scale, + } if isinstance(dtype, pl.Datetime): return { "kind": "datetime", @@ -56,12 +70,17 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader: if isinstance(dtype, pl.Duration): return {"kind": "duration", "time_unit": dtype.time_unit} if isinstance(dtype, pl.List): - return {"kind": "list", "inner": _dtype_to_header(dtype.inner)} + # isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType + return { + "kind": "list", + "inner": _dtype_to_header(cast(pl.DataType, dtype.inner)), + } if isinstance(dtype, pl.Struct): + # isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType return { "kind": "struct", "fields": [ - {"name": f.name, "dtype": _dtype_to_header(f.dtype)} + {"name": f.name, "dtype": _dtype_to_header(cast(pl.DataType, f.dtype))} for f in dtype.fields ], } @@ -78,9 +97,14 @@ def _dtype_from_header(header: DataTypeHeader) -> pl.DataType: if header["kind"] == "decimal": return pl.Decimal(header["precision"], header["scale"]) if header["kind"] == "datetime": - return pl.Datetime(time_unit=header["time_unit"], time_zone=header["time_zone"]) + return pl.Datetime( + time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]), + time_zone=header["time_zone"], + ) if header["kind"] == "duration": - return pl.Duration(time_unit=header["time_unit"]) + return pl.Duration( + time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]) + ) if header["kind"] == "list": return pl.List(_dtype_from_header(header["inner"])) if header["kind"] == "struct": @@ -182,9 +206,14 @@ class DataType: polars_type: pl.datatypes.DataType plc_type: plc.DataType - def __init__(self, polars_dtype: pl.DataType) -> None: - self.polars_type = polars_dtype - self.plc_type = _from_polars(polars_dtype) + def __init__(self, polars_dtype: PolarsDataType) -> None: + # Convert DataTypeClass to DataType instance if needed + # polars allows both pl.Int64 (class) and pl.Int64() (instance) + if isinstance(polars_dtype, type): + polars_dtype = polars_dtype() + # After conversion, it's guaranteed to be a DataType instance + self.polars_type = cast(pl.DataType, polars_dtype) + self.plc_type = _from_polars(self.polars_type) def id(self) -> plc.TypeId: """The pylibcudf.TypeId of this DataType.""" @@ -193,12 +222,16 @@ def id(self) -> plc.TypeId: @property def children(self) -> list[DataType]: """The children types of this DataType.""" - # these type ignores are needed because the type checker doesn't - # see that these equality checks passing imply a specific type for each child field. + # Type checker doesn't narrow polars_type through plc_type.id() checks if self.plc_type.id() == plc.TypeId.STRUCT: - return [DataType(field.dtype) for field in self.polars_type.fields] + # field.dtype returns DataTypeClass | DataType, need to cast to DataType + return [ + DataType(cast(pl.DataType, field.dtype)) + for field in cast(pl.Struct, self.polars_type).fields + ] elif self.plc_type.id() == plc.TypeId.LIST: - return [DataType(self.polars_type.inner)] + # .inner returns DataTypeClass | DataType, need to cast to DataType + return [DataType(cast(pl.DataType, cast(pl.List, self.polars_type).inner))] return [] def scale(self) -> int: diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index c309c6f9db3..b5ef0000fd6 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -8,7 +8,9 @@ from enum import IntEnum, auto from functools import partial, reduce -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast + +import polars as pl import pylibcudf as plc @@ -350,9 +352,14 @@ def do_evaluate( needles, haystack = columns if haystack.obj.type().id() == plc.TypeId.LIST: # Unwrap values from the list column + # .inner returns DataTypeClass | DataType, need to cast to DataType haystack = Column( haystack.obj.children()[1], - dtype=DataType(haystack.dtype.polars_type.inner), + dtype=DataType( + cast( + pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner + ) + ), ).astype(needles.dtype, stream=df.stream) if haystack.size: return Column( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 687ccd44762..f9e8c0a4341 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -10,8 +10,9 @@ import re from datetime import datetime from enum import IntEnum, auto -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast +import polars as pl from polars.exceptions import InvalidOperationError from polars.polars import dtype_str_repr @@ -37,12 +38,12 @@ def _dtypes_for_json_decode(dtype: DataType) -> JsonDecodeType: """Get the dtypes for json decode.""" - # the type checker doesn't know that this equality check implies a struct dtype. + # Type checker doesn't narrow polars_type through dtype.id() check if dtype.id() == plc.TypeId.STRUCT: return [ (field.name, child.plc_type, _dtypes_for_json_decode(child)) for field, child in zip( - dtype.polars_type.fields, + cast(pl.Struct, dtype.polars_type).fields, dtype.children, strict=True, ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/struct.py b/python/cudf_polars/cudf_polars/dsl/expressions/struct.py index daa28deb624..9825ac49f58 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/struct.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/struct.py @@ -8,7 +8,9 @@ from enum import IntEnum, auto from io import StringIO -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast + +import polars as pl import pylibcudf as plc @@ -87,13 +89,14 @@ def do_evaluate( """Evaluate this expression given a dataframe for context.""" columns = [child.evaluate(df, context=context) for child in self.children] (column,) = columns - # these type ignores are needed because the type checker doesn't - # know that polars only calls StructFunction with struct types. + # Type checker doesn't know polars only calls StructFunction with struct types if self.name == StructFunction.Name.FieldByName: field_index = next( ( i - for i, field in enumerate(self.children[0].dtype.polars_type.fields) + for i, field in enumerate( + cast(pl.Struct, self.children[0].dtype.polars_type).fields + ) if field.name == self.options[0] ), None, @@ -113,7 +116,9 @@ def do_evaluate( table, [ (field.name, []) - for field in self.children[0].dtype.polars_type.fields + for field in cast( + pl.Struct, self.children[0].dtype.polars_type + ).fields ], ) options = ( diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 8d9755470f7..288bf6d549e 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -598,7 +598,8 @@ def add_file_paths( plc.Table( [ plc.Column.from_arrow( - pl.Series(values=map(str, paths)), stream=df.stream + pl.Series(values=map(str, paths)), + stream=df.stream, ) ] ), diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index add5891444b..95fd044aa71 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -6,7 +6,9 @@ from __future__ import annotations from functools import partial, reduce, singledispatch -from typing import TYPE_CHECKING, TypeAlias, TypedDict +from typing import TYPE_CHECKING, TypeAlias, TypedDict, cast + +import polars as pl import pylibcudf as plc from pylibcudf import expressions as plc_expr @@ -226,10 +228,10 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: if haystack.dtype.id() == plc.TypeId.LIST: # Because we originally translated pl_expr.Literal with a list scalar # to a expr.LiteralColumn, so the actual type is in the inner type - # - # the type-ignore is safe because the for plc.TypeID.LIST, we know - # we have a polars.List type, which has an inner attribute. - plc_dtype = DataType(haystack.dtype.polars_type.inner).plc_type + # .inner returns DataTypeClass | DataType, need to cast to DataType + plc_dtype = DataType( + cast(pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner) + ).plc_type else: plc_dtype = haystack.dtype.plc_type # pragma: no cover values = ( diff --git a/python/cudf_polars/cudf_polars/experimental/sort.py b/python/cudf_polars/cudf_polars/experimental/sort.py index 71334507da1..9ea5678aad1 100644 --- a/python/cudf_polars/cudf_polars/experimental/sort.py +++ b/python/cudf_polars/cudf_polars/experimental/sort.py @@ -98,6 +98,10 @@ def find_sort_splits( stream=stream, ) # And convert to list for final processing + # The type ignores are for cross-library boundaries: plc.Column -> pl.Series + # These work at runtime via the Arrow C Data Interface protocol + # TODO: Find a way for pylibcudf types to show they export the Arrow protocol + # (mypy wasn't happy with a custom protocol) split_first_list = pl.Series(split_first_col).to_list() split_last_list = pl.Series(split_last_col).to_list() split_part_id_list = pl.Series(split_part_id).to_list() diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 385f65f39b2..95ed7244961 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -112,8 +112,8 @@ def assert_gpu_result_equal( # These keywords are correct, but mypy doesn't see that. # the 'misc' is for 'error: Keywords must be strings' - expect = lazydf.collect(**final_polars_collect_kwargs) - got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) + expect = lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload] + got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload] assert_kwargs_bool: dict[str, bool] = { "check_row_order": check_row_order, @@ -136,7 +136,7 @@ def assert_gpu_result_equal( expect, got, **assert_kwargs_bool, - **tol_kwargs, + **tol_kwargs, # type: ignore[arg-type] ) @@ -294,7 +294,7 @@ def assert_collect_raises( ) try: - lazydf.collect(**final_polars_collect_kwargs) + lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload] except polars_except: pass except Exception as e: @@ -307,7 +307,7 @@ def assert_collect_raises( engine = GPUEngine(raise_on_fail=True) try: - lazydf.collect(**final_cudf_collect_kwargs, engine=engine) + lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload] except cudf_except: pass except Exception as e: diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py index 28ab3d91b16..49d6ef6e01d 100644 --- a/python/cudf_polars/cudf_polars/testing/plugin.py +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -57,11 +57,11 @@ def pytest_configure(config: pytest.Config) -> None: collect = polars.LazyFrame.collect engine = polars.GPUEngine(raise_on_fail=no_fallback) # https://github.com/python/mypy/issues/2427 - polars.LazyFrame.collect = partialmethod(collect, engine=engine) + polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment] elif executor == "in-memory": collect = polars.LazyFrame.collect engine = polars.GPUEngine(executor=executor) - polars.LazyFrame.collect = partialmethod(collect, engine=engine) + polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment] elif executor == "streaming" and blocksize_mode == "small": executor_options: dict[str, Any] = {} executor_options["max_rows_per_partition"] = 4 @@ -70,7 +70,7 @@ def pytest_configure(config: pytest.Config) -> None: executor_options["fallback_mode"] = StreamingFallbackMode.SILENT collect = polars.LazyFrame.collect engine = polars.GPUEngine(executor=executor, executor_options=executor_options) - polars.LazyFrame.collect = partialmethod(collect, engine=engine) + polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment] else: # run with streaming executor and default blocksize polars.Config.set_engine_affinity("gpu") diff --git a/python/cudf_polars/cudf_polars/utils/versions.py b/python/cudf_polars/cudf_polars/utils/versions.py index cf1556551e7..6299a2a7b46 100644 --- a/python/cudf_polars/cudf_polars/utils/versions.py +++ b/python/cudf_polars/cudf_polars/utils/versions.py @@ -18,6 +18,7 @@ POLARS_VERSION_LT_1321 = POLARS_VERSION < parse("1.32.1") POLARS_VERSION_LT_1323 = POLARS_VERSION < parse("1.32.3") POLARS_VERSION_LT_133 = POLARS_VERSION < parse("1.33.0") +POLARS_VERSION_LT_136 = POLARS_VERSION < parse("1.36") def _ensure_polars_version() -> None: diff --git a/python/cudf_polars/tests/containers/test_column.py b/python/cudf_polars/tests/containers/test_column.py index a344123d2cb..053f85d9f35 100644 --- a/python/cudf_polars/tests/containers/test_column.py +++ b/python/cudf_polars/tests/containers/test_column.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest import polars as pl @@ -14,8 +16,11 @@ from cudf_polars.containers import Column, DataType from cudf_polars.utils.cuda_stream import get_cuda_stream +if TYPE_CHECKING: + from cudf_polars.typing import PolarsDataType + -def _as_instance(dtype: pl.DataType) -> pl.DataType: +def _as_instance(dtype: PolarsDataType) -> pl.DataType: if isinstance(dtype, type): return dtype() if isinstance(dtype, pl.List): diff --git a/python/cudf_polars/tests/experimental/test_shuffle.py b/python/cudf_polars/tests/experimental/test_shuffle.py index 4a5c0bd97ad..9452815232d 100644 --- a/python/cudf_polars/tests/experimental/test_shuffle.py +++ b/python/cudf_polars/tests/experimental/test_shuffle.py @@ -3,6 +3,8 @@ from __future__ import annotations +from typing import Literal, cast + import pytest import polars as pl @@ -75,7 +77,9 @@ def test_hash_shuffle(df: pl.LazyFrame, engine: pl.GPUEngine) -> None: qir3, options, ) - # ignore is for polars' EngineType, which isn't publicly exported. + # Cast needed because polars' EngineType "cpu" isn't publicly exported. # https://github.com/pola-rs/polars/issues/17420 - expect = df.collect(engine="cpu") + expect = df.collect( + engine=cast(Literal["auto", "in-memory", "streaming", "gpu"], "cpu") + ) assert_frame_equal(result, expect, check_row_order=False) diff --git a/python/cudf_polars/tests/expressions/test_rolling.py b/python/cudf_polars/tests/expressions/test_rolling.py index e2a02816438..af13fcaad43 100644 --- a/python/cudf_polars/tests/expressions/test_rolling.py +++ b/python/cudf_polars/tests/expressions/test_rolling.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, cast import pytest @@ -388,7 +388,11 @@ def test_fill_over( group_key: str, expr: pl.Expr, ) -> None: - q = df.select(expr.fill_null(strategy=strategy).over(group_key, order_by=order_by)) + q = df.select( + expr.fill_null(strategy=cast(Literal["forward", "backward"], strategy)).over( + group_key, order_by=order_by + ) + ) if POLARS_VERSION_LT_132: assert_ir_translation_raises(q, NotImplementedError) else: diff --git a/python/cudf_polars/tests/test_config.py b/python/cudf_polars/tests/test_config.py index c0952d7a8f4..bee10f3e257 100644 --- a/python/cudf_polars/tests/test_config.py +++ b/python/cudf_polars/tests/test_config.py @@ -4,7 +4,7 @@ from __future__ import annotations import sys -from typing import Any +from typing import Any, cast import pytest @@ -537,7 +537,7 @@ def test_validate_parquet_options(option: str) -> None: def test_validate_raise_on_fail() -> None: with pytest.raises(TypeError, match="'raise_on_fail' must be"): ConfigOptions.from_polars_engine( - pl.GPUEngine(executor="streaming", raise_on_fail=object()) + pl.GPUEngine(executor="streaming", raise_on_fail=cast(bool, object())) ) diff --git a/python/cudf_polars/tests/test_window_functions.py b/python/cudf_polars/tests/test_window_functions.py index 66cedb13488..3d50aa70b55 100644 --- a/python/cudf_polars/tests/test_window_functions.py +++ b/python/cudf_polars/tests/test_window_functions.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from typing import Literal, cast import pytest @@ -106,7 +107,16 @@ def test_over_mapping_strategy(df: pl.LazyFrame, mapping_strategy: str): # ignore is for polars' WindowMappingStrategy, which isn't publicly exported. # https://github.com/pola-rs/polars/issues/17420 q = df.with_columns( - [pl.col("b").rank().over(pl.col("a"), mapping_strategy=mapping_strategy)] + [ + pl.col("b") + .rank() + .over( + pl.col("a"), + mapping_strategy=cast( + Literal["group_to_rows", "join", "explode"], mapping_strategy + ), + ) + ] ) if not POLARS_VERSION_LT_132 and mapping_strategy == "group_to_rows": assert_gpu_result_equal(q) @@ -143,6 +153,14 @@ def test_rolling_closed(df: pl.LazyFrame, closed: str): # ignore is for polars' ClosedInterval, which isn't publicly exported. # https://github.com/pola-rs/polars/issues/17420 query = df.with_columns( - [pl.col("b").sum().rolling(period="2d", index_column="date", closed=closed)] + [ + pl.col("b") + .sum() + .rolling( + period="2d", + index_column="date", + closed=cast(Literal["left", "right", "both", "none"], closed), + ) + ] ) assert_gpu_result_equal(query)