@@ -243,8 +243,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 self.runner.device, non_blocking=True)
             local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                 self.runner.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max()
-            local_max_seq_len = virt_k_seqlens_np.max()
+            local_max_query_len = int(seqlens_q_local_np.max())
+            local_max_seq_len = int(virt_k_seqlens_np.max())
             local_scheduler_metadata = schedule(
                 batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
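For context on the hunk above: ndarray.max() returns a NumPy scalar (e.g. numpy.int32) rather than a Python int, while the metadata fields these values feed are annotated as plain int (see the LocalAttentionMetadata hunk below), which is presumably why they are now wrapped in int(). A minimal standalone sketch with made-up example data, not code from this commit:

import numpy as np

# Example lengths only; the real array comes from the local-attention builder.
virt_k_seqlens_np = np.array([3, 7, 5], dtype=np.int32)

raw_max = virt_k_seqlens_np.max()
print(type(raw_max))       # <class 'numpy.int32'> -- a NumPy scalar
print(type(int(raw_max)))  # <class 'int'>         -- what the int-typed fields expect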
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                 max_seq_len=local_max_seq_len,
                 causal=True)

+            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
+                                            dtype=torch.int32,
+                                            device=self.runner.device)
+            local_cu_seq_lens[1:] = torch.cumsum(
+                torch.from_numpy(virt_k_seqlens_np).to(
+                    device=self.runner.device,
+                    dtype=torch.int32,
+                    non_blocking=True),
+                dim=0)
+
+
             local_attn_metadata = \
                 AiterFlashAttentionMetadata.LocalAttentionMetadata(
                     local_query_start_loc=local_query_start_loc,
                     local_seqused_k=local_seqused_k,
                     local_block_table=virt_block_table_tensor,
                     local_max_query_len=local_max_query_len,
                     local_max_seq_len=local_max_seq_len,
+                    local_cu_seq_lens=local_cu_seq_lens,
                     local_scheduler_metadata=local_scheduler_metadata,
                 )

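The new local_cu_seq_lens tensor follows the usual varlen layout: element 0 is 0 and element i+1 is the running sum of the first i+1 per-virtual-batch KV lengths, so the KV tokens of virtual sequence i occupy the half-open range [cu[i], cu[i+1]), which is what cu_seqlens_k consumers expect (see the last hunk). A minimal CPU-only sketch of the same construction, using made-up lengths rather than values from the commit:

import numpy as np
import torch

# Hypothetical per-virtual-batch KV lengths for chunked local attention.
virt_k_seqlens_np = np.array([4, 2, 6], dtype=np.int32)

local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, dtype=torch.int32)
local_cu_seq_lens[1:] = torch.cumsum(
    torch.from_numpy(virt_k_seqlens_np).to(torch.int32), dim=0)

print(local_cu_seq_lens)  # tensor([ 0,  4,  6, 12], dtype=torch.int32)
# KV tokens of virtual sequence i live in the slice [cu[i], cu[i+1]).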
@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_cu_seq_lens: torch.Tensor
         local_scheduler_metadata: Optional[torch.Tensor]

     local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -387,6 +400,7 @@ def __init__(
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
@@ -408,6 +422,7 @@ def __init__(
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0.
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -478,22 +493,25 @@ def forward(
         # performance to make sure it does not introduce any overhead.

         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )

         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
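The condition added above is the heart of cross-layer KV sharing: a layer constructed with kv_sharing_target_layer_name set never writes its own key/value into the paged cache and instead attends over the cache that the named earlier layer already populated during the same forward pass. A toy sketch of that control flow, with made-up class and method names rather than the backend's actual API:

from typing import Optional


class ToyAttention:
    """Illustration of the KV-sharing gate only; not the vLLM implementation."""

    def __init__(self, kv_sharing_target_layer_name: Optional[str] = None) -> None:
        # When set, this layer reuses the KV cache written by the named layer.
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    def forward(self, key, value, kv_cache):
        if self.kv_sharing_target_layer_name is None:
            # Normal layer: persist this layer's K/V into its own cache slots.
            self._write_kv(key, value, kv_cache)
        # A sharing layer skips the write; kv_cache was already filled by the
        # target layer, so attention simply reads from it.
        return self._attend(kv_cache)

    def _write_kv(self, key, value, kv_cache):
        raise NotImplementedError("placeholder for reshape-and-cache")

    def _attend(self, kv_cache):
        raise NotImplementedError("placeholder for the attention kernel")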
@@ -541,7 +559,8 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
+                              local_metadata.local_cu_seq_lens),
             )

         _, num_heads, head_size = query.shape