Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions vllm/model_executor/models/qwen3_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.model_executor.layers.fused_moe import FusedMoE
Comment thread
mmangkad marked this conversation as resolved.
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
Expand Down Expand Up @@ -43,6 +44,15 @@
logger = init_logger(__name__)


def _get_qwen3_5_mtp_quant_config(
quant_config: QuantizationConfig | None,
) -> QuantizationConfig | None:
# Qwen3.5 NVFP4 checkpoints keep the entire MTP branch in bf16 weights.
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
return None
return quant_config


@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
Expand All @@ -59,7 +69,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
quant_config = _get_qwen3_5_mtp_quant_config(vllm_config.quant_config)

config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config

Expand All @@ -85,14 +95,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=f"{prefix}.fc",
)

self.layers = torch.nn.ModuleList(
Qwen3_5DecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
# Qwen3_5DecoderLayer reads quantization from vllm_config.
original_quant_config = vllm_config.quant_config
vllm_config.quant_config = quant_config
try:
self.layers = torch.nn.ModuleList(
Qwen3_5DecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
for idx in range(self.num_mtp_layers)
)
finally:
vllm_config.quant_config = original_quant_config
Comment thread
mmangkad marked this conversation as resolved.

self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
Expand Down Expand Up @@ -352,7 +368,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"please use '--mamba-cache-mode=align' instead"
)

self.quant_config = vllm_config.quant_config
self.quant_config = _get_qwen3_5_mtp_quant_config(vllm_config.quant_config)

super().__init__()
self.config = config
Expand Down
Loading