Refactor Dask cluster bootstrapping #150
rapids-bot[bot] merged 19 commits into rapidsai:branch-25.06
Conversation
This refactors how we do our dask cluster bootstrapping. It fixes a few issues:
1. No need for a new Cluster class. Just use `LocalCUDACluster`.
2. Avoid using `submit` and `Futures`, which show up in the user's Dashboard.
Still looking into this a bit, but there are apparently some differences in plugin behavior for sync and async clients. With an async client, we need to await the `register_plugin` call to ensure it runs to completion before we proceed to the actual shuffle.
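For reference, a minimal sketch of that sync/async difference, assuming distributed's `Client.register_plugin` and `WorkerPlugin` APIs; the plugin class and scheduler address here are placeholders, not rapidsmp code:

```python
import asyncio

from distributed import Client
from distributed.diagnostics.plugin import WorkerPlugin


class SetupPlugin(WorkerPlugin):
    async def setup(self, worker):
        ...  # per-worker bootstrapping would happen here


async def main(scheduler_address: str) -> None:
    async with Client(scheduler_address, asynchronous=True) as client:
        # With an async client, the registration must be awaited so the plugin
        # has finished on every worker before the shuffle begins.
        await client.register_plugin(SetupPlugin())


# asyncio.run(main("tcp://scheduler:8786"))
```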
This is something @rjzamora and I discussed previously, and neither of us was sure what the "proper" way to do that is. Is there a single way, either sync or async, that works for all cases? This function is inherently blocking because of the barrier we need, so if there's a single implementation that can work everywhere it would probably make the most sense.
If there's nothing that prevents it from working with Dask, I would prefer if we could have a single function that can be called in either case. In my mind, I would hope for a wrapper function that handles the proper case, similar to this.
This kind of wrapper is only an option if the public interface is async (since you're doing an await func() in the iscoroutinefunction case). In our case, I think we need the public API to be sync since we're calling this from rapidsmp_shuffle_graph (which is sync).
I'm kind of down on APIs that magically call async from sync :/ They just always seem to lead to weirdness around coroutines being created / run on the wrong event loop (zarr-developers/zarr-python#2909 is a recent example) that I can never follow. I'd much prefer to duplicate the API definition, and a bit of the implementation if it isn't too complex.
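For reference, a minimal sketch of the kind of wrapper being discussed (a hypothetical helper, not rapidsmp code); it illustrates why the pattern only helps when the public entry point is itself async:

```python
import inspect


async def call_maybe_async(func, *args, **kwargs):
    """Dispatch to a sync or async implementation of the same operation."""
    if inspect.iscoroutinefunction(func):
        # The async case requires an await, so this wrapper itself must be
        # async, which is exactly the problem when the caller is synchronous.
        return await func(*args, **kwargs)
    return func(*args, **kwargs)
```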
Yeah, I don't think we can define one function that can be called by both a synchronous and an asynchronous client without doing some wrapper gymnastics, which would only make sense if we were duplicating a lot of code. In this case, the duplicated code seems pretty minimal.
Something to figure out here: we call this via get_dask_client, which is called via the main rapidsmp_shuffle_graph. If the Client we get happens to be asynchronous, we'd need to call (and await) bootstrap_dask_cluster_async. But we're in a sync function.
We can probably make something work using client.sync.
I'm tempted to punt on this because there are other things to change (all the client.runs need to be awaited) and async clients are relatively uncommon.
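A rough sketch of what the `client.sync` idea mentioned above could look like (hypothetical helper names; with an asynchronous client the caller would still need to await the result):

```python
from distributed import Client


async def _bootstrap_async(client: Client) -> None:
    ...  # stand-in for the real async bootstrapping logic


def bootstrap(client: Client):
    # Client.sync runs the coroutine on the client's event loop: it blocks
    # until completion for a synchronous client, and returns an awaitable
    # for an asynchronous one.
    return client.sync(_bootstrap_async, client)
```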
Is there any good reason to support an asynchronous client at all?
Good question... We do test for it on main using the test test_dask_ucxx_cluster, which I've split/duplicated into a sync and an async version in this PR.
That test only runs the bootstrapping. I don't think we could do an actual rapidsmp_shuffle_graph on either main or this PR using that async client.
So if we want, I can delete that test and this function (but keep the sync version and test).
rjzamora
left a comment
Thanks for working on this @TomAugspurger - Just left a few preliminary comments.
pentschev
left a comment
I think overall this is going in the right direction; I left some comments. There's one case where it's not clear to me this would work: how do we bootstrap on a non-local cluster, e.g., when the cluster is set up using the Dask CLI? Can we still just have the client call bootstrap_dask_cluster{,_async}? The function we had previously could be called from the client to bootstrap the cluster (once it was already set up). That may still be the case, but it's not clear to me whether it works when using plugins.
I'm not sure I understand this: are you saying the root node should not wait at that barrier? If so, that's very much not what is intended; all ranks need to wait on the barrier, both root and non-root. It is true that all other ranks need to have made contact with the root first, though, so I don't expect this asyncio.Event to be necessary. Not sure if I'm missing anything, though.
At https://github.com/rapidsai/rapids-multi-gpu/pull/150/files#diff-79c8049ca8ea9f7d1c3f247e46853db9eb5e9c68647ac49e7ebb5aca7dc7c312R883 we do wait at the barrier(comm), for both the root and non-root nodes. This Event stuff is all so that the root node is the last to arrive at the barrier. I agree that it shouldn't be necessary :/
This mostly came out of a desire to have the same function be used for bootstrapping both the root and non-root nodes. Previously we did that in two stages (initially for the root node, later for the non-root nodes, where we pass in the root node's UCX address). I'd like to avoid that if possible, so I figured why not have the non-root nodes ask the root node what its UCX address is over the Dask comms.
That's when I discovered that if whatever thread is running this WorkerPlugin.setup(dask_worker) blocks, it also blocks the thread that handles RPC calls, and so we deadlock. I'm not sure why, but offloading the barrier to another thread with
await asyncio.to_thread(barrier, comm)
does not work. I'll look into that a bit more (and we really shouldn't be calling the synchronous barrier function from what's apparently the main worker thread, so to_thread is probably a good idea even if we keep the Event stuff). If we're able to respond to RPC calls while waiting at the barrier then we're good to go. Or we give up on doing this in a single shot, and go back to the two-phase bootstrapping, but I'd like to try to avoid that.
That's when I discovered that if whatever thread is running this WorkerPlugin.setup(dask_worker) blocks, it also blocks the thread that handles RPC calls
I guess this makes sense, because plugin registration also happens in an RPC handler here. We're probably stuck at https://github.com/dask/distributed/blob/4ef21ad1c562503a5ac4703f312520cc4559823b/distributed/worker.py#L1892 when the root node is waiting at barrier(comm) and unable to respond to other RPCs asking the root node for its address. I still don't see why the asyncio.to_thread(barrier, comm) doesn't work though. I'd expect that to make a new asyncio task and let the event loop progress (and respond to RPCs).
Ah ok, so the issue is really that by calling barrier we essentially block the async event loop from continuing; that makes sense. What if, instead of another thread, we just wrap barrier in an await asyncio.create_task(barrier(comm))? I think that might work without requiring any extra threads (which may lead to complications in some cases).
asyncio.create_task needs a coroutine. barrier is currently synchronous, so AFAIK we need it to be on a separate thread if we want the calling thread to be available to respond to other RPCs.
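A standalone sketch of that distinction (not rapidsmp code): `asyncio.create_task` only helps if the work is already a coroutine, while a blocking function has to be pushed onto a worker thread to keep the event loop responsive:

```python
import asyncio
import time


def blocking_barrier() -> None:
    # Stand-in for the synchronous ucxx barrier(comm) call.
    time.sleep(1)


async def main() -> None:
    # Offloading the blocking call to a thread keeps the event loop free to
    # service other callbacks (e.g. RPC handlers) while we wait.
    await asyncio.to_thread(blocking_barrier)


asyncio.run(main())
```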
I think this is the main outstanding question for this PR.
As a reminder, we need to get the UCX address from the root node to all the other nodes. Previously, we'd create the UCXX comms in two stages: an initial stage run on the root node (which returned the root node's UCX address), and a subsequent stage on every other node. We used the Client to receive the root node's UCX address and broadcast it to every other node.
This PR does everything in one shot. The root node is mostly unchanged in how it sets up the UCX comms. The other nodes all ask the root node for its UCX address (using Dask's comms).
The current implementation has a wrinkle: because Dask happens to run this function on the same (asyncio event loop) thread as comms, we need the root node to be the last to arrive at barrier(comm), so that it can respond to RPCs asking it for the UCXX address before it blocks.
If we're on board with the goals and OK with that workaround, then this should be good to go. If either of those is a problem I can figure out some other way to do this. Options might include:
- figure out why `asyncio.to_thread(barrier, comm)` apparently doesn't work.
- find some other way to broadcast the root node's address to each worker.
- find some way to make `barrier` async (probably the ideal solution?)
I don't personally think your current approach is a problem as long as the pattern is documented in the code (and the documentation seems pretty clear to me).
As per the comment above, I don't expect this to be necessary since the C++ implementation is supposed to handle this.
Just checking: This comment was assuming that all processes should just wait at the barrier, correct? I believe we definitely do need this logic for now (unless we want to go back to using client.submit in two rounds).
Side note: I don't personally mind seeing the bootstrapping in the task stream, but I do think it's nice to have the worker plugins take care of everything.
Correct. This, or something like it, is necessary with the current implementation (using Dask comms to share the root node's UCX address) and the current behavior of ucxx's (blocking, synchronous) barrier.
I don't personally mind seeing the bootstrapping in the task stream
I'll admit to being overly sensitive to that :) I'm realizing now that we can probably delete the futures from the client once we've actually bootstrapped the worker nodes..., which would mean they're just hanging around temporarily.
I think even with that, I'm still a bit partial to the WorkerPlugin. But I'm going to spend 15 minutes seeing what it would look like.
I'm realizing now that we can probably delete the futures from the client once we've actually bootstrapped the worker nodes..., which would mean they're just hanging around temporarily.
I've implemented this, but I think I might have been mistaken about the previous setup. I'm not sure why I thought they hung around in the dashboard permanently, but I don't see where they would have been persisted.
b63fa4d moves away from a WorkerPlugin and back to a set of client.submit and client.run calls like previously. I've (manually) confirmed that while the submit calls show up in the dashboard, there aren't any permanent futures.
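For readers following along, a rough sketch of that two-stage `client.submit`/`client.run` pattern (hypothetical helper names; the actual rapidsmp functions differ):

```python
from distributed import Client


def _setup_root():
    # Runs on the root worker; would initialize UCXX comms and return the
    # root's address.
    ...


def _setup_worker(root_address):
    # Runs on every worker; would connect to the root using its address.
    ...


def bootstrap(client: Client) -> None:
    workers = sorted(client.scheduler_info()["workers"])
    root = workers[0]
    # Stage 1: initialize comms on the root worker and fetch its address.
    root_address = client.submit(_setup_root, workers=[root], pure=False).result()
    # Stage 2: tell every worker where the root is.
    client.run(_setup_worker, root_address)
```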
Also here, I wouldn't expect this to be needed and instead handled by the C++ implementation.
Yes, that should still work. I tested it manually for now:
@TomAugspurger I retargeted to the new default
rjzamora
left a comment
This is looking really good @TomAugspurger - Just some minor comments.
Simplified this to
Note that rapidsmp does not currently support adding or removing workers from the cluster.
Thanks @TomAugspurger - I'm happy with this change. I also confirmed that it works with multi-GPU polars execution.
My favorite part is that LocalRMPCluster goes away.
Note that I don't mind seeing the bootstrapping/setup functions in the task stream if you/others are uncomfortable with the barrier workaround. I'm on the fence about whether the added complexity is worth the improved UX, but I'm okay with either approach (especially for now).
Still looks good to me after the revisions - I also confirmed that cudf-polars integration works smoothly.
pentschev
left a comment
Sorry I've not been following this too closely, but the current changes look good to me. Thanks @TomAugspurger!
Thanks for the reviews, and apologies for the diversion through that alternate implementation!
/merge
This is a follow-up to #150. Closes #157.

The goal of this PR is to simplify memory-resource creation by avoiding it in `rapidsmp` altogether. Since the user can just use `LocalCUDACluster` to deploy a Dask cluster, they can also use existing options/utilities to create a memory pool on each worker. When `rapidsmp.integrations.dask.bootstrap_dask_cluster` is called, each worker only needs to wrap the current memory resource in a `StatisticsResourceAdaptor`.

This is technically "breaking", because it removes the `pool_size` argument from `bootstrap_dask_cluster`. However, we are only using that option in rapidsai/cudf#18335 (which is still experimental - and can be easily changed).

Authors:
- Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
- Tom Augspurger (https://github.com/TomAugspurger)
- Peter Andreas Entschev (https://github.com/pentschev)

URL: #172
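A hedged sketch of the per-worker wrapping described above, assuming rmm's standard memory-resource API (the helper name is a placeholder; the exact integration code in rapidsmp may differ):

```python
import rmm


def wrap_current_resource() -> None:
    # Runs on each worker: wrap whatever memory resource is already active
    # (e.g. a pool configured via LocalCUDACluster) so allocations are tracked.
    current = rmm.mr.get_current_device_resource()
    rmm.mr.set_current_device_resource(rmm.mr.StatisticsResourceAdaptor(current))
```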
This refactors how we do our dask cluster bootstrapping. It fixes a few issues:

1. No need for a new Cluster class. Just use `LocalCUDACluster`.
2. Avoid using `Client.submit` and `Futures` to bootstrap the UCXX comms, which show up in the user's Dashboard.

Instead, we'll use scheduler and worker plugins. These are all set up in `bootstrap_dask_cluster`, which is called via `rapidsmp.integrations.dask.get_client` (which is called via `rapidsmp_shuffle_graph`). Users shouldn't need to change anything.

The actual implementation grew a little complicated, but I've left a note explaining things. That might be a good spot to start reviewing.
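Based on the description above, typical usage might look roughly like the sketch below. The exact `bootstrap_dask_cluster` signature is an assumption, and per the PR, `get_client`/`rapidsmp_shuffle_graph` normally invoke it for you:

```python
from dask_cuda import LocalCUDACluster
from distributed import Client

from rapidsmp.integrations.dask import bootstrap_dask_cluster

if __name__ == "__main__":
    # No custom Cluster subclass is needed; a plain LocalCUDACluster works.
    with LocalCUDACluster() as cluster, Client(cluster) as client:
        bootstrap_dask_cluster(client)
        # ... build and execute a rapidsmp shuffle graph as usual ...
```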