22 changes: 21 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/benchmarks/pdsh.py
@@ -147,6 +147,8 @@ class RunConfig:
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )
    hardware: HardwareInfo = dataclasses.field(default_factory=HardwareInfo.collect)
    rapidsmpf_spill: bool
    spill_device: float

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> RunConfig:
@@ -174,6 +176,8 @@ def from_args(cls, args: argparse.Namespace) -> RunConfig:
            threads=args.threads,
            iterations=args.iterations,
            suffix=args.suffix,
            spill_device=args.spill_device,
            rapidsmpf_spill=args.rapidsmpf_spill,
        )

    def serialize(self) -> dict:
@@ -197,6 +201,8 @@ def summarize(self) -> None:
        if self.scheduler == "distributed":
            print(f"n_workers: {self.n_workers}")
            print(f"threads: {self.threads}")
            print(f"spill_device: {self.spill_device}")
            print(f"rapidsmpf_spill: {self.rapidsmpf_spill}")
        print(f"iterations: {self.iterations}")
        print("---------------------------------------")
        print(f"min time : {min([record.duration for record in records]):0.4f}")
@@ -1133,6 +1139,18 @@ def _query_type(query: int | str) -> list[int]:
    type=float,
    help="RMM pool size (fractional).",
)
parser.add_argument(
    "--rapidsmpf-spill",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Use rapidsmpf for general spilling.",
)
parser.add_argument(
    "--spill-device",
    default=0.5,
    type=float,
    help="Rapidsmpf device spill threshold.",
)
parser.add_argument(
    "-o",
    "--output",
@@ -1185,7 +1203,7 @@ def run(args: argparse.Namespace) -> None:
    try:
        from rapidsmpf.integrations.dask import bootstrap_dask_cluster

-        bootstrap_dask_cluster(client, spill_device=0.5)
+        bootstrap_dask_cluster(client, spill_device=run_config.spill_device)
    except ImportError as err:
        if run_config.shuffle == "rapidsmpf":
            raise ImportError from err
@@ -1217,6 +1235,8 @@ def run(args: argparse.Namespace) -> None:
"l_orderkey": 1.0, # Q18
},
}
if run_config.rapidsmpf_spill:
executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
if run_config.scheduler == "distributed":
executor_options["scheduler"] = "distributed"

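Taken together, these pdsh.py changes replace the previously hard-coded spill_device=0.5 with user-controlled flags. A minimal sketch of an invocation (other required arguments such as the dataset path and query selection are unchanged by this diff and omitted here; the module path is inferred from the file location):

    python -m cudf_polars.experimental.benchmarks.pdsh \
        --scheduler distributed \
        --rapidsmpf-spill \
        --spill-device 0.75

Because --rapidsmpf-spill uses argparse.BooleanOptionalAction, argparse also generates a --no-rapidsmpf-spill flag that restores the default of False.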
@@ -139,3 +139,11 @@ def _(x: Column) -> int:
def _(x: DataFrame) -> int:
    """The total size of the device buffers used by the DataFrame or Column."""
    return sum(c.obj.device_buffer_size() for c in x.columns)

# Register the rapidsmpf serializer if it's installed.
try:
    from rapidsmpf.integrations.dask.spilling import register_dask_serialize

    register_dask_serialize()  # pragma: no cover; rapidsmpf dependency not included yet
except ImportError:
    pass
35 changes: 35 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
@@ -159,6 +159,39 @@ def get_scheduler(config_options: ConfigOptions) -> Any:
raise ValueError(f"{scheduler} not a supported scheduler option.")


def post_process_task_graph(
    graph: MutableMapping[Any, Any],
    key: str | tuple[str, int],
    config_options: ConfigOptions,
) -> MutableMapping[Any, Any]:
    """
    Post-process the task graph.

    Parameters
    ----------
    graph
        Task graph to post-process.
    key
        Output key for the graph.
    config_options
        GPUEngine configuration options.

    Returns
    -------
    graph
        A Dask-compatible task graph.
    """
    assert config_options.executor.name == "streaming", (
        "'in-memory' executor not supported in 'post_process_task_graph'"
    )

    if config_options.executor.rapidsmpf_spill:  # pragma: no cover
        from cudf_polars.experimental.spilling import wrap_dataframe_in_spillable

        return wrap_dataframe_in_spillable(graph, ignore_key=key)
    return graph


def evaluate_streaming(ir: IR, config_options: ConfigOptions) -> DataFrame:
"""
Evaluate an IR graph with partitioning.
@@ -178,6 +211,8 @@ def evaluate_streaming(ir: IR, config_options: ConfigOptions) -> DataFrame:

    graph, key = task_graph(ir, partition_info)

    graph = post_process_task_graph(graph, key, config_options)

    return get_scheduler(config_options)(graph, key)


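To make the post-processing step concrete, here is a sketch of what the wrapping (implemented in the new spilling.py module below) does to a flat task graph when rapidsmpf_spill is enabled. The task functions and keys are hypothetical stand-ins for real cudf-polars tasks:

    from cudf_polars.experimental.spilling import wrap_dataframe_in_spillable

    def read_chunk(path: str): ...  # hypothetical task returning a DataFrame
    def concatenate(*parts): ...    # hypothetical task combining partitions

    graph = {
        ("scan", 0): (read_chunk, "part-0.parquet"),
        ("scan", 1): (read_chunk, "part-1.parquet"),
        "result": (concatenate, ("scan", 0), ("scan", 1)),
    }
    wrapped = wrap_dataframe_in_spillable(graph, ignore_key="result")

Every callable except the one producing "result" is replaced by a wrapper that unspills SpillableWrapper arguments on entry and wraps DataFrame outputs on exit, so intermediate partitions can be moved off the device under memory pressure while the final output remains a plain DataFrame.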
113 changes: 113 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/spilling.py
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Spilling in multi-partition Dask execution using RAPIDSMPF."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from rapidsmpf.integrations.dask.spilling import SpillableWrapper

from cudf_polars.containers import DataFrame

if TYPE_CHECKING:
    from collections.abc import Callable, MutableMapping


def wrap_arg(obj: Any) -> Any:
    """
    Make `obj` spillable if it is a DataFrame.

    Parameters
    ----------
    obj
        The object to be wrapped (if it is a DataFrame).

    Returns
    -------
    A SpillableWrapper if obj is a DataFrame, otherwise the original object.
    """
    if isinstance(obj, DataFrame):
        return SpillableWrapper(on_device=obj)
    return obj


def unwrap_arg(obj: Any) -> Any:
    """
    Unwrap a SpillableWrapper to retrieve the original object.

    Parameters
    ----------
    obj
        The object to be unwrapped.

    Returns
    -------
    The unwrapped object if obj is a SpillableWrapper, otherwise the original
    object.
    """
    if isinstance(obj, SpillableWrapper):
        return obj.unspill()
    return obj


def wrap_func_spillable(
    func: Callable,
    *,
    make_func_output_spillable: bool,
) -> Callable:
    """
    Wrap a function to handle spillable DataFrames.

    Parameters
    ----------
    func
        The function to be wrapped.
    make_func_output_spillable
        Whether to wrap the function's output in a SpillableWrapper.

    Returns
    -------
    A wrapped function that processes spillable DataFrames.
    """

    def wrapper(*args: Any) -> Any:
        ret: Any = func(*(unwrap_arg(arg) for arg in args))
        if make_func_output_spillable:
            ret = wrap_arg(ret)
        return ret

    return wrapper


def wrap_dataframe_in_spillable(
    graph: MutableMapping[Any, Any], ignore_key: str | tuple[str, int]
) -> MutableMapping[Any, Any]:
    """
    Wrap functions within a task graph to handle spillable DataFrames.

    Only supports flat task graphs where each DataFrame can be found at the
    outermost level. Currently, this is true for all cudf-polars task graphs.

    Parameters
    ----------
    graph
        Task graph.
    ignore_key
        The key to ignore when wrapping functions, typically the key of the
        output node.

    Returns
    -------
    A new task graph with wrapped functions.
    """
    ret = {}
    for key, task in graph.items():
        assert isinstance(task, tuple)
        ret[key] = tuple(
            wrap_func_spillable(a, make_func_output_spillable=key != ignore_key)
            if callable(a)
            else a
            for a in task
        )
    return ret
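The wrapper semantics reduce to a simple round trip. A sketch, assuming rapidsmpf is installed and df is a cudf-polars DataFrame constructed elsewhere; the identity assertion assumes the wrapper has not spilled its payload between the two calls:

    from rapidsmpf.integrations.dask.spilling import SpillableWrapper

    from cudf_polars.experimental.spilling import unwrap_arg, wrap_arg

    wrapped = wrap_arg(df)            # DataFrame -> SpillableWrapper
    assert isinstance(wrapped, SpillableWrapper)
    assert unwrap_arg(wrapped) is df  # unspill() yields the on-device object
    assert wrap_arg(42) == 42         # non-DataFrames pass through untouched
    assert unwrap_arg("x") == "x"     # so do objects that were never wrapped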
6 changes: 6 additions & 0 deletions python/cudf_polars/cudf_polars/utils/config.py
@@ -142,6 +142,9 @@ class StreamingExecutor:
        The method to use for shuffling data between workers. ``None``
        by default, which will use 'rapidsmpf' if installed and fall back to
        'tasks' if not.
    rapidsmpf_spill
        Whether to wrap task arguments and output in objects that are
        spillable by 'rapidsmpf'.
    """

name: Literal["streaming"] = dataclasses.field(default="streaming", init=False)
@@ -153,6 +156,7 @@ class StreamingExecutor:
    groupby_n_ary: int = 32
    broadcast_join_limit: int = 4
    shuffle_method: ShuffleMethod | None = None
    rapidsmpf_spill: bool = False

    def __post_init__(self) -> None:
        if self.scheduler == "synchronous" and self.shuffle_method == "rapidsmpf":
@@ -181,6 +185,8 @@ def __post_init__(self) -> None:
raise TypeError("groupby_n_ary must be an int")
if not isinstance(self.broadcast_join_limit, int):
raise TypeError("broadcast_join_limit must be an int")
if not isinstance(self.rapidsmpf_spill, bool):
raise TypeError("rapidsmpf_spill must be bool")

    def __hash__(self) -> int:
        # cardinality factory, a dict, isn't natively hashable. We'll dump it
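From the user's perspective, the new field is just another executor_options entry, validated in __post_init__ like its siblings. A sketch of enabling it through the polars GPU engine (the query itself is illustrative, and the distributed scheduler assumes a running Dask cluster):

    import polars as pl

    q = pl.LazyFrame({"a": [1, 2, 3, 4]}).group_by("a").len()
    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="streaming",
        executor_options={
            "scheduler": "distributed",
            "rapidsmpf_spill": True,
        },
    )
    result = q.collect(engine=engine)

Passing a non-bool value, e.g. "rapidsmpf_spill": 1, raises the TypeError added above.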
6 changes: 5 additions & 1 deletion python/cudf_polars/pyproject.toml
@@ -78,7 +78,11 @@ exclude_also = [
]
# The cudf_polars test suite doesn't exercise the plugin, so we omit
# it from coverage checks.
omit = ["cudf_polars/testing/plugin.py", "cudf_polars/experimental/benchmarks/pdsh.py"]
omit = [
"cudf_polars/testing/plugin.py",
"cudf_polars/experimental/benchmarks/pdsh.py",
"cudf_polars/experimental/spilling.py",
]

[tool.ruff]
line-length = 88
7 changes: 6 additions & 1 deletion python/cudf_polars/tests/experimental/test_rapidsmpf.py
@@ -10,8 +10,12 @@
from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("rapidsmpf_spill", [False, True])
@pytest.mark.parametrize("max_rows_per_partition", [1, 5])
def test_join_rapidsmpf(max_rows_per_partition: int) -> None:
def test_join_rapidsmpf(
max_rows_per_partition: int,
rapidsmpf_spill: bool, # noqa: FBT001
) -> None:
    # Check that we have a distributed cluster running.
    # These tests must be run with:
    # --executor='streaming' --scheduler='distributed'
@@ -43,6 +47,7 @@ def test_join_rapidsmpf(max_rows_per_partition: int) -> None:
"broadcast_join_limit": 2,
"shuffle_method": "rapidsmpf",
"scheduler": "distributed",
"rapidsmpf_spill": rapidsmpf_spill,
},
)

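As the in-test comment says, these parametrized cases only exercise rapidsmpf when the suite is pointed at a distributed cluster. A hedged sketch of the invocation, reusing the flag names from that comment (the exact conftest wiring may differ):

    pytest tests/experimental/test_rapidsmpf.py \
        --executor streaming --scheduler distributed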
1 change: 1 addition & 0 deletions python/cudf_polars/tests/test_config.py
@@ -225,6 +225,7 @@ def test_validate_shuffle_method() -> None:
"parquet_blocksize",
"groupby_n_ary",
"broadcast_join_limit",
"rapidsmpf_spill",
],
)
def test_validate_max_rows_per_partition(option: str) -> None: