[DeepSeek V4] Fix meaningless numbers in chat output by adding swiglu_limit clamp to DeepseekV2MLP #23776
Changes from all commits
@@ -227,9 +227,11 @@ def __init__(
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> None:
         super().__init__()
         self.tp_size = tp_size
+        self.swiglu_limit = swiglu_limit

         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
@@ -283,6 +285,12 @@ def forward(
             x = (x, None, y)

         gate_up, _ = self.gate_up_proj(x)
+        if self.swiglu_limit is not None:
+            _g, _u = gate_up.chunk(2, dim=-1)
+            _lim = float(self.swiglu_limit)
+            gate_up = torch.cat(
+                [_g.clamp(max=_lim), _u.clamp(min=-_lim, max=_lim)], dim=-1
+            )
Comment on lines +288 to +293
Contributor
Using in-place clamp_ here would avoid the extra torch.cat allocation:

    if self.swiglu_limit is not None:
        _lim = self.swiglu_limit
        _g, _u = gate_up.chunk(2, dim=-1)
        _g.clamp_(max=_lim)
        _u.clamp_(min=-_lim, max=_lim)
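A quick sanity check on why the in-place variant needs no torch.cat: chunk() returns views into gate_up, so clamping the halves in place mutates gate_up itself. Minimal sketch (the toy tensor and the limit of 7.0 are made up for illustration):

    import torch

    gate_up = torch.tensor([[10.0, -3.0, 8.0, -9.0]])  # toy [gate | up] tensor
    _lim = 7.0
    _g, _u = gate_up.chunk(2, dim=-1)  # views over gate_up's storage
    _g.clamp_(max=_lim)                # gate half: upper bound only
    _u.clamp_(min=-_lim, max=_lim)     # up half: symmetric bound
    print(gate_up)                     # tensor([[ 7., -3.,  7., -7.]])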
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
             x,
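For context on what the new clamp guards against, here is a self-contained sketch of the clamped computation. It assumes act_fn is the usual SiLU-and-multiply (the activation itself is not shown in this diff) and uses made-up values; without the limit, a single large gate channel dominates the output and, in low-precision dtypes, can overflow into the "meaningless numbers" the PR title describes.

    import torch
    import torch.nn.functional as F
    from typing import Optional

    def clamped_swiglu(gate_up: torch.Tensor, limit: Optional[float]) -> torch.Tensor:
        # Same clamping as the diff above; act_fn assumed to be silu(gate) * up.
        gate, up = gate_up.chunk(2, dim=-1)
        if limit is not None:
            gate = gate.clamp(max=limit)          # gate half: upper bound only
            up = up.clamp(min=-limit, max=limit)  # up half: symmetric bound
        return F.silu(gate) * up

    x = torch.tensor([[50.0, 2.0, 30.0, 1.0]])    # toy [gate | up] activations
    print(clamped_swiglu(x, limit=None))          # ~ tensor([[1500.0000, 1.7616]])
    print(clamped_swiglu(x, limit=7.0))           # ~ tensor([[  48.9554, 1.7616]])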
@@ -533,6 +541,7 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             reduce_results=False,
+            swiglu_limit=getattr(config, "swiglu_limit", None),
             prefix=add_prefix("shared_experts", prefix),
             **(
                 dict(tp_rank=0, tp_size=1)
@@ -2594,6 +2603,7 @@ def __init__(
             prefix=add_prefix("mlp", prefix),
             tp_rank=mlp_tp_rank,
             tp_size=mlp_tp_size,
+            swiglu_limit=getattr(config, "swiglu_limit", None),
         )

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
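The getattr(config, "swiglu_limit", None) wiring at both call sites keeps existing checkpoints unaffected: a config without the field yields None and the clamp is skipped entirely. A small illustration (the SimpleNamespace configs and the 7.0 value are hypothetical stand-ins for the model config):

    from types import SimpleNamespace

    legacy_config = SimpleNamespace(hidden_act="silu")                 # field absent
    new_config = SimpleNamespace(hidden_act="silu", swiglu_limit=7.0)  # opts in

    print(getattr(legacy_config, "swiglu_limit", None))  # None -> clamp disabled
    print(getattr(new_config, "swiglu_limit", None))     # 7.0  -> clamp enabled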
It is better to cast swiglu_limit to a float once during initialization to avoid repeated casting in the forward pass, which is on the hot path.
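A minimal sketch of that suggestion (hypothetical helper class, not the real DeepseekV2MLP): normalize the limit to a float at construction time so forward() can read self.swiglu_limit directly, with no per-call float(...) conversion.

    from typing import Optional

    class SwigluLimitHolder:
        def __init__(self, swiglu_limit: Optional[float] = None) -> None:
            # Cast once here; the config value may arrive as an int or a float.
            self.swiglu_limit = None if swiglu_limit is None else float(swiglu_limit)

    holder = SwigluLimitHolder(swiglu_limit=7)
    print(holder.swiglu_limit)  # 7.0 -- the forward hot path uses it without casting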