
Refactor Dask cluster bootstrapping #150

Merged
rapids-bot[bot] merged 19 commits into rapidsai:branch-25.06 from TomAugspurger:tom/dask-bootstrap-cleanup
Mar 25, 2025

Conversation

@TomAugspurger
Contributor

@TomAugspurger TomAugspurger commented Mar 20, 2025

This refactors how we do our Dask cluster bootstrapping. It fixes a few issues:

  1. No need for a new Cluster class. Just use LocalCUDACluster.
  2. Avoid using Client.submit and Futures to bootstrap the UCXX comms, which show up in the user's Dashboard.
  3. Avoid a "two-stage" bootstrapping, where we have an initial stage that sets up the root node followed by bootstrapping the other nodes. Worker plugins are applied uniformly to each worker, so the plugin itself needs to have a way to handle both the root and non-root node cases.

Instead, we'll use scheduler and worker plugins. These are all set up in bootstrap_dask_cluster, which is called via rapidsmp.integrations.dask.get_client (which is called via rapidsmp_shuffle_graph). Users shouldn't need to change anything.
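For readers less familiar with Dask's plugin machinery, here is a minimal sketch of the registration pattern described above (the plugin classes and their contents are illustrative assumptions, not rapidsmp's actual implementation; Client.register_plugin requires a reasonably recent distributed):

    from distributed import Client
    from distributed.diagnostics.plugin import SchedulerPlugin, WorkerPlugin

    class RMPSchedulerPlugin(SchedulerPlugin):
        # hypothetical: keep rank bookkeeping on the scheduler
        def start(self, scheduler):
            scheduler.extensions["rapidsmp"] = {}

    class RMPWorkerPlugin(WorkerPlugin):
        # hypothetical: each worker sets up its UCXX comm in setup()
        async def setup(self, worker):
            worker._rapidsmp_ready = True

    def bootstrap_dask_cluster(client: Client) -> None:
        # for a synchronous client, register_plugin blocks until the plugin
        # has run on the scheduler and on every currently connected worker
        client.register_plugin(RMPSchedulerPlugin())
        client.register_plugin(RMPWorkerPlugin())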

The actual implementation grew a little complicated, but I've left a note explaining things. That might be a good spot to start reviewing.

@TomAugspurger TomAugspurger added the improvement (Improves an existing functionality) and non-breaking (Introduces a non-breaking change) labels Mar 20, 2025
Contributor Author

Still looking into this a bit, but there are apparently some differences in plugin behavior for sync and async clients. With an async client, we need to await the register_plugin call to ensure it runs to completion before we proceed to the actual shuffle.
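A minimal sketch of that difference, assuming a WorkerPlugin subclass named BootstrapPlugin (the name is made up):

    from distributed.diagnostics.plugin import WorkerPlugin

    class BootstrapPlugin(WorkerPlugin):  # made-up name
        async def setup(self, worker): ...

    def bootstrap(client):
        # synchronous client: blocks until setup() has finished on every worker
        client.register_plugin(BootstrapPlugin())

    async def bootstrap_async(client):
        # asynchronous client: the same call returns an awaitable; without the
        # await, the shuffle could start before the workers finish bootstrapping
        await client.register_plugin(BootstrapPlugin())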

Member

This is something @rjzamora and I discussed previously, and neither of us was sure what the "proper" way to do that is. Is there a single way, either sync or async, that works for all cases? This function is inherently blocking because of the barrier we need, so if there's a single implementation that can work everywhere, it would probably make the most sense.

Member

If there's nothing that prevents it from working with Dask, I would prefer if we could have a single function that can be called in either case. In my mind, I would hope for a wrapper function that handles the proper case, similar to this.
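For context, a rough sketch of the kind of wrapper being referred to (the helper name is made up, and this is only an illustration of the pattern):

    import inspect

    async def call_maybe_async(func, *args, **kwargs):
        # await the result when func is a coroutine function, otherwise call
        # it directly; note that the wrapper itself is async, which is why it
        # only helps when the public interface is async
        if inspect.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)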

Contributor Author

This kind of wrapper is only an option if the public interface is async (since you're doing an await func() in the iscoroutinefunction case). In our case, I think we need the public API to be sync since we're calling this from rapidsmp_shuffle_graph (which is sync).

I'm kind of down on APIs that magically call async from sync :/ They just always seem to lead to weirdness around coroutines being created / run on the wrong event loop (zarr-developers/zarr-python#2909 is a recent example) that I can never follow. I'd much prefer to duplicate the API definition, and a bit of the implementation if it isn't too complex.

Member

Yeah, I don't think we can define one function that can be called by both a synchronous and an asynchronous client without doing some wrapper gymnastics that would only make sense if we were duplicating a lot of code. In this case, the duplicated code seems pretty minimal.

Contributor Author

Something to figure out here: we call this via get_dask_client, which is called via the main rapidsmp_shuffle_graph. If the Client we get happens to be asynchronous, we'd need to call (and await) bootstrap_dask_cluster_async. But we're in a sync function.

We can probably make something work using client.sync.

I'm tempted to punt on this because there are other things to change (all the client.run calls need to be awaited) and async clients are relatively uncommon.
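A sketch of what the client.sync route might look like (bootstrap_dask_cluster_async is the hypothetical async variant discussed above, and the wrapper name is made up):

    async def bootstrap_dask_cluster_async(client, **kwargs): ...  # hypothetical

    def ensure_bootstrapped(client, **kwargs):
        # Force the blocking path even for an asynchronous client: client.sync
        # runs the coroutine on the client's event loop and waits for it.  It
        # cannot be called from the event-loop thread itself.
        return client.sync(
            bootstrap_dask_cluster_async, client, asynchronous=False, **kwargs
        )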

Member

Is there any good reason to support an asynchronous client at all?

Contributor Author

Good question... We do test for it on main using the test test_dask_ucxx_cluster, which I've split into a sync and an async version in this PR.

That test only runs the bootstrapping. I don't think we could do an actual rapidsmp_shuffle_graph on either main or this PR using that async client.

So if we want, I can delete that test and this function (but keep the sync version and test).

Contributor Author

Removed in a5e4c45

Member

@rjzamora rjzamora left a comment

Thanks for working on this @TomAugspurger - Just left a few preliminary comments.

Member

@pentschev pentschev left a comment

I think overall this is going in the right direction; I left some comments. There's one case where it's not clear to me this would work: how do we bootstrap on a non-local cluster, e.g., when the cluster is set up using the Dask CLI? Can we still just have the client call bootstrap_dask_cluster{,_async}? The function we had previously could be called from the client to bootstrap the cluster (once it was already set up); that may still be the case, but it's not clear to me whether that works when using plugins.


Comment on lines 785 to 790
Member

I'm not sure I understand this. Are you saying the root node should not wait at that barrier? If so, that's very much not what's intended; all ranks need to wait on the barrier, both root and non-root. It is true that all other ranks need to have made contact with the root, though, so I don't expect this asyncio.Event to be necessary. Not sure if I'm missing anything, though.

Contributor Author

At https://github.com/rapidsai/rapids-multi-gpu/pull/150/files#diff-79c8049ca8ea9f7d1c3f247e46853db9eb5e9c68647ac49e7ebb5aca7dc7c312R883 we do wait at the barrier(comm), for both the root and non-root nodes. This Event stuff is all so that the root node is the last to arrive at the barrier. I agree that it shouldn't be necessary :/

This mostly came out of a desire to have the same function be used for bootstrapping both the root and non-root nodes. Previously we did that in two stages (initially for the root node, later for the non-root nodes, where we pass in the root node's UCX address). I'd like to avoid that if possible, so I figured why not have the non-root nodes ask the root node what its UCX address is over the Dask comms.

That's when I discovered that if whatever thread is running this WorkerPlugin.setup(dask_worker) blocks, it also blocks the thread that handles RPC calls, and so we deadlock. I'm not sure why, but offloading the barrier to another thread with

        await asyncio.to_thread(barrier, comm)

does not work. I'll look into that a bit more (and we really shouldn't be calling the synchronous barrier function from what's apparently the main worker thread, so to_thread is probably a good idea even if we keep the Event stuff). If we're able to respond to RPC calls while waiting at the barrier then we're good to go. Or we give up on doing this in a single shot, and go back to the two-phase bootstrapping, but I'd like to try to avoid that.
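For concreteness, a rough sketch of the single-shot plugin shape being described here. The UCXX setup helpers, the RPC handler name, and the attribute names are all hypothetical placeholders; the sketch only illustrates the ordering problem (the root must keep serving RPCs until every other rank has fetched its UCX address, and only then reach the barrier):

    import asyncio
    from distributed.diagnostics.plugin import WorkerPlugin

    class BootstrapPlugin(WorkerPlugin):
        def __init__(self, nranks: int, root_dask_address: str):
            self.nranks = nranks
            self.root_dask_address = root_dask_address

        async def setup(self, worker):
            if worker.address == self.root_dask_address:
                comm = self._setup_root_ucxx(worker)
                # hypothetical event, set by the RPC handler once every other
                # rank has fetched the root's UCX address; it keeps the root
                # from reaching the barrier while RPCs are still pending
                await worker._rapidsmp_all_fetched.wait()
            else:
                # ask the root worker, over Dask's comms, for its UCX address
                # (assumes a "rapidsmp_root_address" handler was registered)
                root_ucx = await worker.rpc(self.root_dask_address).rapidsmp_root_address()
                comm = self._setup_worker_ucxx(worker, root_ucx)
            # the UCXX barrier is synchronous, so run it off the event-loop thread
            await asyncio.to_thread(self._barrier, comm)

        def _setup_root_ucxx(self, worker): ...          # placeholder
        def _setup_worker_ucxx(self, worker, addr): ...  # placeholder
        def _barrier(self, comm): ...                    # placeholder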

Contributor Author

That's when I discovered that if whatever thread is running this WorkerPlugin.setup(dask_worker) blocks, it also blocks the thread that handles RPC calls

I guess this makes sense, because plugin registration also happens in an RPC handler here. We're probably stuck at https://github.com/dask/distributed/blob/4ef21ad1c562503a5ac4703f312520cc4559823b/distributed/worker.py#L1892 when the root node is waiting at barrier(comm) and unable to respond to other RPCs asking the root node for its address. I still don't see why the asyncio.to_thread(barrier, comm) doesn't work, though. I'd expect that to create a new asyncio task and let the event loop progress (and respond to RPCs).

Member

Ah, OK, so the issue is really that by calling barrier we essentially block the async event loop from continuing; that makes sense. What if instead of another thread we just wrap barrier in an await asyncio.create_task(barrier(comm))? I think that might work without requiring any extra threads (which may lead to complications in some cases).

Contributor Author

asyncio.create_task needs a coroutine. barrier is currently synchronous, so AFAIK we need it to be on a separate thread if we want the calling thread to be available to respond to other RPCs.
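A small self-contained sketch of that distinction (the barrier below is just a blocking stand-in for ucxx's synchronous barrier):

    import asyncio
    import time

    def barrier(comm):
        # stand-in for ucxx's synchronous, blocking barrier
        time.sleep(1)

    async def wait_at_barrier(comm):
        # asyncio.create_task(barrier(comm)) would not help: barrier is a plain
        # synchronous function, so it blocks before any task is ever created.
        # Offloading it to a worker thread keeps the event loop free for RPCs.
        await asyncio.to_thread(barrier, comm)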

Contributor Author

I think this is the main outstanding question for this PR.

As a reminder, we need to get the UCX address from the root node to all the other nodes. Previously, we'd create the UCXX comms in two stages: an initial stage run on a single worker node (which returned the root node's UCX address), and a subsequent stage on every other node. We used the Client to receive the root node's UCX address and broadcast it to every other node.

This PR does everything in one shot. The root node is mostly unchanged in how it sets up the UCX comms. The other nodes all ask the root node for its UCX address (using Dask's comms).

The current implementation has a wrinkle: because Dask happens to run this function on the same (asyncio event loop) thread as the comms, we need the root node to be the last to arrive at the barrier(comm), so that it can first respond to the RPCs asking it for the UCXX address.

If we're on board with the goals and OK with that workaround, then this should be good to go. If either of those is a problem I can figure out some other way to do this. Options might include:

  1. figure out why asyncio.to_thread(barrier, comm) apparently doesn't work.
  2. find some other way to broadcast the root node's address to each worker.
  3. find some way to make barrier async (probably the ideal solution?)

Member

I don't personally think your current approach is a problem as long as the pattern is documented in the code (and the documentation seems pretty clear to me).

Comment on lines 819 to 827
Member

As per the comment above, I don't expect this to be necessary since the C++ implementation is supposed to handle this.

Member

Just checking: This comment was assuming that all processes should just wait at the barrier, correct? I believe we definitely do need this logic for now (unless we want to go back to using client.submit in two rounds).

Side note: I don't personally mind seeing the bootstrapping in the task stream, but I do think it's nice to have the worker plugins take care of everything.

Contributor Author

Correct. This, or something like it, is necessary with the current implementation (use Dask Comms to share the root node's UCX address) and the current behavior of ucxx's (blocking, synchronous) barrier.

I don't personally mind seeing the bootstrapping in the task stream

I'll admit to being overly sensitive to that :) I'm realizing now that we can probably delete the futures from the client once we've actually bootstrapped the worker nodes..., which would mean they're just hanging around temporarily.

I think even with that, I'm still a bit partial to the WorkerPlugin. But I'm going to spend 15 minutes seeing what it would look like.

Contributor Author

I'm realizing now that we can probably delete the futures from the client once we've actually bootstrapped the worker nodes..., which would mean they're just hanging around temporarily.

I've implemented this, but I think I might have been mistaken about the previous setup. I'm not sure why I thought they hung around in the dashboard permanently, but I don't see where they would have been persisted.

b63fa4d moves away from a WorkerPlugin and back to a set of client.submit and client.run calls like previously. I've (manually) confirmed that while the submit calls show up in the dashboard, there aren't any permanent futures.
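For reference, a rough sketch of the two-round client.submit/client.run pattern this goes back to (setup_root_node and setup_worker_node are placeholders, not the actual rapidsmp functions):

    def setup_root_node(): ...                    # placeholder
    def setup_worker_node(root_ucx_address): ...  # placeholder

    def bootstrap_in_two_rounds(client):
        workers = sorted(client.scheduler_info()["workers"])
        root = workers[0]

        # round 1: set up UCXX on the root worker and fetch its address
        fut = client.submit(setup_root_node, workers=[root], pure=False)
        root_ucx_address = fut.result()

        # round 2: bootstrap every other worker with the root's address
        client.run(setup_worker_node, root_ucx_address, workers=workers[1:])
        # the temporary future goes out of scope here, so nothing lingers on
        # the dashboard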

Comment on lines 931 to 934
Member

Also here, I wouldn't expect this to be needed and instead handled by the C++ implementation.

@TomAugspurger TomAugspurger mentioned this pull request Mar 20, 2025
@TomAugspurger
Contributor Author

how do we bootstrap on a non-local cluster, e.g., when the cluster is setup using the Dask CLI? Can we still just have the client call bootstrap_dask_cluster{,_async}?

Yes, that should still work. I tested it manually for now:

$ dask-scheduler  # terminal 1
$ CUDA_VISIBLE_DEVICES=0 dask cuda worker <address>  # terminal 2
$ CUDA_VISIBLE_DEVICES=1 dask cuda worker <address>  # terminal 3
$ ipython  # terminal 4
In [1]: import dask.datasets
   ...: from distributed import Client
   ...: from rapidsmp.integrations.dask import bootstrap_dask_cluster
   ...: from rapidsmp.examples.dask import dask_cudf_shuffle
   ...:
   ...: client = Client("tcp://10.33.227.162:8786")
   ...:
   ...: df = dask.datasets.timeseries().reset_index(drop=True).to_backend("cudf")
   ...: bootstrap_dask_cluster(client, pool_size=0.25, spill_device=0.1)
   ...: shuffled = dask_cudf_shuffle(df, shuffle_on=["name"])
   ...: result = shuffled.compute()

@pentschev pentschev changed the base branch from main to branch-25.06 March 21, 2025 17:23
@pentschev
Member

@TomAugspurger I retargeted to the new default branch-25.06; could you make sure you pull the latest changes and resolve conflicts here?

@TomAugspurger TomAugspurger marked this pull request as ready for review March 24, 2025 15:21
@TomAugspurger TomAugspurger requested a review from a team as a code owner March 24, 2025 15:21
Member

@rjzamora rjzamora left a comment

This is looking really good @TomAugspurger - Just some minor comments.

Member

This sentence sounds funny.

Contributor Author

Simplified this to

Note that rapidsmp does not currently support adding or removing workers from the cluster.

Member

@rjzamora rjzamora left a comment

Thanks @TomAugspurger - I'm happy with this change. I also confirmed that it works with multi-GPU polars execution.

My favorite part is that LocalRMPCluster goes away.

Note that I don't mind seeing the bootstrapping/setup functions in the task stream if you/others are uncomfortable with the barrier workaround. I'm on the fence about whether the added complexity is worth the improved UX, but I'm okay with either approach (especially for now).


@rjzamora
Member

Still looks good to me after the revisions - I also confirmed that cudf-polars integration works smoothly.

Member

@pentschev pentschev left a comment

Sorry I've not been following this too closely, but the current changes look good to me. Thanks @TomAugspurger !

@TomAugspurger
Contributor Author

Thanks for the reviews, and apologies for the diversion through that alternate implementation!

@TomAugspurger
Contributor Author

/merge

@rapids-bot rapids-bot bot merged commit 2241c0c into rapidsai:branch-25.06 Mar 25, 2025
21 checks passed
@TomAugspurger TomAugspurger deleted the tom/dask-bootstrap-cleanup branch March 25, 2025 23:47
rapids-bot bot pushed a commit that referenced this pull request Mar 26, 2025
This is a follow-up to #150

Closes #157

The goal of this PR is to simplify memory-resource creation by avoiding it in `rapidsmp` altogether. Since the user can just use `LocalCUDACluster` to deploy a Dask cluster, they can also use existing options/utilities to create a memory pool on each worker. When `rapidsmp.integrations.dask.bootstrap_dask_cluster` is called, each worker only needs to wrap the current memory resource in a `StatisticsResourceAdaptor`.

This is technically "breaking", because it removes the `pool_size` argument from `bootstrap_dask_cluster`. However, we are only using that option in rapidsai/cudf#18335 (which is still experimental - and can be easily changed).

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Tom Augspurger (https://github.com/TomAugspurger)
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #172
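As context for that follow-up, a sketch of what wrapping the current memory resource might look like with RMM's Python API (the helper name is made up; this is not the actual #172 implementation):

    import rmm.mr

    def wrap_current_mr_with_statistics():
        # wrap whatever memory resource the worker already uses (for example
        # the pool configured by LocalCUDACluster) in a statistics adaptor
        current = rmm.mr.get_current_device_resource()
        stats = rmm.mr.StatisticsResourceAdaptor(current)
        rmm.mr.set_current_device_resource(stats)
        return stats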

Labels

improvement (Improves an existing functionality), non-breaking (Introduces a non-breaking change)
