[ROCm] Cast score correction bias tensor during model construction for DeepSeek/Kimi-K2 #39999
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | @@ -351,6 +351,16 @@ def __init__( |
| | | else torch.bfloat16 |
| | | ) |
| | | |
| | | # Pre-cast the bias to match the gate output dtype so the |
| | | # conversion is not repeated on every forward pass. All |
| | | # downstream references (FusedMoE, router) share the same |
| | | # nn.Parameter object, so mutating .data propagates everywhere. |
| | | # Weight loading uses copy_(), which handles the dtype conversion. |
| | | if self.gate.e_score_correction_bias is not None: |
| | | self.gate.e_score_correction_bias.data = ( |
| | | self.gate.e_score_correction_bias.data.to(self.gate.out_dtype) |
| | | ) |
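
The code comment in this hunk makes two claims about PyTorch behavior: rebinding `.data` keeps the same `nn.Parameter` object, so other holders of that reference see the new dtype, and a later `copy_()` converts incoming values to the destination dtype. A minimal sketch of both, assuming a float32 bias and a bfloat16 gate output; `ToyGate` and `router_ref` are hypothetical stand-ins, and only the attribute names mirror the diff:

```python
import torch
from torch import nn


class ToyGate(nn.Module):
    """Hypothetical stand-in for the gate; not vLLM's real class."""

    def __init__(self, num_experts: int, out_dtype: torch.dtype):
        super().__init__()
        self.out_dtype = out_dtype
        # Assumption for this sketch: the bias starts out as float32.
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(num_experts, dtype=torch.float32), requires_grad=False
        )


gate = ToyGate(num_experts=4, out_dtype=torch.bfloat16)
router_ref = gate.e_score_correction_bias  # second holder of the same object

# One-time pre-cast, as in the diff: rebind .data instead of replacing
# the Parameter, so existing references stay valid.
if gate.e_score_correction_bias is not None:
    gate.e_score_correction_bias.data = gate.e_score_correction_bias.data.to(
        gate.out_dtype
    )

same_object = router_ref is gate.e_score_correction_bias
dtype_after_cast = router_ref.dtype

# Simulated weight loading: copy_() casts the float32 source to the
# destination dtype in place. These values are exact in bfloat16.
loaded = torch.tensor([0.5, -1.0, 2.0, 0.25], dtype=torch.float32)
with torch.no_grad():
    gate.e_score_correction_bias.copy_(loaded)
dtype_after_load = gate.e_score_correction_bias.dtype
```

In this sketch `router_ref` reports bfloat16 after the cast without being reassigned, and the parameter stays bfloat16 after loading float32 values.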
heachary marked this conversation as resolved.
Comment on lines +351 to +364
Collaborator
I think this block of code could live in
Contributor
Author
I already mentioned in my previous comment why that's a harder change, which I decided to skip. Let me elaborate with some details here:

Moving the bias pre-cast (lines 354-367) into FusedMoE.__init__() isn't standalone: it depends on gate.set_out_dtype(), which is called just above it, and that call relies on self.experts.quant_method.is_monolithic and self.experts.routing_method_type, both of which are only available after FusedMoE.__init__() completes. So both blocks (

The concern is that this becomes more invasive: every model passing gate= to FusedMoE, including qwen3_moe, qwen3_next, step3p5, and AXK1, would now have set_out_dtype called automatically in FusedMoE.__init__(), which changes their gate output dtype behavior even though they don't currently call set_out_dtype at all. If this is not a big concern, I would like to leave this section as is to minimise the impact.
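
The ordering concern above can be sketched with plain-Python stubs. Everything here is hypothetical: `StubGate`, `StubExperts`, `pick_dtype`, and the `auto_set` flag are illustrative names, dtypes are plain strings, and the selection rule is an assumption, since the real condition is not shown in the diff. The point is only that a call moved into the experts' constructor also reaches models that never opted in.

```python
class StubGate:
    """Hypothetical gate; out_dtype of None means "never set explicitly"."""

    def __init__(self):
        self.out_dtype = None

    def set_out_dtype(self, dtype: str) -> None:
        self.out_dtype = dtype


def pick_dtype(experts) -> str:
    # Assumed rule for illustration; it needs an attribute that only
    # exists once the experts object has been constructed.
    return "float32" if experts.is_monolithic else "bfloat16"


class StubExperts:
    """Hypothetical experts layer that receives the gate, like gate=."""

    def __init__(self, gate: StubGate, is_monolithic: bool, auto_set: bool = False):
        self.gate = gate
        self.is_monolithic = is_monolithic
        if auto_set:
            # The "moved" variant: every model passing a gate gets this call.
            gate.set_out_dtype(pick_dtype(self))


def build_model(calls_set_out_dtype: bool, auto_set: bool) -> StubGate:
    gate = StubGate()
    experts = StubExperts(gate, is_monolithic=False, auto_set=auto_set)
    if calls_set_out_dtype:
        # Current layout: the model opts in after the experts exist.
        gate.set_out_dtype(pick_dtype(experts))
    return gate


opt_in_today = build_model(calls_set_out_dtype=True, auto_set=False).out_dtype
other_today = build_model(calls_set_out_dtype=False, auto_set=False).out_dtype
other_if_moved = build_model(calls_set_out_dtype=False, auto_set=True).out_dtype
```

Here `other_today` stays `None` while `other_if_moved` becomes `"bfloat16"`, which is the behavior change for non-opt-in models that the comment describes.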
| | | |
| | | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| | | num_tokens, hidden_dim = hidden_states.shape |
| | | hidden_states = hidden_states.view(-1, hidden_dim) |