Commit e032b4f

Liu-congo authored and lpc0220 committed
[BugFix] test_mla_fp8.py fails on Cublas 12.9 (sgl-project#11360)
Signed-off-by: Liu-congo <[email protected]>
1 parent d40bdca commit e032b4f

File tree

2 files changed: +31 -8 lines changed

python/sglang/srt/models/deepseek_v2.py

Lines changed: 20 additions & 8 deletions
@@ -94,6 +94,7 @@
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -131,6 +132,7 @@
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -189,6 +191,7 @@
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 
 logger = logging.getLogger(__name__)
@@ -1572,10 +1575,15 @@ def forward_absorb_prepare(
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
+            if _is_cublas_ge_129:
+                q_nope_val, q_nope_scale = input_to_float8(
+                    q_nope.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
+                )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
@@ -1716,10 +1724,14 @@ def forward_absorb_core(
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
 
         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
-            )
+            if _is_cublas_ge_129:
+                attn_output_val, attn_output_scale = input_to_float8(
+                    attn_output.transpose(0, 1), torch.float8_e4m3fn
+                )
+            else:
+                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
+                )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
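
For context, the change above follows a common feature-flag pattern: probe the environment once at import time, cache the result in a module-level constant (_is_cublas_ge_129), and branch on it in the hot path so no version check happens per forward pass. Below is a minimal, self-contained sketch of that pattern; the names quantize_activation and _naive_per_tensor_fp8, and the naive per-tensor cast itself, are illustrative stand-ins, not the actual input_to_float8 or per_tensor_quant_mla_fp8 implementations from SGLang.

import torch

# Import-time probe; the commit uses is_nvidia_cublas_cu12_version_ge_12_9()
# from sglang.srt.utils.common for this, stubbed to a constant here.
_IS_CUBLAS_GE_129 = True


def _naive_per_tensor_fp8(x: torch.Tensor):
    """Simplified per-tensor FP8 cast (illustrative stand-in, not input_to_float8)."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Scale so the largest magnitude maps to the FP8 e4m3fn max value.
    scale = x.abs().amax().clamp(min=1e-12).float() / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale


def quantize_activation(x: torch.Tensor):
    """Branch once on the cached cuBLAS flag, mirroring the dispatch in deepseek_v2.py."""
    if _IS_CUBLAS_GE_129:
        # cuBLAS >= 12.9: take the plain per-tensor cast path.
        return _naive_per_tensor_fp8(x)
    # Older cuBLAS: the fused per_tensor_quant_mla_fp8 kernel would run here;
    # the naive cast is reused only to keep this sketch self-contained.
    return _naive_per_tensor_fp8(x)


x_q, scale = quantize_activation(torch.randn(8, 16))

Caching the probe at module import keeps the per-call overhead at a single boolean check, which matters since both quantization sites sit on the attention forward path.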

python/sglang/srt/utils/common.py

Lines changed: 11 additions & 0 deletions
@@ -263,6 +263,17 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
+def is_nvidia_cublas_cu12_version_ge_12_9():
+    """
+    temporary fix for issue #11272
+    """
+    try:
+        installed_version = version("nvidia-cublas-cu12")
+    except PackageNotFoundError:
+        return False
+    return pkg_version.parse(installed_version) >= pkg_version.parse("12.9")
+
+
 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
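
The helper references version, PackageNotFoundError, and pkg_version, names that come from importlib.metadata and the packaging library and sit outside the diff context shown here. A standalone sketch of the same check, assuming the packaging package is installed (the helper name cublas_cu12_at_least is illustrative, not part of SGLang):

from importlib.metadata import PackageNotFoundError, version

from packaging import version as pkg_version


def cublas_cu12_at_least(minimum: str = "12.9") -> bool:
    """True if the pip-installed nvidia-cublas-cu12 wheel is at least `minimum`."""
    try:
        installed = version("nvidia-cublas-cu12")
    except PackageNotFoundError:
        # cuBLAS may come from a system CUDA toolkit rather than the pip wheel,
        # in which case this check conservatively reports False.
        return False
    return pkg_version.parse(installed) >= pkg_version.parse(minimum)


if __name__ == "__main__":
    print(cublas_cu12_at_least())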
