diff --git a/tests/e2e/multicard/moe/test_moe_comm.py b/tests/e2e/multicard/moe/test_moe_comm.py index b1de5e680f9..2b09f57dbcc 100644 --- a/tests/e2e/multicard/moe/test_moe_comm.py +++ b/tests/e2e/multicard/moe/test_moe_comm.py @@ -18,29 +18,30 @@ import pytest import torch -from transformers import PretrainedConfig -from vllm import forward_context -from vllm_ascend.distributed import moe_comm_method -from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - NativeAllGatherCommImpl) +from vllm.model_executor.layers.fused_moe.config import ( # isort: skip + FusedMoEConfig, FusedMoEParallelConfig) + +from vllm_ascend.distributed.moe_comm_method import ( # isort: skip + AllGatherCommImpl, NativeAllGatherCommImpl) @pytest.mark.parametrize("num_tokens", [16, 128]) @pytest.mark.parametrize("hidden_size", [64, 128]) @pytest.mark.parametrize("global_num_experts", [8, 16]) +@pytest.mark.parametrize("num_local_experts", [4, 8]) @pytest.mark.parametrize("top_k_num", [2, 4]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("num_local_experts", [4, 8]) @pytest.mark.parametrize("ep_rank", [0, 1]) def test_all_gather_comm_impl( num_tokens, hidden_size, global_num_experts, + num_local_experts, top_k_num, dtype, - num_local_experts, ep_rank, + mocker, ): """ Tests the AllGatherCommImpl against the NativeAllGatherCommImpl. @@ -56,23 +57,37 @@ def test_all_gather_comm_impl( "num_local_experts cannot be greater than global_num_experts") device = torch.device("npu") - hf_config = PretrainedConfig( - num_experts_per_tok=top_k_num, + + # mock get_tensor_model_parallel_rank to return ep_rank + mocker.patch( + "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank", + return_value=ep_rank, + ) + + # make moe config + parallel_config = SimpleNamespace( + enable_expert_parallel=num_local_experts < global_num_experts) + moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=max(2, global_num_experts // num_local_experts), + dp_size_=1, + vllm_parallel_config=parallel_config, + ) + + moe_config = FusedMoEConfig( num_experts=global_num_experts, + experts_per_token=top_k_num, + hidden_dim=hidden_size, + num_local_experts=num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=dtype, + quant_config=None, # No quantization in this test + max_num_tokens=num_tokens, ) # Instantiate implementations - native_impl = NativeAllGatherCommImpl(device, dtype, hf_config) - - all_gather_impl = AllGatherCommImpl(device, dtype, hf_config) + native_impl = NativeAllGatherCommImpl(moe_config) - # TODO: Find out if this is the correct way to mock the forward context and ep group - # Mock get_forward_context to return an object with moe_comm_method - forward_context._forward_context = SimpleNamespace( - moe_comm_method=all_gather_impl) - # Mock get_ep_group to return a fake group with the specified ep_rank - fake_ep_group = SimpleNamespace(rank_in_group=ep_rank) - moe_comm_method.get_ep_group = lambda: fake_ep_group + all_gather_impl = AllGatherCommImpl(moe_config) # --- Input Data --- hidden_states = torch.randn(num_tokens, @@ -103,11 +118,11 @@ def test_all_gather_comm_impl( native_permuted_hidden, native_expert_tokens, _, - ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights, - expert_map, num_experts) + ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map, + num_experts) # Simulate MLP output native_mlp_output = torch.randn_like(native_permuted_hidden) - 
native_impl._post_process(native_mlp_output, native_hidden_states_out) + native_impl.unpermute(native_mlp_output, native_hidden_states_out) # --- Run AllGather Implementation --- all_gather_hidden_states_out = hidden_states.clone() @@ -115,15 +130,14 @@ def test_all_gather_comm_impl( all_gather_permuted_hidden, all_gather_expert_tokens, _, - ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids, - topk_weights, expert_map, - num_experts) + ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights, + expert_map, num_experts) # Use the same simulated MLP output for a fair comparison all_gather_mlp_output = native_mlp_output.clone() - torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output, - all_gather_hidden_states_out) + all_gather_impl.unpermute(all_gather_mlp_output, + all_gather_hidden_states_out) # --- Assertions --- # Define tolerance based on dtype diff --git a/tests/ut/distributed/test_communicator.py b/tests/ut/distributed/test_communicator.py index 880cb246ea7..edaae2a1025 100644 --- a/tests/ut/distributed/test_communicator.py +++ b/tests/ut/distributed/test_communicator.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import torch import torch.distributed as dist @@ -87,69 +87,3 @@ def patched_all_to_all(output_tensor_list, output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0) assert output.tolist() == [[10, 20], [50, 60]] - - @patch("vllm.config.get_current_vllm_config", return_value=None) - @patch("torch.npu.current_device", return_value=MagicMock()) - @patch("torch.npu.set_device", return_value=MagicMock()) - @patch("torch.distributed.get_process_group_ranks", - return_value={ - 0: 0, - 1: 1 - }) - @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1}) - @patch("torch.distributed.is_initialized", return_value=True) - @patch("torch.distributed.get_rank", return_value=1) - @patch("torch.distributed.is_initialized", return_value=True) - @patch("torch.distributed.get_backend", return_value="hccl") - @patch("torch.distributed.get_rank", return_value=1) - @patch("torch.distributed.get_world_size", return_value=2) - @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1]) - @patch("torch.npu.device") - def test_dispatch(self, *_): - comm = NPUCommunicator(cpu_group=dist.group.WORLD) - comm.all2all_manager = Mock() - hidden_states = torch.randn(2, 4, 8) - router_logits = torch.randn(2, 4, 2) - - mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2)) - comm.all2all_manager.dispatch.return_value = mock_dispatch_result - - result_hidden, result_logits = comm.dispatch(hidden_states, - router_logits) - - assert torch.allclose(result_hidden, mock_dispatch_result[0]) - assert torch.allclose(result_logits, mock_dispatch_result[1]) - - comm.all2all_manager.dispatch.assert_called_once_with( - hidden_states, router_logits) - - @patch("vllm.config.get_current_vllm_config", return_value=None) - @patch("torch.npu.current_device", return_value=MagicMock()) - @patch("torch.npu.set_device", return_value=MagicMock()) - @patch("torch.distributed.get_process_group_ranks", - return_value={ - 0: 0, - 1: 1 - }) - @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1}) - @patch("torch.distributed.is_initialized", return_value=True) - @patch("torch.distributed.get_rank", return_value=1) - @patch("torch.distributed.is_initialized", return_value=True) - @patch("torch.distributed.get_backend", return_value="hccl") - 
@patch("torch.distributed.get_rank", return_value=1) - @patch("torch.distributed.get_world_size", return_value=2) - @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1]) - @patch("torch.npu.device") - def test_combine(self, *_): - comm = NPUCommunicator(cpu_group=dist.group.WORLD) - comm.all2all_manager = Mock() - hidden_states = torch.randn(2, 4, 8) - - mock_combine_result = torch.randn(2, 4, 8) - comm.all2all_manager.combine.return_value = mock_combine_result - - result = comm.combine(hidden_states) - - assert torch.allclose(result, mock_combine_result) - - comm.all2all_manager.combine.assert_called_once_with(hidden_states) diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 8166e17a0b8..b1440164b73 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm, # ascend custom op is not registered utils.register_ascend_customop() # should call register_oot three - self.assertEqual(mock_customop.register_oot.call_count, 8) + self.assertEqual(mock_customop.register_oot.call_count, 9) self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED) # ascend custom op is already registered utils.register_ascend_customop() # should not register_oot again, thus only called three in this ut - self.assertEqual(mock_customop.register_oot.call_count, 8) + self.assertEqual(mock_customop.register_oot.call_count, 9) class TestProfileExecuteDuration(TestBase): diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 700fafdca90..def0d353092 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -11,7 +11,6 @@ set_forward_context) import vllm_ascend.envs as envs_ascend -from vllm_ascend.distributed.moe_comm_method import MoECommMethod class FusedMoEState(Enum): @@ -57,7 +56,7 @@ def set_ascend_forward_context( with_prefill: bool = True, in_profile_run: bool = False, reserved_mc2_mask: Optional[torch.Tensor] = None, - moe_comm_method: Optional[MoECommMethod] = None, + moe_comm_method: str = "", num_actual_tokens: Optional[int] = None, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: Optional[BatchDescriptor] = None): @@ -75,7 +74,7 @@ def set_ascend_forward_context( batch_descriptor=batch_descriptor, ): forward_context = get_forward_context() - forward_context.moe_comm_method = moe_comm_method + forward_context.moe_comm_method_name = moe_comm_method + "commimpl" forward_context.with_prefill = with_prefill ep_size = (get_ep_group().world_size if vllm_config.parallel_config.enable_expert_parallel else 1) diff --git a/vllm_ascend/distributed/communicator.py b/vllm_ascend/distributed/communicator.py index 79adc89c798..7c14befa804 100644 --- a/vllm_ascend/distributed/communicator.py +++ b/vllm_ascend/distributed/communicator.py @@ -20,7 +20,6 @@ import torch.distributed as dist from vllm.distributed.device_communicators.base_device_communicator import \ DeviceCommunicatorBase -from vllm.utils import logger class NPUCommunicator(DeviceCommunicatorBase): @@ -35,12 +34,6 @@ def __init__(self, # init device according to rank self.device = torch.npu.current_device() - if self.use_all2all: - from vllm.distributed.device_communicators.all2all import \ - NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group) - logger.info("Using naive all2all manager.") - def all_to_all(self, input_: torch.Tensor, scatter_dim: int = 0, @@ -80,17 +73,3 @@ def all_to_all(self, 
dist.all_to_all(output_list, input_list, group=self.device_group) output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() return output_tensor - - # TODO: Add ut for dispatch and combine - def dispatch( - self, hidden_states: torch.Tensor, - router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits) - return hidden_states, router_logits - - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states) - return hidden_states diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py index f347ab06cb4..02f6d52aff8 100644 --- a/vllm_ascend/distributed/moe_comm_method.py +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -1,12 +1,18 @@ from abc import ABC, abstractmethod +from typing import Optional import torch +import torch.distributed as dist +import torch.nn as nn import torch_npu -from transformers.configuration_utils import PretrainedConfig -from vllm.distributed.parallel_state import get_ep_group, get_tp_group -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.utils import direct_register_custom_op - +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe import FusedMoEConfig + +from vllm_ascend.distributed.communication_op import \ + data_parallel_reduce_scatter from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version @@ -14,26 +20,34 @@ class MoECommMethod(ABC): """Base class for MoE communication methods.""" - def __init__( - self, - device: torch.device, - dtype: torch.dtype, - hf_config: PretrainedConfig, - ): - self.device = device - self.dtype = dtype - self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0) - # global_num_experts may be called num_experts or n_routed_experts in different models. - possible_keys = ["num_experts", "n_routed_experts"] - for key in possible_keys: - if hasattr(hf_config, key): - self.global_num_experts = getattr(hf_config, key) - break - else: - self.global_num_experts = 0 + def __init__(self, moe_config: FusedMoEConfig): + self.moe_config = moe_config + + @abstractmethod + def prepare( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare the MoE communication method. + + This method is called before quant_method.apply to prepare the + communication method. It can be used to initialize any necessary + resources or configurations. + """ + pass + + @abstractmethod + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """Finalize the MoE communication method. + + This method is called after quant_method.apply to finalize the + communication method. It can be used to clean up any resources or + configurations. 
+ """ + pass @abstractmethod - def _pre_process( + def permute( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, @@ -67,8 +81,8 @@ def _pre_process( pass @abstractmethod - def _post_process(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: """Post-process after MLP. Args: @@ -82,7 +96,18 @@ def _post_process(self, mlp_output: torch.Tensor, class DummyCommImpl(MoECommMethod): - def _pre_process( + def prepare( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Dummy prepare method that does nothing.""" + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """Dummy finalize method that does nothing.""" + return hidden_states + + def permute( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, @@ -90,17 +115,133 @@ def _pre_process( expert_map: torch.Tensor, num_experts: int, ) -> tuple[torch.Tensor, torch.Tensor, int]: - """Dummy implementation, see moe_comm_pre_process_fake for details.""" - return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights, - expert_map, num_experts) + """Dummy implementation, make sure the output shapes are correct.""" + top_k_num = topk_ids.shape[1] + permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, + dim=0) + expert_tokens = torch.zeros((num_experts, ), + dtype=torch.int64, + device=hidden_states.device) + group_list_type = 0 + return permuted_hidden_states, expert_tokens, group_list_type - def _post_process(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: """Dummy implementation that does nothing.""" pass -class NativeAllGatherCommImpl(MoECommMethod): +class AllGatherCommImpl(MoECommMethod): + """This implementation is the same as NativeAllGatherCommImpl, + but uses NPU-specific ops for better performance. + + This implementation should be compatible with all scenarios, and + thus it is the default implementation for MoE communication methods. + It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing + and `torch_npu.npu_moe_token_unpermute` for post-processing + to handle the token-to-expert mapping and communication efficiently. + + NOTE(Yizhou): TBH, it is really weird that we were supposed to use + `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` + or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` + for pre-processing and post-processing, respectively. + But `npu_moe_finalize_routing` will lead to accuracy issues so we have to + use `torch_npu.npu_moe_token_unpermute` instead. + This is a workaround and should be removed after the issue is fixed. 
+ """ + + def prepare( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """When DP size > 1, pad the hidden states and router logits for communication.""" + if self.moe_config.dp_size > 1: + forward_context = get_forward_context() + max_tokens_across_dp = forward_context.max_tokens_across_dp + + self.num_tokens = hidden_states.shape[0] + pad_size = max_tokens_across_dp - self.num_tokens + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + hidden_states = self.moe_config.dp_group.all_gather( + hidden_states, 0) + router_logits = self.moe_config.dp_group.all_gather( + router_logits, 0) + + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """When DP size > 1, reduce-scatter the hidden states to get the final output. + + When TP size > 1, all-reduce the hidden states to get the final output. + """ + if self.moe_config.dp_size > 1: + hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0) + hidden_states = hidden_states[:self.num_tokens] + + if reduce_results and (self.moe_config.tp_size > 1 + or self.moe_config.ep_size > 1): + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states + + def permute( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, # noqa: F841 + num_experts: int, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + num_tokens = hidden_states.shape[0] + + self.topk_weights = topk_weights + self.topk_ids = topk_ids + + first_expert_idx = 0 + if expert_map is not None: + # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] + # So we need to filter out invalid tokens by zeroing their weights. + # This is a workaround and should be removed after the issue is fixed + mask = expert_map[topk_ids] != -1 + # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0, + # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph + self.topk_weights = torch.where(mask, topk_weights, 0.0) + + first_expert_idx = self.moe_config.ep_rank * num_experts + last_expert_idx = first_expert_idx + num_experts + + permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( + torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * self.moe_config.experts_per_token, + expert_num=self.moe_config.num_experts, + expert_tokens_num_type=1, # Only support `count` mode now + expert_tokens_num_flag=True, # Output `expert_tokens` + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=-1, + )) + self.expanded_row_idx = expanded_row_idx + permuted_hidden_states = permuted_hidden_states + + group_list_type = 1 # `count` mode + + return permuted_hidden_states, expert_tokens, group_list_type + + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + hidden_states[:] = torch_npu.npu_moe_token_unpermute( + permuted_tokens=mlp_output, + sorted_indices=self.expanded_row_idx, + probs=self.topk_weights) + + +class NativeAllGatherCommImpl(AllGatherCommImpl): """This implementation should be compatible with all scenarios. 
Note that this implementation purely consists of native PyTorch ops @@ -108,7 +249,7 @@ class NativeAllGatherCommImpl(MoECommMethod): But it is a good fallback for scenarios where NPU-specific ops are not available. """ - def _pre_process( + def permute( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, @@ -120,10 +261,10 @@ def _pre_process( # Generate token indices and flatten token_indices = torch.arange(num_tokens, - device=self.device, + device=hidden_states.device, dtype=torch.int64) token_indices = (token_indices.unsqueeze(1).expand( - -1, self.top_k_num).reshape(-1)) + -1, self.moe_config.experts_per_token).reshape(-1)) # Flatten token-to-expert mappings and map to local experts weights_flat = topk_weights.view(-1) @@ -138,7 +279,7 @@ def _pre_process( # This is a workaround and should be removed after the issue is fixed filtered_weights = torch.where(mask, weights_flat, torch.zeros_like(weights_flat)).to( - self.dtype) + topk_weights.dtype) filtered_experts = torch.where( mask, local_experts_flat, @@ -154,7 +295,7 @@ def _pre_process( # This is equivalent to but faster than: # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] token_counts = torch.zeros(num_experts + 1, - device=self.device, + device=hidden_states.device, dtype=torch.int64) ones = torch.ones_like(filtered_experts, dtype=torch.int64) token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) @@ -167,8 +308,8 @@ def _pre_process( return permuted_hidden_states, expert_tokens, group_list_type - def _post_process(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) final_hidden_states = torch.zeros_like(hidden_states) @@ -178,77 +319,6 @@ def _post_process(self, mlp_output: torch.Tensor, hidden_states[:] = final_hidden_states -class AllGatherCommImpl(MoECommMethod): - """This implementation is the same as NativeAllGatherCommImpl, - but uses NPU-specific ops for better performance. - - This implementation should be compatible with all scenarios, and - thus it is the default implementation for MoE communication methods. - It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing - and `torch_npu.npu_moe_token_unpermute` for post-processing - to handle the token-to-expert mapping and communication efficiently. - - NOTE(Yizhou): TBH, it is really weird that we were supposed to use - `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing` - or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute` - for pre-processing and post-processing, respectively. - But `npu_moe_finalize_routing` will lead to accuracy issues so we have to - use `torch_npu.npu_moe_token_unpermute` instead. - This is a workaround and should be removed after the issue is fixed. - """ - - def _pre_process( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, # noqa: F841 - num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: - num_tokens = hidden_states.shape[0] - - self.topk_weights = topk_weights - self.topk_ids = topk_ids - - first_expert_idx = 0 - if expert_map is not None: - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. 
- # This is a workaround and should be removed after the issue is fixed - mask = expert_map[topk_ids] != -1 - # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0, - # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph - self.topk_weights = torch.where(mask, topk_weights, 0.0) - - first_expert_idx = get_ep_group().rank_in_group * num_experts - last_expert_idx = first_expert_idx + num_experts - - permuted_hidden_states, expanded_row_idx, expert_tokens, _ = ( - torch_npu.npu_moe_init_routing_v2( - hidden_states, - topk_ids, - active_num=num_tokens * self.top_k_num, - expert_num=self.global_num_experts, - expert_tokens_num_type=1, # Only support `count` mode now - expert_tokens_num_flag=True, # Output `expert_tokens` - active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=-1, - )) - self.expanded_row_idx = expanded_row_idx - permuted_hidden_states = permuted_hidden_states - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, group_list_type - - def _post_process(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - hidden_states[:] = torch_npu.npu_moe_token_unpermute( - permuted_tokens=mlp_output, - sorted_indices=self.expanded_row_idx, - probs=self.topk_weights) - - class MC2CommImpl(MoECommMethod): """This implementation is for the scenarios listed below: 1. `enable_expert_parallel=True`. @@ -259,40 +329,83 @@ class MC2CommImpl(MoECommMethod): Communication and Computation parallelism on Ascend devices. """ - def __init__( - self, - device: torch.device, - dtype: torch.dtype, - hf_config: PretrainedConfig, - ): - super().__init__(device, dtype, hf_config) - - # Shared communication configurations - ep_group = get_mc2_group() - self.ep_rank_id = ep_group.rank_in_group - self.ep_world_size = ep_group.world_size - self.tp_world_size = get_tp_group().world_size - - device_group = ep_group.device_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + def __init__(self, moe_config: Optional[FusedMoEConfig]): + super().__init__(moe_config) + + # NOTE: We do not need to use mc2_group's rank and world size + # because ep_group and mc2_group basically have the same init params. + # We only init another group because of the restriction of MC2: + # "No other groups can be used in the same process as the MC2 group." + self.mc2_comm_name = get_mc2_group().device_group._get_backend( + torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank) # Feature flags self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3 - self.need_extra_args = self.is_ascend_a3 # or is_torchair + self.need_extra_args = self.is_ascend_a3 + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + # NOTE: Since vLLM flatten tp across dp, we need to restore the original + # tp_size and tp_rank. + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """The target_pad_length is calculated in forward_context, here we pad the + hidden states and router logits. And if TP size > 1, we also need to split + the tensors accordingly. 
+ """ + self.num_tokens, _ = hidden_states.shape + forward_context = get_forward_context() + self.mc2_mask = forward_context.mc2_mask + target_pad_length = forward_context.padded_num_tokens + pad_size = target_pad_length - self.num_tokens + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + split_mc2_mask = torch.tensor_split(self.mc2_mask, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + self.mc2_mask = split_mc2_mask[self.tp_rank] + + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """If TP size > 1, all-gather the hidden states to get the final output. + + Also, unpad the hidden states if needed. + """ + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) - # Intermediate tensors to be passed from pre_process to post_process - self.topk_ids = None - self.topk_weights = None - self.mc2_mask = None - self.assist_info_for_combine = None - self.ep_recv_counts = None - self.tp_recv_counts = None + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] - def _pre_process( + return hidden_states + + def permute( self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, @@ -303,25 +416,24 @@ def _pre_process( # Store tensors needed for post_process self.topk_ids = topk_ids self.topk_weights = topk_weights.to(torch.float32) - self.mc2_mask = get_forward_context().mc2_mask dispatch_kwargs = { "x": hidden_states, "expert_ids": self.topk_ids, "expert_shard_type": 0, "shared_expert_rank_num": 0, - "moe_expert_num": self.global_num_experts, + "moe_expert_num": self.moe_config.num_experts, "global_bs": 0, "scales": None, "quant_mode": 0, - "group_ep": self.moe_all_to_all_group_name, - "ep_world_size": self.ep_world_size, - "ep_rank_id": self.ep_rank_id, + "group_ep": self.mc2_comm_name, + "ep_world_size": self.moe_config.ep_size, + "ep_rank_id": self.moe_config.ep_rank, } if self.need_extra_args: dispatch_kwargs.update({ - "group_tp": self.moe_all_to_all_group_name, + "group_tp": self.mc2_comm_name, "tp_world_size": 1, "tp_rank_id": 0, }) @@ -345,20 +457,20 @@ def _pre_process( return permuted_hidden_states, expert_tokens, group_list_type - def _post_process(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: combine_kwargs = { "expand_x": mlp_output, "expert_ids": self.topk_ids, "expert_scales": self.topk_weights, "expert_shard_type": 0, "shared_expert_rank_num": 0, - "moe_expert_num": self.global_num_experts, + "moe_expert_num": self.moe_config.num_experts, "global_bs": 0, "ep_send_counts": self.ep_recv_counts, - "group_ep": self.moe_all_to_all_group_name, - "ep_world_size": self.ep_world_size, - "ep_rank_id": self.ep_rank_id, + "group_ep": self.mc2_comm_name, + "ep_world_size": self.moe_config.ep_size, + "ep_rank_id": self.moe_config.ep_rank, } if self.enable_dispatch_v2: @@ -370,7 +482,7 @@ def 
_post_process(self, mlp_output: torch.Tensor, if self.need_extra_args: combine_kwargs.update({ "tp_send_counts": self.tp_recv_counts, - "group_tp": self.moe_all_to_all_group_name, + "group_tp": self.mc2_comm_name, "tp_world_size": 1, "tp_rank_id": 0, }) @@ -382,68 +494,3 @@ def _post_process(self, mlp_output: torch.Tensor, combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine hidden_states[:] = combine(**combine_kwargs) - - -def moe_comm_pre_process( - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, int]: - """This function is a wrapper for the pre_process method of the - MoECommMethod instance stored in the ForwardContext. So it can be - used as a custom op in the vllm framework. - """ - forward_context: ForwardContext = get_forward_context() - self = forward_context.moe_comm_method - return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map, - num_experts) - - -def moe_comm_pre_process_fake( - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, -) -> tuple[torch.Tensor, torch.Tensor, int]: - """This is a fake implementation of the pre_process method. - torch.compile will use this implementation to generate FX graph. - """ - top_k_num = topk_ids.shape[1] - permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0) - expert_tokens = torch.zeros((num_experts, ), - dtype=torch.int64, - device=hidden_states.device) - group_list_type = 0 - return permuted_hidden_states, expert_tokens, group_list_type - - -def moe_comm_post_process(mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - """This function is a wrapper for the post_process method of the - MoECommMethod instance stored in the ForwardContext. So it can be - used as a custom op in the vllm framework. - """ - forward_context: ForwardContext = get_forward_context() - self = forward_context.moe_comm_method - self._post_process(mlp_output, hidden_states) - return - - -direct_register_custom_op( - op_name="moe_comm_pre_process", - op_func=moe_comm_pre_process, - mutates_args=[], - fake_impl=moe_comm_pre_process_fake, - dispatch_key="PrivateUse1", -) - -direct_register_custom_op( - op_name="moe_comm_post_process", - op_func=moe_comm_post_process, - mutates_args=["hidden_states"], - fake_impl=lambda x, y: None, # No-op for fake implementation - dispatch_key="PrivateUse1", -) diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index 1dfbcf811e9..c0b2a142bd8 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -497,6 +497,10 @@ def forward( router_logits, _ = self.gate(hidden_states) global _ROUTER_SCALE _ROUTER_SCALE = self.router_scale + + # TODO(angazenn): Does not support MC2 currently + get_forward_context().moe_comm_method_name = "allgathercommimpl" + if not use_h2p(): final_hidden_states = self.experts.forward_impl( hidden_states=hidden_states, router_logits=router_logits) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 19a86a7d032..cc0f735e63b 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -15,22 +15,84 @@ # limitations under the License. 
# -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.layer import \ - UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts +from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, + DummyCommImpl, + MC2CommImpl, + MoECommMethod) +from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.utils import is_310p original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + moe_comm_method: Optional[MoECommMethod] = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, +) -> torch.Tensor: + # Check constraints + assert hidden_states.shape[1] == w1.shape[2], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + assert moe_comm_method is not None, "Missing communication context" + + num_experts = w1.shape[0] + + permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute( + hidden_states, topk_ids, topk_weights, expert_map, num_experts) + mlp_output = apply_mlp( + permuted_hidden_states, + w1, + w2, + expert_tokens, + group_list_type=group_list_type, + ) + moe_comm_method.unpermute(mlp_output, hidden_states) + + return hidden_states + + def unquantized_fused_moe_init_func(self, *args, **kwargs): original_unquantized_fused_moe_init_func(self, *args, **kwargs) vllm_config = get_current_vllm_config() @@ -97,7 +159,7 @@ def forward_oot( moe_comm_method = get_forward_context().moe_comm_method - return unified_fused_experts( + return fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -109,5 +171,112 @@ def forward_oot( ) +class AscendFusedMoE(FusedMoE): + + def __init__( + self, + num_experts, + top_k, + hidden_size, + intermediate_size, + params_dtype=None, + reduce_results=False, + renormalize=True, + use_grouped_topk=False, + 
num_expert_group=None, + topk_group=None, + quant_config=None, + tp_size=None, + ep_size=None, + dp_size=None, + prefix="", + custom_routing_function=None, + scoring_func="softmax", + e_score_correction_bias=None, + apply_router_weight_on_input=False, + activation="silu", + enable_eplb=False, + num_redundant_experts=0, + has_bias=False, + ): + super().__init__( + num_experts, + top_k, + hidden_size, + intermediate_size, + params_dtype, + reduce_results, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + quant_config, + tp_size, + ep_size, + dp_size, + prefix, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + enable_eplb, + num_redundant_experts, + has_bias, + ) + + self.moe_config.tp_group = get_tp_group() + self.moe_config.dp_group = get_dp_group() + self.moe_config.ep_group = get_ep_group() + self.moe_config.mc2_group = get_mc2_group() + + for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}: + setattr( + self, method.__name__.lower(), + method(moe_config=self.moe_config)) # type: ignore[abstract] + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + forward_context = get_forward_context() + moe_comm_method_name = forward_context.moe_comm_method_name + if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl": + moe_comm_method_name = "allgathercommimpl" + forward_context.moe_comm_method = getattr(self, moe_comm_method_name) + + hidden_states, router_logits = forward_context.moe_comm_method.prepare( + hidden_states=hidden_states, router_logits=router_logits) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, + ) + + final_hidden_states = forward_context.moe_comm_method.finalize( + hidden_states=final_hidden_states, + reduce_results=self.reduce_results) + + return final_hidden_states + + UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 611935cbef6..70e87dc28f0 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -43,7 +43,6 @@ from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter -from vllm_ascend.distributed.moe_comm_method import MoECommMethod from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.layers.experts_selector import select_experts @@ -58,60 +57,6 @@ MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER -def unified_fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: 
torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - global_num_experts: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - moe_comm_method: Optional[MoECommMethod] = None, - # For Cube/Vector parallel - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - # For load balance - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, -) -> torch.Tensor: - # Check constraints - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert moe_comm_method is not None, "Missing communication context" - - num_experts = w1.shape[0] - - permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process( - hidden_states, topk_ids, topk_weights, expert_map, num_experts) - mlp_output = apply_mlp( - permuted_hidden_states, - w1, - w2, - expert_tokens, - group_list_type=group_list_type, - ) - torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states) - - return hidden_states - - def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, max_row_per_ep_rank: int, num_tokens: int, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 3d7ed2978b8..1273805fbee 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -509,6 +509,9 @@ def register_ascend_customop(): from vllm_ascend.ops.layernorm import AscendRMSNorm CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm") + from vllm_ascend.ops.common_fused_moe import AscendFusedMoE + CustomOp.register_oot(_decorated_op_cls=AscendFusedMoE, name="FusedMoE") + # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 03486b05557..aa72eb05f0f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -24,7 +24,7 @@ import time from contextlib import contextmanager, nullcontext from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast import numpy as np import numpy.typing as npt @@ -85,9 +85,6 @@ from vllm_ascend.attention.mla_v1 import AscendMLAMetadata from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import ACLGraphWrapper -from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - DummyCommImpl, - MoECommMethod) from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler @@ -368,13 +365,16 
@@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer + self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size self.reserved_mc2_mask = torch.zeros( - 512, + self.mc2_tokens_capacity, dtype=torch.bool, device=self.device, ) - self.moe_comm_method = AllGatherCommImpl + self.moe_comm_method = "mc2" + self.fallback_moe_comm_method = "allgather" + self.dummy_moe_comm_method = "dummy" def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager @@ -1622,6 +1622,10 @@ def execute_model( intermediate_tensors) = (self._prepare_inputs( scheduler_output, intermediate_tensors)) + moe_comm_method = (self.moe_comm_method + if num_input_tokens <= self.mc2_tokens_capacity else + self.fallback_moe_comm_method) + # Run forward pass with ProfileExecuteDuration().capture_async("forward"): with set_ascend_forward_context( @@ -1631,8 +1635,7 @@ def execute_model( num_tokens_across_dp=num_tokens_across_dp, with_prefill=self.with_prefill, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=self.moe_comm_method( - self.device, self.dtype, self.model_config.hf_config), + moe_comm_method=moe_comm_method, num_actual_tokens=scheduler_output. total_num_scheduled_tokens): self.maybe_setup_kv_connector(scheduler_output) @@ -1938,7 +1941,7 @@ def _dummy_run( num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, - moe_comm_method: Type[MoECommMethod] = DummyCommImpl, + moe_comm_method: str = "dummy", aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, force_attention: bool = False, uniform_decode: bool = False, @@ -2061,8 +2064,7 @@ def _dummy_run( with_prefill=with_prefill, in_profile_run=self.in_profile_run, reserved_mc2_mask=self.reserved_mc2_mask, - moe_comm_method=moe_comm_method( - self.device, self.dtype, self.model_config.hf_config), + moe_comm_method=moe_comm_method, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor):
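
For reference, a minimal sketch of how the refactored MoECommMethod interface introduced by this patch is driven end to end, pieced together from AscendFusedMoE.forward_impl and the new fused_experts helper above. This is illustrative only, not part of the patch: `layer`, `comm_impl`, `select_experts`, and `apply_mlp` are placeholders standing in for the real call sites (expert selection and the grouped MLP normally happen inside quant_method.apply).

# Illustrative sketch (assumed names, not the exact vllm_ascend call sites).
def moe_forward_sketch(layer, comm_impl, hidden_states, router_logits):
    # 1. prepare(): pad / all-gather activations across DP ranks before routing.
    hidden_states, router_logits = comm_impl.prepare(hidden_states, router_logits)

    # 2. Routing (done inside quant_method.apply via select_experts in the patch).
    topk_weights, topk_ids = layer.select_experts(hidden_states, router_logits)

    # 3. permute(): group tokens by expert; returns permuted activations,
    #    per-expert token counts, and the group-list type for the grouped matmul.
    permuted, expert_tokens, group_list_type = comm_impl.permute(
        hidden_states, topk_ids, topk_weights, layer.expert_map,
        layer.local_num_experts)

    # 4. Expert MLP on the permuted tokens (apply_mlp in the patch).
    mlp_out = layer.apply_mlp(permuted, expert_tokens, group_list_type)

    # 5. unpermute(): scatter expert outputs back into hidden_states in place.
    comm_impl.unpermute(mlp_out, hidden_states)

    # 6. finalize(): reduce-scatter / all-reduce to produce the final output.
    return comm_impl.finalize(hidden_states, reduce_results=layer.reduce_results)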