diff --git a/vllm/model_executor/models/qwen3_5_mtp.py b/vllm/model_executor/models/qwen3_5_mtp.py
index e42403213da7..2e6c7052764b 100644
--- a/vllm/model_executor/models/qwen3_5_mtp.py
+++ b/vllm/model_executor/models/qwen3_5_mtp.py
@@ -13,7 +13,7 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -59,7 +59,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
         model_config = vllm_config.model_config
-        quant_config = vllm_config.quant_config
+        quant_config = vllm_config.quant_config  # noqa: F841
         config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config
 
@@ -75,13 +75,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             config.hidden_size,
         )
 
-        self.fc = ColumnParallelLinear(
+        self.fc = ReplicatedLinear(
             self.config.hidden_size * 2,
             self.config.hidden_size,
-            gather_output=True,
             bias=False,
-            return_bias=False,
-            quant_config=quant_config,
+            # Not quantizing the MTP layer.
+            quant_config=None,
             prefix=f"{prefix}.fc",
         )
 
@@ -125,7 +124,7 @@ def forward(
             inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
             hidden_states = self.pre_fc_norm_hidden(hidden_states)
             hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
-            hidden_states = self.fc(hidden_states)
+            hidden_states, _ = self.fc(hidden_states)
             residual = None
         else:
             assert intermediate_tensors is not None