diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index af63117618c1..e5d734e21015 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -227,9 +227,11 @@ def __init__(
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> None:
         super().__init__()
         self.tp_size = tp_size
+        self.swiglu_limit = swiglu_limit

         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
@@ -283,6 +285,12 @@ def forward(
             x = (x, None, y)

         gate_up, _ = self.gate_up_proj(x)
+        if self.swiglu_limit is not None:
+            _g, _u = gate_up.chunk(2, dim=-1)
+            _lim = float(self.swiglu_limit)
+            gate_up = torch.cat(
+                [_g.clamp(max=_lim), _u.clamp(min=-_lim, max=_lim)], dim=-1
+            )
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
             x,
@@ -533,6 +541,7 @@ def __init__(
             hidden_act=config.hidden_act,
             quant_config=quant_config,
             reduce_results=False,
+            swiglu_limit=getattr(config, "swiglu_limit", None),
             prefix=add_prefix("shared_experts", prefix),
             **(
                 dict(tp_rank=0, tp_size=1)
@@ -2594,6 +2603,7 @@ def __init__(
             prefix=add_prefix("mlp", prefix),
             tp_rank=mlp_tp_rank,
             tp_size=mlp_tp_size,
+            swiglu_limit=getattr(config, "swiglu_limit", None),
         )

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
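For reference, here is a minimal standalone sketch of what the new `swiglu_limit` path does, isolated from the model code. It assumes the usual SiluAndMul layout where `gate_up` packs `[gate | up]` along the last dimension; the helper names (`apply_swiglu_limit`, `swiglu`) are illustrative and do not exist in sglang.

```python
import torch
import torch.nn.functional as F

def apply_swiglu_limit(gate_up: torch.Tensor, limit: float) -> torch.Tensor:
    """Mirror the clamping in the diff: gate is clamped to (-inf, limit],
    up is clamped to [-limit, limit], then the halves are re-concatenated."""
    g, u = gate_up.chunk(2, dim=-1)
    return torch.cat([g.clamp(max=limit), u.clamp(min=-limit, max=limit)], dim=-1)

def swiglu(gate_up: torch.Tensor) -> torch.Tensor:
    """SiluAndMul-style activation: silu(gate) * up."""
    g, u = gate_up.chunk(2, dim=-1)
    return F.silu(g) * u

x = torch.randn(2, 8)  # [batch, 2 * intermediate_size]
y = swiglu(apply_swiglu_limit(x, limit=7.0))
```

Note that clamping happens before the activation, so the limit bounds the inputs to `silu(gate) * up` rather than the MLP output, matching where the diff inserts the clamp (between `gate_up_proj` and `act_fn`).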