diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py index 971d931039a..8a162b91461 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py @@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - sorted_hidden_states = dispatch_output["hidden_states"] - group_list = dispatch_output["group_list"] - group_list_type = dispatch_output.get("group_list_type", 1) - context_metadata = dispatch_output["context_metadata"] + sorted_hidden_states = dispatch_output.hidden_states + group_list = dispatch_output.group_list + group_list_type = dispatch_output.group_list_type + context_metadata = dispatch_output.context_metadata expert_output = apply_mlp(hidden_states=sorted_hidden_states, w1=w1_local, @@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather( torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map) - torch.testing.assert_close(combined_output, + torch.testing.assert_close(combined_output.routed_out, torch_output, atol=4e-2, rtol=1) @@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant( apply_router_weight_on_input=apply_router_weight_on_input, with_quant=True) - sorted_hidden_states = dispatch_output["hidden_states"] - group_list = dispatch_output["group_list"] - group_list_type = dispatch_output.get("group_list_type", 1) - dynamic_scale = dispatch_output["dynamic_scale"] - context_metadata = dispatch_output["context_metadata"] + sorted_hidden_states = dispatch_output.hidden_states + group_list = dispatch_output.group_list + group_list_type = dispatch_output.group_list_type + dynamic_scale = dispatch_output.dynamic_scale + context_metadata = dispatch_output.context_metadata expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, w1=w1, @@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant( hidden_states=expert_output, context_metadata=context_metadata, bias=None) - assert combined_output.shape == (m, k) + assert combined_output.routed_out.shape == (m, k) gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 7620999a159..e40f67084bd 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -8,6 +8,8 @@ AlltoAllCommImpl, MC2CommImpl) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType +from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult, + TokenDispatchResult) class TestMoECommMethod(TestBase): @@ -178,12 +180,12 @@ def test_fused_experts_method(self, mock_unified_apply_mlp, # Mock token dispatcher mock_td_instance = MagicMock() - mock_td_instance.token_dispatch.return_value = { - "hidden_states": torch.randn(6, 8), - "group_list": torch.tensor([2, 2, 2]), - "group_list_type": 1 - } - mock_td_instance.token_combine.return_value = torch.randn(4, 8) + mock_td_instance.token_dispatch.return_value = TokenDispatchResult( + hidden_states=torch.randn(6, 8), + group_list=torch.tensor([2, 2, 2]), + group_list_type=1) + mock_td_instance.token_combine.return_value = TokenCombineResult( + routed_out=torch.randn(4, 8)) mock_token_dispatcher.return_value = mock_td_instance # Mock unified_apply_mlp @@ -213,7 +215,7 @@ def test_fused_experts_method(self, 
mock_unified_apply_mlp, activation="silu") # Verify result shape - self.assertEqual(result.shape, (4, 8)) + self.assertEqual(result.routed_out.shape, (4, 8)) # Verify token_dispatch was called mock_td_instance.token_dispatch.assert_called_once() diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 140bae5cd20..027815ba0c8 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -97,8 +97,7 @@ def test_token_permutation_dispatch(self): topk_weights, topk_ids, expert_map) mock_dispatch.assert_called_once() - self.assertEqual(output["group_list_type"], - 0) # group_list_type == 0 + self.assertEqual(output.group_list_type, 0) # group_list_type == 0 def test_token_dispatch_with_shared_experts_and_quant(self): self.shared_experts = MagicMock() @@ -149,43 +148,6 @@ def test_get_combine_mc_kwargs_with_quant(self): context_metadata) self.assertIn("tp_send_counts", kwargs) - def test_token_combine_with_shared_experts(self): - shared_experts = MagicMock() - shared_experts.down_proj.return_value = (torch.randn(10, 128), - torch.tensor(1.0)) - - topk_ids = torch.randint(0, 8, (10, 1)) - topk_weights = torch.randn(10, 1) - expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - assist_info_for_combine = torch.arange(10) - tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "mc2_mask": None, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - "shared_experts": shared_experts, - "shared_act": torch.randn(10, 128), - "swiglu_out_scale": torch.randn(10, 1), - "tp_recv_counts": tp_recv_counts - } - - self.dispatcher.with_quant = True - self.dispatcher.need_extra_args = True - self.dispatcher.enable_dispatch_v2 = True - - hidden_states = torch.randn(10, 128) - with patch("torch_npu.npu_moe_distribute_combine_v2", - return_value=torch.randn(10, 128)): - result = self.dispatcher.token_combine(hidden_states, - context_metadata) - self.assertIsInstance(result, tuple) - class TestTokenDispatcherWithAllGather(TestBase): @@ -233,7 +195,7 @@ def test_token_dispatch_without_expert_map(self): self.mock_npu_moe_init_routing_v2.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) @@ -248,7 +210,7 @@ def test_token_dispatch_with_expert_map(self): self.mock_npu_moe_init_routing_v2.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_without_quant(self): kwargs = { @@ -268,7 +230,7 @@ def test_token_dispatch_without_quant(self): topk_weights, topk_ids, None) - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_with_quant(self): kwargs = { @@ -290,10 +252,10 @@ def test_token_dispatch_with_quant(self): None, with_quant=True) - self.assertIsNotNone(results["hidden_states"]) - self.assertIsNotNone(results["group_list"]) - self.assertIsNotNone(results["dynamic_scale"]) - self.assertEqual(results["group_list_type"], 1) + self.assertIsNotNone(results.hidden_states) + 
self.assertIsNotNone(results.group_list) + self.assertIsNotNone(results.dynamic_scale) + self.assertEqual(results.group_list_type, 1) def test_token_combine_with_expert_map(self): hidden_states = torch.randn(6, 128) @@ -303,7 +265,7 @@ def test_token_combine_with_expert_map(self): } self.dispatcher.original_shape = (6, 128) final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + hidden_states, context_metadata).routed_out self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): @@ -314,7 +276,7 @@ def test_token_combine_without_expert_map(self): } self.dispatcher.original_shape = (6, 128) final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + hidden_states, context_metadata).routed_out self.mock_npu_moe_token_unpermute.assert_called_once() self.assertEqual(final_hidden_states.shape, (6, 128)) @@ -326,7 +288,7 @@ def test_token_dispatch_with_router_weight(self): results = self.dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids, None) - self.assertEqual(results["hidden_states"].shape, (6, 128)) + self.assertEqual(results.hidden_states.shape, (6, 128)) class TestTokenDispatcherWithAll2AllV(TestBase): @@ -437,9 +399,9 @@ def test_token_dispatch(self): topk_ids=topk_ids, expert_map=expert_map) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertEqual(result.group_list_type, 1) def test_token_combine(self): hidden_states = torch.randn(16, 16) @@ -458,7 +420,7 @@ def test_token_combine(self): output = self.dispatcher.token_combine(hidden_states, context_metadata) self.assertIsNotNone(output) - self.assertEqual(output.shape, (8, 16)) + self.assertEqual(output.routed_out.shape, (8, 16)) def test_token_dispatch_with_quant(self): self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, @@ -480,10 +442,10 @@ def test_token_dispatch_with_quant(self): expert_map=expert_map, with_quant=True) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertIsNotNone(result["dynamic_scale"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertIsNotNone(result.dynamic_scale) + self.assertEqual(result.group_list_type, 1) def test_token_dispatch_with_quant_no_active_tokens(self): self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, @@ -508,10 +470,10 @@ def test_token_dispatch_with_quant_no_active_tokens(self): expert_map=expert_map, with_quant=True) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertIsNotNone(result["dynamic_scale"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertIsNotNone(result.dynamic_scale) + self.assertEqual(result.group_list_type, 1) def test_token_dispatch_with_log2phy(self): hidden_states = torch.randn(8, 16) @@ -530,6 +492,6 @@ def test_token_dispatch_with_log2phy(self): expert_map=expert_map, log2phy=log2phy) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertEqual(result.group_list_type, 1) 
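Note (editor commentary, not part of the patch): the test edits above all track one refactor. token_dispatch now returns a typed TokenDispatchResult and token_combine a TokenCombineResult, replacing the previous untyped dicts and bare tensors. Below is a minimal, self-contained sketch mirroring the dataclass this patch adds in vllm_ascend/ops/fused_moe/token_dispatcher.py; the `torch.Tensor | None` annotations assume Python 3.10+, and the asserts restate the old dict accesses the tests used to perform.

    from dataclasses import dataclass, field

    import torch

    @dataclass
    class TokenDispatchResult:
        hidden_states: torch.Tensor
        group_list: torch.Tensor
        group_list_type: int
        dynamic_scale: torch.Tensor | None = None
        topk_scales: torch.Tensor | None = None
        context_metadata: dict = field(default_factory=dict)

    result = TokenDispatchResult(hidden_states=torch.randn(6, 8),
                                 group_list=torch.tensor([2, 2, 2]),
                                 group_list_type=1)
    assert result.group_list_type == 1   # was: results.get("group_list_type", 1)
    assert result.dynamic_scale is None  # was: results.get("dynamic_scale")

One practical consequence: a misspelled key against the old dict failed silently through .get(...), whereas a mistyped attribute on the dataclass now raises AttributeError at the access site, which is what the updated assertions in these tests rely on.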
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index daaca8b95d7..5e818fd0452 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -37,6 +37,7 @@ set_flash_common3_context) from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, + FusedExpertsResult, setup_moe_comm_method) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType from vllm_ascend.quantization.w4a8_dynamic import \ @@ -325,7 +326,7 @@ def forward_impl(self, hidden_states: torch.Tensor, pertoken_scale = None # Matrix multiply. - final_hidden_states = self.quant_method.apply( + fused_experts_results: FusedExpertsResult = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -350,25 +351,25 @@ def forward_impl(self, hidden_states: torch.Tensor, global_redundant_expert_num=self.global_redundant_expert_num, mc2_mask=mc2_mask) - if isinstance(final_hidden_states, tuple): - final_hidden_states, group_list_type, expert_tokens = final_hidden_states - if self.dynamic_eplb: - - moe_load_stream = moe_load_async_stream() - cur_stream = torch.npu.current_stream() - - moe_load_stream.wait_stream(cur_stream) - with npu_stream_switch(moe_load_stream): - self.moe_load += expert_tokens if group_list_type == 1 else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) - cur_stream.wait_stream(moe_load_stream) - - final_hidden_states = forward_context.moe_comm_method.finalize( - hidden_states=final_hidden_states, + if self.dynamic_eplb: + expert_tokens = fused_experts_results.expert_tokens + group_list_type = fused_experts_results.group_list_type + assert expert_tokens is not None and group_list_type is not None, \ + "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled." 
+ moe_load_stream = moe_load_async_stream() + cur_stream = torch.npu.current_stream() + moe_load_stream.wait_stream(cur_stream) + with npu_stream_switch(moe_load_stream): + self.moe_load += expert_tokens if group_list_type == 1 else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + cur_stream.wait_stream(moe_load_stream) + + routed_out = forward_context.moe_comm_method.finalize( + hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, context_metadata=context_metadata) - return final_hidden_states + return routed_out class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): @@ -439,7 +440,7 @@ def forward_impl(self, hidden_states: torch.Tensor, else: set_flash_common3_context(shared_experts=self._shared_experts) - fused_output = AscendFusedMoE.forward_impl( + routed_out = AscendFusedMoE.forward_impl( self, hidden_states=hidden_states, router_logits=router_logits, @@ -462,4 +463,4 @@ def forward_impl(self, hidden_states: torch.Tensor, assert fc3_context is not None shared_out = fc3_context.shared_out - return shared_out, fused_output + return shared_out, routed_out diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 30d1e5c1376..06fd2fe4415 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -16,6 +16,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Dict, Optional import torch @@ -26,11 +27,11 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.prepare_finalize import ( - PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather, - PrepareAndFinalizeWithMC2, QuantType) + PrepareAndFinalize, PrepareAndFinalizeWithAll2All, + PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType) from vllm_ascend.ops.fused_moe.token_dispatcher import ( - TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, - TokenDispatcherWithMC2) + MoETokenDispatcher, TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, TokenDispatcherWithMC2) _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} @@ -47,6 +48,14 @@ def setup_moe_comm_method(moe_config): _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config) +@dataclass +class FusedExpertsResult: + routed_out: torch.Tensor + # For dynamic_eplb + group_list_type: int | None = None + expert_tokens: torch.Tensor | None = None + + class MoECommMethod(ABC): """Base class for MoE communication methods.""" @@ -118,7 +127,7 @@ def fused_experts( moe_comm_method = get_forward_context().moe_comm_method assert moe_comm_method is not None, "Missing communication context" - results = self.token_dispatcher.token_dispatch( + dispatch_results = self.token_dispatcher.token_dispatch( hidden_states=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, @@ -134,43 +143,41 @@ def fused_experts( dynamic_eplb=dynamic_eplb, pertoken_scale=pertoken_scale) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata") - - mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - 
group_list=expert_tokens, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - w1_offset=w1_offset, - w2_offset=w2_offset, - topk_scales=topk_scales, - with_quant=use_int8_w8a8 - or use_int4_w4a8 or use_int4_w4a16, - fusion=use_int8_w8a8, - need_trans=need_trans, - dynamic_eplb=dynamic_eplb) - - final_hidden_states = self.token_dispatcher.token_combine( - hidden_states=mlp_output, context_metadata=context_metadata) - - if dynamic_eplb: - return (final_hidden_states, group_list_type, expert_tokens) - - return final_hidden_states + mlp_output = unified_apply_mlp( + hidden_states=dispatch_results.hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=dispatch_results.group_list, + dynamic_scale=dispatch_results.dynamic_scale, + group_list_type=dispatch_results.group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + topk_scales=dispatch_results.topk_scales, + with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16, + fusion=use_int8_w8a8, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb) + + combine_results = self.token_dispatcher.token_combine( + hidden_states=mlp_output, + context_metadata=dispatch_results.context_metadata) + + return FusedExpertsResult( + routed_out=combine_results.routed_out, + group_list_type=dispatch_results.group_list_type, + expert_tokens=dispatch_results.group_list) @abstractmethod - def _get_token_dispatcher(self): + def _get_token_dispatcher(self) -> MoETokenDispatcher: raise NotImplementedError( "_get_token_dispatcher function not implemented.") @abstractmethod - def _get_prepare_finalize(self): + def _get_prepare_finalize(self) -> PrepareAndFinalize: raise NotImplementedError( "_get_prepare_finalize function not implemented.") @@ -292,9 +299,11 @@ def fused_experts( w1_scale is None or w2_scale is None ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." + assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ + "token_dispatcher must be an instance of TokenDispatcherWithMC2." if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: out = torch.empty_like(hidden_states) - torch.ops._C_ascend.dispatch_ffn_combine( + torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore x=hidden_states, weight1=w1[0], weight2=w2[0], @@ -308,7 +317,7 @@ def fused_experts( ) elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: assert expert_map is not None, "expert_map cannot be None." - out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( + out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore x=hidden_states, expert_ids=topk_ids, gmm1_permuted_weight=w1[0], @@ -325,4 +334,4 @@ def fused_experts( else: raise ValueError( f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") - return out + return FusedExpertsResult(routed_out=out) diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index aeb751d0d8d..8df56ff3531 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Optional import torch @@ -35,6 +36,21 @@ is_hierarchical_communication_enabled) +@dataclass +class TokenDispatchResult: + hidden_states: torch.Tensor + group_list: torch.Tensor + group_list_type: int + dynamic_scale: torch.Tensor | None = field(default=None) + topk_scales: torch.Tensor | None = field(default=None) + context_metadata: dict = field(default_factory=dict) + + +@dataclass +class TokenCombineResult: + routed_out: torch.Tensor + + class MoETokenDispatcher(ABC): def __init__(self, **kwargs) -> None: @@ -74,14 +90,14 @@ def token_dispatch( with_quant: bool = False, dynamic_eplb: bool = False, pertoken_scale: Optional[torch.Tensor] = None, - ): + ) -> TokenDispatchResult: raise NotImplementedError("Dispatch function not implemented.") @abstractmethod def token_combine(self, hidden_states: torch.Tensor, context_metadata: dict, - bias: torch.Tensor = None): + bias: torch.Tensor | None = None) -> TokenCombineResult: raise NotImplementedError("Combine function not implemented.") @@ -207,24 +223,6 @@ def token_dispatch(self, expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ ep_recv_counts, tp_recv_counts, expand_scales = output[0:7] - # Handle shared experts (store intermediate results in local vars, not self) - shared_act = None - swiglu_out_scale = None - if with_quant: - if shared_experts is not None: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - shared_act_out = shared_experts.act_fn( - (shared_gate_up, shared_dequant_scale)) - shared_act, swiglu_out_scale = shared_act_out[ - 0], shared_act_out[1] - else: - if shared_experts is not None: - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - shared_act = shared_experts.act_fn(shared_gate_up) - context_metadata = { "topk_ids": topk_ids, "topk_weights": topk_weights, @@ -233,20 +231,16 @@ def token_dispatch(self, "tp_recv_counts": tp_recv_counts, "assist_info_for_combine": assist_info_for_combine, "shared_experts": shared_experts, - "shared_act": shared_act, - "swiglu_out_scale": swiglu_out_scale, "expand_scales": expand_scales } group_list_type = 0 - return { - "group_list_type": group_list_type, - "hidden_states": expand_x, - "group_list": expert_token_nums, - "dynamic_scale": dynamic_scale, - "context_metadata": context_metadata - } + return TokenDispatchResult(hidden_states=expand_x, + dynamic_scale=dynamic_scale, + group_list=expert_token_nums, + group_list_type=group_list_type, + context_metadata=context_metadata) def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict): @@ -300,12 +294,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): + def token_combine(self, hidden_states, context_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." 
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, @@ -313,20 +302,7 @@ def token_combine( combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - # Handle shared experts from metadata - shared_experts = context_metadata["shared_experts"] - if shared_experts is None: - return combined_output - - shared_act = context_metadata["shared_act"] - if self.with_quant: - swiglu_out_scale = context_metadata["swiglu_out_scale"] - shared_hidden_states, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) - else: - shared_hidden_states, _ = shared_experts.down_proj(shared_act) - - return combined_output, shared_hidden_states + return TokenCombineResult(routed_out=combined_output) class TokenDispatcherWithAllGather(MoETokenDispatcher): @@ -399,18 +375,16 @@ def token_dispatch(self, "topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx } - return { - "group_list_type": group_list_type, - "hidden_states": sorted_hidden_states, - "group_list": expert_tokens, - "dynamic_scale": pertoken_scale if self.with_quant else None, - "context_metadata": context_metadata - } - def token_combine(self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None): + return TokenDispatchResult( + hidden_states=sorted_hidden_states, + dynamic_scale=pertoken_scale if self.with_quant else None, + group_list=expert_tokens, + group_list_type=group_list_type, + context_metadata=context_metadata, + ) + + def token_combine(self, hidden_states, context_metadata, bias=None): assert self.original_shape is not None final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, @@ -420,7 +394,7 @@ def token_combine(self, final_hidden_states = final_hidden_states.view(self.original_shape) # these values are no longer used, so they need to be set to None for memory release. - return final_hidden_states + return TokenCombineResult(routed_out=final_hidden_states) class TokenDispatcherWithAll2AllV(MoETokenDispatcher): @@ -528,20 +502,15 @@ def token_dispatch(self, reversed_global_input_permutation_mapping } - return { - "hidden_states": global_input_tokens, - "group_list": tokens_per_expert, - "group_list_type": 1, - "dynamic_scale": dynamic_scale_final, - "context_metadata": context_metadata, - } + return TokenDispatchResult( + hidden_states=global_input_tokens, + dynamic_scale=dynamic_scale_final, + group_list=tokens_per_expert, + group_list_type=1, + context_metadata=context_metadata, + ) - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): + def token_combine(self, hidden_states, context_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." # 1. Preprocess using metadata @@ -562,7 +531,7 @@ def token_combine( output = self._combine_postprocess(permutated_local_input_tokens, context_metadata) - return output + return TokenCombineResult(routed_out=output) def _dispatch_preprocess(self, hidden_states, topk_ids): assert self.hidden_shape is not None
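Note (editor commentary, not part of the patch): fused_experts() now always returns a FusedExpertsResult instead of conditionally returning a (hidden_states, group_list_type, expert_tokens) tuple, which is what lets forward_impl() drop the isinstance(..., tuple) probe. Below is a minimal sketch, in plain torch, of the moe_load accumulation forward_impl performs on that result when dynamic_eplb is enabled; the NPU stream handoff is omitted, and accumulate_moe_load is a hypothetical helper name introduced here for illustration.

    import torch

    def accumulate_moe_load(moe_load: torch.Tensor,
                            expert_tokens: torch.Tensor,
                            group_list_type: int) -> torch.Tensor:
        # group_list_type == 1: expert_tokens already holds per-expert counts.
        # group_list_type == 0: expert_tokens is a cumulative total, so
        # difference adjacent entries to recover the per-expert counts.
        if group_list_type == 1:
            counts = expert_tokens
        else:
            counts = torch.cat(
                [expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
        return moe_load + counts

    # A cumulative group list [2, 5, 9] decodes to per-expert counts [2, 3, 4].
    load = accumulate_moe_load(torch.zeros(3, dtype=torch.long),
                               torch.tensor([2, 5, 9]), group_list_type=0)
    assert load.tolist() == [2, 3, 4]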