[Perf] DSV3.2 Indexer Fused Weights Projection#38684
robertgshaw2-redhat merged 3 commits into vllm-project:main
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Code Review
This pull request fuses the `wk` and `weights_proj` layers into a single `MergedColumnParallelLinear` named `wk_weights_proj` within the DeepSeek-V2 and MTP models to optimize performance. Several critical issues were identified: forcing `quant_config=None` may lead to loading errors with quantized checkpoints; the tensor slicing in the forward pass assumes a 2D input, which could fail with 3D tensors (a flattening suggestion was provided); and the weight loading logic is susceptible to name corruption and potential crashes due to substring matching and missing guard conditions.
```diff
-        self.wk = ReplicatedLinear(
-            hidden_size,
-            self.head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.wk",
-        )
+        self.wk_weights_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [self.head_dim, self.n_head],
+            bias=False,
+            quant_config=None,
+            disable_tp=True,
+            prefix=f"{prefix}.wk_weights_proj",
+        )
         self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
-        self.weights_proj = ReplicatedLinear(
-            hidden_size,
-            self.n_head,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.weights_proj",
-        )
```
Forcing quant_config=None for wk_weights_proj will cause correctness issues when loading from quantized checkpoints (e.g., FP8). Since wk is typically quantized in DeepSeek-V3/V3.2 checkpoints, the weight_loader will attempt to bit-copy quantized weights into a non-quantized parameter without applying the necessary scales. The current weight_loader for MergedColumnParallelLinear does not handle on-the-fly dequantization. If fusion is required for performance, you must implement a custom weight loader that can dequantize wk during the loading process or ensure that the quantization configuration is correctly propagated.
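If fusion with the unquantized `weights_proj` is kept, the loader would need to dequantize the FP8-block-quantized `wk` shard on the fly before copying it into the full-precision fused parameter. A minimal numpy sketch of that idea — the function names, the `BLOCK` size, and the scale layout are assumptions for illustration, not vLLM's actual loader API:

```python
import numpy as np

BLOCK = 128  # assumed block size for block-wise FP8 quantization

def dequantize_block_fp8(q_weight: np.ndarray, scale_inv: np.ndarray) -> np.ndarray:
    """Multiply per-(BLOCK x BLOCK) scales back into the quantized weight.

    q_weight:  (out, in) quantized values, modeled as floats for this sketch.
    scale_inv: (ceil(out/BLOCK), ceil(in/BLOCK)) per-block scale factors.
    """
    out_dim, in_dim = q_weight.shape
    # Expand each block scale to cover its BLOCK x BLOCK tile, then trim.
    scales = np.repeat(np.repeat(scale_inv, BLOCK, axis=0), BLOCK, axis=1)
    return q_weight * scales[:out_dim, :in_dim]

def load_fused_shard(fused_weight, q_weight, scale_inv, shard_offset):
    # Dequantize the quantized `wk` shard, then copy the full-precision
    # result into its slice of the unquantized fused parameter.
    dq = dequantize_block_fp8(q_weight, scale_inv)
    fused_weight[shard_offset:shard_offset + dq.shape[0]] = dq
```

Without a step like this, bit-copying the quantized values loses the scales entirely, which is the correctness issue described above.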
I can easily run DSR1 NVFP4 on this branch, which notably does not have the indexer module. I don't expect any DeepSeek checkpoints will have these weights pre-fused. If the bot has a more nuanced point here, I don't see it
This actually makes sense: one projection has fp8 quant and the other is unquantized, so it is not wise to fuse them unconditionally. We should make a special case for the original checkpoint.
```python
kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :]
```
The slicing kw[:, : self.head_dim] assumes that kw is a 2D tensor. However, when using torch.compile or in certain prefill paths, hidden_states (and thus kw) can be a 3D tensor. In such cases, this slicing will produce incorrect results. It is safer to flatten hidden_states to a token-based representation before the projection, which also ensures consistency for the subsequent indexer_op call on line 733.
Suggested change:

```diff
+hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 kw, _ = self.wk_weights_proj(hidden_states)
 k = kw[:, : self.head_dim]
 weights_raw = kw[:, self.head_dim :]
```
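The failure mode is easy to reproduce with a numpy analogue (torch indexing behaves the same way here):

```python
import numpy as np

head_dim = 3
hidden = np.arange(2 * 4 * 8).reshape(2, 4, 8)  # (batch, seq, hidden): a 3D input

# On a 3D tensor, [:, :head_dim] slices the *sequence* axis, not the features:
wrong = hidden[:, :head_dim]
assert wrong.shape == (2, 3, 8)  # sequence silently truncated — no error raised

# Flattening to (tokens, hidden) first makes the slice hit the feature axis:
flat = hidden.reshape(-1, hidden.shape[-1])
right = flat[:, :head_dim]
assert right.shape == (8, 3)
```

Because the wrong slice raises no exception, the bug would surface only as degraded model output, which is why flattening before the projection is the safer pattern.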
```python
indexer_fused_mapping = [
    ("wk_weights_proj", "wk", 0),
    ("wk_weights_proj", "weights_proj", 1),
]
stacked_params_mapping.extend(indexer_fused_mapping)
```
The substring replacement logic in load_weights is susceptible to a bug where wk matches wk_weights_proj. If a checkpoint already contains fused weights (e.g., from a previous save), name.replace("wk", "wk_weights_proj") will corrupt the parameter name (e.g., resulting in ...wk_weights_proj_weights_proj). Additionally, this mapping is added unconditionally, which may cause crashes in non-V3.2 models where wk_weights_proj is not defined. Please add wk_weights_proj to the guard condition in the weight loading loop (around line 1503) to ensure it only attempts to map if the parameter exists in params_dict.
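A short Python illustration of both hazards (the parameter names are hypothetical):

```python
# Hazard 1: "wk" is a substring of "wk_weights_proj", so re-mapping an
# already-fused checkpoint name corrupts it.
name = "model.layers.0.self_attn.indexer.wk_weights_proj.weight"
corrupted = name.replace("wk", "wk_weights_proj")
assert corrupted.endswith("wk_weights_proj_weights_proj.weight")

# Hazard 2 and its fix: a params_dict membership guard avoids applying
# the mapping blindly when the target parameter does not exist.
params_dict = {"model.layers.0.self_attn.indexer.wk_weights_proj.weight"}
mapped = name.replace("wk", "wk_weights_proj")
if mapped not in params_dict:
    mapped = name  # checkpoint already fused, or param absent in this model
assert mapped in params_dict
```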
| ("wk_weights_proj", "wk", 0), | ||
| ("wk_weights_proj", "weights_proj", 1), |
Similar to the issue in deepseek_v2.py, the wk substring match can corrupt weight names if the checkpoint is already fused. Ensure that wk_weights_proj is included in the guard condition within the load_weights loop (around line 292) to prevent incorrect mapping and potential KeyError when the parameter is missing in certain model configurations.
LopezCastroRoberto left a comment:
LGTM. Projections are tiny, bandwidth-bound, and same precision. Fusion wins :)
great job ben, love picking the low hanging fruit

@benchislett This change is breaking the fp8 checkpoint. Can you look into it? Thanks

@robertgshaw2-redhat @benchislett can you revert this? This is breaking the fp8 model.

fixed by #38870
GLM-5-FP8 checkpoints quantize the fused wk_weights_proj tensor with FP8 block quantization (weight + weight_scale_inv). Resolve merge conflict with upstream indexer refactor (vllm-project#38684 / vllm-project#38870) by always using the fused MergedColumnParallelLinear with quant_config:

- FP4: quant_config=None (weights_proj should not be quantized)
- Non-FP4: quant_config=quant_config (supports FP8 weight_scale_inv)

Add a fallback in load_weights to handle both fused and separate checkpoint formats gracefully via stacked_params_mapping. Also reverts glm_moe_dsa from _DEEPSEEK_V3_FAMILY_MODEL_TYPES per review feedback (will be submitted as a standalone PR).

Co-authored-by: Claude
Signed-off-by: Chuan Li <Chuan.Li2@amd.com>
Made-with: Cursor
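The "fused or separate checkpoint format" fallback can be sketched as a name-resolution helper (a hypothetical helper for illustration, not the actual GLM-5 or vLLM code):

```python
def resolve_param_name(name, stacked_params_mapping, params_dict):
    """Map a checkpoint weight name to a model parameter name.

    Keeps the name unchanged when the checkpoint already stores fused
    weights, and only applies a stacked-params entry when the resulting
    fused parameter actually exists. A simplified sketch — the real
    load_weights loop also tracks shard ids, expert weights, etc.
    """
    if name in params_dict:  # already-fused checkpoint format
        return name, None
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        candidate = name.replace(weight_name, param_name)
        if candidate in params_dict:  # guard against missing fused params
            return candidate, shard_id
    return name, None  # no mapping applies; load as a plain parameter
```

Checking `name in params_dict` first is what prevents the substring corruption on already-fused checkpoints discussed in the review above.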
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Rishi Puri <riship@nvidia.com>
Purpose
Fuse the WK and Weights_Proj projections in the DSV3.2 Indexer. This is an alternative optimization to #35968, which overlaps the projections instead of fusing them. Doing the fusion provides a greater speedup:
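The fusion is mathematically exact: stacking the two weight matrices along the output dimension and splitting the single GEMM's output reproduces both projections while reading the activations from memory only once, which is where the speedup on these small, bandwidth-bound layers comes from. A numpy sketch of the equivalence (dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, head_dim, n_head = 16, 4, 8
x = rng.standard_normal((5, hidden_size))           # (tokens, hidden)
w_k = rng.standard_normal((head_dim, hidden_size))  # wk weight
w_p = rng.standard_normal((n_head, hidden_size))    # weights_proj weight

# Separate: two small GEMMs, each re-reading x from memory.
k_sep = x @ w_k.T
p_sep = x @ w_p.T

# Fused: stack the weights along the output dim, run one GEMM, then split.
w_fused = np.concatenate([w_k, w_p], axis=0)        # (head_dim + n_head, hidden)
kw = x @ w_fused.T
k_fused, p_fused = kw[:, :head_dim], kw[:, head_dim:]

assert np.allclose(k_sep, k_fused) and np.allclose(p_sep, p_fused)
```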
Benchmark timings for DSV3.2 NVFP4 on 8xB200 (TP8, No Specdec)
BS128 8k/1k
BS1 8k/1k
Testing
GSM8k accuracy shows a slight decrease.
PR (2 runs):
Main (2 runs):