-
-
Notifications
You must be signed in to change notification settings - Fork 18.8k
[Perf] DSV3.2 Indexer Fused Weights Projection #38684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -639,21 +639,19 @@ def __init__( | |||||||||||||||
| quant_config=quant_config, | ||||||||||||||||
| prefix=f"{prefix}.wq_b", | ||||||||||||||||
| ) | ||||||||||||||||
| self.wk = ReplicatedLinear( | ||||||||||||||||
| # Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. | ||||||||||||||||
| # weights_proj does not get quantized, so we run both with quant_config=None | ||||||||||||||||
| # wk may be upcasted from the default quant; experiments show fusion is always | ||||||||||||||||
| # faster unless WK proj is in FP4, which is not the case for all known quants. | ||||||||||||||||
| self.wk_weights_proj = MergedColumnParallelLinear( | ||||||||||||||||
| hidden_size, | ||||||||||||||||
| self.head_dim, | ||||||||||||||||
| bias=False, | ||||||||||||||||
| quant_config=quant_config, | ||||||||||||||||
| prefix=f"{prefix}.wk", | ||||||||||||||||
| ) | ||||||||||||||||
| self.k_norm = LayerNorm(self.head_dim, eps=1e-6) | ||||||||||||||||
| self.weights_proj = ReplicatedLinear( | ||||||||||||||||
| hidden_size, | ||||||||||||||||
| self.n_head, | ||||||||||||||||
| [self.head_dim, self.n_head], | ||||||||||||||||
| bias=False, | ||||||||||||||||
| quant_config=None, | ||||||||||||||||
| prefix=f"{prefix}.weights_proj", | ||||||||||||||||
| disable_tp=True, | ||||||||||||||||
| prefix=f"{prefix}.wk_weights_proj", | ||||||||||||||||
| ) | ||||||||||||||||
|
Comment on lines
+646
to
653
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forcing
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can easily run DSR1 NVFP4 on this branch, which notably does not have the indexer module. I don't expect any DeepSeek checkpoints will have these weights pre-fused. If the bot has a more nuanced point here, I don't see it
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This actually makes sense. One proj with fp8 quant and the other without quant. It is not wise to fuse all of them. We should make a special case for the original checkpoint |
||||||||||||||||
| self.k_norm = LayerNorm(self.head_dim, eps=1e-6) | ||||||||||||||||
| self.softmax_scale = self.head_dim**-0.5 | ||||||||||||||||
|
|
||||||||||||||||
| self.scale_fmt = "ue8m0" | ||||||||||||||||
|
|
@@ -694,7 +692,11 @@ def forward( | |||||||||||||||
| q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| k, _ = self.wk(hidden_states) | ||||||||||||||||
| # Fused wk + weights_proj: one GEMM, then split | ||||||||||||||||
| kw, _ = self.wk_weights_proj(hidden_states) | ||||||||||||||||
| k = kw[:, : self.head_dim] | ||||||||||||||||
| weights_raw = kw[:, self.head_dim :] | ||||||||||||||||
|
Comment on lines
+696
to
+698
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The slicing
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| k = self.k_norm(k) | ||||||||||||||||
| k_pe, k_nope = torch.split( | ||||||||||||||||
| k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 | ||||||||||||||||
|
|
@@ -723,9 +725,8 @@ def forward( | |||||||||||||||
| q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) | ||||||||||||||||
| q_scale = q_scale.view(-1, self.n_head, 1) | ||||||||||||||||
|
|
||||||||||||||||
| weights, _ = self.weights_proj(hidden_states) | ||||||||||||||||
| weights = ( | ||||||||||||||||
| weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 | ||||||||||||||||
| weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 | ||||||||||||||||
| ) | ||||||||||||||||
| weights = weights.squeeze(-1) | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -1438,6 +1439,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | |||||||||||||||
| ("qkv_proj", "k_proj", "k"), | ||||||||||||||||
| ("qkv_proj", "v_proj", "v"), | ||||||||||||||||
| ] | ||||||||||||||||
| # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) | ||||||||||||||||
| indexer_fused_mapping = [ | ||||||||||||||||
| ("wk_weights_proj", "wk", 0), | ||||||||||||||||
| ("wk_weights_proj", "weights_proj", 1), | ||||||||||||||||
| ] | ||||||||||||||||
| stacked_params_mapping.extend(indexer_fused_mapping) | ||||||||||||||||
|
Comment on lines
+1443
to
+1447
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The substring replacement logic in |
||||||||||||||||
|
|
||||||||||||||||
| if self.use_mha: | ||||||||||||||||
| stacked_params_mapping.extend(mha_params_mapping) | ||||||||||||||||
| else: | ||||||||||||||||
|
|
||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the issue in
deepseek_v2.py, thewksubstring match can corrupt weight names if the checkpoint is already fused. Ensure thatwk_weights_projis included in the guard condition within theload_weightsloop (around line 292) to prevent incorrect mapping and potentialKeyErrorwhen the parameter is missing in certain model configurations.