@@ -2175,9 +2175,10 @@ def trtllm_batch_decode_with_kv_cache(
21752175
21762176 if backend == "xqa" :
21772177 # TODO(Siyuan): support device scale factors, which was removed in #2033
2178- assert isinstance (bmm1_scale , float ) and isinstance (bmm2_scale , float ), (
2179- "XQA MLA only supports float scale factors"
2180- )
2178+ if not isinstance (bmm1_scale , float ):
2179+ bmm1_scale = bmm1_scale .item ()
2180+ if not isinstance (bmm2_scale , float ):
2181+ bmm2_scale = bmm2_scale .item ()
21812182 # xqa backend doesn't support nvfp4 output
21822183 if out_dtype == "nvfp4" or (out_dtype is None and isinstance (out , FP4Tensor )):
21832184 raise ValueError ("xqa backend does not support nvfp4 output" )
@@ -2590,9 +2591,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
25902591 )
25912592 if backend == "xqa" :
25922593 # TODO(Siyuan): support device scale factors, which was removed in #2033
2593- assert isinstance (bmm1_scale , float ) and isinstance (bmm2_scale , float ), (
2594- "XQA MLA only supports float scale factors"
2595- )
2594+ if not isinstance (bmm1_scale , float ):
2595+ bmm1_scale = bmm1_scale .item ()
2596+ if not isinstance (bmm2_scale , float ):
2597+ bmm2_scale = bmm2_scale .item ()
25962598 if (
25972599 get_compute_capability (query .device )[0 ] != 12
25982600 or query .dtype != torch .float8_e4m3fn
0 commit comments