From b30878aa485719c20a488344f0ee0ca6d6afe687 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 19:51:56 +0800 Subject: [PATCH 01/10] feat: optimize forward metadata collection across dp ranks Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 36 ++++++++++++++++----------- vllm_ascend/worker/worker_v1.py | 25 ++++++++++++++----- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d8ee57afcdb..546cb11ae2e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -547,17 +547,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if batch_changed: self.input_batch.refresh_sampling_metadata() - def _get_forward_metadata_across_dp( - self, total_num_scheduled_tokens: int, - with_prefill: bool) -> tuple[int, bool]: - forward_metadata = torch.tensor( - [total_num_scheduled_tokens, with_prefill], - device="cpu", - dtype=torch.int32) - dist.all_reduce(forward_metadata, - op=ReduceOp.MAX, - group=get_dp_group().cpu_group) - return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + def _get_forward_metadata_across_dp(self, num_tokens: int, + with_prefill: bool) -> tuple[int, bool]: + local_forward_metadata = torch.tensor([num_tokens, with_prefill], + device="npu", dtype=torch.int32) + global_forward_metadata = get_dp_group().all_gather( + local_forward_metadata) + num_tokens_across_dp = global_forward_metadata[:, 0].cpu() + with_prefill = bool(global_forward_metadata[:, 1].any()) + return num_tokens_across_dp, with_prefill def get_eagle_atten_dict( self, @@ -1013,9 +1011,12 @@ def _process_reqs( AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + num_tokens_across_dp = None if self.dp_size > 1: - max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( - total_num_scheduled_tokens, with_prefill) + num_tokens_across_dp, with_prefill = \ + 
self._get_forward_metadata_across_dp(num_input_tokens, + with_prefill) + max_num_tokens = int(num_tokens_across_dp.max().item()) extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens extra_builder_kwargs['with_prefill_across_dp'] = with_prefill @@ -1024,6 +1025,8 @@ def _process_reqs( if self.dp_size > 1: padded_batch_size = self.select_torchair_padded_batch_size( max_num_tokens) + num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, + padded_batch_size) else: padded_batch_size = self.select_torchair_padded_batch_size( total_num_scheduled_tokens) @@ -1106,7 +1109,8 @@ def _process_reqs( # Run forward pass with set_forward_context(attn_metadata, self.vllm_config, - num_tokens=num_input_tokens): + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp): with ProfileExecuteDuration().capture_async("forward"): model_kwargs = {} if self.torchair_graph_enabled: @@ -1585,6 +1589,7 @@ def _dummy_run( num_tokens: int, is_compile: bool = False, with_prefill: bool = True, + num_tokens_across_dp: Optional[int] = None, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -1628,7 +1633,8 @@ def _dummy_run( with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): if self.torchair_graph_enabled and not with_prefill: attn_metadata = self.attn_metadata_builder.build_dummy( num_reqs=num_tokens, num_actual_tokens=1) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index df03d508e43..f4e5e56b2e7 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -294,16 +294,29 @@ def pin_lora(self, lora_id: int) -> bool: def execute_dummy_batch(self) -> None: runner = self.model_runner - max_num_tokens = 1 + + # If torchair graph is enabled, notify the other DP ranks that this is a + # dummy run by using '-1' as a flag 
for num_tokens. This will be + # replaced with the final determined graph size before the forward pass. + num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill + else 1) + num_tokens_across_dp = None with_prefill = False + if runner.dp_size > 1: - max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( - max_num_tokens, with_prefill) + num_tokens_across_dp, with_prefill = \ + runner._get_forward_metadata_across_dp(num_tokens, with_prefill) + num_tokens = int(num_tokens_across_dp.max().item()) + if runner.torchair_graph_enabled and not with_prefill: - max_num_tokens = runner.select_torchair_padded_batch_size( - max_num_tokens) - runner._dummy_run(max_num_tokens, + num_tokens = runner.select_torchair_padded_batch_size(num_tokens) + if num_tokens_across_dp is not None: + num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, + num_tokens) + + runner._dummy_run(num_tokens, is_compile=False, + num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill) def _init_worker_distributed_environment(self) -> None: From e160939bb3ee975c2ab494a678c05f7680204760 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 22:20:00 +0800 Subject: [PATCH 02/10] refactor: remove unused imports from model_runner_v1.py Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 546cb11ae2e..0d89ea99506 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -30,9 +30,7 @@ import numpy.typing as npt import torch import torch._dynamo.cache_size -import torch.distributed as dist import torch.nn as nn -from torch.distributed import ReduceOp from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig From dc9a0de2640c959bb22148b5cc2d6830d469d9a9 Mon Sep 17 
00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 22:22:55 +0800 Subject: [PATCH 03/10] fix: correct handling of num_tokens for dummy run Signed-off-by: Jade Zheng --- vllm_ascend/worker/worker_v1.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index f4e5e56b2e7..ef4735515af 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -294,26 +294,25 @@ def pin_lora(self, lora_id: int) -> bool: def execute_dummy_batch(self) -> None: runner = self.model_runner + assert runner.dp_size > 1, "Dummy batch execution should only be " \ + "performed with data parallelism enabled, but got " \ + f"dp_size={runner.dp_size}." # If torchair graph is enabled, notify the other DP ranks that this is a # dummy run by using '-1' as a flag for num_tokens. This will be # replaced with the final determined graph size before the forward pass. - num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill - else 1) - num_tokens_across_dp = None - with_prefill = False - - if runner.dp_size > 1: - num_tokens_across_dp, with_prefill = \ - runner._get_forward_metadata_across_dp(num_tokens, with_prefill) - num_tokens = int(num_tokens_across_dp.max().item()) + num_tokens_across_dp, with_prefill = \ + runner._get_forward_metadata_across_dp(-1, False) if runner.torchair_graph_enabled and not with_prefill: - num_tokens = runner.select_torchair_padded_batch_size(num_tokens) - if num_tokens_across_dp is not None: - num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, - num_tokens) + max_num_tokens = int(num_tokens_across_dp.max().item()) + num_tokens = runner.select_torchair_padded_batch_size( + max_num_tokens) + else: + num_tokens = 1 + num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, + num_tokens) runner._dummy_run(num_tokens, is_compile=False, num_tokens_across_dp=num_tokens_across_dp, From 
91b49e3dcf6a7aa5f6313bcbbd15b9465c0c3c4d Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 22:30:51 +0800 Subject: [PATCH 04/10] chore: lint Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0d89ea99506..c99c7ca9272 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -545,10 +545,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if batch_changed: self.input_batch.refresh_sampling_metadata() - def _get_forward_metadata_across_dp(self, num_tokens: int, - with_prefill: bool) -> tuple[int, bool]: + def _get_forward_metadata_across_dp( + self, num_tokens: int, + with_prefill: bool) -> tuple[torch.Tensor, bool]: local_forward_metadata = torch.tensor([num_tokens, with_prefill], - device="npu", dtype=torch.int32) + device="npu", + dtype=torch.int32) global_forward_metadata = get_dp_group().all_gather( local_forward_metadata) num_tokens_across_dp = global_forward_metadata[:, 0].cpu() @@ -1629,10 +1631,11 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): if self.torchair_graph_enabled and not with_prefill: attn_metadata = self.attn_metadata_builder.build_dummy( num_reqs=num_tokens, num_actual_tokens=1) From ad1e3415bfa52824766e3045bfc241a79c5c137f Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 22:41:53 +0800 Subject: [PATCH 05/10] fix: improve handling of max_num_tokens Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git 
a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c99c7ca9272..6e8407be365 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1022,17 +1022,18 @@ def _process_reqs( # Add graph_pad_size here if self.torchair_graph_enabled and not with_prefill: - if self.dp_size > 1: - padded_batch_size = self.select_torchair_padded_batch_size( - max_num_tokens) - num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, - padded_batch_size) - else: - padded_batch_size = self.select_torchair_padded_batch_size( - total_num_scheduled_tokens) + max_num_tokens = (max_num_tokens + if self.dp_size > 1 else num_input_tokens) + padded_batch_size = self.select_torchair_padded_batch_size( + max_num_tokens) + num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, + padded_batch_size) graph_pad_size = padded_batch_size - total_num_scheduled_tokens - extra_builder_kwargs['graph_pad_size'] = graph_pad_size + else: + # If torchair graph is not enabled, or if with_prefill is True, the + # dummy run batch size is set to 1. 
+ num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1) if self.vllm_config.model_config.use_mla: query_start_loc = self.query_start_loc[:num_reqs + 1] From 83b9dca5e09a4419c79c9d72e7053de30fd5ecc6 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 22:56:01 +0800 Subject: [PATCH 06/10] fix: update dummy run batch size handling Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6e8407be365..4f705608cfe 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1026,14 +1026,19 @@ def _process_reqs( if self.dp_size > 1 else num_input_tokens) padded_batch_size = self.select_torchair_padded_batch_size( max_num_tokens) - num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, - padded_batch_size) graph_pad_size = padded_batch_size - total_num_scheduled_tokens extra_builder_kwargs['graph_pad_size'] = graph_pad_size + # If torchair graph is enabled and in decode mode, the dummy run + # batch size is set to the selected graph size. + dummy_num_tokens = padded_batch_size else: # If torchair graph is not enabled, or if with_prefill is True, the # dummy run batch size is set to 1. 
- num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, 1) + dummy_num_tokens = 1 + + if self.dp_size > 1: + num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, + dummy_num_tokens) if self.vllm_config.model_config.use_mla: query_start_loc = self.query_start_loc[:num_reqs + 1] From 41f905eefaae58137b4ec599c43e2092782a8751 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 23:16:39 +0800 Subject: [PATCH 07/10] fix: add assertion for num_tokens_across_dp in NPUModelRunner Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 4f705608cfe..0f737903948 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1037,6 +1037,7 @@ def _process_reqs( dummy_num_tokens = 1 if self.dp_size > 1: + assert num_tokens_across_dp is not None num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, dummy_num_tokens) From 82c53b2e4f28c87db12339530fc167fb4be5bb15 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 23:17:18 +0800 Subject: [PATCH 08/10] fix: change assertion to exception for dummy batch execution in NPUWorker Signed-off-by: Jade Zheng --- vllm_ascend/worker/worker_v1.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index ef4735515af..111a0af2778 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -294,9 +294,10 @@ def pin_lora(self, lora_id: int) -> bool: def execute_dummy_batch(self) -> None: runner = self.model_runner - assert runner.dp_size > 1, "Dummy batch execution should only be " \ - "performed with data parallelism enabled, but got " \ - f"dp_size={runner.dp_size}." 
+ if runner.dp_size <= 1: + raise ValueError("Dummy batch execution should only be " + "performed with data parallelism enabled, but got " + f"dp_size={runner.dp_size}.") # If torchair graph is enabled, notify the other DP ranks that this is a # dummy run by using '-1' as a flag for num_tokens. This will be From bc3b360740c802617ecc4a63864129f2c25dd131 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Wed, 2 Jul 2025 23:39:18 +0800 Subject: [PATCH 09/10] chore: lint Signed-off-by: Jade Zheng --- vllm_ascend/worker/worker_v1.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 111a0af2778..4b3c057073c 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -295,9 +295,10 @@ def pin_lora(self, lora_id: int) -> bool: def execute_dummy_batch(self) -> None: runner = self.model_runner if runner.dp_size <= 1: - raise ValueError("Dummy batch execution should only be " - "performed with data parallelism enabled, but got " - f"dp_size={runner.dp_size}.") + raise ValueError( + "Dummy batch execution should only be " + "performed with data parallelism enabled, but got " + f"dp_size={runner.dp_size}.") # If torchair graph is enabled, notify the other DP ranks that this is a # dummy run by using '-1' as a flag for num_tokens. 
This will be From d35cbf35e44144def7886890b9950d3789887a74 Mon Sep 17 00:00:00 2001 From: Jade Zheng Date: Tue, 8 Jul 2025 14:46:06 +0800 Subject: [PATCH 10/10] Update vllm_ascend/worker/model_runner_v1.py Co-authored-by: Angazenn <92204292+Angazenn@users.noreply.github.com> Signed-off-by: Jade Zheng --- vllm_ascend/worker/model_runner_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0f737903948..99abdfa60b9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -550,9 +550,9 @@ def _get_forward_metadata_across_dp( with_prefill: bool) -> tuple[torch.Tensor, bool]: local_forward_metadata = torch.tensor([num_tokens, with_prefill], device="npu", - dtype=torch.int32) + dtype=torch.int32).unsqueeze(0) global_forward_metadata = get_dp_group().all_gather( - local_forward_metadata) + local_forward_metadata, dim=0) num_tokens_across_dp = global_forward_metadata[:, 0].cpu() with_prefill = bool(global_forward_metadata[:, 1].any()) return num_tokens_across_dp, with_prefill