@@ -48,17 +48,19 @@ def _vllm_layout_trans_kernel(
4848 ):
4949 batch_idx = tl .program_id (0 )
5050 block_idx = tl .program_id (1 )
51- batch_token_indexes = tl .load (b_seq_lens_loc + batch_idx +
52- tl .arange (0 , 2 ))
53- batch_token_start , batch_token_end = tl .split (batch_token_indexes )
54- seq_len = batch_token_end - batch_token_start
5551
5652 batch_query_indexes = tl .load (b_query_lens_loc + batch_idx +
5753 tl .arange (0 , 2 ))
5854 batch_query_start , batch_query_end = tl .split (batch_query_indexes )
5955 query_len = batch_query_end - batch_query_start
6056 if query_len <= 1 :
6157 return
58+
59+ batch_token_indexes = tl .load (b_seq_lens_loc + batch_idx +
60+ tl .arange (0 , 2 ))
61+ batch_token_start , batch_token_end = tl .split (batch_token_indexes )
62+ seq_len = batch_token_end - batch_token_start
63+
6264 if block_idx * BLOCK_SIZE < seq_len :
6365 block_mask = (block_idx * BLOCK_SIZE +
6466 tl .arange (0 , BLOCK_SIZE )[:, None ]) < seq_len
@@ -269,12 +271,13 @@ def build(self, common_prefix_len: int,
269271 max_query_len = common_attn_metadata .max_query_len
270272
271273 max_seq_len = int (self .runner .seq_lens_np [:num_reqs ].max ())
272- total_tokens = int (self .runner .seq_lens_np [:num_reqs ].sum ())
273274 query_start_loc = common_attn_metadata .query_start_loc
274275 seq_lens = common_attn_metadata .seq_lens
275276 block_table = self .block_table
276277 block_table_tensor = block_table .get_device_tensor ()[:num_reqs ]
277-
278+ query_lens = query_start_loc [1 :] - query_start_loc [:- 1 ]
279+ masked_seq_lens = torch .where (query_lens > 1 , seq_lens ,
280+ torch .zeros_like (seq_lens ))
278281 block_table .slot_mapping [:num_actual_tokens ].copy_ (
279282 block_table .slot_mapping_cpu [:num_actual_tokens ],
280283 non_blocking = True )
@@ -284,10 +287,10 @@ def build(self, common_prefix_len: int,
284287
285288 slot_mapping = block_table .slot_mapping [:num_actual_tokens ]
286289
287- cu_seq_lens = torch .zeros (seq_lens .shape [0 ] + 1 ,
290+ cu_seq_lens = torch .zeros (masked_seq_lens .shape [0 ] + 1 ,
288291 dtype = torch .int32 ,
289292 device = "cuda" )
290- torch .cumsum (seq_lens ,
293+ torch .cumsum (masked_seq_lens ,
291294 dim = 0 ,
292295 dtype = cu_seq_lens .dtype ,
293296 out = cu_seq_lens [1 :])
@@ -356,14 +359,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
356359 dtype = torch .uint8 ,
357360 device = self .runner .device ,
358361 )
359-
362+ masked_total_tokens = cu_seq_lens [ - 1 ]. item ()
360363 k_buffer = torch .empty (
361- (total_tokens , self .num_heads_kv , self .headdim ),
364+ (masked_total_tokens , self .num_heads_kv , self .headdim ),
362365 dtype = self .runner .dtype ,
363366 device = self .runner .device ,
364367 )
365368 v_buffer = torch .empty (
366- (total_tokens , self .num_heads_kv , self .headdim ),
369+ (masked_total_tokens , self .num_heads_kv , self .headdim ),
367370 dtype = self .runner .dtype ,
368371 device = self .runner .device ,
369372 )
0 commit comments