@@ -193,10 +193,10 @@ def __init__(
             else 0)
 
         # Determine if this is VSWA (Variable Sliding Window Attention)
-        is_vswa = len(self.max_attention_window_vec) > 1
+        self.is_vswa = len(self.max_attention_window_vec) > 1
 
         # Calculate blocks per window using appropriate method
-        if is_vswa:
+        if self.is_vswa:
             # VSWA case: use C++ implementation for variable window sizes
             # model config check
             if model_config is None:
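For context on the first hunk: the manager treats the cache as VSWA whenever the per-layer attention-window vector has more than one entry, and the change simply promotes that flag to an attribute so later methods can reuse it. A minimal sketch of the check, with an invented window vector (real values come from the model config):

```python
# Hypothetical per-layer max attention windows (illustrative values only):
# some layers use a small sliding window, others effectively full attention.
max_attention_window_vec = [4096, 32768]

# Same test the constructor performs: more than one window size => VSWA.
is_vswa = len(max_attention_window_vec) > 1
print(is_vswa)  # True
```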
@@ -523,14 +523,29 @@ def get_batch_cache_indices(
         return result
 
     def get_num_free_blocks(self) -> int:
-        return self.impl.get_kv_cache_stats().free_num_blocks
+        if self.is_vswa:
+            logger.info(
+                f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}"
+            )
+            return min(self.impl.get_kv_cache_stats().
+                       num_free_blocks_per_window_size.values())
+        else:
+            return self.impl.get_kv_cache_stats().free_num_blocks
 
     def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block
 
     def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
-        return (self.get_num_free_blocks() * self.tokens_per_block -
-                self.num_extra_kv_tokens - max_num_draft_tokens)
+        if self.max_attention_window_vec and len(
+                self.max_attention_window_vec) > 1:
+            # VSWA case: the available tokens should be the minimum of the available tokens for each window size
+            min_free_blocks = min(self.impl.get_kv_cache_stats().
+                                  num_free_blocks_per_window_size.values())
+            res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens
+        else:
+            res = (self.get_num_free_blocks() * self.tokens_per_block -
+                   self.num_extra_kv_tokens - max_num_draft_tokens)
+        return res
 
     def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
         layer_offset = self.layer_offsets[layer_idx]
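Taken together, the two modified methods apply one rule: under VSWA each window size has its own block budget, so free capacity is governed by the window size with the fewest free blocks. A self-contained sketch of that accounting, using the same dictionary shape as `num_free_blocks_per_window_size` from the stats object but invented numbers:

```python
# Illustrative per-window free-block counts (not real stats).
num_free_blocks_per_window_size = {4096: 120, 32768: 45}
tokens_per_block = 32
num_extra_kv_tokens = 0
max_num_draft_tokens = 4

# VSWA: the most constrained window bounds what can still be scheduled.
min_free_blocks = min(num_free_blocks_per_window_size.values())  # 45
available_tokens = (min_free_blocks * tokens_per_block -
                    num_extra_kv_tokens - max_num_draft_tokens)
print(available_tokens)  # 45 * 32 - 0 - 4 == 1436
```

The non-VSWA branch is unchanged: it still reads `free_num_blocks` directly and converts blocks to tokens the same way.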