Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions vllm/model_executor/models/qwen3_5_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

model_config = vllm_config.model_config
-quant_config = vllm_config.quant_config
+quant_config = vllm_config.quant_config # noqa: F841

config: Qwen3_5TextConfig | Qwen3_5MoeTextConfig = model_config.hf_text_config

Expand All @@ -75,13 +75,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.hidden_size,
)

-self.fc = ColumnParallelLinear(
+self.fc = ReplicatedLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
-gather_output=True,
bias=False,
-return_bias=False,
-quant_config=quant_config,
+# Not quantizing the MTP layer.
+quant_config=None,
prefix=f"{prefix}.fc",
prefix=f"{prefix}.fc",
)

Expand Down Expand Up @@ -125,7 +124,7 @@ def forward(
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
-hidden_states = self.fc(hidden_states)
+hidden_states, _ = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
Expand Down