# Post-patch state of `_get_forward_metadata_across_dp` from
# vllm_ascend/worker/model_runner_v1.py. The same patch also removes the
# `import torch.distributed as dist` line from that file, because this method
# no longer calls `dist.all_reduce`.
def _get_forward_metadata_across_dp(
        self, num_tokens: int, with_prefill: bool,
        enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
    """Exchange per-rank forward metadata across the data-parallel group.

    Each rank contributes one ``[num_tokens, with_prefill, enable_dbo]`` row;
    the rows are all-gathered and then reduced column by column.

    Args:
        num_tokens: Token count contributed by this rank.
        with_prefill: Prefill flag of this rank.
        enable_dbo: DBO flag of this rank.

    Returns:
        A tuple ``(num_tokens_across_dp, with_prefill, enable_dbo)``:
        ``num_tokens_across_dp`` is a CPU int32 tensor holding the first
        column of the gathered rows, ``with_prefill`` is true if any gathered
        row sets it, and ``enable_dbo`` is true only if every gathered row
        sets it.
    """
    local_forward_metadata = torch.tensor(
        [[num_tokens, with_prefill, enable_dbo]],
        device="npu",
        dtype=torch.int32)
    # NOTE(review): presumably the gathered tensor is (dp_size, 3) with one
    # row per DP rank in rank order -- confirm against get_dp_group().
    # Move it to the host once, so the three reads below do not each trigger
    # their own device-to-host transfer.
    global_forward_metadata = get_dp_group().all_gather(
        local_forward_metadata, dim=0).cpu()
    num_tokens_across_dp = global_forward_metadata[:, 0].contiguous()
    # Prefill is an OR across rows: one row with the flag set is enough.
    with_prefill = bool(global_forward_metadata[:, 1].any())
    # DBO is an AND across rows. The all_reduce code this replaced summed
    # int(not enable_dbo) and negated the result, so a single rank with DBO
    # disabled turned it off everywhere; `.any()` would invert that rule.
    enable_dbo = bool(global_forward_metadata[:, 2].all())
    return num_tokens_across_dp, with_prefill, enable_dbo