From ecc2aca00b2b41f6663156623365304e331da68c Mon Sep 17 00:00:00 2001 From: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com> Date: Tue, 31 Mar 2026 11:36:40 +0800 Subject: [PATCH] upd Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com> --- vllm/model_executor/models/qwen3_5_mtp.py | 34 +++++++++++++++++------ 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen3_5_mtp.py b/vllm/model_executor/models/qwen3_5_mtp.py index 0eca47492c91..6f33505e69c4 100644 --- a/vllm/model_executor/models/qwen3_5_mtp.py +++ b/vllm/model_executor/models/qwen3_5_mtp.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -43,6 +44,15 @@ logger = init_logger(__name__) +def _get_qwen3_5_mtp_quant_config( + quant_config: QuantizationConfig | None, +) -> QuantizationConfig | None: + # Qwen3.5 NVFP4 checkpoints keep the entire MTP branch in bf16 weights. + if quant_config is not None and quant_config.get_name() == "modelopt_fp4": + return None + return quant_config + + @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, @@ -59,7 +69,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() model_config = vllm_config.model_config - quant_config = vllm_config.quant_config + quant_config = _get_qwen3_5_mtp_quant_config(vllm_config.quant_config) config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config @@ -85,14 +95,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.fc", ) - self.layers = torch.nn.ModuleList( - Qwen3_5DecoderLayer( - vllm_config, - layer_type="full_attention", - prefix=f"{prefix}.layers.{idx}", + # Qwen3_5DecoderLayer reads quantization from vllm_config. + original_quant_config = vllm_config.quant_config + vllm_config.quant_config = quant_config + try: + self.layers = torch.nn.ModuleList( + Qwen3_5DecoderLayer( + vllm_config, + layer_type="full_attention", + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) ) - for idx in range(self.num_mtp_layers) - ) + finally: + vllm_config.quant_config = original_quant_config self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size @@ -352,7 +368,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "please use '--mamba-cache-mode=align' instead" ) - self.quant_config = vllm_config.quant_config + self.quant_config = _get_qwen3_5_mtp_quant_config(vllm_config.quant_config) super().__init__() self.config = config