Skip to content

[Perf] DSV3.2 Indexer Fused Weights Projection#38684

Merged
robertgshaw2-redhat merged 3 commits intovllm-project:mainfrom
CentML:perf/deepseek-fused-wk
Apr 2, 2026
Merged

[Perf] DSV3.2 Indexer Fused Weights Projection#38684
robertgshaw2-redhat merged 3 commits intovllm-project:mainfrom
CentML:perf/deepseek-fused-wk

Conversation

@benchislett
Copy link
Copy Markdown
Collaborator

Purpose

Fuse the WK and Weights_Proj projections in the DSV3.2 Indexer. This is an alternative optimization to #35968, which overlaps the projections instead of fusing them. Doing the fusion provides a greater speedup:

Benchmark timings for DSV3.2 NVFP4 on 8xB200 (TP8, No Specdec)

BS128 8k/1k

Fused WK+WeightsProj (Decode):  21.6 ms
Baseline             (Decode):  22.3 ms

BS1 8k/1k

Fused WK+WeightsProj (Decode):  9.5 ms
Overlap              (Decode):  9.8  ms
Baseline             (Decode):  10.2  ms

Fused WK+WeightsProj (TTFT):  339 ms
Overlap              (TTFT):  340 ms
Baseline             (TTFT):  341 ms

Testing

GSM8k shows a slight decrease.

PR (2 runs):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9500|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9492|±  | 0.006|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9553|±  |0.0057|
|     |       |strict-match    |     5|exact_match|↑  |0.9530|±  |0.0058|

Main (2 runs):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9545|±  |0.0057|

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match|↑  |0.9462|±  |0.0062|

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett requested a review from luccafong as a code owner April 1, 2026 03:59
@mergify mergify Bot added the deepseek Related to DeepSeek models label Apr 1, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fuses the "wk" and "weights_proj" layers into a single "MergedColumnParallelLinear" named "wk_weights_proj" within the DeepSeek-V2 and MTP models to optimize performance. Several critical issues were identified: forcing "quant_config=None" may lead to loading errors with quantized checkpoints, the tensor slicing in the forward pass assumes a 2D input which could fail with 3D tensors (a flattening suggestion was provided), and the weight loading logic is susceptible to name corruption and potential crashes due to substring matching and missing guard conditions.

Comment on lines +646 to 653
self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
[self.head_dim, self.n_head],
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
disable_tp=True,
prefix=f"{prefix}.wk_weights_proj",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Forcing quant_config=None for wk_weights_proj will cause correctness issues when loading from quantized checkpoints (e.g., FP8). Since wk is typically quantized in DeepSeek-V3/V3.2 checkpoints, the weight_loader will attempt to bit-copy quantized weights into a non-quantized parameter without applying the necessary scales. The current weight_loader for MergedColumnParallelLinear does not handle on-the-fly dequantization. If fusion is required for performance, you must implement a custom weight loader that can dequantize wk during the loading process or ensure that the quantization configuration is correctly propagated.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can easily run DSR1 NVFP4 on this branch, which notably does not have the indexer module. I don't expect any DeepSeek checkpoints will have these weights pre-fused. If the bot has a more nuanced point here, I don't see it

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually makes sense. One proj with fp8 quant and the other without quant. It is not wise to fuse all of them. We should make a special case for the original checkpoint

Comment on lines +696 to +698
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The slicing kw[:, : self.head_dim] assumes that kw is a 2D tensor. However, when using torch.compile or in certain prefill paths, hidden_states (and thus kw) can be a 3D tensor. In such cases, this slicing will produce incorrect results. It is safer to flatten hidden_states to a token-based representation before the projection, which also ensures consistency for the subsequent indexer_op call on line 733.

Suggested change
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]

Comment on lines +1443 to +1447
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The substring replacement logic in load_weights is susceptible to a bug where wk matches wk_weights_proj. If a checkpoint already contains fused weights (e.g., from a previous save), name.replace("wk", "wk_weights_proj") will corrupt the parameter name (e.g., resulting in ...wk_weights_proj_weights_proj). Additionally, this mapping is added unconditionally, which may cause crashes in non-V3.2 models where wk_weights_proj is not defined. Please add wk_weights_proj to the guard condition in the weight loading loop (around line 1503) to ensure it only attempts to map if the parameter exists in params_dict.

Comment on lines +245 to +246
("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the issue in deepseek_v2.py, the wk substring match can corrupt weight names if the checkpoint is already fused. Ensure that wk_weights_proj is included in the guard condition within the load_weights loop (around line 292) to prevent incorrect mapping and potential KeyError when the parameter is missing in certain model configurations.

Copy link
Copy Markdown
Contributor

@LopezCastroRoberto LopezCastroRoberto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Projections are tiny, bandwidth-bound, and same precision. Fusion wins :)

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) April 1, 2026 16:22
@robertgshaw2-redhat
Copy link
Copy Markdown
Collaborator

great job ben, love picking the low hanging fruit

@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 1, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit 5f96f9a into vllm-project:main Apr 2, 2026
58 checks passed
@benchislett benchislett deleted the perf/deepseek-fused-wk branch April 2, 2026 03:49
vllm-agent pushed a commit to vllm-agent/vllm that referenced this pull request Apr 2, 2026
@zyongye
Copy link
Copy Markdown
Member

zyongye commented Apr 2, 2026

@benchislett This change is breaking the fp8 checkpoint. Can you look into it? Thanks

@zyongye
Copy link
Copy Markdown
Member

zyongye commented Apr 3, 2026

@robertgshaw2-redhat @benchislett can you revert this? This is breaking the fp8 model.

@zyongye zyongye mentioned this pull request Apr 3, 2026
5 tasks
@zyongye
Copy link
Copy Markdown
Member

zyongye commented Apr 3, 2026

fixed by #38870

ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 5, 2026
GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with
FP8 block quantization (weight + weight_scale_inv).  Resolve merge
conflict with upstream indexer refactor (vllm-project#38684/vllm-project#38870) by always using
fused MergedColumnParallelLinear with quant_config:
- FP4: quant_config=None (weights_proj should not be quantized)
- Non-FP4: quant_config=quant_config (supports FP8 weight_scale_inv)

Add fallback in load_weights to handle both fused and separate checkpoint
formats gracefully via stacked_params_mapping.

Also reverts glm_moe_dsa from _DEEPSEEK_V3_FAMILY_MODEL_TYPES per
review feedback (will be submitted as a standalone PR).

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with
FP8 block quantization (weight + weight_scale_inv).  Resolve merge
conflict with upstream indexer refactor (vllm-project#38684/vllm-project#38870) by always using
fused MergedColumnParallelLinear with quant_config:
- FP4: quant_config=None (weights_proj should not be quantized)
- Non-FP4: quant_config=quant_config (supports FP8 weight_scale_inv)

Add fallback in load_weights to handle both fused and separate checkpoint
formats gracefully via stacked_params_mapping.

Also reverts glm_moe_dsa from _DEEPSEEK_V3_FAMILY_MODEL_TYPES per
review feedback (will be submitted as a standalone PR).

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants