
[0.13.0][Bugfix] Add synced_cudagraph_mode to limit mixed graph modes in dp ranks#6011

Merged
yiz-liu merged 2 commits intovllm-project:releases/v0.13.0from
slippersss:0.13.0_bugfix_graph
Jan 20, 2026

Conversation

@slippersss
Contributor

@slippersss slippersss commented Jan 19, 2026

What this PR does / why we need it?

This PR fixes a hang when using A2 + AIV, caused by the fact that HCCL does not support mixing eager-mode and graph-mode communication across ranks. Following vllm-project/vllm#30173, we introduce `synced_cudagraph_mode` so that every rank learns the minimum graph mode across all ranks. The main changes are:

  1. `execute_model` now performs "dispatch -> sync -> re-dispatch", just like `_dummy_run`
  2. `_sync_metadata_across_dp` now receives `cudagraph_mode` from all ranks and returns `synced_cudagraph_mode` to all ranks
  3. The re-dispatch steps in both `execute_model` and `_dummy_run` pass `disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value`, so that when this is true, FULL is never dispatched
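The "sync to the minimum mode, then re-dispatch" idea above can be sketched in plain Python. This is a minimal illustration, not the PR's code: the enum values and the helper names `sync_cudagraph_mode` / `redispatch` are hypothetical, and the cross-rank `all_reduce(MIN)` over the DP group is simulated with `min()`.

```python
from enum import IntEnum

# Hypothetical values mirroring the ordering the PR relies on
# (lower value == more conservative mode); the real vLLM enum may differ.
class CUDAGraphMode(IntEnum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def sync_cudagraph_mode(local_modes):
    """Simulate the cross-rank all_reduce(MIN): every DP rank learns the
    minimum mode, so no rank dispatches a graph mode others cannot join."""
    return min(local_modes)

def redispatch(synced_mode):
    """Re-dispatch with FULL disabled whenever any rank fell below FULL,
    matching disable_full = synced_mode <= CUDAGraphMode.PIECEWISE.value."""
    disable_full = synced_mode <= CUDAGraphMode.PIECEWISE.value
    return CUDAGraphMode.PIECEWISE if disable_full else CUDAGraphMode.FULL

# Rank 1 hit a prefill batch and dropped to PIECEWISE; after the sync,
# every rank re-dispatches to PIECEWISE instead of hanging in mixed modes.
modes = [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
print(redispatch(sync_cudagraph_mode(modes)))
```

Only when all ranks report FULL does the re-dispatch keep FULL; a single divergent rank pulls everyone down to the common mode.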

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

By CI.

…es in dp ranks

Signed-off-by: Zetong Li <slippersss@126.com>
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a mechanism to synchronize the CUDAGraph mode across data-parallel ranks to prevent hangs when different ranks operate in mixed eager/graph modes. The changes are logical and well-implemented, introducing a synced_cudagraph_mode that is used to control graph dispatching. However, I've identified a critical issue where this synchronization is skipped in a specific optimization path for MoE models, which could lead to the very problem this PR aims to solve. My review includes a suggested fix for this issue.

Comment on lines 441 to +446
        if self._skip_all_reduce_across_dp_group():
            num_tokens_after_padding = torch.tensor([num_tokens] *
                                                    self.dp_size,
                                                    device="cpu",
                                                    dtype=torch.int32)
            return num_tokens, num_tokens_after_padding, with_prefill
        return num_tokens, num_tokens_after_padding, with_prefill, cudagraph_mode
Contributor


critical

When _skip_all_reduce_across_dp_group() is true, the all_reduce operation for syncing metadata is skipped. However, this also skips syncing cudagraph_mode: the local cudagraph_mode is returned instead. Different ranks could then operate in different CUDAGraph modes, which is exactly the issue this pull request aims to fix and could cause hangs.

The cudagraph_mode should be synced across all DP ranks regardless of whether other metadata syncing is skipped.

        if self._skip_all_reduce_across_dp_group():
            # Even if we skip syncing num_tokens, we must sync cudagraph_mode.
            mode_tensor = torch.tensor([cudagraph_mode], dtype=torch.int32, device="cpu")
            dist.all_reduce(mode_tensor, op=dist.ReduceOp.MIN, group=get_dp_group().cpu_group)
            synced_cudagraph_mode = mode_tensor.item()

            num_tokens_after_padding = torch.tensor([num_tokens] *
                                                    self.dp_size,
                                                    device="cpu",
                                                    dtype=torch.int32)
            return num_tokens, num_tokens_after_padding, with_prefill, synced_cudagraph_mode

Signed-off-by: Zetong Li <slippersss@126.com>
Comment on lines 443 to +448
        if self._skip_all_reduce_across_dp_group():
            num_tokens_after_padding = torch.tensor([num_tokens] *
                                                    self.dp_size,
                                                    device="cpu",
                                                    dtype=torch.int32)
            return num_tokens, num_tokens_after_padding, with_prefill
        return num_tokens, num_tokens_after_padding, with_prefill, cudagraph_mode
Contributor Author


@jianzs Hi, as you mentioned in #5979, when we skip the all_reduce there is still a possibility that different DP ranks run different graph modes. Since we have already hit the hang with A2 + AIV, we have to ensure that there is no prefill when entering this _skip_all_reduce_across_dp_group branch.
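The invariant this reply relies on can be sketched as follows. This is a hypothetical simplification, not the real body of `_skip_all_reduce_across_dp_group`; it only shows why gating the skip path on `with_prefill` keeps graph modes consistent.

```python
# Hypothetical simplification of the skip-path guard: the metadata
# all_reduce may only be skipped on decode-only steps, because a rank
# doing prefill would fall back from FULL graphs while others stay in
# FULL, and HCCL hangs when eager and graph communication are mixed.
def can_skip_all_reduce(with_prefill: bool) -> bool:
    # Decode-only: every rank picks the same graph mode on its own,
    # so skipping the sync cannot produce mixed modes across DP ranks.
    return not with_prefill

print(can_skip_all_reduce(with_prefill=False))
```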

@yiz-liu yiz-liu added the ready and ready-for-test labels Jan 20, 2026
@yiz-liu yiz-liu merged commit 9f5a033 into vllm-project:releases/v0.13.0 Jan 20, 2026
20 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 20, 2026
…lm-ascend into FIA_v0.13.0

* 'releases/v0.13.0' of https://github.com/vllm-project/vllm-ascend:
  [0.13.0][Bugfix] Add `synced_cudagraph_mode` to limit mixed graph modes in dp ranks (vllm-project#6011)
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…es in dp ranks (vllm-project#6011)
tangtiangu pushed a commit to tangtiangu/jiusi-vllm-ascend that referenced this pull request Feb 24, 2026

Labels

ready, ready-for-test


2 participants