Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@

if _use_aiter:
import aiter
from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant

# from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
from aiter import gemm_a8w8_bpreshuffle, get_hip_quant
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)

Expand Down
71 changes: 57 additions & 14 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported

if _use_aiter_gfx95:
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant,
)
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant

from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
Expand Down Expand Up @@ -1813,10 +1816,25 @@ def forward_absorb_prepare(
q_nope_out,
)
else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:

q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
X=q_nope,
WQ=self.w_kc.transpose(-1, -2),
w_scale=self.w_scale,
group_size=128,
YQ=None, # allocate (B, M, N)
transpose_bm=False, # (B, M, N)
transpose_bm_in=True, # (M, B, K)
dtype=torch.bfloat16,
)

else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)

elif self.w_kc.dtype == torch.float8_e4m3fn:
# fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
Expand Down Expand Up @@ -1964,10 +1982,22 @@ def forward_absorb_core(
attn_bmm_output,
)
else:
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)
if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:
attn_bmm_output = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
X=attn_output,
WQ=self.w_vc.transpose(-1, -2),
w_scale=self.w_scale,
group_size=128,
YQ=None,
transpose_bm=False,
transpose_bm_in=True,
dtype=torch.bfloat16,
)
else:
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)

if self.o_proj.weight.dtype == torch.uint8:
attn_bmm_output = attn_bmm_output.transpose(0, 1)
Expand Down Expand Up @@ -2162,10 +2192,23 @@ def forward_npu_sparse_prepare(
q_nope_out,
)
else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
if _use_aiter_gfx95 and self.w_kc.dtype == torch.float8_e4m3fn:

q_nope_out = batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant(
X=q_nope,
WQ=self.w_kc.transpose(-1, -2),
w_scale=self.w_scale, #
group_size=128,
YQ=None, # allocate (B, M, N)
transpose_bm=False, # (B, M, N)
transpose_bm_in=True, # (M, B, K)
dtype=torch.bfloat16,
)
else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
Expand Down Expand Up @@ -2648,7 +2691,7 @@ def _set_mla_kv_buffer(
k_pe: torch.Tensor,
forward_batch: ForwardBatch,
):
if _is_cuda:
if _is_cuda or _use_aiter_gfx95:
# Save latent cache
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
Expand All @@ -2673,7 +2716,7 @@ def _get_mla_kv_buffer(
dst_dtype: torch.dtype,
forward_batch: ForwardBatch,
):
if _is_cuda:
if _is_cuda or _use_aiter_gfx95:
kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
self.attn_mha, kv_indices, dst_dtype
)
Expand Down
Loading