Merged
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,7 +34,7 @@ repos:
    rev: 'v1.13.0'
    hooks:
      - id: mypy
        additional_dependencies: [types-cachetools, pyarrow-stubs, numpy, pytest]
        additional_dependencies: [numpy, polars, pyarrow-stubs, pytest, types-cachetools]
        args: ["--config-file=pyproject.toml",
               "python/cudf/cudf",
               "python/pylibcudf/pylibcudf",
59 changes: 46 additions & 13 deletions python/cudf_polars/cudf_polars/containers/datatype.py
@@ -6,17 +6,20 @@
from __future__ import annotations

from functools import cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal, cast

from typing_extensions import assert_never

import polars as pl

import pylibcudf as plc

from cudf_polars.utils.versions import POLARS_VERSION_LT_136

if TYPE_CHECKING:
    from cudf_polars.typing import (
        DataTypeHeader,
        PolarsDataType,
    )

__all__ = ["DataType"]
@@ -46,7 +49,18 @@ def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
    if name in SCALAR_NAME_TO_POLARS_TYPE_MAP:
        return {"kind": "scalar", "name": name}
    if isinstance(dtype, pl.Decimal):
        return {"kind": "decimal", "precision": dtype.precision, "scale": dtype.scale}
        # Workaround for incorrect polars stubs where precision is typed as int | None
        # Fixed upstream: https://github.com/pola-rs/polars/pull/25227
        # TODO: Remove this workaround when polars >= 1.36
        if POLARS_VERSION_LT_136:
            assert (
                dtype.precision is not None
            )  # Decimal always has precision at runtime
        return {
            "kind": "decimal",
            "precision": cast(int, dtype.precision),
            "scale": dtype.scale,
        }
    if isinstance(dtype, pl.Datetime):
        return {
            "kind": "datetime",
@@ -56,12 +70,17 @@
    if isinstance(dtype, pl.Duration):
        return {"kind": "duration", "time_unit": dtype.time_unit}
    if isinstance(dtype, pl.List):
        return {"kind": "list", "inner": _dtype_to_header(dtype.inner)}
        # isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType
        return {
            "kind": "list",
            "inner": _dtype_to_header(cast(pl.DataType, dtype.inner)),
        }
    if isinstance(dtype, pl.Struct):
        # isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType
        return {
            "kind": "struct",
            "fields": [
                {"name": f.name, "dtype": _dtype_to_header(f.dtype)}
                {"name": f.name, "dtype": _dtype_to_header(cast(pl.DataType, f.dtype))}
                for f in dtype.fields
            ],
        }
@@ -78,9 +97,14 @@ def _dtype_from_header(header: DataTypeHeader) -> pl.DataType:
if header["kind"] == "decimal":
return pl.Decimal(header["precision"], header["scale"])
if header["kind"] == "datetime":
return pl.Datetime(time_unit=header["time_unit"], time_zone=header["time_zone"])
return pl.Datetime(
time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]),
time_zone=header["time_zone"],
)
if header["kind"] == "duration":
return pl.Duration(time_unit=header["time_unit"])
return pl.Duration(
time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"])
)
if header["kind"] == "list":
return pl.List(_dtype_from_header(header["inner"]))
if header["kind"] == "struct":
@@ -182,9 +206,14 @@ class DataType:
    polars_type: pl.datatypes.DataType
    plc_type: plc.DataType

    def __init__(self, polars_dtype: pl.DataType) -> None:
        self.polars_type = polars_dtype
        self.plc_type = _from_polars(polars_dtype)
    def __init__(self, polars_dtype: PolarsDataType) -> None:
        # Convert DataTypeClass to a DataType instance if needed;
        # polars allows both pl.Int64 (class) and pl.Int64() (instance).
        if isinstance(polars_dtype, type):
            polars_dtype = polars_dtype()
        # After conversion, it's guaranteed to be a DataType instance
        self.polars_type = cast(pl.DataType, polars_dtype)
        self.plc_type = _from_polars(self.polars_type)

    def id(self) -> plc.TypeId:
        """The pylibcudf.TypeId of this DataType."""
@@ -193,12 +222,16 @@ def id(self) -> plc.TypeId:
    @property
    def children(self) -> list[DataType]:
        """The children types of this DataType."""
        # these type ignores are needed because the type checker doesn't
        # see that these equality checks passing imply a specific type for each child field.
        # Type checker doesn't narrow polars_type through plc_type.id() checks
        if self.plc_type.id() == plc.TypeId.STRUCT:
[Review comment, contributor author]: We could consider implementing some TypeGuard functions for our DataType classes so that we can get better narrowing, if others prefer that approach to casting like this. (A sketch of this idea follows this file's diff.)
            return [DataType(field.dtype) for field in self.polars_type.fields]
            # field.dtype returns DataTypeClass | DataType, need to cast to DataType
            return [
                DataType(cast(pl.DataType, field.dtype))
                for field in cast(pl.Struct, self.polars_type).fields
            ]
        elif self.plc_type.id() == plc.TypeId.LIST:
            return [DataType(self.polars_type.inner)]
            # .inner returns DataTypeClass | DataType, need to cast to DataType
            return [DataType(cast(pl.DataType, cast(pl.List, self.polars_type).inner))]
        return []

    def scale(self) -> int:
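A minimal sketch of the TypeGuard idea from the review comment above — hypothetical helpers, not part of this PR, assuming Python >= 3.10 for typing.TypeGuard:

```python
from typing import TypeGuard

import polars as pl


def is_struct(dtype: pl.DataType) -> TypeGuard[pl.Struct]:
    # Lets mypy narrow dtype to pl.Struct at the call site, replacing
    # cast(pl.Struct, ...) in spots like DataType.children above.
    return isinstance(dtype, pl.Struct)


def is_list(dtype: pl.DataType) -> TypeGuard[pl.List]:
    # Same idea for list dtypes; the cast on .inner would still be
    # needed, since the stubs type List.inner as DataTypeClass | DataType.
    return isinstance(dtype, pl.List)
```

With these, `if is_struct(self.polars_type):` would narrow the attribute without a cast, though the checks keyed on `plc_type.id()` would need restructuring to benefit.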
11 changes: 9 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
@@ -8,7 +8,9 @@

from enum import IntEnum, auto
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import polars as pl

import pylibcudf as plc

@@ -350,9 +352,14 @@ def do_evaluate(
        needles, haystack = columns
        if haystack.obj.type().id() == plc.TypeId.LIST:
            # Unwrap values from the list column
            # .inner returns DataTypeClass | DataType, need to cast to DataType
            haystack = Column(
                haystack.obj.children()[1],
                dtype=DataType(haystack.dtype.polars_type.inner),
                dtype=DataType(
                    cast(
                        pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner
                    )
                ),
            ).astype(needles.dtype, stream=df.stream)
        if haystack.size:
            return Column(
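The casts above lean on a runtime invariant the stubs don't capture; a quick illustrative check (not from this PR):

```python
import polars as pl

# The stubs type pl.List.inner as DataTypeClass | DataType, but a
# constructed List dtype holds a DataType instance at runtime, even if
# the inner type was given as a bare class.
dtype = pl.List(pl.Int64)
assert isinstance(dtype.inner, pl.DataType)
assert dtype.inner == pl.Int64()
```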
7 changes: 4 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
@@ -10,8 +10,9 @@
import re
from datetime import datetime
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import polars as pl
from polars.exceptions import InvalidOperationError
from polars.polars import dtype_str_repr

@@ -37,12 +38,12 @@

def _dtypes_for_json_decode(dtype: DataType) -> JsonDecodeType:
    """Get the dtypes for json decode."""
    # the type checker doesn't know that this equality check implies a struct dtype.
    # Type checker doesn't narrow polars_type through dtype.id() check
    if dtype.id() == plc.TypeId.STRUCT:
        return [
            (field.name, child.plc_type, _dtypes_for_json_decode(child))
            for field, child in zip(
                dtype.polars_type.fields,
                cast(pl.Struct, dtype.polars_type).fields,
                dtype.children,
                strict=True,
            )
15 changes: 10 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/expressions/struct.py
@@ -8,7 +8,9 @@

from enum import IntEnum, auto
from io import StringIO
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, cast

import polars as pl

import pylibcudf as plc

@@ -87,13 +89,14 @@ def do_evaluate(
"""Evaluate this expression given a dataframe for context."""
columns = [child.evaluate(df, context=context) for child in self.children]
(column,) = columns
# these type ignores are needed because the type checker doesn't
# know that polars only calls StructFunction with struct types.
# Type checker doesn't know polars only calls StructFunction with struct types
if self.name == StructFunction.Name.FieldByName:
field_index = next(
(
i
for i, field in enumerate(self.children[0].dtype.polars_type.fields)
for i, field in enumerate(
cast(pl.Struct, self.children[0].dtype.polars_type).fields
)
if field.name == self.options[0]
),
None,
@@ -113,7 +116,9 @@
                table,
                [
                    (field.name, [])
                    for field in self.children[0].dtype.polars_type.fields
                    for field in cast(
                        pl.Struct, self.children[0].dtype.polars_type
                    ).fields
                ],
            )
            options = (
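Similarly, the struct casts rely on pl.Struct exposing its fields at runtime; a small illustration with arbitrary field names:

```python
import polars as pl

# Field objects carry .name and .dtype; passing dtype instances here
# guarantees each field.dtype is a DataType instance at runtime.
dtype = pl.Struct({"a": pl.Int64(), "b": pl.String()})
assert [field.name for field in dtype.fields] == ["a", "b"]
assert all(isinstance(field.dtype, pl.DataType) for field in dtype.fields)
```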
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
@@ -598,7 +598,8 @@ def add_file_paths(
            plc.Table(
                [
                    plc.Column.from_arrow(
                        pl.Series(values=map(str, paths)), stream=df.stream
                        pl.Series(values=map(str, paths)),
                        stream=df.stream,
                    )
                ]
            ),
12 changes: 7 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -6,7 +6,9 @@
from __future__ import annotations

from functools import partial, reduce, singledispatch
from typing import TYPE_CHECKING, TypeAlias, TypedDict
from typing import TYPE_CHECKING, TypeAlias, TypedDict, cast

import polars as pl

import pylibcudf as plc
from pylibcudf import expressions as plc_expr
@@ -226,10 +228,10 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression:
    if haystack.dtype.id() == plc.TypeId.LIST:
        # Because we originally translated pl_expr.Literal with a list scalar
        # to an expr.LiteralColumn, the actual type is the inner type.
        #
        # the type-ignore is safe because for plc.TypeId.LIST, we know
        # we have a polars.List type, which has an inner attribute.
        # .inner returns DataTypeClass | DataType, need to cast to DataType
        plc_dtype = DataType(
            cast(pl.DataType, cast(pl.List, haystack.dtype.polars_type).inner)
        ).plc_type
    else:
        plc_dtype = haystack.dtype.plc_type  # pragma: no cover
    values = (
4 changes: 4 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/sort.py
@@ -98,6 +98,10 @@ def find_sort_splits(
        stream=stream,
    )
    # And convert to list for final processing
    # The type ignores are for cross-library boundaries: plc.Column -> pl.Series
    # These work at runtime via the Arrow C Data Interface protocol
    # TODO: Find a way for pylibcudf types to show they export the Arrow protocol
    # (mypy wasn't happy with a custom protocol)
    split_first_list = pl.Series(split_first_col).to_list()
    split_last_list = pl.Series(split_last_col).to_list()
    split_part_id_list = pl.Series(split_part_id).to_list()
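A sketch of the cross-library round trip those ignores cover, based on the calls used in this PR (the stream argument is omitted for brevity; assumes a default stream is acceptable):

```python
import polars as pl
import pylibcudf as plc

# pl.Series accepts the pylibcudf column at runtime via the Arrow C
# Data Interface, even though mypy sees no shared protocol between them.
col = plc.Column.from_arrow(pl.Series(values=[1, 2, 3]))
back = pl.Series(col)
assert back.to_list() == [1, 2, 3]
```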
10 changes: 5 additions & 5 deletions python/cudf_polars/cudf_polars/testing/asserts.py
@@ -112,8 +112,8 @@ def assert_gpu_result_equal(

    # These keywords are correct, but mypy doesn't see that.
    # The 'misc' is for "error: Keywords must be strings".
    expect = lazydf.collect(**final_polars_collect_kwargs)
    got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
    expect = lazydf.collect(**final_polars_collect_kwargs)  # type: ignore[misc, call-overload]
    got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine)  # type: ignore[misc, call-overload]

    assert_kwargs_bool: dict[str, bool] = {
        "check_row_order": check_row_order,
@@ -136,7 +136,7 @@
        expect,
        got,
        **assert_kwargs_bool,
        **tol_kwargs,
        **tol_kwargs,  # type: ignore[arg-type]
    )


@@ -294,7 +294,7 @@ def assert_collect_raises(
    )

    try:
        lazydf.collect(**final_polars_collect_kwargs)
        lazydf.collect(**final_polars_collect_kwargs)  # type: ignore[misc, call-overload]
    except polars_except:
        pass
    except Exception as e:
@@ -307,7 +307,7 @@

engine = GPUEngine(raise_on_fail=True)
try:
lazydf.collect(**final_cudf_collect_kwargs, engine=engine)
lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload]
except cudf_except:
pass
except Exception as e:
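For context, a minimal reproduction of the mypy complaints these ignores silence — unpacking a plain dict hides the keyword names mypy needs (illustrative, not from the PR):

```python
from typing import Any

import polars as pl

# collect() is typed with specific keyword parameters; **dict[str, Any]
# erases the keys, so mypy reports [call-overload]/[misc] ("Keywords
# must be strings") even though the call is fine at runtime.
collect_kwargs: dict[str, Any] = {"no_optimization": True}
result = pl.LazyFrame({"a": [1, 2]}).collect(**collect_kwargs)  # type: ignore[misc, call-overload]
assert result.shape == (2, 1)
```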
6 changes: 3 additions & 3 deletions python/cudf_polars/cudf_polars/testing/plugin.py
@@ -57,11 +57,11 @@ def pytest_configure(config: pytest.Config) -> None:
        collect = polars.LazyFrame.collect
        engine = polars.GPUEngine(raise_on_fail=no_fallback)
        # https://github.com/python/mypy/issues/2427
        polars.LazyFrame.collect = partialmethod(collect, engine=engine)
        polars.LazyFrame.collect = partialmethod(collect, engine=engine)  # type: ignore[method-assign, assignment]
    elif executor == "in-memory":
        collect = polars.LazyFrame.collect
        engine = polars.GPUEngine(executor=executor)
        polars.LazyFrame.collect = partialmethod(collect, engine=engine)
        polars.LazyFrame.collect = partialmethod(collect, engine=engine)  # type: ignore[method-assign, assignment]
    elif executor == "streaming" and blocksize_mode == "small":
        executor_options: dict[str, Any] = {}
        executor_options["max_rows_per_partition"] = 4
@@ -70,7 +70,7 @@
executor_options["fallback_mode"] = StreamingFallbackMode.SILENT
collect = polars.LazyFrame.collect
engine = polars.GPUEngine(executor=executor, executor_options=executor_options)
polars.LazyFrame.collect = partialmethod(collect, engine=engine)
polars.LazyFrame.collect = partialmethod(collect, engine=engine) # type: ignore[method-assign, assignment]
else:
# run with streaming executor and default blocksize
polars.Config.set_engine_affinity("gpu")
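A standalone sketch of the monkeypatch pattern above: partialmethod preserves the `self` binding that a plain functools.partial would lose, and mypy flags the method reassignment (python/mypy#2427), hence the ignores:

```python
from functools import partialmethod

import polars

# Rebind LazyFrame.collect so every collect() call defaults to the GPU
# engine, while still passing the LazyFrame instance through as `self`.
_collect = polars.LazyFrame.collect
_engine = polars.GPUEngine(raise_on_fail=True)
polars.LazyFrame.collect = partialmethod(_collect, engine=_engine)  # type: ignore[method-assign, assignment]

frame = polars.LazyFrame({"a": [1, 2, 3]}).collect()  # now uses the GPU engine
```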
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/utils/versions.py
@@ -18,6 +18,7 @@
POLARS_VERSION_LT_1321 = POLARS_VERSION < parse("1.32.1")
POLARS_VERSION_LT_1323 = POLARS_VERSION < parse("1.32.3")
POLARS_VERSION_LT_133 = POLARS_VERSION < parse("1.33.0")
POLARS_VERSION_LT_136 = POLARS_VERSION < parse("1.36")


def _ensure_polars_version() -> None:
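For reference, a sketch of how such a gate is defined and consumed (a hypothetical reconstruction assuming POLARS_VERSION is parsed from polars.__version__, as the surrounding module suggests):

```python
import polars
from packaging.version import parse

# Compare the installed polars version against a pinned threshold once,
# at import time, and branch on the resulting flag elsewhere.
POLARS_VERSION = parse(polars.__version__)
POLARS_VERSION_LT_136 = POLARS_VERSION < parse("1.36")

if POLARS_VERSION_LT_136:
    ...  # apply pre-1.36 workarounds, e.g. the Decimal precision assert above
```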
7 changes: 6 additions & 1 deletion python/cudf_polars/tests/containers/test_column.py
@@ -3,6 +3,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
@@ -14,8 +16,11 @@
from cudf_polars.containers import Column, DataType
from cudf_polars.utils.cuda_stream import get_cuda_stream

if TYPE_CHECKING:
    from cudf_polars.typing import PolarsDataType


def _as_instance(dtype: pl.DataType) -> pl.DataType:
def _as_instance(dtype: PolarsDataType) -> pl.DataType:
    if isinstance(dtype, type):
        return dtype()
    if isinstance(dtype, pl.List):
8 changes: 6 additions & 2 deletions python/cudf_polars/tests/experimental/test_shuffle.py
@@ -3,6 +3,8 @@

from __future__ import annotations

from typing import Literal, cast

import pytest

import polars as pl
@@ -75,7 +77,9 @@ def test_hash_shuffle(df: pl.LazyFrame, engine: pl.GPUEngine) -> None:
        qir3,
        options,
    )
    # ignore is for polars' EngineType, which isn't publicly exported.
    # Cast needed because polars' EngineType "cpu" isn't publicly exported.
    # https://github.com/pola-rs/polars/issues/17420
    expect = df.collect(engine="cpu")
    expect = df.collect(
        engine=cast(Literal["auto", "in-memory", "streaming", "gpu"], "cpu")
    )
    assert_frame_equal(result, expect, check_row_order=False)