Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ repos:
"python/custreamz/custreamz",
"python/cudf_kafka/cudf_kafka",
"python/cudf_polars/cudf_polars",
"python/cudf_polars/tests",
"python/dask_cudf/dask_cudf"]
pass_filenames: false
- repo: https://github.com/pre-commit/mirrors-clang-format
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ exclude = [
module = ["cudf_polars.*"]
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = ["polars.*"]
# override the global "skip" above so that types from polars
# aren't skipped and assumed to be `Any`.
follow_imports = "normal"

[tool.codespell]
# note: pre-commit passes explicit lists of files here, which this skip file list doesn't override -
# this is only to allow you to run codespell interactively
Expand Down
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def to_polars(self) -> pl.DataFrame:
self.table,
[plc.interop.ColumnMetadata(name=name) for name in name_map],
)
df: pl.DataFrame = pl.from_arrow(table)
# from_arrow returns DataFrame | Series, but we know that it's a
# DataFrame since the input was a Table.
df = cast(pl.DataFrame, pl.from_arrow(table))
return df.rename(name_map).with_columns(
pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
if c.is_sorted
Expand Down
10 changes: 6 additions & 4 deletions python/cudf_polars/cudf_polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ def assert_gpu_result_equal(
collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs
)

expect = lazydf.collect(**final_polars_collect_kwargs)
got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
# These keywords are correct, but mypy doesn't see that.
# the 'misc' is for 'error: Keywords must be strings'
expect = lazydf.collect(**final_polars_collect_kwargs) # type: ignore[call-overload,misc]
got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[call-overload,misc]
assert_frame_equal(
expect,
got,
Expand Down Expand Up @@ -227,7 +229,7 @@ def assert_collect_raises(
)

try:
lazydf.collect(**final_polars_collect_kwargs)
lazydf.collect(**final_polars_collect_kwargs) # type: ignore[call-overload,misc]
except polars_except:
pass
except Exception as e:
Expand All @@ -240,7 +242,7 @@ def assert_collect_raises(

engine = GPUEngine(raise_on_fail=True)
try:
lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[call-overload,misc]
except cudf_except:
pass
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def pytest_configure(config: pytest.Config) -> None:
if no_fallback:
collect = polars.LazyFrame.collect
engine = polars.GPUEngine(raise_on_fail=no_fallback)
polars.LazyFrame.collect = partialmethod(collect, engine=engine)
# https://github.com/python/mypy/issues/2427
polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign,assignment]
else:
polars.Config.set_engine_affinity("gpu")
config.addinivalue_line(
Expand Down
6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
Union,
)

import polars as pl
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir

if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from typing import TypeAlias

import polars as pl

import pylibcudf as plc

from cudf_polars.containers import DataFrame, DataType
Expand Down Expand Up @@ -82,6 +81,7 @@
pl_expr.PyExprIR,
]

PolarsSchema: TypeAlias = dict[str, pl.DataType]
Schema: TypeAlias = dict[str, "DataType"]

Slice: TypeAlias = tuple[int, int | None]
Expand All @@ -108,7 +108,7 @@ def view_current_node(self) -> PolarsIR:
"""Convert current plan node to python rep."""
...

def get_schema(self) -> Schema:
def get_schema(self) -> PolarsSchema:
Comment thread
TomAugspurger marked this conversation as resolved.
"""Get the schema of the current plan node."""
...

Expand Down
4 changes: 3 additions & 1 deletion python/cudf_polars/tests/experimental/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,7 @@ def test_hash_shuffle(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:

# Check that streaming evaluation works
result = evaluate_streaming(qir3, options).to_polars()
expect = df.collect(engine="cpu")
# ignore is for polars' EngineType, which isn't publicly exported.
# https://github.com/pola-rs/polars/issues/17420
expect = df.collect(engine="cpu") # type: ignore[call-overload]
assert_frame_equal(result, expect, check_row_order=False)
15 changes: 11 additions & 4 deletions python/cudf_polars/tests/expressions/test_booleanfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
Expand All @@ -12,18 +14,21 @@
)
from cudf_polars.utils.versions import POLARS_VERSION_LT_128, POLARS_VERSION_LT_130

if TYPE_CHECKING:
from collections.abc import Callable


@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"])
def has_nulls(request):
def has_nulls(request: pytest.FixtureRequest) -> bool:
return request.param


@pytest.fixture(params=[False, True], ids=["include_nulls", "ignore_nulls"])
def ignore_nulls(request):
def ignore_nulls(request: pytest.FixtureRequest) -> bool:
return request.param


def test_booleanfunction_reduction(ignore_nulls):
def test_booleanfunction_reduction(*, ignore_nulls: bool) -> None:
ldf = pl.LazyFrame(
{
"a": pl.Series([1, 2, 3.0, 2, 5], dtype=pl.Float64()),
Expand Down Expand Up @@ -70,7 +75,9 @@ def test_booleanfunction_all_any_kleene(expr, ignore_nulls):
ids=lambda f: f"{f.__name__}()",
)
@pytest.mark.parametrize("has_nans", [False, True], ids=["no_nans", "nans"])
def test_boolean_function_unary(expr, has_nans, has_nulls):
def test_boolean_function_unary(
expr: Callable[[pl.Expr], pl.Expr], *, has_nans: bool, has_nulls: bool
) -> None:
values: list[float | None] = [1, 2, 3, 4, 5]
if has_nans:
values[3] = float("nan")
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/tests/expressions/test_distinct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

Expand All @@ -17,7 +17,7 @@ def op(request):


@pytest.fixture
def df(with_nulls):
def df(*, with_nulls: bool) -> pl.LazyFrame:
values: list[int | None] = [1, 2, 3, 1, 1, 7, 3, 2, 7, 8, 1]
if with_nulls:
values[1] = None
Expand Down
6 changes: 5 additions & 1 deletion python/cudf_polars/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_validate_streaming_executor_shuffle_method() -> None:
executor_options={"shuffle_method": "tasks"},
)
)
assert config.executor.name == "streaming"
assert config.executor.shuffle_method == "tasks"

