diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsh.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsh.py index 1f528d36a28..29e3553c490 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsh.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsh.py @@ -147,6 +147,8 @@ class RunConfig: default_factory=lambda: datetime.now(timezone.utc).isoformat() ) hardware: HardwareInfo = dataclasses.field(default_factory=HardwareInfo.collect) + rapidsmpf_spill: bool = False + spill_device: float = 0.5 @classmethod def from_args(cls, args: argparse.Namespace) -> RunConfig: @@ -174,6 +176,8 @@ def from_args(cls, args: argparse.Namespace) -> RunConfig: threads=args.threads, iterations=args.iterations, suffix=args.suffix, + spill_device=args.spill_device, + rapidsmpf_spill=args.rapidsmpf_spill, ) def serialize(self) -> dict: @@ -197,6 +201,8 @@ def summarize(self) -> None: if self.scheduler == "distributed": print(f"n_workers: {self.n_workers}") print(f"threads: {self.threads}") + print(f"spill_device: {self.spill_device}") + print(f"rapidsmpf_spill: {self.rapidsmpf_spill}") print(f"iterations: {self.iterations}") print("---------------------------------------") print(f"min time : {min([record.duration for record in records]):0.4f}") @@ -1133,6 +1139,18 @@ def _query_type(query: int | str) -> list[int]: type=float, help="RMM pool size (fractional).", ) +parser.add_argument( + "--rapidsmpf-spill", + action=argparse.BooleanOptionalAction, + default=False, + help="Use rapidsmpf for general spilling.", +) +parser.add_argument( + "--spill-device", + default=0.5, + type=float, + help="Rapidsmpf device spill threshold.", +) parser.add_argument( "-o", "--output", @@ -1185,7 +1203,7 @@ def run(args: argparse.Namespace) -> None: try: from rapidsmpf.integrations.dask import bootstrap_dask_cluster - bootstrap_dask_cluster(client, spill_device=0.5) + bootstrap_dask_cluster(client, spill_device=run_config.spill_device) except ImportError as 
err: if run_config.shuffle == "rapidsmpf": raise ImportError from err @@ -1217,6 +1235,8 @@ def run(args: argparse.Namespace) -> None: "l_orderkey": 1.0, # Q18 }, } + if run_config.rapidsmpf_spill: + executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill if run_config.scheduler == "distributed": executor_options["scheduler"] = "distributed" diff --git a/python/cudf_polars/cudf_polars/experimental/dask_registers.py b/python/cudf_polars/cudf_polars/experimental/dask_registers.py index c9f2b0be72b..39260d88f63 100644 --- a/python/cudf_polars/cudf_polars/experimental/dask_registers.py +++ b/python/cudf_polars/cudf_polars/experimental/dask_registers.py @@ -139,3 +139,11 @@ def _(x: Column) -> int: def _(x: DataFrame) -> int: """The total size of the device buffers used by the DataFrame or Column.""" return sum(c.obj.device_buffer_size() for c in x.columns) + + # Register rapidsmpf serializer if it's installed. + try: + from rapidsmpf.integrations.dask.spilling import register_dask_serialize + + register_dask_serialize() # pragma: no cover; rapidsmpf dependency not included yet + except ImportError: + pass diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index aee7590c4b2..87bc5fa5e9e 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -159,6 +159,39 @@ def get_scheduler(config_options: ConfigOptions) -> Any: raise ValueError(f"{scheduler} not a supported scheduler option.") +def post_process_task_graph( + graph: MutableMapping[Any, Any], + key: str | tuple[str, int], + config_options: ConfigOptions, +) -> MutableMapping[Any, Any]: + """ + Post-process the task graph. + + Parameters + ---------- + graph + Task graph to post-process. + key + Output key for the graph. + config_options + GPUEngine configuration options. + + Returns + ------- + graph + A Dask-compatible task graph. 
+ """ + assert config_options.executor.name == "streaming", ( + "'in-memory' executor not supported in 'post_process_task_graph'" + ) + + if config_options.executor.rapidsmpf_spill: # pragma: no cover + from cudf_polars.experimental.spilling import wrap_dataframe_in_spillable + + return wrap_dataframe_in_spillable(graph, ignore_key=key) + return graph + + def evaluate_streaming(ir: IR, config_options: ConfigOptions) -> DataFrame: """ Evaluate an IR graph with partitioning. @@ -178,6 +211,8 @@ def evaluate_streaming(ir: IR, config_options: ConfigOptions) -> DataFrame: graph, key = task_graph(ir, partition_info) + graph = post_process_task_graph(graph, key, config_options) + return get_scheduler(config_options)(graph, key) diff --git a/python/cudf_polars/cudf_polars/experimental/spilling.py b/python/cudf_polars/cudf_polars/experimental/spilling.py new file mode 100644 index 00000000000..62433c8e7ce --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/spilling.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Spilling in multi-partition Dask execution using RAPIDSMPF.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from rapidsmpf.integrations.dask.spilling import SpillableWrapper + +from cudf_polars.containers import DataFrame + +if TYPE_CHECKING: + from collections.abc import Callable, MutableMapping + from typing import Any + + +def wrap_arg(obj: Any) -> Any: + """ + Make `obj` spillable if it is a DataFrame. + + Parameters + ---------- + obj + The object to be wrapped (if it is a DataFrame). + + Returns + ------- + A SpillableWrapper if obj is a DataFrame, otherwise the original object. + """ + if isinstance(obj, DataFrame): + return SpillableWrapper(on_device=obj) + return obj + + +def unwrap_arg(obj: Any) -> Any: + """ + Unwraps a SpillableWrapper to retrieve the original object. 
+ + Parameters + ---------- + obj + The object to be unwrapped. + + Returns + ------- + The unwrapped object if obj is a SpillableWrapper, otherwise the original object. + """ + if isinstance(obj, SpillableWrapper): + return obj.unspill() + return obj + + + def wrap_func_spillable( + func: Callable, + *, + make_func_output_spillable: bool, + ) -> Callable: + """ + Wraps a function to handle spillable DataFrames. + + Parameters + ---------- + func + The function to be wrapped. + make_func_output_spillable + Whether to wrap the function's output in a SpillableWrapper. + + Returns + ------- + A wrapped function that processes spillable DataFrames. + """ + + def wrapper(*args: Any) -> Any: + ret: Any = func(*(unwrap_arg(arg) for arg in args)) + if make_func_output_spillable: + ret = wrap_arg(ret) + return ret + + return wrapper + + + def wrap_dataframe_in_spillable( + graph: MutableMapping[Any, Any], ignore_key: str | tuple[str, int] + ) -> MutableMapping[Any, Any]: + """ + Wraps functions within a task graph to handle spillable DataFrames. + + Only supports flat task graphs where each DataFrame can be found in the + outermost level. Currently, this is true for all cudf-polars task graphs. + + Parameters + ---------- + graph + Task graph. + ignore_key + The key to ignore when wrapping functions, typically the key of the + output node. + + Returns + ------- + A new task graph with wrapped functions. + """ + ret = {} + for key, task in graph.items(): + assert isinstance(task, tuple) + ret[key] = tuple( + wrap_func_spillable(a, make_func_output_spillable=key != ignore_key) + if callable(a) + else a + for a in task + ) + return ret diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index 1814649cf65..b73ec466a68 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -142,6 +142,9 @@ class StreamingExecutor: The method to use for shuffling data between workers. 
``None`` by default, which will use 'rapidsmpf' if installed and fall back to 'tasks' if not. + rapidsmpf_spill + Whether to wrap task arguments and output in objects that are + spillable by 'rapidsmpf'. """ name: Literal["streaming"] = dataclasses.field(default="streaming", init=False) @@ -153,6 +156,7 @@ class StreamingExecutor: groupby_n_ary: int = 32 broadcast_join_limit: int = 4 shuffle_method: ShuffleMethod | None = None + rapidsmpf_spill: bool = False def __post_init__(self) -> None: if self.scheduler == "synchronous" and self.shuffle_method == "rapidsmpf": @@ -181,6 +185,8 @@ def __post_init__(self) -> None: raise TypeError("groupby_n_ary must be an int") if not isinstance(self.broadcast_join_limit, int): raise TypeError("broadcast_join_limit must be an int") + if not isinstance(self.rapidsmpf_spill, bool): + raise TypeError("rapidsmpf_spill must be bool") def __hash__(self) -> int: # cardinality factory, a dict, isn't natively hashable. We'll dump it diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 3d7a942c8ee..cf6511958f0 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -78,7 +78,11 @@ exclude_also = [ ] # The cudf_polars test suite doesn't exercise the plugin, so we omit # it from coverage checks. 
-omit = ["cudf_polars/testing/plugin.py", "cudf_polars/experimental/benchmarks/pdsh.py"] +omit = [ + "cudf_polars/testing/plugin.py", + "cudf_polars/experimental/benchmarks/pdsh.py", + "cudf_polars/experimental/spilling.py", + ] [tool.ruff] line-length = 88 diff --git a/python/cudf_polars/tests/experimental/test_rapidsmpf.py b/python/cudf_polars/tests/experimental/test_rapidsmpf.py index a2c1ac62aaf..4b611355b12 100644 --- a/python/cudf_polars/tests/experimental/test_rapidsmpf.py +++ b/python/cudf_polars/tests/experimental/test_rapidsmpf.py @@ -10,8 +10,12 @@ from cudf_polars.testing.asserts import assert_gpu_result_equal +@pytest.mark.parametrize("rapidsmpf_spill", [False, True]) @pytest.mark.parametrize("max_rows_per_partition", [1, 5]) -def test_join_rapidsmpf(max_rows_per_partition: int) -> None: +def test_join_rapidsmpf( + max_rows_per_partition: int, + rapidsmpf_spill: bool, # noqa: FBT001 +) -> None: # Check that we have a distributed cluster running. # This tests must be run with: # --executor='streaming' --scheduler='distributed' @@ -43,6 +47,7 @@ def test_join_rapidsmpf(max_rows_per_partition: int) -> None: "broadcast_join_limit": 2, "shuffle_method": "rapidsmpf", "scheduler": "distributed", + "rapidsmpf_spill": rapidsmpf_spill, }, ) diff --git a/python/cudf_polars/tests/test_config.py b/python/cudf_polars/tests/test_config.py index 155d004785f..e43ad1e784b 100644 --- a/python/cudf_polars/tests/test_config.py +++ b/python/cudf_polars/tests/test_config.py @@ -225,6 +225,7 @@ def test_validate_shuffle_method() -> None: "parquet_blocksize", "groupby_n_ary", "broadcast_join_limit", + "rapidsmpf_spill", ], ) def test_validate_max_rows_per_partition(option: str) -> None: