diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 121a31630ae3..fcf62efc8727 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -327,7 +327,28 @@ def __init__(
         self.e_score_correction_bias = nn.Parameter(
             torch.empty((config.n_routed_experts), dtype=torch.float32)
         )
+        # GLM requires an FP32 gate projection; cache the cast weight so the
+        # cast is not repeated on every forward.
+        # Non-persistent: the cache is derived state and must not enter the
+        # state_dict.
+        self.register_buffer("_weight_fp32", None, persistent=False)
+        # Fingerprint (storage pointer, in-place version) of the weight the
+        # cache was built from; lets forward() detect runtime weight updates
+        # (e.g. expert rebalancing) or device moves and invalidate the cache.
+        self._weight_fp32_src = None
 
     def forward(self, hidden_states):
-        logits = F.linear(hidden_states, self.weight, None)
+        w = self.weight
+        # Rebuild the FP32 copy when missing or when the underlying weight
+        # changed: data_ptr() changes on reallocation/device move, _version
+        # increments on in-place updates (NOTE(review): _version is a private
+        # autograd counter — stable in practice, but verify on torch upgrades).
+        if (
+            self._weight_fp32 is None
+            or self._weight_fp32_src != (w.data_ptr(), w._version)
+        ):
+            self._weight_fp32 = w.data.to(torch.float32)
+            self._weight_fp32_src = (w.data_ptr(), w._version)
+        # Cast activations too: routing logits are computed entirely in FP32.
+        logits = F.linear(hidden_states.to(torch.float32), self._weight_fp32, None)
         return logits