Skip to content

Commit 9ff7384

Browse files
committed
fix
1 parent 6405c37 commit 9ff7384

File tree

4 files changed

+7
-11
lines changed

4 files changed

+7
-11
lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def set_ascend_forward_context(
             batch_descriptor=batch_descriptor,
     ):
         forward_context = get_forward_context()
+        if moe_comm_method == "allgather" and with_prefill:
+            moe_comm_method = "naivemulticast"
+
         forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
         forward_context.with_prefill = with_prefill
         ep_size = (get_ep_group().world_size if

vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def prepare(self,
                 rm_router_logits: bool = False,
                 replace_allreduce: bool = False,
                 gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self.enable_shared_expert_dp = enable_shared_expert_dp
+
         if self.moe_config.dp_size > 1:
             self.cu_tokens_across_dp_cpu = get_forward_context(
             ).dp_metadata.cu_tokens_across_dp_cpu

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,8 @@ def fused_experts(
             # For load balance
             log2phy: torch.Tensor = None,
             global_redundant_expert_num: int = 0,
-            fusion_mlp: bool = False,
             need_trans: bool = False) -> torch.Tensor:
         # Check constraints
-        assert hidden_states.shape[1] == w1.shape[1], (
-            f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
-        assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-        assert hidden_states.is_contiguous(
-        ), "Hidden_states must be contiguous"
-        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
-        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
         assert hidden_states.dtype in [
             torch.float32, torch.float16, torch.bfloat16
         ]
@@ -137,7 +129,7 @@ def fused_experts(
                 w2_scale_bias=w2_scale_bias,
                 with_quant=use_int8_w8a8
                 or use_int4_w4a8,
-                fusion=fusion_mlp,
+                fusion=use_int8_w8a8,
                 need_trans=need_trans)

         hidden_states[:] = self.token_dispatcher.token_combine(

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,7 @@ def apply(
             global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
-            shared_dequant_scale=shared_dequant_scale,
-            fusion_mlp=True
+            shared_dequant_scale=shared_dequant_scale
         )

         # return unified_fused_experts_eager(

0 commit comments

Comments (0)