Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions vllm_ascend/ascend_forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,24 @@ def select_moe_comm_method(num_tokens: int,
ascend_config = get_ascend_config()
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and get_ep_group(
).world_size <= 16 and (not dynamic_eplb) and (not is_mtp_model)
# TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
# TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8
fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
not dynamic_eplb)
if num_tokens <= mc2_tokens_capacity:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.MC2
fused_decode_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_decode_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_mtp_model)
moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
else:
moe_comm_type = MoECommType.FUSED_MC2 if fused_mc2_enable else MoECommType.ALLTOALL
fused_prefill_enable = fused_mc2_enable
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
fused_prefill_enable = fused_mc2_enable and get_ep_group(
).world_size <= 16 and (not is_mtp_model)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_prefill_enable = False
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL

else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
Expand Down
8 changes: 7 additions & 1 deletion vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@
# Whether to enable dynamic EPLB
"DYNAMIC_EPLB":
lambda: os.getenv("DYNAMIC_EPLB", "false").lower(),
# Whether to anbale fused mc2(dispatch_gmm_combine_decode/dispatch_ffn_combine operator)
# Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator)
# 0, or not set: default ALLTOALL and MC2 will be used.
# 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator.
# `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=16, non-mtp, non-dynamic-eplb.
# 2: MC2 might be replaced by `dispatch_gmm_combine_decode` operator.
# `dispatch_gmm_combine_decode` can be used only for moe layers on a **decode node**
# with W8A8 and without dynamic EPLB. The MTP layer must also be W8A8.
"VLLM_ASCEND_ENABLE_FUSED_MC2":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", '0')),
}
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/ops/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)
Expand Down
18 changes: 16 additions & 2 deletions vllm_ascend/ops/fused_moe/moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def fused_experts(
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
out = torch.empty_like(hidden_states)

if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1[0],
Expand All @@ -307,7 +307,21 @@ def fused_experts(
out=out,
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
raise NotImplementedError()
assert expert_map is not None, "expert_map cannot be None."
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1[0],
gmm1_permuted_weight_scale=w1_scale[0],
gmm2_weight=w2[0],
gmm2_weight_scale=w2_scale[0],
expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id,
moe_expert_num=len(expert_map),
global_bs=self.token_dispatcher.fused_global_bs)
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
Expand Down
1 change: 1 addition & 0 deletions vllm_ascend/ops/fused_moe/token_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(self, **kwargs):
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
self.fused_global_bs = max_num_tokens * self.ep_world_size

def get_dispatch_mc2_kwargs(
self,
Expand Down
11 changes: 10 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def apply(
topk_weights = topk_weights.to(self.in_dtype)

moe_comm_method = get_forward_context().moe_comm_method
# When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, dispatch_gmm_combine_decode is used, which needs an fp32 w2 scale
w2_weight_scale_fp32_flag = (
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
Expand All @@ -240,7 +244,10 @@ def apply(
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
w2_scale = [
layer.w2_weight_scale_fp32
if w2_weight_scale_fp32_flag else layer.w2_weight_scale
]

fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
Expand Down Expand Up @@ -279,6 +286,8 @@ def process_weights_after_loading(self, layer):
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
torch.float32)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)

Expand Down
Loading