feat: introduce async rebalance mode for expert load balancer#8529
TheBasy wants to merge 1 commit into sgl-project:main
Conversation
Summary of Changes
Hello @TheBasy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the Expert Load Balancer (EPLB) by introducing an asynchronous rebalance mode. This new mode offloads the rebalance plan computation to a background thread and employs non-blocking communication primitives for distributed synchronization, aiming to reduce the performance impact of rebalancing on the main training loop.
Highlights
- Asynchronous Rebalance Mode: Introduced a new asynchronous rebalance mode for the Expert Load Balancer (EPLB), enabled via the `--eplb-rebalance-async-mode` CLI argument. This mode allows the computationally intensive rebalance plan calculation to occur in a background thread, preventing it from blocking the main training loop.
- Non-blocking Communication Primitives: Added new `send_signal_cpu` and `recv_signal_cpu` methods to the `ProcessGroup` class. These methods use `torch.distributed.isend` and `irecv` for non-blocking, CPU-based integer signal communication, crucial for coordinating the asynchronous rebalance across distributed ranks.
- TP-wide Synchronization Barrier: Implemented a new synchronization mechanism (`tp_sync`) using the non-blocking CPU signals. This ensures that all ranks in the Tensor Parallel (TP) group agree on a specific iteration (`_begin_transfer`) before the rebalance plan is applied, guaranteeing an atomic plan swap across the entire group.
- Yield-based Generator for Training Loop Integration: The rebalance process is now integrated into the main training loop via a yield-based generator. This design allows the rebalance computation and synchronization steps to interleave with training iterations, keeping the loop non-blocking while the rebalance is prepared in the background.
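As an illustration of the yield-based design described above, here is a minimal toy sketch. The names `rebalance_generator` and `slow_plan` are hypothetical stand-ins, not the PR's actual code; the point is only how a generator lets a background computation interleave with loop iterations:

```python
import threading
import time

def rebalance_generator(compute_fn):
    """Drive a background plan computation as a generator.

    Each next() call corresponds to one loop iteration: the generator
    yields None while the plan is still being computed, then yields the
    finished plan once the background thread is done.
    """
    result = {}
    done = threading.Event()

    def worker():
        result["plan"] = compute_fn()  # expensive plan computation
        done.set()

    # Start the expensive computation off the main thread.
    threading.Thread(target=worker, daemon=True).start()

    # Each yield returns control to the main loop.
    while not done.is_set():
        yield None
    yield result["plan"]

def slow_plan():
    time.sleep(0.05)  # stands in for rebalance plan computation
    return "new_expert_placement"

# Main loop: interleave iterations with the rebalance generator.
gen = rebalance_generator(slow_plan)
plan = None
steps = 0
for plan in gen:         # one generator step per loop iteration
    steps += 1
    time.sleep(0.001)    # a real loop would run a decode step here
    if plan is not None:
        break
print(plan)  # new_expert_placement
```

The main loop never blocks on the computation; it simply keeps pulling from the generator until a plan appears.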
Code Review
This pull request introduces an asynchronous rebalancing mode for the expert load balancer to improve performance by reducing blocking time in the main loop. The changes involve adding new asynchronous signaling functions, using background threads for computation and synchronization, and updating the server arguments.
The implementation is well-structured, but there are a few areas for improvement:
- There are a couple of incorrect type hints in the new signaling functions.
- A hardcoded value for the number of GPUs per node could cause issues on different hardware setups.
- There is some dead code (an unused variable).
- The waiting logic for asynchronous operations can be made more efficient.
The line `torch.cuda.set_device(local_rank % 8)` uses a hardcoded value `8` for the number of GPUs per node. This assumption may not hold for all hardware configurations and could lead to a `CUDA error: invalid device ordinal` if a node has a different number of GPUs. It's better to determine the number of devices dynamically using `torch.cuda.device_count()`.
```diff
- torch.cuda.set_device(local_rank%8)
+ torch.cuda.set_device(local_rank % torch.cuda.device_count())
```
The function is type-hinted to return None, but it actually returns a tuple (w, size_tensor). The type hint should be updated to reflect the actual return value for better type safety and code clarity.
```diff
- def send_signal_cpu(self, dst_rank_global: int, value: int, group: Optional[ProcessGroup] = None,) -> None:
+ def send_signal_cpu(self, dst_rank_global: int, value: int, group: Optional[ProcessGroup] = None,) -> Tuple[torch.distributed.Work, torch.Tensor]:
```
The function is type-hinted to return Any, but it returns a tuple (w, size_tensor). For better code clarity and type safety, the type hint should be more specific.
```diff
- def recv_signal_cpu(self, src_rank_global: int, group: Optional[ProcessGroup] = None, ) -> Any:
+ def recv_signal_cpu(self, src_rank_global: int, group: Optional[ProcessGroup] = None, ) -> Tuple[torch.distributed.Work, torch.Tensor]:
```
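The non-blocking signal pattern these methods implement can be illustrated with a thread-and-queue stand-in. This is a toy emulation of `isend`/`irecv` semantics (a posted operation returns a work handle immediately; completion is observed via `wait()`/`is_completed()`), not the real torch API:

```python
import queue
import threading

class Work:
    """Minimal handle mimicking torch.distributed.Work."""
    def __init__(self):
        self._done = threading.Event()
        self.value = None
    def is_completed(self):
        return self._done.is_set()
    def wait(self):
        self._done.wait()

channel = queue.Queue()  # stands in for the gloo CPU channel

def isend(value):
    w = Work()
    def run():
        channel.put(value)       # "transmit" the integer signal
        w._done.set()
    threading.Thread(target=run, daemon=True).start()
    return w                     # returns immediately, like isend

def irecv():
    w = Work()
    def run():
        w.value = channel.get()  # blocks in background, not the caller
        w._done.set()
    threading.Thread(target=run, daemon=True).start()
    return w

recv_work = irecv()   # post the receive first (non-blocking)
send_work = isend(42) # send the signal (non-blocking)
for w in (send_work, recv_work):
    w.wait()
print(recv_work.value)  # 42
```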
The `while True` loop with a nested `for w in works: w.wait()` waits for the asynchronous operations sequentially. This negates the benefit of using non-blocking `isend`/`irecv` and is inefficient.
A better approach is to wait for all operations to complete in parallel by checking their completion status in a loop. Also, the `while True` loop in the original implementation is redundant, as it would only ever execute once.
```python
while not all(w.is_completed() for w in works):
    time.sleep(0.001)
self._tp_sync_ongoing = False
```
We ran the bench_serving test with a batch size of 4096 and a request rate of 4 req/s. Under this load, switching from synchronous to asynchronous EPLB yields a small but consistent throughput gain (~0.3% more tokens and requests per second) while also cutting latency: mean end-to-end time drops 2-3%, P99 TTFT and max inter-token latency fall 9% and 32% respectively, and average in-flight concurrency shrinks by ~2%. These improvements confirm that moving expert-parameter transfers to a background thread effectively overlaps communication with computation, giving higher hardware utilization and noticeably better tail latency for large-batch, high-concurrency inference workloads.
fzyzcjy left a comment
Oh, I realize the code is not ready, so please ignore the code cleanup suggestions below for now. As long as the code is clean and neat when it is ready, that is good!
Force-pushed from d2173a4 to 3fd46b3
/gemini review
Code Review
This pull request introduces an asynchronous mode for expert rebalancing in EPLB, which is a great feature for improving performance by overlapping computation and communication with the main model execution loop. The use of background threads and a dedicated CPU process group for synchronization is well-thought-out.
My review focuses on ensuring the correctness of the new asynchronous logic. I've identified a couple of race conditions related to thread synchronization flags that could lead to incorrect behavior. I've also suggested some minor renamings for clarity.
Overall, the changes are a significant improvement. Addressing the identified issues will make the implementation more robust.
As mentioned in the comment on `_compute_expert_metadata`, to fix the race condition with `self._compute_ongoing`, it should be set to `True` here before starting the background thread.
```diff
+ self._compute_ongoing = True
  self._compute_thread = threading.Thread(
      target=self._compute_expert_metadata,
      args=(logical_count_sum,),
      daemon=True
  )
  self._compute_thread.start()
```
There is a race condition in how self._compute_ongoing is managed. It's set to True here in the background thread, but the main thread in compute() starts checking it in a loop immediately after starting this thread. The main thread might read the old False value before this line is executed, causing it to skip waiting for the computation to finish.
To fix this, self._compute_ongoing should be set to True in the compute() method before this thread is started. This line should be removed. I'll add another comment in compute() with the corresponding change.
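A small runnable sketch of the set-before-start pattern (the `Rebalancer` class and its members are hypothetical stand-ins for the PR's code): because the flag is raised on the main thread before the worker exists, the spin-wait can never observe a stale `False` and skip waiting.

```python
import threading
import time

class Rebalancer:
    def __init__(self):
        self._compute_ongoing = False
        self._result = None

    def _compute(self):
        time.sleep(0.01)                 # stands in for plan computation
        self._result = "plan"
        self._compute_ongoing = False    # cleared only by the worker

    def compute(self):
        # Set the flag on the main thread BEFORE starting the worker,
        # so the spin-wait below cannot race with the worker setting it.
        self._compute_ongoing = True
        threading.Thread(target=self._compute, daemon=True).start()
        while self._compute_ongoing:     # main-thread wait loop
            time.sleep(0.001)
        return self._result

r = Rebalancer()
print(r.compute())  # plan
```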
As mentioned in the comment on `_wait_compute_sig`, to fix the race condition with `self._tp_sync_ongoing`, it should be set to `True` here before starting the background thread.
```diff
+ self._tp_sync_ongoing = True
  self._tp_sync_thread = threading.Thread(
      target=self._wait_compute_sig,
      args=(send_works, recv_works,),
      daemon=True
  )
  self._tp_sync_thread.start()
```
Similar to _compute_ongoing, there is a race condition in how self._tp_sync_ongoing is managed. It's set to True here in the background thread, but the main thread in tp_sync() starts checking it in a loop after starting this thread. The main thread might read the old False value before this line is executed, causing it to skip waiting for the TP sync to finish.
To fix this, self._tp_sync_ongoing should be set to True in the tp_sync() method before this thread is started. This line should be removed. I'll add another comment in tp_sync() with the corresponding change.
/gemini review
Code Review
This pull request introduces an asynchronous rebalancing mode for the expert load balancer, which is a great feature for improving performance by overlapping computation and communication. The implementation uses background threads and generators to keep the main loop non-blocking, which is a solid approach.
My main concerns are around the robustness of the background threading. Specifically, exceptions in the background threads are not propagated and will cause the main loop to hang. I've added critical comments with suggestions to fix this using try...finally blocks.
I've also included a few medium-severity comments to improve code clarity and maintainability, such as removing redundant yield statements, clarifying a magic number with a comment, and renaming a function to better reflect its purpose.
Overall, this is a well-structured feature, and addressing the exception handling will make it much more robust.
If an exception occurs within ExpertLocationMetadata.init_by_eplb, the self._compute_ongoing = False line will not be executed. This will cause the main generator loop in compute() to hang indefinitely in the while self._compute_ongoing: spin-wait loop.
To prevent this, you should wrap the logic in a try...finally block to ensure self._compute_ongoing is always set to False, even if an error occurs.
For even greater robustness, consider capturing the exception and propagating it to the main thread to be handled, rather than just preventing the hang. A simple way to do this is to store the exception in an instance variable (e.g., self._compute_exception) and check for it in the main loop after the thread completes.
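The propagate-via-instance-variable idea can be sketched as follows. The `Worker` class and its members are hypothetical, not the PR's code; the shape is `try/except/finally` in the background thread plus a re-raise on the main thread:

```python
import threading

class Worker:
    def __init__(self):
        self._ongoing = False
        self._exception = None
        self._result = None

    def _run(self, fail):
        try:
            if fail:
                raise RuntimeError("rebalance failed")
            self._result = "ok"
        except Exception as e:
            self._exception = e       # stash for the main thread
        finally:
            self._ongoing = False     # always unblock the waiter

    def run(self, fail=False):
        self._ongoing = True
        t = threading.Thread(target=self._run, args=(fail,), daemon=True)
        t.start()
        t.join()                      # a real loop would spin-wait instead
        if self._exception is not None:
            raise self._exception     # surface the error where it is handled
        return self._result

w = Worker()
assert w.run() == "ok"
try:
    w.run(fail=True)
except RuntimeError as e:
    print(e)  # rebalance failed
```

Either way the waiter is always released; with the stashed exception, the failure is additionally visible instead of being silently swallowed by the daemon thread.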
Suggested change:

```python
def _compute_expert_metadata(self, logical_count):
    try:
        local_rank = self._model_runner.tp_rank
        num_gpu_per_node = self._num_gpu_per_node
        torch.cuda.set_device(local_rank % num_gpu_per_node)
        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
            self._server_args, self._model_runner.model_config, logical_count
        )
        self._rebalance_result = expert_location_metadata
    finally:
        self._compute_ongoing = False
```
If any of the w.wait() calls raise an exception (e.g., due to a timeout or a distributed communication error), the self._tp_sync_ongoing = False line will be skipped. This will cause the main generator loop in tp_sync() to hang in the while self._tp_sync_ongoing: spin-wait loop.
To ensure the application remains responsive, you should wrap the waiting logic in a try...finally block. This guarantees that _tp_sync_ongoing is set to False regardless of whether the operations succeed or fail.
Suggested change:

```python
def _wait_compute_sig(self, send_works, recv_works):
    try:
        works = send_works + recv_works
        for w in works:
            w.wait()
    finally:
        self._tp_sync_ongoing = False
```
This yield statement appears to be redundant. The while self._compute_ongoing: loop on the following lines will yield on its first iteration if the condition is met. If the condition is not met (i.e., the computation is already complete), the loop is skipped, and yielding here is unnecessary. Removing this line would make the code slightly cleaner without changing the logic.
The + 2 here is a magic number. For better maintainability and readability, it would be helpful to add a comment explaining why a buffer of 2 steps is needed. This will help future developers understand the synchronization logic.
For example, a comment could clarify that this buffer ensures all ranks have processed the synchronization signal before proceeding with the transfer.
```python
# Add a buffer of 2 steps to ensure all ranks see the new step before acting on it.
self._begin_transfer_step = max(max_tensor.item(), self._send_signal_step) + 2
```
The function name _wait_compute_sig is a bit misleading, as this function waits for communication signals (isend/irecv works) from the tensor parallel group to complete, not for a computation signal. A more descriptive name like _wait_for_tp_sync_signals would better reflect its purpose and improve code clarity.
Remember to update the call site in tp_sync (line 145) if you rename this function.
```diff
- def _wait_compute_sig(self, send_works,recv_works):
+ def _wait_for_tp_sync_signals(self, send_works,recv_works):
```
Thanks for the heads-up!
- Add CLI arg `--enable-eplb-rebalance-async`
- Background thread: broadcast `logical_count` → compute `ExpertLocationMetadata` → store in `_rebalance_result`
- TP barrier: use gloo `cpu_group` signals (`send_single_signal` / `recv_single_signal`) to ensure all ranks swap the plan atomically
- Yield-based generator keeps the decoding loop non-blocking; transfer starts after TP-wide agreement via `_begin_transfer`
- Sync mode (`async=False`) unchanged: blocking single-thread rebalance
The comment below refers to this section of the change:

```python
        self._rebalance_result = expert_location_metadata
        self._compute_ongoing = False

    def tp_sync(self):
```
hmm, not sure whether this is optimal or not (I do not have time to dig deeper), but FYI there is a `PollBasedBarrier` that may be somehow related
Thanks for the note!
In our tp_sync function, we're aiming to implement a non-blocking all-reduce max operation to synchronize the step_counter across TP ranks. The goal is to ensure that all TP ranks proceed to the data transfer phase only after the maximum step_counter among them is reached, while allowing each rank to continue decoding independently during the synchronization window.
If TP ranks enter the transfer phase at different decoding steps, the earlier ones would have to transfer one layer and then wait for others to finish their current decode — this creates global blocking and hurts throughput.
We did try using torch.distributed.all_reduce directly, but it blocks the main inference thread due to the collective communication being synchronous, which disrupts the continuous decoding flow. That’s why we’re exploring alternative designs to achieve asynchronous coordination without stalling inference.
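The intended coordination (each rank contributes its step counter, all ranks agree on the maximum plus a buffer, and no rank's decode loop blocks in the meantime) can be illustrated with a plain-threading toy. This uses `threading.Barrier` as a stand-in for the CPU signal exchange, not the actual `torch.distributed` code:

```python
import threading

NUM_RANKS = 4
step_counters = [3, 7, 5, 6]          # current decode step per rank
agreed = [None] * NUM_RANKS
barrier = threading.Barrier(NUM_RANKS)
lock = threading.Lock()
shared_max = [0]

def background_max_sync(rank):
    # Non-blocking from the decode loop's perspective: this runs in a
    # background thread while the rank keeps "decoding".
    with lock:
        shared_max[0] = max(shared_max[0], step_counters[rank])
    barrier.wait()                    # all ranks have contributed
    agreed[rank] = shared_max[0] + 2  # +2 buffer before transfer begins

threads = [threading.Thread(target=background_max_sync, args=(r,))
           for r in range(NUM_RANKS)]
for t in threads:
    t.start()
# Each rank's main loop would keep decoding here instead of blocking.
for t in threads:
    t.join()
print(agreed)  # [9, 9, 9, 9]
```

All ranks converge on the same transfer step (max step 7, plus the buffer of 2) without any rank's main loop ever entering a collective call.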
Which MoE model did you test? Could you share your model server command? That would help with reproducing the test.
I tested DeepSeek-R1. You only need to add one line, `--enable-eplb-rebalance-async`, to the original launch arguments to enable asynchronous computation during rebalancing. It is also recommended to set `--eplb-rebalance-layers-per-chunk 1`, which allows layer-wise data migration during the rebalancing phase, further reducing the user-perceived interruption latency. For example, my original launch arguments for the decode stage are:

```shell
python3 -m sglang.launch_server \
    --disaggregation-mode decode \
    --disaggregation-transfer-backend mooncake \
    --attention-backend flashmla \
    --trust-remote-code \
    --enable-dp-attention \
    --mem-fraction-static 0.9 \
    --context-length 65535 \
    --decode-log-interval 50 \
    --enable-cache-report \
    --moe-dense-tp-size 1 \
    --enable-dp-lm-head \
    --enable-deepep-moe \
    --deepep-mode low_latency \
    --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \
    --model-path /home/models/deepseek-ai__DeepSeek-R1 \
    --host 0.0.0.0 \
    --port ${DEFAULT_PORT} \
    --dist-init-addr ${DECODE_MASTER_IP}:${MASTER_PORT} \
    --tp-size $((8 * NUM_DECODE)) \
    --dp-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE)) \
    --chunked-prefill-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * CHUNKED_PREFILL_SIZE_PER_DP_RANK)) \
    --max-running-requests $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * DECODE_MAX_RUNNING_REQUEST_PER_DP_RANK)) \
    --nnodes ${NUM_DECODE} \
    --node-rank $node_rank \
    --enable-dp-attention \
    --enable-expert-distribution-metrics \
    --enable-eplb \
    --eplb-rebalance-num-iterations 1000 \
    --cuda-graph-max-bs ${DECODE_MAX_RUNNING_REQUEST_PER_DP_RANK} \
    > ${WORKDIR}/decode.log 2>&1 &
```

After enabling async rebalance mode, the launch arguments become:

```shell
python3 -m sglang.launch_server \
    --disaggregation-mode decode \
    --disaggregation-transfer-backend mooncake \
    --attention-backend flashmla \
    --trust-remote-code \
    --enable-dp-attention \
    --mem-fraction-static 0.9 \
    --context-length 65535 \
    --decode-log-interval 50 \
    --enable-cache-report \
    --moe-dense-tp-size 1 \
    --enable-dp-lm-head \
    --enable-deepep-moe \
    --deepep-mode low_latency \
    --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \
    --model-path /home/models/deepseek-ai__DeepSeek-R1 \
    --host 0.0.0.0 \
    --port ${DEFAULT_PORT} \
    --dist-init-addr ${DECODE_MASTER_IP}:${MASTER_PORT} \
    --tp-size $((8 * NUM_DECODE)) \
    --dp-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE)) \
    --chunked-prefill-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * CHUNKED_PREFILL_SIZE_PER_DP_RANK)) \
    --max-running-requests $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * DECODE_MAX_RUNNING_REQUEST_PER_DP_RANK)) \
    --nnodes ${NUM_DECODE} \
    --node-rank $node_rank \
    --enable-dp-attention \
    --enable-expert-distribution-metrics \
    --enable-eplb \
    --eplb-rebalance-layers-per-chunk 1 \
    --enable-eplb-rebalance-async \
    --eplb-rebalance-num-iterations 1000 \
    --cuda-graph-max-bs ${DECODE_MAX_RUNNING_REQUEST_PER_DP_RANK} \
    > ${WORKDIR}/decode.log 2>&1 &
```
Do you have related benchmark results on the decode instance only?

@TheBasy Could you merge the latest master to resolve the conflicts?

You can check this PR for the merge: antgroup#3. It includes the latest changes from master and resolves the conflicts.
- CLI arg: `--eplb-rebalance-async-mode`
- Background thread: broadcast `logical_count` → compute `ExpertLocationMetadata` → store in `_rebalance_result`
- TP barrier: `cpu_group` signals (`send_signal_cpu` / `recv_signal_cpu`) ensure all ranks swap the plan atomically
- Transfer starts after TP-wide agreement via `_begin_transfer`
- Sync mode (`async=False`) unchanged: blocking single-thread rebalance

Motivation
Modifications
Accuracy Test
Benchmark & Profiling
Checklist