diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 8e9c4c496df..95f5e7c6783 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -233,7 +233,8 @@ def forward(self, hidden_states): ) if ( - hidden_states.shape[0] < 4 + _is_cuda + and hidden_states.shape[0] < 4 and hidden_states.shape[1] == 7168 and self.weight.shape[0] == 256 and _device_sm >= 90