Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/ut/ops/test_moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher,
"vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather"
)
@patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp")
@patch("torch.npu.current_stream", MagicMock())
def test_fused_experts_method(self, mock_unified_apply_mlp,
mock_token_dispatcher, mock_prepare_finalize,
mock_get_forward_context):
Expand Down
20 changes: 0 additions & 20 deletions tests/ut/ops/test_token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,6 @@ def test_token_permutation_dispatch(self):
mock_dispatch.assert_called_once()
self.assertEqual(output.group_list_type, 0) # group_list_type == 0

def test_token_dispatch_with_shared_experts_and_quant(self):
    """Smoke-test token_dispatch with a mocked shared-experts module.

    NOTE(review): despite "and_quant" in the name, `with_quant` is set to
    False below — confirm whether the quantized path was ever exercised.
    NOTE(review): no assertions are made on the result; this only checks
    that the call completes without raising.
    """
    self.shared_experts = MagicMock()
    # gate_up_proj returns (hidden, scale); act_fn returns the activation.
    self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
                                                     torch.tensor(1.0))
    self.shared_experts.act_fn.return_value = torch.randn(10, 128)
    self.dispatcher.with_quant = False
    self.dispatcher.shared_act = torch.randn(10, 128)
    self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
    self.hidden_states = torch.randn(10, 128)
    self.topk_weights = torch.randn(10, 1)

    # npu_moe_distribute_dispatch_v2 is patched to return 5 tensors plus
    # two trailing Nones, matching the op's 7-tuple output shape.
    with patch("torch_npu.npu_moe_distribute_dispatch_v2",
               return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
        self.dispatcher.token_dispatch(self.hidden_states,
                                       self.topk_weights,
                                       torch.randint(0, 8, (10, 1)),
                                       torch.tensor(
                                           [0, 1, 2, 3, 4, 5, 6, 7]),
                                       shared_experts=self.shared_experts)

def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
Expand Down
193 changes: 151 additions & 42 deletions vllm_ascend/ops/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Optional
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
Expand Down Expand Up @@ -47,7 +50,20 @@
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream)
shared_experts_calculation_stream, vllm_version_is)

@dataclass
class FusedMoEResult:
    # Output of the routed-experts computation.
    routed_out: torch.Tensor
    # Events recorded on the compute stream right before the dispatch and
    # combine communications, so the shared-experts path can overlap with
    # them. None when the active MoE comm method does not record them.
    before_dispatch_evt: torch.npu.Event | None = None
    before_combine_evt: torch.npu.Event | None = None


