Skip to content

Commit d8f6387

Browse files
committed
relax the assertions for xqa
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 2f39e1f commit d8f6387

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

flashinfer/decode.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,9 +2175,10 @@ def trtllm_batch_decode_with_kv_cache(
21752175

21762176
if backend == "xqa":
21772177
# TODO(Siyuan): support device scale factors, which was removed in #2033
2178-
assert isinstance(bmm1_scale, float) and isinstance(bmm2_scale, float), (
2179-
"XQA MLA only supports float scale factors"
2180-
)
2178+
if not isinstance(bmm1_scale, float):
2179+
bmm1_scale = bmm1_scale.item()
2180+
if not isinstance(bmm2_scale, float):
2181+
bmm2_scale = bmm2_scale.item()
21812182
# xqa backend doesn't support nvfp4 output
21822183
if out_dtype == "nvfp4" or (out_dtype is None and isinstance(out, FP4Tensor)):
21832184
raise ValueError("xqa backend does not support nvfp4 output")
@@ -2590,9 +2591,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
25902591
)
25912592
if backend == "xqa":
25922593
# TODO(Siyuan): support device scale factors, which was removed in #2033
2593-
assert isinstance(bmm1_scale, float) and isinstance(bmm2_scale, float), (
2594-
"XQA MLA only supports float scale factors"
2595-
)
2594+
if not isinstance(bmm1_scale, float):
2595+
bmm1_scale = bmm1_scale.item()
2596+
if not isinstance(bmm2_scale, float):
2597+
bmm2_scale = bmm2_scale.item()
25962598
if (
25972599
get_compute_capability(query.device)[0] != 12
25982600
or query.dtype != torch.float8_e4m3fn

0 commit comments

Comments
 (0)