Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ from rapidsmp.examples.dask import dask_cudf_shuffle

df = dask.datasets.timeseries().reset_index(drop=True).to_backend("cudf")

with LocalCUDACluster() as cluster:
# Rapidsmp is compatible with `dask_cuda` workers.
# Use an RMM pool for optimal performance.
with LocalCUDACluster(rmm_pool_size=0.8) as cluster:
with dask.distributed.Client(cluster) as client:
shuffled = dask_cudf_shuffle(df, shuffle_on=["name"])

Expand Down
30 changes: 5 additions & 25 deletions python/rapidsmp/rapidsmp/integrations/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,6 @@ async def rapidsmp_ucxx_rank_setup_node(
def rmp_worker_setup(
dask_worker: Worker,
*,
pool_size: float = 0.75,
spill_device: float = 0.50,
enable_statistics: bool = True,
) -> None:
Expand All @@ -489,8 +488,6 @@ def rmp_worker_setup(
----------
dask_worker
The current Dask worker.
pool_size
The desired RMM pool size.
spill_device
GPU memory limit for shuffling.
enable_statistics
Expand Down Expand Up @@ -529,22 +526,14 @@ def rmp_worker_setup(
else:
dask_worker._rmp_statistics = None

# Setup a buffer_resource
# Create a RMM stack with both a device pool and statistics.
available_memory = rmm.mr.available_device_memory()[1]
rmm_pool_size = int(available_memory * pool_size)
rmm_pool_size = (rmm_pool_size // 256) * 256
mr = rmm.mr.StatisticsResourceAdaptor(
rmm.mr.PoolMemoryResource(
rmm.mr.CudaMemoryResource(),
initial_pool_size=rmm_pool_size,
maximum_pool_size=rmm_pool_size,
)
)
# Setup a buffer_resource.
# Wrap the current RMM resource in a statistics adaptor.
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: The current memory resource doesn't need to be a pool for spilling to work.

rmm.mr.set_current_device_resource(mr)
total_memory = rmm.mr.available_device_memory()[1]
memory_available = {
MemoryType.DEVICE: LimitAvailableMemory(
mr, limit=int(available_memory * spill_device)
mr, limit=int(total_memory * spill_device)
)
}
dask_worker._rmp_buffer_resource = BufferResource(mr, memory_available)
Expand All @@ -553,7 +542,6 @@ def rmp_worker_setup(
def bootstrap_dask_cluster(
client: Client,
*,
pool_size: float = 0.75,
spill_device: float = 0.50,
enable_statistics: bool = True,
) -> None:
Expand All @@ -564,18 +552,11 @@ def bootstrap_dask_cluster(
----------
client
The current Dask client.
pool_size
The desired RMM pool size.
spill_device
GPU memory limit for shuffling.
enable_statistics
Whether to track shuffler statistics.

See Also
--------
bootstrap_dask_cluster_async
Setup an asynchronous Dask cluster for rapidsmp shuffling.

Notes
-----
This utility must be executed before rapidsmp shuffling can be used within a
Expand Down Expand Up @@ -627,7 +608,6 @@ def bootstrap_dask_cluster(
# Finally, prepare the rapidsmp resources on top of the UCXX comms
client.run(
rmp_worker_setup,
pool_size=pool_size,
spill_device=spill_device,
enable_statistics=enable_statistics,
)
Expand Down
10 changes: 5 additions & 5 deletions python/rapidsmp/rapidsmp/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def test_dask_ucxx_cluster_sync() -> None:
Client(cluster) as client,
):
assert len(cluster.workers) == get_n_gpus()
bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
bootstrap_dask_cluster(client, spill_device=0.1)

def get_rank(dask_worker: Worker) -> int:
# TODO: maybe move the cast into rapidsmp_comm?
Expand All @@ -69,7 +69,7 @@ def test_dask_cudf_integration(loop: pytest.FixtureDef, partition_count: int) ->

with LocalCUDACluster(loop=loop) as cluster: # noqa: SIM117
with Client(cluster) as client:
bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
bootstrap_dask_cluster(client, spill_device=0.1)
df = (
dask.datasets.timeseries(
freq="3600s",
Expand All @@ -93,14 +93,14 @@ def test_dask_cudf_integration(loop: pytest.FixtureDef, partition_count: int) ->

def test_bootstrap_dask_cluster_idempotent() -> None:
with LocalCUDACluster() as cluster, Client(cluster) as client:
bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
bootstrap_dask_cluster(client, spill_device=0.1)
before = client.run(lambda dask_worker: id(dask_worker._rapidsmp_comm))

bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
bootstrap_dask_cluster(client, spill_device=0.1)
after = client.run(lambda dask_worker: id(dask_worker._rapidsmp_comm))
assert before == after


def test_boostrap_single_node_cluster_no_deadlock() -> None:
with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
bootstrap_dask_cluster(client, spill_device=0.1)