Skip to content

Commit 39e36dc

Browse files
committed
upd
1 parent e535e80 commit 39e36dc

File tree

3 files changed

+31
-25
lines changed

3 files changed

+31
-25
lines changed

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,11 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
232232
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
233233
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;
234234

235-
// Assume NHD layout: [..., N, H, D]
236-
int page_size = key_cache.size(-3);
237-
int num_kv_heads = key_cache.size(-2);
238-
int kv_stride_keys_values = key_cache.stride(-3); // key/values
239-
int kv_stride_heads = key_cache.stride(-2); // head
235+
// Assume HND layout: [..., H, N, D]
236+
int page_size = key_cache.size(-2);
237+
int num_kv_heads = key_cache.size(-3);
238+
int kv_stride_keys_values = key_cache.stride(-2); // key/values
239+
int kv_stride_heads = key_cache.stride(-3); // head
240240

241241
int kv_stride_batch = key_cache.stride(0); // batch
242242

@@ -294,11 +294,11 @@ void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_sca
294294
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
295295
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;
296296

297-
// Assume NHD layout: [..., N, H, D]
298-
int page_size = key_cache.size(-3);
299-
int num_kv_heads = key_cache.size(-2);
300-
int kv_stride_keys_values = key_cache.stride(-3); // key/values
301-
int kv_stride_heads = key_cache.stride(-2); // head
297+
// Assume HND layout: [..., H, N, D]
298+
int page_size = key_cache.size(-2);
299+
int num_kv_heads = key_cache.size(-3);
300+
int kv_stride_keys_values = key_cache.stride(-2); // key/values
301+
int kv_stride_heads = key_cache.stride(-3); // head
302302
int kv_stride_batch = key_cache.stride(0); // batch
303303

304304
const auto stream = get_stream(query.device());

flashinfer/decode.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,9 +1235,9 @@ def run(
12351235
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
12361236
)
12371237

1238-
# Convert HND layout to NHD for trtllm-gen backend
1239-
if self._backend == "trtllm-gen" and self._kv_layout == "HND":
1240-
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
1238+
# Convert NHD layout to HND for trtllm-gen backend
1239+
if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
1240+
# For NHD: [..., N, H, D] -> HND: [..., H, N, D]
12411241
k_cache = k_cache.transpose(-3, -2)
12421242
v_cache = v_cache.transpose(-3, -2)
12431243

@@ -2198,9 +2198,9 @@ def trtllm_batch_decode_with_kv_cache(
21982198
q_len_per_req=q_len_per_req,
21992199
)
22002200
elif backend == "trtllm-gen":
2201-
# Convert HND layout to NHD if necessary (transpose only changes stride, not data)
2202-
if kv_layout == "HND":
2203-
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
2201+
# Convert NHD layout to HND if necessary (transpose only changes stride, not data)
2202+
if kv_layout == "NHD":
2203+
# For NHD: [..., N, H, D] -> HND: [..., H, N, D]
22042204
k_cache = k_cache.transpose(-3, -2)
22052205
v_cache = v_cache.transpose(-3, -2)
22062206

@@ -2431,7 +2431,9 @@ def xqa_batch_decode_with_kv_cache(
24312431
page_size = k_cache.shape[2]
24322432
head_dim = k_cache.shape[3]
24332433

2434-
workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
2434+
workspace_u8 = workspace_buffer.view(torch.uint8)
2435+
semaphore = workspace_u8[: round_up(4 * sm_count, 16)]
2436+
scratch = workspace_u8[round_up(4 * sm_count, 16) :]
24352437
kv_scale_value = bmm2_scale
24362438
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
24372439

@@ -2448,8 +2450,8 @@ def xqa_batch_decode_with_kv_cache(
24482450
block_tables,
24492451
seq_lens_new,
24502452
out,
2451-
workspace_0,
2452-
workspace_1,
2453+
scratch,
2454+
semaphore,
24532455
num_kv_heads,
24542456
page_size,
24552457
sinks=sinks_new,
@@ -2571,6 +2573,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
25712573
): # todo(Yingyi): add support for more block sizes?
25722574
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
25732575

2576+
print(
2577+
f"Running TRTLLM batch decode with KV cache: {query.shape}, {kv_cache.shape}, {workspace_buffer.shape}"
2578+
)
2579+
25742580
_check_trtllm_gen_mla_shape(
25752581
query,
25762582
kv_cache,

flashinfer/prefill.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,9 +2088,9 @@ def run(
20882088
out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
20892089
)
20902090

2091-
# Convert HND layout to NHD for trtllm-gen backend
2092-
if self._backend == "trtllm-gen" and self._kv_layout == "HND":
2093-
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
2091+
# Convert NHD layout to HND for trtllm-gen backend
2092+
if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
2093+
# For NHD: [..., N, H, D] -> HND: [..., H, N, D]
20942094
k_cache = k_cache.transpose(-3, -2)
20952095
v_cache = v_cache.transpose(-3, -2)
20962096

@@ -3411,9 +3411,9 @@ def trtllm_batch_context_with_kv_cache(
34113411
# it doesn't change underlying storage
34123412
k_cache, v_cache = kv_cache.unbind(dim=1)
34133413

3414-
# Convert HND layout to NHD if necessary (transpose only changes stride, not data)
3415-
if kv_layout == "HND":
3416-
# For HND: [..., H, N, D] -> NHD: [..., N, H, D]
3414+
# Convert NHD layout to HND if necessary (transpose only changes stride, not data)
3415+
if kv_layout == "NHD":
3416+
# For NHD: [..., N, H, D] -> HND: [..., H, N, D]
34173417
k_cache = k_cache.transpose(-3, -2)
34183418
v_cache = v_cache.transpose(-3, -2)
34193419

0 commit comments

Comments
 (0)