From 12bfa49daaa9416015c1a34775ebb9b4e29b794a Mon Sep 17 00:00:00 2001 From: Yizhou Liu Date: Wed, 27 Aug 2025 15:15:33 +0800 Subject: [PATCH 1/3] refactor(model_runner): Refactor input preparation logic in NPUModelRunner, preparing to remove `get_dp_padding` Moves the determination of attention state, padding, and other forward metadata to an earlier stage within the input preparation method. This improves code clarity by grouping related metadata calculations together before tensor manipulations occur. The variable `padded_num_tokens_across_dp` is also renamed to `maybe_padded_num_tokens` to more accurately reflect that padding is conditional. Signed-off-by: Yizhou Liu --- vllm_ascend/worker/model_runner_v1.py | 89 ++++++++++++++++----------- 1 file changed, 52 insertions(+), 37 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index fe7a9795afc..3ae23c739a2 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1059,24 +1059,10 @@ def _prepare_inputs( torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - if (self.use_aclgraph and total_num_scheduled_tokens - <= self.aclgraph_batch_sizes[-1]): - # Add padding to the batch size. - num_input_tokens = self.vllm_config.pad_for_cudagraph( - total_num_scheduled_tokens) - else: - # Eager mode. - num_input_tokens = total_num_scheduled_tokens - - # Padding for DP - num_pad, num_tokens_across_dp_native = self.get_dp_padding( - num_input_tokens) - num_input_tokens += num_pad self.attn_metadata_builder.reorder_batch(self.input_batch, scheduler_output) @@ -1097,6 +1083,42 @@ def _prepare_inputs( max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) + if (self.use_aclgraph and total_num_scheduled_tokens + <= self.aclgraph_batch_sizes[-1]): + # Add padding to the batch size. + num_input_tokens = self.vllm_config.pad_for_cudagraph( + total_num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = total_num_scheduled_tokens + + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + + self.query_lens = torch.from_numpy(num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) + + # Get info across DP ranks. 
+ # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, + # Otherwise, it's just total_num_scheduled_tokens + (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp_and_pad( + num_input_tokens, with_prefill, enable_dbo) + + # # Padding for DP + # num_pad, num_tokens_across_dp_native = self.get_dp_padding( + # num_input_tokens) + # num_input_tokens += num_pad + # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) @@ -1165,20 +1187,9 @@ def _prepare_inputs( self.seq_lens[num_reqs:].fill_(0) self.query_start_loc[num_reqs + 1:].fill_(-1) - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - - enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), - attn_state, - total_num_scheduled_tokens) - - (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, - enable_dbo) = self._get_forward_metadata_across_dp_and_pad( - total_num_scheduled_tokens, with_prefill, enable_dbo) self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp - self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp) + self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], @@ -1246,7 +1257,7 @@ def _prepare_inputs( positions = self.positions[:num_input_tokens] input_ids, positions = self._update_input_ids_and_positions( input_ids, positions, num_input_tokens, with_prefill, - padded_num_tokens_across_dp) + maybe_padded_num_tokens) if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1265,9 +1276,13 @@ def _prepare_inputs( # MC2 may not be available in eager mode # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP if self.use_aclgraph: - num_tokens_across_dp = num_tokens_across_dp_native + # print(f"num_tokens_across_dp: {num_tokens_across_dp}, " + # f"num_tokens_across_dp_native: {num_tokens_across_dp_native}") + # num_tokens_across_dp = num_tokens_across_dp_native + ... else: - num_input_tokens = padded_num_tokens_across_dp + # num_input_tokens = maybe_padded_num_tokens + ... 
use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -1296,12 +1311,12 @@ def _prepare_inputs( return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, - padded_num_tokens_across_dp, logits_indices, + maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors) def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, - padded_num_tokens_across_dp, + maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds): @@ -1344,7 +1359,7 @@ def _update_graph_pad_size(self, with_prefill, graph_pad_size): def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, - padded_num_tokens_across_dp): + maybe_padded_num_tokens): if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions @@ -1649,7 +1664,7 @@ def execute_model( return self.kv_connector_no_forward(scheduler_output) (attn_metadata, positions, num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, - padded_num_tokens_across_dp, logits_indices, spec_decode_metadata, + maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) @@ -1680,7 +1695,7 @@ def execute_model( hidden_states = self._generate_process_reqs_hidden_states( attn_metadata, self.with_prefill, - padded_num_tokens_across_dp, input_ids, positions, + maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() @@ -2000,9 +2015,9 @@ def _dummy_run( "Capturing attention in aclgraph is unexpected, because full graph is not supported now" ) - # Padding for DP - num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens) - # num_tokens += num_pad ## Uncomment this after TorchAir is removed + # # Padding for DP + # num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens) + # # num_tokens += num_pad ## Uncomment this after TorchAir is removed # Padding for DP (for TorchAir) (num_tokens, num_tokens_across_dp, with_prefill, From f0a89b5fb2b049a9363314894aed04c72de7a0b5 Mon Sep 17 00:00:00 2001 From: Yizhou Liu Date: Wed, 27 Aug 2025 22:37:23 +0800 Subject: [PATCH 2/3] refactor(model_runner): Refactor DP metadata synchronization and padding logic Unifies and simplifies the logic for synchronizing metadata (number of tokens, prefill status, DBO status) across data parallel (DP) ranks. This change renames `_get_forward_metadata_across_dp_and_pad` to a more descriptive `_sync_metadata_across_dp` and consolidates the padding logic within it. The separate `get_dp_padding` function is removed. The synchronization mechanism is improved by packing all metadata into a single tensor for a more efficient `all_reduce` operation. This refactoring streamlines the code, removes redundancy, and clarifies the data flow for DP padding in both TorchAir and standard execution modes. 
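The packed all_reduce can be sketched as follows; this is a minimal illustration of the pattern the patch adopts, assuming an already-initialized DP process group, not the exact vllm_ascend code:

    import torch
    import torch.distributed as dist

    def sync_metadata_across_dp(num_tokens: int, with_prefill: bool,
                                enable_dbo: bool, dp_rank: int, dp_size: int,
                                group=None):
        # Pack per-rank token counts plus two boolean flags into one tensor
        # so a single SUM all_reduce replaces several collectives.
        packed = torch.zeros(dp_size + 2, dtype=torch.int32)
        packed[dp_rank] = num_tokens
        packed[-2] = int(with_prefill)    # after SUM: nonzero iff any rank prefills
        packed[-1] = int(not enable_dbo)  # inverted so any rank opting out wins
        dist.all_reduce(packed, group=group)

        num_tokens_across_dp = packed[:-2]
        max_tokens = int(num_tokens_across_dp.max().item())
        with_prefill = bool(packed[-2].item())
        enable_dbo = not bool(packed[-1].item())
        # Every rank pads to the global max so collective kernels see
        # identical shapes on all ranks.
        num_tokens_after_padding = torch.full((dp_size,), max_tokens,
                                              dtype=torch.int32)
        return max_tokens, num_tokens_after_padding, with_prefill, enable_dbo

Summing 0/1 flags makes the all_reduce behave as a logical OR: prefill on any rank forces the prefill path everywhere, and DBO stays enabled only if every rank enabled it, which matches the unpacking in the diff below.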
Signed-off-by: Yizhou Liu --- vllm_ascend/torchair/torchair_model_runner.py | 15 +- vllm_ascend/worker/model_runner_v1.py | 133 +++++++----------- vllm_ascend/worker/mtp_proposer_v1.py | 6 +- 3 files changed, 63 insertions(+), 91 deletions(-) diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index fb4f583d11d..24fd33a1ea4 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -70,7 +70,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): register_torchair_model() torchair_quant_method_register() - def _get_forward_metadata_across_dp_and_pad( + def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: """Override from NPUModelRunner to pad num_tokens""" @@ -81,8 +81,17 @@ def _get_forward_metadata_across_dp_and_pad( return maybe_padded_num_tokens, None, with_prefill, enable_dbo return num_tokens, None, with_prefill, enable_dbo - num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( - num_tokens, with_prefill, enable_dbo) + num_tokens_across_dp = torch.zeros(self.dp_size + 2, + dtype=torch.int32, + device="npu") + num_tokens_across_dp[self.dp_rank] = num_tokens + num_tokens_across_dp[-2] = int(with_prefill) + num_tokens_across_dp[-1] = int(not enable_dbo) + dist.all_reduce(num_tokens_across_dp, + group=get_dp_group().device_group) + with_prefill = bool(num_tokens_across_dp[-2]) + enable_dbo = not bool(num_tokens_across_dp[-1]) + num_tokens_across_dp = num_tokens_across_dp[:-2] if not with_prefill: max_num_token = num_tokens_across_dp.max().item() diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3ae23c739a2..8ff55742b43 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -43,8 +43,7 @@ from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, get_tp_group, is_global_first_rank) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - get_forward_context) +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -593,32 +592,43 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. 
self.input_batch.refresh_metadata() - def _get_forward_metadata_across_dp( - self, num_tokens: int, with_prefill: bool, - enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: - - # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo) - num_tokens_across_dp = torch.zeros(self.dp_size + 2, - dtype=torch.int32, - device="cpu") - num_tokens_across_dp[self.dp_rank] = num_tokens - num_tokens_across_dp[-2] = int(with_prefill) - num_tokens_across_dp[-1] = int(not enable_dbo) - dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group) - with_prefill = bool(num_tokens_across_dp[-2]) - enable_dbo = not bool(num_tokens_across_dp[-1]) - num_tokens_across_dp = num_tokens_across_dp[:-2] - return num_tokens_across_dp, with_prefill, enable_dbo - - def _get_forward_metadata_across_dp_and_pad( + def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: - if self.dp_size == 1: + if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager: return num_tokens, None, with_prefill, enable_dbo - num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( - num_tokens, with_prefill, enable_dbo) - return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo + # Sync num_tokens, with_prefill, enable_dbo across dp ranks + num_tokens_tensor = torch.tensor([ + num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) + ], + dtype=torch.int32, + device="npu") + + flags_tensor = torch.tensor( + [int(with_prefill), int(not enable_dbo)], + dtype=torch.int32, + device="npu") + + packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) + + dist.all_reduce(packed_tensor, group=get_dp_group().device_group) + + # Unpack the results + num_tokens_across_dp = packed_tensor[:-2] + synced_flags = packed_tensor[-2:] + + max_tokens_across_dp = torch.max(num_tokens_across_dp).item() + global_with_prefill = bool(synced_flags[0]) + global_enable_dbo = not bool(synced_flags[1]) + + # Create a tensor for num_tokens_after_padding + num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * + self.dp_size, + device="npu", + dtype=torch.int32) + + return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo def _check_dbo_is_valid(self, query_lens: torch.Tensor, attn_state: AscendAttentionState, @@ -1024,32 +1034,6 @@ def _gather_mm_embeddings( mm_embeds.append(mm_embeds_item) return mm_embeds - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - """This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`. - Please note that vLLM may refactor or modify this function over time, - at present, we are using the version introduced in PR #18935. - """ - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use ACL graphs (enabled by this padding) on the decoder. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. 
- return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -1109,15 +1093,14 @@ def _prepare_inputs( # Get info across DP ranks. # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, - # Otherwise, it's just total_num_scheduled_tokens + # Otherwise, it's just max_tokens_across_dp_cpu (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, - enable_dbo) = self._get_forward_metadata_across_dp_and_pad( - num_input_tokens, with_prefill, enable_dbo) + enable_dbo) = self._sync_metadata_across_dp(num_input_tokens, + with_prefill, enable_dbo) - # # Padding for DP - # num_pad, num_tokens_across_dp_native = self.get_dp_padding( - # num_input_tokens) - # num_input_tokens += num_pad + if self.use_aclgraph: + # When using TorchAir with DP, we have other plans for padding + num_input_tokens = maybe_padded_num_tokens # Hot-Swap lora model if self.lora_config: @@ -1272,18 +1255,6 @@ def _prepare_inputs( for k, v in self.intermediate_tensors.items() }) - # NOTE: Currently this padding logic is really messy, - # MC2 may not be available in eager mode - # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP - if self.use_aclgraph: - # print(f"num_tokens_across_dp: {num_tokens_across_dp}, " - # f"num_tokens_across_dp_native: {num_tokens_across_dp_native}") - # num_tokens_across_dp = num_tokens_across_dp_native - ... - else: - # num_input_tokens = maybe_padded_num_tokens - ... 
- use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1311,9 +1282,8 @@ def _prepare_inputs( return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, - maybe_padded_num_tokens, logits_indices, - spec_decode_metadata, input_ids, inputs_embeds, - intermediate_tensors) + maybe_padded_num_tokens, logits_indices, spec_decode_metadata, + input_ids, inputs_embeds, intermediate_tensors) def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, maybe_padded_num_tokens, @@ -1663,9 +1633,8 @@ def execute_model( return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) (attn_metadata, positions, num_scheduled_tokens_np, - num_input_tokens, num_tokens_across_dp, - maybe_padded_num_tokens, logits_indices, spec_decode_metadata, - input_ids, inputs_embeds, + num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, + logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) @@ -1694,9 +1663,8 @@ def execute_model( self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( - attn_metadata, self.with_prefill, - maybe_padded_num_tokens, input_ids, positions, - intermediate_tensors, inputs_embeds) + attn_metadata, self.with_prefill, maybe_padded_num_tokens, + input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( @@ -2015,14 +1983,9 @@ def _dummy_run( "Capturing attention in aclgraph is unexpected, because full graph is not supported now" ) - # # Padding for DP - # num_pad, num_tokens_across_dp_native = self.get_dp_padding(num_tokens) - # # num_tokens += num_pad ## Uncomment this after TorchAir is removed - - # Padding for DP (for TorchAir) + # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, - _) = self._get_forward_metadata_across_dp_and_pad( - num_tokens, with_prefill, False) + _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). 
This means that we are using diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 1ec14363724..120b17a652c 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -194,7 +194,7 @@ def propose( # torch mode need to update num_tokens_across_dp # TODO: adapt enable_dbo later (num_input_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._get_forward_metadata_across_dp_and_pad( + _) = self.runner._sync_metadata_across_dp( num_tokens, self.runner.with_prefill, False) attn_metadata.slot_mapping = target_slot_mapping else: @@ -281,8 +281,8 @@ def dummy_run(self, if not self.torchair_graph_enabled: # TODO: adapt enable_dbo later (num_tokens, num_tokens_across_dp, with_prefill, - _) = self.runner._get_forward_metadata_across_dp_and_pad( - num_tokens, with_prefill, False) + _) = self.runner._sync_metadata_across_dp(num_tokens, + with_prefill, False) is_running_torchair = self.torchair_graph_enabled and \ not with_prefill From 70818aa411e48d731859e953ebc67a83aafff2d6 Mon Sep 17 00:00:00 2001 From: Yizhou Liu Date: Wed, 27 Aug 2025 23:16:56 +0800 Subject: [PATCH 3/3] refactor(moe): Refactor MoE communication and remove dummy implementation Removes the unused `DummyCommImpl` for Mixture-of-Experts communication. The logic for selecting the MoE communication method is centralized into a new `_select_moe_comm_method` within the model runner. This method dynamically chooses the appropriate communication strategy based on the number of tokens, simplifying the control flow and removing hardcoded defaults from model execution and warmup routines. Signed-off-by: Yizhou Liu --- vllm_ascend/distributed/moe_comm_method.py | 37 ---------------------- vllm_ascend/ops/common_fused_moe.py | 8 +++-- vllm_ascend/worker/model_runner_v1.py | 22 ++++++------- 3 files changed, 15 insertions(+), 52 deletions(-) diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py index 02f6d52aff8..ea324958415 100644 --- a/vllm_ascend/distributed/moe_comm_method.py +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -94,43 +94,6 @@ def unpermute(self, mlp_output: torch.Tensor, pass -class DummyCommImpl(MoECommMethod): - - def prepare( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Dummy prepare method that does nothing.""" - return hidden_states, router_logits - - def finalize(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: - """Dummy finalize method that does nothing.""" - return hidden_states - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: - """Dummy implementation, make sure the output shapes are correct.""" - top_k_num = topk_ids.shape[1] - permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, - dim=0) - expert_tokens = torch.zeros((num_experts, ), - dtype=torch.int64, - device=hidden_states.device) - group_list_type = 0 - return permuted_hidden_states, expert_tokens, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - """Dummy implementation that does nothing.""" - pass - - class AllGatherCommImpl(MoECommMethod): """This implementation is the same as NativeAllGatherCommImpl, but uses NPU-specific ops for better performance. 
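The selection scheme that the remaining files wire together can be sketched as below. This is a hypothetical reduction of the pattern: implementations are registered under their lowercased class names, and a threshold on the token count picks one per batch. The mapping from the short names returned by `_select_moe_comm_method` to attribute names like `allgathercommimpl` is an assumption here, not shown verbatim in the diff:

    class AllGatherCommImpl:  # stand-in for the real implementation
        pass

    class MC2CommImpl:        # stand-in for the real implementation
        pass

    class FusedMoELayer:
        def __init__(self):
            for method in (AllGatherCommImpl, MC2CommImpl):
                # Registers self.allgathercommimpl and self.mc2commimpl,
                # mirroring the setattr loop in common_fused_moe.py.
                setattr(self, method.__name__.lower(), method())

    def select_moe_comm_method(num_tokens: int, mc2_tokens_capacity: int) -> str:
        # MC2 kernels handle batches only up to a fixed token capacity;
        # larger batches fall back to the all-gather implementation.
        return "mc2" if num_tokens <= mc2_tokens_capacity else "allgather"

    layer = FusedMoELayer()
    name = select_moe_comm_method(num_tokens=512, mc2_tokens_capacity=1024)
    impl = getattr(layer, name + "commimpl")  # -> layer.mc2commimpl
    assert isinstance(impl, MC2CommImpl)

Keeping the per-batch decision in one place is what lets the patch drop the `moe_comm_method` parameter from `_dummy_run` and the hardcoded defaults in the runner.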
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index ffc1dea87e4..72ee91b31eb 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -26,7 +26,6 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - DummyCommImpl, MC2CommImpl, MoECommMethod) from vllm_ascend.distributed.parallel_state import get_mc2_group @@ -230,7 +229,7 @@ def __init__( self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}: + for method in {AllGatherCommImpl, MC2CommImpl}: setattr( self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] @@ -241,8 +240,11 @@ def forward_impl(self, hidden_states: torch.Tensor, forward_context = get_forward_context() moe_comm_method_name = forward_context.moe_comm_method_name - if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl": + + # TODO: Can we refactor this logic to model_runner? + if not self.moe_config.use_ep: moe_comm_method_name = "allgathercommimpl" + forward_context.moe_comm_method = getattr(self, moe_comm_method_name) hidden_states, router_logits = forward_context.moe_comm_method.prepare( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8ff55742b43..2c86ec901bb 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -372,10 +372,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): device=self.device, ) - self.moe_comm_method = "mc2" - self.fallback_moe_comm_method = "allgather" - self.dummy_moe_comm_method = "dummy" - def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager @@ -1616,6 +1612,10 @@ def _pool( kv_connector_output=kv_connector_output, ) + def _select_moe_comm_method(self, num_tokens: int) -> str: + return ("mc2" + if num_tokens <= self.mc2_tokens_capacity else "allgather") + @torch.inference_mode() def execute_model( self, @@ -1638,9 +1638,8 @@ def execute_model( intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) - moe_comm_method = (self.moe_comm_method - if num_input_tokens <= self.mc2_tokens_capacity else - self.fallback_moe_comm_method) + moe_comm_method = self._select_moe_comm_method(num_input_tokens) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ @@ -1969,7 +1968,6 @@ def _dummy_run( num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, - moe_comm_method: str = "dummy", aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, @@ -1987,6 +1985,8 @@ def _dummy_run( (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) + moe_comm_method = self._select_moe_comm_method(num_tokens) + # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. 
@@ -2494,12 +2494,10 @@ def _capture_aclgraphs(self, compilation_cases: list[int], self._dummy_run(num_tokens, aclgraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, - uniform_decode=uniform_decode, - moe_comm_method=self.moe_comm_method) + uniform_decode=uniform_decode) self._dummy_run(num_tokens, aclgraph_runtime_mode=aclgraph_runtime_mode, - uniform_decode=uniform_decode, - moe_comm_method=self.moe_comm_method) + uniform_decode=uniform_decode) def _capture_model(self): if not self.use_aclgraph:
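For reference, the batch-size rounding that `pad_for_cudagraph` performs in the hunks above, choosing the smallest captured graph size that holds the batch, can be sketched as follows; this is an illustrative stand-in, the real helper lives in vLLM's `VllmConfig`:

    import bisect

    def pad_for_graph_capture(num_tokens: int, captured_sizes: list[int]) -> int:
        # captured_sizes is sorted ascending (cf. aclgraph_batch_sizes).
        # Callers guard num_tokens <= captured_sizes[-1], mirroring the
        # use_aclgraph branch in _prepare_inputs, so the index is valid.
        return captured_sizes[bisect.bisect_left(captured_sizes, num_tokens)]

    sizes = [1, 2, 4, 8, 16, 32]
    assert pad_for_graph_capture(5, sizes) == 8
    assert pad_for_graph_capture(16, sizes) == 16

Replaying a graph captured for the padded size avoids recapturing for every distinct batch size; the padding rows are ignored when logits are gathered at `logits_indices`.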