diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index 5ff39c4ee6f..c8e0eef46a4 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -163,6 +163,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
         "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
     )
     @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
+    @patch("torch.npu.current_stream", MagicMock())
     def test_fused_experts_method(self, mock_unified_apply_mlp,
                                   mock_token_dispatcher, mock_prepare_finalize,
                                   mock_get_forward_context):
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index a1919b6b00d..9609b97d990 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -116,26 +116,6 @@ def test_token_permutation_dispatch(self):
         mock_dispatch.assert_called_once()
         self.assertEqual(output.group_list_type, 0)  # group_list_type == 0
 
-    def test_token_dispatch_with_shared_experts_and_quant(self):
-        self.shared_experts = MagicMock()
-        self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
-                                                         torch.tensor(1.0))
-        self.shared_experts.act_fn.return_value = torch.randn(10, 128)
-        self.dispatcher.with_quant = False
-        self.dispatcher.shared_act = torch.randn(10, 128)
-        self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
-        self.hidden_states = torch.randn(10, 128)
-        self.topk_weights = torch.randn(10, 1)
-
-        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
-                   return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
-            self.dispatcher.token_dispatch(self.hidden_states,
-                                           self.topk_weights,
-                                           torch.randint(0, 8, (10, 1)),
-                                           torch.tensor(
-                                               [0, 1, 2, 3, 4, 5, 6, 7]),
-                                           shared_experts=self.shared_experts)
-
     def test_get_combine_mc_kwargs_with_quant(self):
         self.dispatcher.with_quant = True
         hidden_states = torch.randn(10, 128)
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 8d7628360d2..81bf0796dba 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -14,9 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Callable, Optional
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import Callable, Optional
 
 import torch
+import torch.nn.functional as F
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
                               tensor_model_parallel_all_reduce)
@@ -47,7 +50,20 @@ from vllm_ascend.utils import (AscendDeviceType, enable_sp,
                                get_ascend_device_type, maybe_trans_nz,
                                npu_stream_switch, shared_expert_dp_enabled,
-                               shared_experts_calculation_stream)
+                               shared_experts_calculation_stream, vllm_version_is)
+
+@dataclass
+class FusedMoEResult:
+    routed_out: torch.Tensor
+    before_dispatch_evt: torch.npu.Event | None = None
+    before_combine_evt: torch.npu.Event | None = None
+
+
+@dataclass
+class FusedMoEEvents:
+    before_routed_experts: torch.npu.Event
+    before_dispatch: torch.npu.Event | None = field(default=None)
+    before_combine: torch.npu.Event | None = field(default=None)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -90,7 +106,6 @@ def apply(self,
               expert_map: Optional[torch.Tensor] = None,
               apply_router_weight_on_input: bool = False,
               enable_force_load_balance: bool = False,
-              shared_experts: Optional[Any] = None,
               **kwargs) -> torch.Tensor:
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -137,7 +152,6 @@ def apply(self,
             topk_ids=topk_ids,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            shared_experts=shared_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
@@ -268,13 +282,13 @@ def maybe_all_reduce_tensor_model_parallel(
         return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
             final_hidden_states)
 
-    def forward_impl(self, hidden_states: torch.Tensor,
-                     router_logits: torch.Tensor):
+    def forward_impl(  # type: ignore[override]
+            self,
+            hidden_states: torch.Tensor,
+            router_logits: torch.Tensor,
+            return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
         assert self.quant_method is not None
 
-        # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
-        quantized_x_for_share, dynamic_scale_for_share = None, None
-
         forward_context = get_forward_context()
 
         # Load balancing for token distribution among experts in dummy_run
@@ -359,9 +373,6 @@ def forward_impl(self, hidden_states: torch.Tensor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
-            quantized_x_for_share=quantized_x_for_share,
-            dynamic_scale_for_share=dynamic_scale_for_share,
-            shared_experts=None,
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
@@ -380,7 +391,14 @@ def forward_impl(self, hidden_states: torch.Tensor,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata)
 
-        return routed_out
+        if return_with_event:
+            return FusedMoEResult(
+                routed_out=routed_out,
+                before_dispatch_evt=fused_experts_results.before_dispatch_evt,
+                before_combine_evt=fused_experts_results.before_combine_evt)
+        else:
+            # The vLLM FusedMoE forward_impl does not return events.
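+            # Returning the bare tensor keeps the base-class contract for
+            # callers that do not opt in via `return_with_event`.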
+            return routed_out
 
 
 class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
@@ -407,6 +425,74 @@ def __init__(
 
         self._gate = gate
 
+        # Wrap the quant_method's process_weights_after_loading to validate
+        # that splitting the shared expert computation (gate_up projection +
+        # activation, then down projection) yields identical results to the
+        # integrated computation after weight loading.
+        original_process_weights = self.quant_method.process_weights_after_loading
+
+        @wraps(original_process_weights)
+        def wrapped_process_weights(*args, **kwargs):
+            result = original_process_weights(*args, **kwargs)
+            self._validate_shared_expert_consistency()
+            return result
+
+        self.quant_method.process_weights_after_loading = wrapped_process_weights  # type: ignore
+
+    def _shared_experts_part1(self, hidden_states: torch.Tensor):
+        shared_gate_up, _ = self._shared_experts.gate_up_proj(
+            hidden_states)  # type: ignore
+        return shared_gate_up
+
+    def _shared_experts_part2(self, hidden_states: torch.Tensor,
+                              shared_gate_up: torch.Tensor):
+        shared_act = self._shared_experts.act_fn(
+            shared_gate_up)  # type: ignore
+        shared_out, _ = self._shared_experts.down_proj(
+            shared_act)  # type: ignore
+
+        # Qwen3-Next specific gating mechanism
+        if hasattr(self._shared_experts, "expert_gate") and \
+                self._shared_experts.expert_gate is not None:
+            if vllm_version_is('0.13.0'):
+                # TODO(jianzs): remove this branch after the next vLLM
+                # release.
+                gate_out = self._shared_experts.expert_gate(hidden_states)  # type: ignore
+            else:
+                gate_out, _ = self._shared_experts.expert_gate(hidden_states)  # type: ignore
+            shared_out = F.sigmoid(gate_out) * shared_out
+        return shared_out
+
+    def _validate_shared_expert_consistency(self):
+        """Validate that the split shared expert computation matches the
+        integrated computation."""
+        test_input = torch.rand(
+            10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
+        ) * 2 - 1  # Random input for testing, scaled to [-1, 1]
+
+        integrated_out = self._shared_experts(test_input)
+        part1_out = self._shared_experts_part1(test_input)
+        split_out = self._shared_experts_part2(test_input, part1_out)
+
+        if not torch.allclose(integrated_out, split_out):
+            diff = (integrated_out - split_out).abs()
+            logger.error(
+                "SharedFusedMoE shared experts split computation does not "
+                "match the integrated computation.")
+            logger.error("Max absolute difference: %s", diff.max().item())
+            logger.error("Integrated output - sum: %s, norm: %s",
                         integrated_out.sum().item(),
                         integrated_out.norm().item())
+            logger.error("Split output - sum: %s, norm: %s",
+                         split_out.sum().item(), split_out.norm().item())
+            raise ValueError(
+                "SharedFusedMoE shared experts split computation does not "
+                "match the integrated computation.")
+        logger.info_once(
+            "SharedFusedMoE shared experts split computation matches the "
+            "integrated computation.")
+
     @property
     def gate(self) -> Optional[torch.nn.Module]:
         return self._gate if self.use_overlapped else None
@@ -434,44 +520,67 @@ def forward(
         )
 
         return shared_out, fused_out
 
-    def forward_impl(self, hidden_states: torch.Tensor,
-                     router_logits: torch.Tensor):
-        shared_out = None
-        if not self.multistream_overlap_gate:
-            # Make sure the shared experts stream begins after hidden_states are ready.
-            if self.multistream_overlap_shared_expert:
-                shared_experts_calculation_stream(
-                ).wait_stream(  # type: ignore
-                    torch.npu.current_stream())
-            with npu_stream_switch(
-                    shared_experts_calculation_stream(),
-                    enabled=self.multistream_overlap_shared_expert):
-                # Use a separate stream to run shared experts.
-                shared_out = self._shared_experts(hidden_states)
-        else:
+    def _forward_shared_experts(self, hidden_states: torch.Tensor,
+                                fused_moe_evts: FusedMoEEvents):
+
+        def maybe_wait_event(evt: torch.npu.Event | None):
+            if evt is not None:
+                torch.npu.current_stream().wait_event(evt)
+
+        with npu_stream_switch(shared_experts_calculation_stream(),
+                               enabled=self.multistream_overlap_shared_expert):
+            # Ensure the shared experts wait for hidden_states to be ready.
+            torch.npu.current_stream().wait_event(
+                fused_moe_evts.before_routed_experts)
+            # Execute the gate_up projection and activation concurrently with
+            # the dispatch communication.
+            maybe_wait_event(fused_moe_evts.before_dispatch)
+            part1_out = self._shared_experts_part1(hidden_states)
+            # Execute the down projection concurrently with the combine
+            # communication.
+            maybe_wait_event(fused_moe_evts.before_combine)
+            shared_out = self._shared_experts_part2(hidden_states, part1_out)
+
+        # Make sure the default stream waits for the shared experts stream to
+        # finish.
+        if self.multistream_overlap_shared_expert:
+            torch.npu.current_stream().wait_stream(
+                shared_experts_calculation_stream())
+
+        # NOTE: This is exactly the opposite of
+        # `maybe_all_reduce_tensor_model_parallel`.
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
+                and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
+        return shared_out
+
+    def forward_impl(  # type: ignore[override]
+            self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if self.multistream_overlap_gate:
             set_flash_common3_context(shared_experts=self._shared_experts)
-        routed_out = AscendFusedMoE.forward_impl(
+        before_routed_experts = torch.npu.current_stream().record_event()
+        fused_moe_results = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
+            return_with_event=True,
         )
+        routed_out = fused_moe_results.routed_out
 
-        if not self.multistream_overlap_gate:
-            # Make sure the default stream waits for the shared experts stream to finish.
-            if self.multistream_overlap_shared_expert:
-                torch.npu.current_stream().wait_stream(
-                    shared_experts_calculation_stream())
-
-            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-            forward_context = get_forward_context()
-            moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
-                and not shared_expert_dp_enabled():
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
-        else:
+        if self.multistream_overlap_gate:
             fc3_context = get_flash_common3_context()
             assert fc3_context is not None
             shared_out = fc3_context.shared_out
+        else:
+            shared_out = self._forward_shared_experts(
+                hidden_states,
+                FusedMoEEvents(
+                    before_routed_experts=before_routed_experts,
+                    before_dispatch=fused_moe_results.before_dispatch_evt,
+                    before_combine=fused_moe_results.before_combine_evt,
+                ))
 
         return shared_out, routed_out
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 1692f1453e1..b5427fa753e 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -17,7 +17,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 import torch
 from vllm.forward_context import get_forward_context
@@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
 @dataclass
 class FusedExpertsResult:
     routed_out: torch.Tensor
+    # These fields are for shared experts and should be set by any MoE
+    # communication method that runs the shared experts in parallel with
+    # the routed experts.
+    before_dispatch_evt: torch.npu.Event | None = None
+    before_combine_evt: torch.npu.Event | None = None
     # For dynamic_eplb
     group_list_type: int | None = None
     expert_tokens: torch.Tensor | None = None
@@ -108,10 +113,6 @@ def fused_experts(
         w2_scale_bias: torch.Tensor = None,
         w1_offset: Optional[torch.Tensor] = None,
         w2_offset: Optional[torch.Tensor] = None,
-        # For Cube/Vector parallel
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
         need_trans: bool = False,
@@ -126,6 +127,7 @@ def fused_experts(
         moe_comm_method = get_forward_context().moe_comm_method
         assert moe_comm_method is not None, "Missing communication context"
 
+        before_dispatch_evt = torch.npu.current_stream().record_event()
         dispatch_results = self.token_dispatcher.token_dispatch(
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=self.moe_config.
             global_redundant_expert_num,
-            shared_experts=shared_experts,
-            quantized_x_for_share=quantized_x_for_share,
-            dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
             with_quant=use_int8_w8a8 or use_int4_w4a8,
@@ -162,12 +161,15 @@ def fused_experts(
             need_trans=need_trans,
             dynamic_eplb=dynamic_eplb)
 
+        before_combine_evt = torch.npu.current_stream().record_event()
         combine_results = self.token_dispatcher.token_combine(
             hidden_states=mlp_output,
             context_metadata=dispatch_results.context_metadata)
 
         return FusedExpertsResult(
             routed_out=combine_results.routed_out,
+            before_dispatch_evt=before_dispatch_evt,
+            before_combine_evt=before_combine_evt,
             group_list_type=dispatch_results.group_list_type,
             expert_tokens=dispatch_results.group_list)
@@ -284,10 +286,6 @@ def fused_experts(
         w2_scale_bias: torch.Tensor = None,
         w1_offset: Optional[torch.Tensor] = None,
         w2_offset: Optional[torch.Tensor] = None,
-        # For Cube/Vector parallel
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
         # For load balance
         log2phy: torch.Tensor = None,
         need_trans: bool = False,
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index 05f439129f4..df8e4cf420f 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -151,8 +151,8 @@ def prepare(
         """
         self.replace_allreduce = replace_allreduce
         self.enable_shared_expert_dp = enable_shared_expert_dp
-        split_hidden_states = None
 
+        padded_hidden_states_shape = hidden_states.shape
         if not (self.replace_allreduce or self.enable_shared_expert_dp):
             self.num_tokens, _ = hidden_states.shape
             pad_size = self.tp_size - self.num_tokens  # Pad to TP size (cyclic)
@@ -162,6 +162,7 @@ def prepare(
                                                  (0, 0, 0, pad_size))
                 router_logits = nn.functional.pad(router_logits,
                                                   (0, 0, 0, pad_size))
+            padded_hidden_states_shape = hidden_states.shape
 
             if self.tp_size > 1:
                 split_hidden_states = torch.tensor_split(hidden_states,
@@ -174,7 +175,9 @@ def prepare(
                 hidden_states = split_hidden_states[self.tp_rank]
                 router_logits = split_router_logits[self.tp_rank]
 
-        context_metadata = {"split_hidden_states": split_hidden_states}
+        context_metadata = {
+            "padded_hidden_states_shape": padded_hidden_states_shape
+        }
 
         return hidden_states, router_logits, None, context_metadata
@@ -190,14 +193,25 @@ def finalize(self,
 
         Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
         """
-        assert context_metadata is not None
-        split_hidden_states = context_metadata["split_hidden_states"]
 
         if not (self.enable_shared_expert_dp or self.replace_allreduce):
             if self.tp_size > 1:
+                assert context_metadata is not None
+                # Cannot reuse `split_hidden_states` from the prepare phase:
+                # it may share memory with the original hidden_states, which
+                # the shared experts may still be reading; reusing it would
+                # modify that tensor in place during all_gather and corrupt it.
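+                # The fix: gather into a freshly allocated buffer so the
+                # all_gather output stays disjoint from the tensor the shared
+                # experts are still reading.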
+                padded_hidden_states_shape = context_metadata[
+                    "padded_hidden_states_shape"]
+                gathered_hidden_states = torch.empty(
+                    padded_hidden_states_shape,
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype)
+                split_hidden_states = torch.tensor_split(
+                    gathered_hidden_states, self.tp_size, dim=0)
                 dist.all_gather(list(split_hidden_states), hidden_states,
                                 self.moe_config.tp_group.device_group)
-                hidden_states = torch.cat(split_hidden_states, dim=0)
+                hidden_states = gathered_hidden_states
 
             if self.num_tokens < hidden_states.shape[0]:
                 hidden_states = hidden_states[:self.num_tokens]
@@ -249,7 +263,6 @@ def prepare(
         """
         self.replace_allreduce = replace_allreduce
         self.enable_shared_expert_dp = enable_shared_expert_dp
-        split_hidden_states = None
 
         forward_context = get_forward_context()
         mc2_mask = forward_context.mc2_mask
         if self.tp_size > 1:
@@ -257,6 +270,7 @@ def prepare(
             split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
             mc2_mask = split_mc2_mask[self.tp_rank]
 
+        padded_hidden_states_shape = hidden_states.shape
         if not self.replace_allreduce:
             self.num_tokens, _ = hidden_states.shape
             target_pad_length = forward_context.padded_num_tokens
@@ -268,6 +282,7 @@ def prepare(
                                                  (0, 0, 0, pad_size))
                 router_logits = nn.functional.pad(router_logits,
                                                   (0, 0, 0, pad_size))
+            padded_hidden_states_shape = hidden_states.shape
 
             # Slice across TP ranks
             if self.tp_size > 1 and not self.enable_shared_expert_dp:
@@ -280,7 +295,9 @@ def prepare(
                 hidden_states = split_hidden_states[self.tp_rank]
                 router_logits = split_router_logits[self.tp_rank]
 
-        context_metadata = {"split_hidden_states": split_hidden_states}
+        context_metadata = {
+            "padded_hidden_states_shape": padded_hidden_states_shape,
+        }
 
         return hidden_states, router_logits, mc2_mask, context_metadata
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index d90f4c71f05..4b699e4298c 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 import torch_npu
@@ -82,9 +82,6 @@ def token_dispatch(
             expert_map: Optional[torch.Tensor] = None,
             log2phy: Optional[torch.Tensor] = None,
             global_redundant_expert_num: int = 0,
-            shared_experts: Optional[Any] = None,
-            quantized_x_for_share: Optional[Any] = None,
-            dynamic_scale_for_share: Optional[Any] = None,
             mc2_mask: Optional[torch.Tensor] = None,
             apply_router_weight_on_input: bool = False,
             with_quant: bool = False,
@@ -193,9 +190,6 @@ def token_dispatch(self,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
@@ -226,12 +220,10 @@ def token_dispatch(self,
             "ep_recv_counts": ep_recv_counts,
             "tp_recv_counts": tp_recv_counts,
             "assist_info_for_combine": assist_info_for_combine,
-            "shared_experts": shared_experts,
             "expand_scales": expand_scales
         }
 
         group_list_type = 0
-
         return TokenDispatchResult(hidden_states=expand_x,
                                    dynamic_scale=dynamic_scale,
                                    group_list=expert_token_nums,
@@ -297,7 +289,7 @@ def token_combine(self, hidden_states, context_metadata, bias=None):
         combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
             if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-        return TokenCombineResult(routed_out=combined_output)
+        return TokenCombineResult(routed_out=combined_output, )
 
 
 class TokenDispatcherWithAllGather(MoETokenDispatcher):
@@ -319,9 +311,6 @@ def token_dispatch(self,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
@@ -442,9 +431,6 @@ def token_dispatch(self,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py
index 3767593fa14..45fb9bd8910 100644
--- a/vllm_ascend/quantization/w4a16.py
+++ b/vllm_ascend/quantization/w4a16.py
@@ -204,9 +204,6 @@ def apply(
         enable_force_load_balance: bool = True,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -229,24 +226,21 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
-        return moe_comm_method.fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight_packed,
-            w2=layer.w2_weight_packed,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            w1_offset=layer.w13_weight_offset,
-            w2_offset=layer.w2_weight_offset,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            use_int4_w4a16=True,
-            expert_map=expert_map,
-            log2phy=log2phy,
-            shared_experts=shared_experts,
-            quantized_x_for_share=quantized_x_for_share,
-            dynamic_scale_for_share=dynamic_scale_for_share,
-            dynamic_eplb=self.dynamic_eplb,
-            mc2_mask=kwargs.get("mc2_mask", None))
+        return moe_comm_method.fused_experts(hidden_states=x,
+                                             w1=layer.w13_weight_packed,
+                                             w2=layer.w2_weight_packed,
+                                             w1_scale=layer.w13_weight_scale,
+                                             w2_scale=layer.w2_weight_scale,
+                                             w1_offset=layer.w13_weight_offset,
+                                             w2_offset=layer.w2_weight_offset,
+                                             topk_weights=topk_weights,
+                                             topk_ids=topk_ids,
+                                             use_int4_w4a16=True,
+                                             expert_map=expert_map,
+                                             log2phy=log2phy,
+                                             dynamic_eplb=self.dynamic_eplb,
+                                             mc2_mask=kwargs.get(
+                                                 "mc2_mask", None))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.transpose_weight:
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index 21a4302891a..e46ac1cb57d 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -341,9 +341,6 @@ def apply(
         enable_force_load_balance: bool = False,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -390,9 +387,6 @@ def apply(
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            shared_experts=shared_experts,
-            quantized_x_for_share=quantized_x_for_share,
-            dynamic_scale_for_share=dynamic_scale_for_share,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 103b6c10c7c..0c6def3371c 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -190,9 +190,6 @@ def apply(
         enable_force_load_balance: bool = False,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
-        shared_experts: Optional[Any] = None,
-        quantized_x_for_share: Optional[Any] = None,
-        dynamic_scale_for_share: Optional[Any] = None,
         pertoken_scale: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -280,9 +277,6 @@ def apply(
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
-            shared_experts=shared_experts,
-            quantized_x_for_share=quantized_x_for_share,
-            dynamic_scale_for_share=dynamic_scale_for_share,
             dynamic_eplb=self.dynamic_eplb,
             mc2_mask=kwargs.get("mc2_mask", None))
         if zero_expert_num > 0 and zero_expert_type is not None:
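
The whole change reduces to a small event handshake between the default stream (routed experts plus dispatch/combine communication) and the shared-expert stream. A minimal sketch of that pattern, assuming a generic two-stream setup; `routed`, `shared_part1`, and `shared_part2` are illustrative stand-ins for the real dispatcher and shared-expert halves, not APIs from this repo:

import torch

# Second stream dedicated to the shared experts (mirrors
# shared_experts_calculation_stream() above).
shared_stream = torch.npu.Stream()


def moe_forward_with_overlap(hidden_states, routed, shared_part1,
                             shared_part2):
    default = torch.npu.current_stream()

    # Mark the point where the inputs are ready. In the real code the three
    # events are recorded at separate points inside fused_experts; here they
    # are adjacent only for brevity.
    before_routed = default.record_event()

    before_dispatch = default.record_event()
    dispatched = routed.dispatch(hidden_states)  # all-to-all / MC2 comm
    mlp_out = routed.experts(dispatched)         # grouped expert MLPs
    before_combine = default.record_event()
    routed_out = routed.combine(mlp_out)         # combine comm

    # Shared-expert path, phase-aligned by the events: part 1 (gate_up +
    # activation) overlaps the dispatch, part 2 (down projection) overlaps
    # the combine.
    with torch.npu.stream(shared_stream):
        shared_stream.wait_event(before_routed)
        shared_stream.wait_event(before_dispatch)
        act = shared_part1(hidden_states)
        shared_stream.wait_event(before_combine)
        shared_out = shared_part2(hidden_states, act)

    # Rejoin so the default stream does not consume shared_out early.
    default.wait_stream(shared_stream)
    return shared_out, routed_out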