diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 489994c18cd..0f115da8970 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -856,6 +856,7 @@ def is_drafter_moe_model(vllm_config: VllmConfig): def speculative_enable_dispatch_gmm_combine_decode( vllm_config: VllmConfig) -> bool: + """When draft contains MOE Arch and non-w8a8, disable dispatch_gmm_combine_decode.""" if vllm_config.speculative_config is None: return True speculative_method = getattr(vllm_config.speculative_config, "method", @@ -863,7 +864,15 @@ def speculative_enable_dispatch_gmm_combine_decode( if speculative_method in [None, "ngram", "suffix"]: return True if speculative_method in ["eagle", "eagle3"]: - return False + if is_drafter_moe_model(vllm_config): + draft_model_config = vllm_config.speculative_config.draft_model_config + hf_text_config = draft_model_config.hf_text_config + quant_type = getattr(hf_text_config, "moe_quantize", None) + if quant_type is None: + quant_type = getattr(hf_text_config, "quantize", None) + return quant_type == "w8a8_dynamic" + else: + return True if speculative_method == "mtp": mtp_quant_type = getattr(vllm_config.model_config.hf_text_config, "mtp_quantize", None)