diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 571e061067fd..4dddd407f296 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -45,7 +45,10 @@
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
+
+    # from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
+    from aiter import gemm_a8w8_bpreshuffle, get_hip_quant
+    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index ca8586f7ca0f..00314bf018db 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -167,6 +167,9 @@
 _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
 
 if _use_aiter_gfx95:
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
+        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant,
+    )
     from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
 
     from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
@@ -1813,10 +1816,25 @@ def forward_absorb_prepare(
                 q_nope_out,
             )
         else:
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:
+
+                q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
+                    X=q_nope,
+                    WQ=self.w_kc.transpose(-1, -2),
+                    w_scale=self.w_scale,
+                    group_size=128,
+                    YQ=None,  # allocate (B, M, N)
+                    transpose_bm=False,  # (B, M, N)
+                    transpose_bm_in=True,  # (M, B, K)
+                    dtype=torch.bfloat16,
+                )
+
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
+
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
@@ -1964,10 +1982,22 @@ def forward_absorb_core(
                 attn_bmm_output,
             )
         else:
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:
+                attn_bmm_output = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
+                    X=attn_output,
+                    WQ=self.w_vc.transpose(-1, -2),
+                    w_scale=self.w_scale,
+                    group_size=128,
+                    YQ=None,
+                    transpose_bm=False,
+                    transpose_bm_in=True,
+                    dtype=torch.bfloat16,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
 
         if self.o_proj.weight.dtype == torch.uint8:
             attn_bmm_output = attn_bmm_output.transpose(0, 1)
@@ -2162,10 +2192,23 @@ def forward_npu_sparse_prepare(
                 q_nope_out,
             )
         else:
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:
+
+                q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
+                    X=q_nope,
+                    WQ=self.w_kc.transpose(-1, -2),
+                    w_scale=self.w_scale,
+                    group_size=128,
+                    YQ=None,  # allocate (B, M, N)
+                    transpose_bm=False,  # (B, M, N)
+                    transpose_bm_in=True,  # (M, B, K)
+                    dtype=torch.bfloat16,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -2648,7 +2691,7 @@ def _set_mla_kv_buffer(
         k_pe: torch.Tensor,
         forward_batch: ForwardBatch,
     ):
-        if _is_cuda:
+        if _is_cuda or _use_aiter_gfx95:
             # Save latent cache
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
@@ -2673,7 +2716,7 @@ def _get_mla_kv_buffer(
         dst_dtype: torch.dtype,
         forward_batch: ForwardBatch,
     ):
-        if _is_cuda:
+        if _is_cuda or _use_aiter_gfx95:
             kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
                 self.attn_mha, kv_indices, dst_dtype
             )