@dataclass
class FusedMoEEvents:
    """Stream events handed to the shared-experts path.

    The shared experts wait on these events so their computation overlaps
    with the routed experts' dispatch/combine communication.
    """
    # Recorded on the default stream before the routed-experts computation.
    before_routed_experts: torch.npu.Event
    # Optional events recorded before the dispatch / combine communication.
    # Plain `= None` defaults instead of the redundant `field(default=None)`,
    # consistent with FusedMoEResult above.
    before_dispatch: torch.npu.Event | None = None
    before_combine: torch.npu.Event | None = None


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
Expand Down Expand Up @@ -90,7 +106,6 @@ def apply(self,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
Expand Down Expand Up @@ -137,7 +152,6 @@ def apply(self,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
Expand Down Expand Up @@ -268,13 +282,13 @@ def maybe_all_reduce_tensor_model_parallel(
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_impl( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None

# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None

forward_context = get_forward_context()

# Load balancing for token distribution among experts in dummy_run
Expand Down Expand Up @@ -359,9 +373,6 @@ def forward_impl(self, hidden_states: torch.Tensor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
shared_experts=None,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
Expand All @@ -380,7 +391,14 @@ def forward_impl(self, hidden_states: torch.Tensor,
reduce_results=self.reduce_results,
context_metadata=context_metadata)

return routed_out
if return_with_event:
return FusedMoEResult(
routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt)
else:
# The vLLM FusedMoE forward_impl does not return events.
return routed_out


class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
Expand All @@ -407,6 +425,74 @@ def __init__(

self._gate = gate

# Wrap the quant_method's process_weights_after_loading to validate that
# splitting shared expert computation (gate_up projection + activation,
# then down projection) yields identical results to integrated
# computation after weight loading.
original_process_weights = self.quant_method.process_weights_after_loading

@wraps(original_process_weights)
def wrapped_process_weights(*args, **kwargs):
result = original_process_weights(*args, **kwargs)
self._validate_shared_expert_consistency()
return result

self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore

def _shared_experts_part1(self, hidden_states: torch.Tensor):
shared_gate_up, _ = self._shared_experts.gate_up_proj(
hidden_states) # type: ignore
return shared_gate_up

def _shared_experts_part2(self, hidden_states: torch.Tensor,
shared_gate_up: torch.Tensor):
shared_act = self._shared_experts.act_fn(
shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(
shared_act) # type: ignore

# Qwen3-Next specific gating mechanism
if hasattr(self._shared_experts, "expert_gate") and \
self._shared_experts.expert_gate is not None:
if vllm_version_is('0.13.0'):
# TODO(jianzs): remove this branch after vLLM new version is
# released
gate_out = self._shared_experts.expert_gate(hidden_states) # type: ignore
else:
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
shared_out = F.sigmoid(gate_out) * shared_out
return shared_out

def _validate_shared_expert_consistency(self):
    """Validate that split shared expert computation matches integrated
    computation.

    Called after process_weights_after_loading (via the wrapper installed
    in __init__), since weight processing may repack or quantize weights
    and invalidate the split-computation assumption.

    Raises:
        ValueError: if the split path (_shared_experts_part1/2) does not
            produce the same output as calling the module directly.
    """
    test_input = torch.rand(
        10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
    ) * 2 - 1  # Random input for testing, scoped to [-1, 1]

    integrated_out = self._shared_experts(test_input)
    part1_out = self._shared_experts_part1(test_input)
    split_out = self._shared_experts_part2(test_input, part1_out)

    # NOTE(review): torch.allclose uses its default tolerances here; for
    # low-precision in_dtype (e.g. bf16) confirm the defaults are tight
    # enough to be meaningful yet loose enough to avoid false failures.
    if not torch.allclose(integrated_out, split_out):
        diff = (integrated_out - split_out).abs()
        logger.error(
            "SharedFusedMoE shared experts split computation does not "
            "match the integrated computation.")
        logger.error(f"Max absolute difference: {diff.max().item()}")
        logger.error("Integrated output - sum: %s, norm: %s",
                     integrated_out.sum().item(),
                     integrated_out.norm().item())
        logger.error("Split output - sum: %s, norm: %s",
                     split_out.sum().item(),
                     split_out.norm().item())
        raise ValueError(
            "SharedFusedMoE shared experts split computation does not "
            "match the integrated computation.")
    logger.info_once(
        "SharedFusedMoE shared experts split computation matches the "
        "integrated computation.")

@property
def gate(self) -> Optional[torch.nn.Module]:
    """Expose the gate module only when overlapped execution is enabled."""
    if self.use_overlapped:
        return self._gate
    return None
Expand Down Expand Up @@ -434,44 +520,67 @@ def forward(
)
return shared_out, fused_out

def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_out = None
if not self.multistream_overlap_gate:
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream(
).wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
else:
def _forward_shared_experts(self, hidden_states: torch.Tensor,
                            fused_moe_evts: FusedMoEEvents):
    """Run the shared experts, overlapping with MoE communication.

    Waits on the events in *fused_moe_evts* so that the gate_up projection
    overlaps with the dispatch communication and the down projection
    overlaps with the combine communication.
    """

    def _wait_if_recorded(evt: torch.npu.Event | None):
        # The comm method may not record dispatch/combine events.
        if evt is not None:
            torch.npu.current_stream().wait_event(evt)

    with npu_stream_switch(shared_experts_calculation_stream(),
                           enabled=self.multistream_overlap_shared_expert):
        # Ensure the shared experts wait for hidden_states to be ready.
        torch.npu.current_stream().wait_event(
            fused_moe_evts.before_routed_experts)
        # Execute the gate projection and activation concurrently with the
        # dispatch communication.
        _wait_if_recorded(fused_moe_evts.before_dispatch)
        gate_up_out = self._shared_experts_part1(hidden_states)
        # Execute the down projection concurrently with the combine
        # communication.
        _wait_if_recorded(fused_moe_evts.before_combine)
        shared_out = self._shared_experts_part2(hidden_states, gate_up_out)

    # Make sure the default stream waits for the shared experts stream to
    # finish.
    if self.multistream_overlap_shared_expert:
        torch.npu.current_stream().wait_stream(
            shared_experts_calculation_stream())

    # NOTE: This is exactly the opposite of
    # `maybe_all_reduce_tensor_model_parallel`
    forward_context = get_forward_context()
    moe_comm_type = forward_context.moe_comm_type
    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2,
                         MoECommType.FUSED_MC2} \
            and not shared_expert_dp_enabled():
        shared_out = tensor_model_parallel_all_reduce(shared_out)
    return shared_out

def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
if self.multistream_overlap_gate:
set_flash_common3_context(shared_experts=self._shared_experts)

routed_out = AscendFusedMoE.forward_impl(
before_routed_experts = torch.npu.current_stream().record_event()
fused_moe_results = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
return_with_event=True,
)
routed_out = fused_moe_results.routed_out

if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())

# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
shared_out = fc3_context.shared_out
else:
shared_out = self._forward_shared_experts(
Comment thread
jianzs marked this conversation as resolved.
hidden_states,
FusedMoEEvents(
before_routed_experts=before_routed_experts,
before_dispatch=fused_moe_results.before_dispatch_evt,
before_combine=fused_moe_results.before_combine_evt,
))

return shared_out, routed_out
22 changes: 10 additions & 12 deletions vllm_ascend/ops/fused_moe/moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Dict, Optional

import torch
from vllm.forward_context import get_forward_context
Expand Down Expand Up @@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
@dataclass
class FusedExpertsResult:
routed_out: torch.Tensor
# This field is for shared experts and should be set by the MoE
# communication method that supports shared experts in parallel with routed
# experts.
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None
# For dynamic_eplb
group_list_type: int | None = None
expert_tokens: torch.Tensor | None = None
Expand Down Expand Up @@ -108,10 +113,6 @@ def fused_experts(
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
Expand All @@ -126,6 +127,7 @@ def fused_experts(
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"

before_dispatch_evt = torch.npu.current_stream().record_event()
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
Expand All @@ -134,9 +136,6 @@ def fused_experts(
log2phy=log2phy,
global_redundant_expert_num=self.moe_config.
global_redundant_expert_num,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8,
Expand All @@ -162,12 +161,15 @@ def fused_experts(
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)

before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata)

return FusedExpertsResult(
routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list)

Expand Down Expand Up @@ -284,10 +286,6 @@ def fused_experts(
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
Expand Down
Loading