@@ -13,7 +13,13 @@ class KVCache(ABC, nn.Module):
     relevant_kwargs = ["max_cache_length"]

     def __init__(
-        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, head_specific=False, **kwargs
+        self,
+        max_batch_size,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        head_specific=False,
+        **kwargs,
     ):
         super().__init__()

@@ -28,7 +34,15 @@ def __init__(
         # We use n_heads as an optional second dimension to allow for head-specific evictions.
         self.register_buffer(
             "pos",
-            torch.full((max_batch_size, n_heads if head_specific else 1, self.max_cache_length), -1, dtype=torch.int),
+            torch.full(
+                (
+                    max_batch_size,
+                    n_heads if head_specific else 1,
+                    self.max_cache_length,
+                ),
+                -1,
+                dtype=torch.int,
+            ),
         )

         self.updates = 0
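
A quick aside on the buffer being reshaped here: `pos` tracks which original token position occupies each cache slot, with `-1` marking an empty slot, and the head dimension collapses to 1 when evictions are shared across heads. A minimal standalone sketch (toy sizes, not the model's):

```python
import torch

max_batch_size, n_heads, max_cache_length = 1, 8, 4
head_specific = False

# Shared eviction schedule: one position track broadcast over all heads.
pos = torch.full(
    (max_batch_size, n_heads if head_specific else 1, max_cache_length),
    -1,
    dtype=torch.int,
)
print(pos.shape)  # torch.Size([1, 1, 4]); every slot starts empty (-1)
```
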
@@ -49,7 +63,7 @@ def reset(self):
         self.pos.fill_(-1)
         self.insertions = 0
         self.updates = 0
-
+
     def update(self, input_pos, k_val, v_val):
         """
         Updates the cache with the given input positions, keys, and values.
@@ -73,7 +87,9 @@ def update(self, input_pos, k_val, v_val):
         # Truncate the unfilled part of the cache
         # Since we always fill in-order it will be at the end
         truncate_idx = min(self.insertions, self.max_cache_length)
-        return self.k_cache[:, :, :truncate_idx, :], self.v_cache[:, :, :truncate_idx, :]
+        return self.k_cache[:, :, :truncate_idx, :], self.v_cache[
+            :, :, :truncate_idx, :
+        ]

     @abstractmethod
     def _update(self, input_pos, k_val, v_val):
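
The slicing reformatted above relies on the invariant stated in the comments: slots are filled in order, so the unfilled region is always a contiguous tail. A toy illustration with invented shapes:

```python
import torch

# (batch, n_heads, max_cache_length, head_dim), filled front-to-back
k_cache = torch.randn(1, 2, 8, 4)
insertions, max_cache_length = 5, 8

# min() caps the index once more tokens have been inserted than the cache holds
truncate_idx = min(insertions, max_cache_length)
print(k_cache[:, :, :truncate_idx, :].shape)  # torch.Size([1, 2, 5, 4])
```
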
@@ -116,19 +132,16 @@ def __init__(
     def mark_global_tokens(self) -> bool:
         """
         Update POS tensor to give global tokens highest priority.
-
+
         Return a boolean indicating whether or not all global tokens were filled.

         If it returns True, this function won't be called again to save computation.
         """
         # We put max priority on leading "global" tokens
-        global_mask = torch.logical_and(
-            self.pos < self.global_tokens, self.pos >= 0
-        )
+        global_mask = torch.logical_and(self.pos < self.global_tokens, self.pos >= 0)
         # Give global tokens an arbitrarily high position value so that they are not replaced
         self.pos.masked_fill_(global_mask, LARGE_INTEGER)
-        return global_mask.sum() == self.global_tokens
-
+        return (global_mask.sum() == self.global_tokens).item()

     def _update(self, input_pos, k_val, v_val):
         # Prefill case: If prompt > window, then we need to chop off early positions
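
The `.item()` added to the return is what lets the caller's `or` short-circuit (last hunk below) operate on a plain Python bool rather than a zero-dim tensor. A small repro of the distinction, with toy values:

```python
import torch

global_tokens = 4
pos = torch.tensor([[[0, 1, 2, 3, 9, -1]]])  # toy pos buffer

global_mask = torch.logical_and(pos < global_tokens, pos >= 0)
print(global_mask.sum() == global_tokens)           # tensor(True): 0-dim tensor
print((global_mask.sum() == global_tokens).item())  # True: plain Python bool
```
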
@@ -144,19 +157,18 @@ def _update(self, input_pos, k_val, v_val):
         input_pos = input_pos[keep_idxs]
         k_val = k_val[:, :, keep_idxs]
         v_val = v_val[:, :, keep_idxs]
-
+
         # Identify the lowest positions in the cache that are not filled
-        # For window, all heads are the same so let's just use the first head for "pos"
         pos = self.pos[:, 0, :].squeeze(1)
         _, min_k_indices = pos.topk(input_pos.shape[0], largest=False)
+        min_k_indices = min_k_indices.squeeze(0)

-        # Sort the indices in ascending order
-        min_k_indices, _ = min_k_indices.squeeze(0).sort()
-
-        self.fill(fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val)
+        self.fill(
+            fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val
+        )

         # This is a potentially costly operation which doesn't need to be repeated once we've filled the global tokens
-        self.global_filled |= self.mark_global_tokens()
+        self.global_filled = self.global_filled or self.mark_global_tokens()


 def get_cache_constructor(cache_strategy):
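
For context, the eviction path above depends on `mark_global_tokens` having pinned global slots to `LARGE_INTEGER`: `topk(..., largest=False)` then always prefers empty (`-1`) or low-priority slots. A toy run, with `LARGE_INTEGER` stood in by an arbitrary big constant:

```python
import torch

LARGE_INTEGER = 2**20  # stand-in for the module's constant
pos = torch.tensor([[LARGE_INTEGER, LARGE_INTEGER, 9, 5, -1, -1]])

# Two incoming tokens -> fill the two lowest-priority slots.
_, min_k_indices = pos.topk(2, largest=False)
min_k_indices = min_k_indices.squeeze(0)
print(min_k_indices)  # typically tensor([4, 5]): the empty slots, never the pinned globals
```
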