diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 8ca7b255075..6baa199bf43 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -14,7 +14,8 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
                                get_ascend_device_type, has_layer_idx,
-                               is_moe_model)
+                               is_moe_model,
+                               speculative_enable_dispatch_gmm_combine_decode)
 
 
 class MoECommType(Enum):
@@ -242,7 +243,7 @@ def select_moe_comm_method(num_tokens: int,
     dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
     # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
     # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
-    # TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8
+    # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
     fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic" and (
         not dynamic_eplb)
     if num_tokens <= mc2_tokens_capacity:
@@ -250,6 +251,9 @@ def select_moe_comm_method(num_tokens: int,
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             fused_decode_enable = fused_mc2_enable and get_ep_group(
             ).world_size <= 16 and (not is_draft_model)
+        elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+            fused_decode_enable = fused_mc2_enable and \
+                speculative_enable_dispatch_gmm_combine_decode(vllm_config)
         moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
     else:
         fused_prefill_enable = fused_mc2_enable
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 97f8e2b66cc..9a1e53e3713 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -823,6 +823,23 @@ def is_moe_model(vllm_config: VllmConfig):
     return _IS_MOE_MODEL
 
 
+def speculative_enable_dispatch_gmm_combine_decode(
+        vllm_config: VllmConfig) -> bool:
+    # Gate the fused dispatch_gmm_combine_decode path on the speculative
+    # decoding method: draft-model-based methods (eagle/eagle3) and mtp with
+    # a non-w8a8_dynamic draft quantization are not supported by the fused op.
+    if vllm_config.speculative_config is None:
+        return True
+    speculative_method = getattr(vllm_config.speculative_config, "method",
+                                 None)
+    if speculative_method in [None, "ngram", "suffix"]:
+        return True
+    if speculative_method in ["eagle", "eagle3"]:
+        return False
+    if speculative_method == "mtp":
+        mtp_quant_type = getattr(vllm_config.model_config.hf_config,
+                                 "mtp_quantize", None)
+        return mtp_quant_type == "w8a8_dynamic"
+    return False
+
+
 def _is_contain_expert(config: Any):
     if isinstance(config, dict):
         for k, v in config.items():