We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c2a0cad commit 21de9afCopy full SHA for 21de9af
flashinfer/decode.py
@@ -2216,14 +2216,18 @@ def trtllm_batch_decode_with_kv_cache(
2216
else:
2217
raise ValueError(f"Invalid out_dtype: {out_dtype}")
2218
2219
+ bmm1_scale = (
2220
+ bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
2221
+ )
2222
+ bmm2_scale = (
2223
+ bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
2224
2225
+
2226
run_func(
2227
out,
2228
out_scale_factor,
2229
query.view(
- query.size(0) // q_len_per_req,
- q_len_per_req,
- query.size(1),
- query.size(2),
2230
+ query.size(0) // q_len_per_req, q_len_per_req, query.size(1), query.size(2)
2231
),
2232
k_cache,
2233
v_cache,
0 commit comments