config = ConfigOptions.from_polars_engine(
Expand All @@ -151,6 +152,7 @@ def test_validate_streaming_executor_shuffle_method() -> None:
},
)
)
assert config.executor.name == "streaming"
assert config.executor.shuffle_method == "rapidsmpf"

# rapidsmpf with sync is not allowed
Expand Down Expand Up @@ -201,6 +203,7 @@ def test_validate_scheduler() -> None:
executor="streaming",
)
)
assert config.executor.name == "streaming"
assert config.executor.scheduler == "synchronous"

with pytest.raises(ValueError, match="'foo' is not a valid Scheduler"):
Expand All @@ -218,6 +221,7 @@ def test_validate_shuffle_method() -> None:
executor="streaming",
)
)
assert config.executor.name == "streaming"
assert config.executor.shuffle_method is None

with pytest.raises(ValueError, match="'foo' is not a valid ShuffleMethod"):
Expand Down Expand Up @@ -264,7 +268,7 @@ def test_validate_parquet_options(option: str) -> None:
def test_validate_raise_on_fail() -> None:
with pytest.raises(TypeError, match="'raise_on_fail' must be"):
ConfigOptions.from_polars_engine(
pl.GPUEngine(executor="streaming", raise_on_fail=object())
pl.GPUEngine(executor="streaming", raise_on_fail=object()) # type: ignore[arg-type]
)


Expand Down
8 changes: 6 additions & 2 deletions python/cudf_polars/tests/test_window_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def test_over_with_sort(df: pl.LazyFrame):
@pytest.mark.parametrize("mapping_strategy", ["group_to_rows", "explode", "join"])
def test_over_mapping_strategy(df: pl.LazyFrame, mapping_strategy: str):
"""Test window functions with different mapping strategies."""
# ignore is for polars' WindowMappingStrategy, which isn't publicly exported.
# https://github.com/pola-rs/polars/issues/17420
query = df.with_columns(
[pl.col("b").rank().over(pl.col("a"), mapping_strategy=mapping_strategy)]
[pl.col("b").rank().over(pl.col("a"), mapping_strategy=mapping_strategy)] # type: ignore[arg-type]
)
assert_ir_translation_raises(query, NotImplementedError)

Expand Down Expand Up @@ -130,7 +132,9 @@ def test_rolling_unsupported(df: pl.LazyFrame, unsupported_agg_expr):
@pytest.mark.parametrize("closed", ["left", "right", "both", "none"])
def test_rolling_closed(df: pl.LazyFrame, closed: str):
"""Test rolling window functions with different closed parameters."""
# ignore is for polars' ClosedInterval, which isn't publicly exported.
# https://github.com/pola-rs/polars/issues/17420
query = df.with_columns(
[pl.col("b").sum().rolling(period="2d", index_column="date", closed=closed)]
[pl.col("b").sum().rolling(period="2d", index_column="date", closed=closed)] # type: ignore[arg-type]
)
assert_gpu_result_equal(query)