From 15b115b06f4838ecca741a65c8d6111700c2bf07 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 22 Aug 2025 16:19:34 +0800 Subject: [PATCH 1/3] optimize dp allreduce Signed-off-by: Angazenn --- vllm_ascend/worker/model_runner_v1.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 859f21ece31..715ad764ea3 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -587,17 +587,19 @@ def _get_forward_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: - # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo) - num_tokens_across_dp = torch.zeros(self.dp_size + 2, - dtype=torch.int32, - device="cpu") - num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-2] = int(with_prefill) - num_tokens_across_dp[-1] = int(not enable_dbo) - dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group) - with_prefill = bool(num_tokens_across_dp[-2]) - enable_dbo = not bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-2] + local_forward_metadata = torch.tensor([[num_tokens, with_prefill, enable_dbo]], + device="npu", + dtype=torch.int32) + global_forward_metadata = get_dp_group().all_gather( + local_forward_metadata, dim=0) + maybe_padded_num_tokens = global_forward_metadata[:, 0].cpu().max() + num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] * + self.dp_size, + device="cpu", + dtype=torch.int32) + with_prefill = bool(global_forward_metadata[:, 1].any()) + enable_dbo = bool(global_forward_metadata[:, 2].any()) + return num_tokens_across_dp, with_prefill, enable_dbo def _get_forward_metadata_across_dp_and_pad( From 03d790344b28f16e34166ff6fa5f68f4ec8270c2 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Fri, 22 Aug 2025 17:36:47 +0800 Subject: [PATCH 2/3] fix lint Signed-off-by: Angazenn --- vllm_ascend/worker/model_runner_v1.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 715ad764ea3..ef6e38a33b9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -30,7 +30,6 @@ import numpy.typing as npt import torch import torch._dynamo.cache_size -import torch.distributed as dist import torch.nn as nn from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend @@ -587,19 +586,20 @@ def _get_forward_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: - local_forward_metadata = torch.tensor([[num_tokens, with_prefill, enable_dbo]], - device="npu", - dtype=torch.int32) + local_forward_metadata = torch.tensor( + [[num_tokens, with_prefill, enable_dbo]], + device="npu", + dtype=torch.int32) global_forward_metadata = get_dp_group().all_gather( local_forward_metadata, dim=0) - maybe_padded_num_tokens = global_forward_metadata[:, 0].cpu().max() + maybe_padded_num_tokens = global_forward_metadata[:, 0].max().item() num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] * self.dp_size, device="cpu", dtype=torch.int32) with_prefill = bool(global_forward_metadata[:, 1].any()) enable_dbo = bool(global_forward_metadata[:, 2].any()) - + return num_tokens_across_dp, with_prefill, enable_dbo def _get_forward_metadata_across_dp_and_pad( From d9e687b72819f840cf3780d3c540c6e1cc8b11e1 Mon Sep 17 00:00:00 2001 From: Angazenn Date: Mon, 25 Aug 2025 15:10:25 +0800 Subject: [PATCH 3/3] comment Signed-off-by: Angazenn --- vllm_ascend/worker/model_runner_v1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ef6e38a33b9..c9c94dcc5a0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -590,6 +590,7 @@ def _get_forward_metadata_across_dp( [[num_tokens, with_prefill, enable_dbo]], device="npu", dtype=torch.int32) + # use device all_gather global_forward_metadata = get_dp_group().all_gather( local_forward_metadata, dim=0) maybe_padded_num_tokens = global_forward_metadata[:, 0].max().item()