Skip to content

Commit 21de9af

Browse files
committed
upd
Signed-off-by: Qidi Sang <[email protected]>
1 parent c2a0cad commit 21de9af

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

flashinfer/decode.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2216,14 +2216,18 @@ def trtllm_batch_decode_with_kv_cache(
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")
 
+    bmm1_scale = (
+        bmm1_scale.item() if isinstance(bmm1_scale, torch.Tensor) else bmm1_scale
+    )
+    bmm2_scale = (
+        bmm2_scale.item() if isinstance(bmm2_scale, torch.Tensor) else bmm2_scale
+    )
+
     run_func(
         out,
         out_scale_factor,
         query.view(
-            query.size(0) // q_len_per_req,
-            q_len_per_req,
-            query.size(1),
-            query.size(2),
+            query.size(0) // q_len_per_req, q_len_per_req, query.size(1), query.size(2)
         ),
         k_cache,
         v_cache,

0 commit comments

Comments
 (0)