Skip to content

Commit 5c6b9d9

Browse files
committed
fix
1 parent 08d088a commit 5c6b9d9

File tree

3 files changed: +16 additions, −4 deletions

flashinfer/decode.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,6 @@ def plan(
984984
else:
985985
kv_lens_arr_host = seq_lens.cpu()
986986
if self._backend == "trtllm-gen":
987-
assert self._kv_layout == "HND"
988987
assert logits_soft_cap == 0.0
989988
self._max_kv_len = max(kv_lens_arr_host).item()
990989
self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
@@ -1227,6 +1226,7 @@ def run(
12271226
if enable_pdl is None:
12281227
enable_pdl = device_support_pdl(q.device)
12291228
k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
1229+
12301230
if self._kv_layout == "NHD":
12311231
page_size = k_cache.shape[1]
12321232
else:
@@ -1235,6 +1235,12 @@ def run(
12351235
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
12361236
)
12371237

1238+
# Convert HND layout to NHD for trtllm-gen backend
1239+
if self._backend == "trtllm-gen" and self._kv_layout == "HND":
1240+
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
1241+
k_cache = k_cache.transpose(-3, -2)
1242+
v_cache = v_cache.transpose(-3, -2)
1243+
12381244
pos_encoding_mode = self._pos_encoding_mode
12391245
window_left = self._window_left if window_left is None else window_left
12401246
if self._backend != "trtllm-gen":
@@ -1997,7 +2003,6 @@ def paged_run(
19972003
1.0, # NOTE(Siyuan): update this to expose bmm2 scale
19982004
workspace_size,
19992005
window_left,
2000-
layout,
20012006
enable_pdl,
20022007
out=o,
20032008
sinks=sinks,

flashinfer/prefill.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,6 @@ def paged_run(
582582
batch_size,
583583
cum_seq_lens_q,
584584
cum_seq_lens_kv,
585-
layout,
586585
enable_pdl,
587586
workspace_size,
588587
window_left,
@@ -2041,6 +2040,7 @@ def run(
20412040
_check_cached_qkv_data_type(
20422041
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
20432042
)
2043+
20442044
stride_block = k_cache.stride(0)
20452045
if self._kv_layout == "NHD":
20462046
page_size = k_cache.shape[1]
@@ -2088,6 +2088,12 @@ def run(
20882088
out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
20892089
)
20902090

2091+
# Convert HND layout to NHD for trtllm-gen backend
2092+
if self._backend == "trtllm-gen" and self._kv_layout == "HND":
2093+
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
2094+
k_cache = k_cache.transpose(-3, -2)
2095+
v_cache = v_cache.transpose(-3, -2)
2096+
20912097
if self._custom_mask_buf is not None:
20922098
mask_mode = MaskMode.CUSTOM.value
20932099
else:

tests/attention/test_trtllm_gen_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,8 +809,9 @@ def test_trtllm_batch_decode(
809809
out_dtype=out_dtype,
810810
o_sf_scale=o_sf_scale,
811811
o_sf_vec_size=o_sf_vec_size,
812-
enable_pdl=enable_pdl,
813812
sinks=(sink if enable_sink else None),
813+
kv_layout=kv_layout,
814+
enable_pdl=enable_pdl,
814815
q_len_per_req=q_len_per_req,
815816
)
816817
# check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero

Comments: 0 commit comments