diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 327e04c657d..e87eb21f1b2 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -92,7 +92,6 @@
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -1619,15 +1618,15 @@ def forward_absorb_prepare(
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
-            if _is_cublas_ge_129:
-                q_nope_val, q_nope_scale = input_to_float8(
-                    q_nope.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            # Fix bmm_fp8 error under cuBLAS 12.9 caused by the bump allocator; see PR #11612 for details.
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
@@ -1768,14 +1767,14 @@ def forward_absorb_core(
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            if _is_cublas_ge_129:
-                attn_output_val, attn_output_scale = input_to_float8(
-                    attn_output.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
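
For context, both hunks now route every case through `per_tensor_quant_mla_fp8` and only swap the scale buffer: a freshly allocated one-element fp32 zero tensor when cuBLAS >= 12.9, the shared bump-allocator slot (`zero_allocator.allocate(1)`) otherwise. The snippet below is a minimal pure-PyTorch sketch of that pattern, not the sglang kernels; all names in it are illustrative.

```python
import torch

def per_tensor_quant_fp8_sketch(x: torch.Tensor, scale_buf: torch.Tensor):
    # Per-tensor FP8 quantization: write the scale into a caller-provided fp32
    # buffer so the max magnitude maps into the e4m3 range (~448).
    amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
    scale_buf.copy_((amax / 448.0).reshape(scale_buf.shape))
    x_q = (x / scale_buf).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return x_q, scale_buf

# Choosing the scale buffer, mirroring the ternary in the diff.
is_cublas_ge_129 = True  # stand-in for sglang's _is_cublas_ge_129 flag
x = torch.randn(4, 8, dtype=torch.bfloat16)
scale_buf = (
    torch.zeros((1,), dtype=torch.float32, device=x.device)
    if is_cublas_ge_129
    # stand-in for zero_allocator.allocate(1), i.e. a slot from a shared buffer
    else torch.zeros((1,), dtype=torch.float32)
)
x_q, x_scale = per_tensor_quant_fp8_sketch(x, scale_buf)
```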