diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 859f21ece31..c9c94dcc5a0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -30,7 +30,6 @@ import numpy.typing as npt import torch import torch._dynamo.cache_size -import torch.distributed as dist import torch.nn as nn from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend @@ -587,17 +586,21 @@ def _get_forward_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: - # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo) - num_tokens_across_dp = torch.zeros(self.dp_size + 2, - dtype=torch.int32, - device="cpu") - num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-2] = int(with_prefill) - num_tokens_across_dp[-1] = int(not enable_dbo) - dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group) - with_prefill = bool(num_tokens_across_dp[-2]) - enable_dbo = not bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-2] + local_forward_metadata = torch.tensor( + [[num_tokens, with_prefill, enable_dbo]], + device="npu", + dtype=torch.int32) + # use device all_gather + global_forward_metadata = get_dp_group().all_gather( + local_forward_metadata, dim=0) + maybe_padded_num_tokens = global_forward_metadata[:, 0].max().item() + num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] * + self.dp_size, + device="cpu", + dtype=torch.int32) + with_prefill = bool(global_forward_metadata[:, 1].any()) + enable_dbo = bool(global_forward_metadata[:, 2].any()) + return num_tokens_across_dp, with_prefill, enable_dbo def _get_forward_metadata_across_dp_and_pad(