Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/model_executor/models/deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# Fused indexer wk + weights_proj
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
Comment on lines +245 to +246

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the issue in deepseek_v2.py, the wk substring match can corrupt weight names if the checkpoint is already fused. Ensure that wk_weights_proj is included in the guard condition within the load_weights loop (around line 292) to prevent incorrect mapping and potential KeyError when the parameter is missing in certain model configurations.

]

expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
Expand Down
36 changes: 22 additions & 14 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,21 +639,19 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
)
self.wk = ReplicatedLinear(
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized, so we run both with quant_config=None
# wk may be upcasted from the default quant; experiments show fusion is always
# faster unless WK proj is in FP4, which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
Comment on lines +646 to 653

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Forcing quant_config=None for wk_weights_proj will cause correctness issues when loading from quantized checkpoints (e.g., FP8). Since wk is typically quantized in DeepSeek-V3/V3.2 checkpoints, the weight_loader will attempt to bit-copy quantized weights into a non-quantized parameter without applying the necessary scales. The current weight_loader for MergedColumnParallelLinear does not handle on-the-fly dequantization. If fusion is required for performance, you must implement a custom weight loader that can dequantize wk during the loading process or ensure that the quantization configuration is correctly propagated.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can easily run DSR1 NVFP4 on this branch, which notably does not have the indexer module. I don't expect any DeepSeek checkpoints will have these weights pre-fused. If the bot has a more nuanced point here, I don't see it

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually makes sense. One proj with fp8 quant and the other without quant. It is not wise to fuse all of them. We should make a special case for the original checkpoint

self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5

self.scale_fmt = "ue8m0"
Expand Down Expand Up @@ -694,7 +692,11 @@ def forward(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)

k, _ = self.wk(hidden_states)
# Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
Comment on lines +696 to +698

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The slicing kw[:, : self.head_dim] assumes that kw is a 2D tensor. However, when using torch.compile or in certain prefill paths, hidden_states (and thus kw) can be a 3D tensor. In such cases, this slicing will produce incorrect results. It is safer to flatten hidden_states to a token-based representation before the projection, which also ensures consistency for the subsequent indexer_op call on line 733.

Suggested change
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]


k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
Expand Down Expand Up @@ -723,9 +725,8 @@ def forward(
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)

weights, _ = self.weights_proj(hidden_states)
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)

Expand Down Expand Up @@ -1438,6 +1439,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
Comment on lines +1443 to +1447

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The substring replacement logic in load_weights is susceptible to a bug where wk matches wk_weights_proj. If a checkpoint already contains fused weights (e.g., from a previous save), name.replace("wk", "wk_weights_proj") will corrupt the parameter name (e.g., resulting in ...wk_weights_proj_weights_proj). Additionally, this mapping is added unconditionally, which may cause crashes in non-V3.2 models where wk_weights_proj is not defined. Please add wk_weights_proj to the guard condition in the weight loading loop (around line 1503) to ensure it only attempts to map if the parameter exists in params_dict.


if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
else:
Expand Down
Loading