+import regex as re
 from abc import ABC, abstractmethod
 from typing import Tuple, Callable
 
@@ -19,10 +20,14 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
1920 help = "Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
2021 If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size." ,
2122 )
+    strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
+    debug_strategies = [f"debug_{strategy}" for strategy in strategies]
+    strategies.extend(debug_strategies)
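+    # Every base strategy also gets a "debug_" variant, handled by KVCacheAnalysis below.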
+
     group.add_argument(
         "--cache_strategy",
         default="full",
-        choices=["full", "random", "window", "scissor", "l2"],
+        choices=strategies,
     )
 
     # Dealing with Long Prompts
@@ -126,7 +131,7 @@ def create_window_attention_mask(seq_len, window_size, device):
 class KVCache(ABC, nn.Module):
     # Define which hyperparameters are relevant for the cache.
     # Override as needed for sub-classes.
-    relevant_kwargs = ["max_cache_length", "global_tokens"]
+    relevant_kwargs = ["max_cache_length", "max_seq_length", "global_tokens"]
 
     def __init__(
         self,
@@ -208,6 +213,17 @@ def return_attn(self):
208213 """
209214 return False
210215
+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: A mapping from statistic name to value; currently just the compression ratio.
222+ """
223+ return {
224+ "compression_ratio" : self .compression_ratio (seq_len ).item (),
225+ }
226+
211227 def compression_ratio (self , seq_len ):
212228 """
213229 Returns the compression ratio of the cache.
@@ -276,6 +292,24 @@ def compress_prompt(
         # Yet we return the un-compressed KV since during pre-fill we compute full causal attention.
         return k_val, v_val, mask, new_callback
 
+    def attn_history_callback(self) -> Callable | None:
+        """
+        Returns a callback to update the attention history.
+
+        Returns None if attention weights are not needed.
300+ """
+        return (
+            {
+                "func": lambda input_pos,
+                input_ids,
+                k_val,
+                v_val,
+                attn: self.update_attn_history(attn)
+            }
+            if self.return_attn()
+            else None
+        )
+
     def update(self, input_pos, k_val, v_val, input_ids=None):
         """
         Updates the cache with the given input positions, keys, and values.
@@ -424,7 +458,7 @@ def mark_global_tokens(self, num_total_insertions: int) -> bool:
424458 ), "This cache does not have global tokens so we cannot mark them."
425459 # Give self.pos an highest possible position value for global tokens so that they are not replaced
426460 num_to_mark = min (self .global_tokens , num_total_insertions )
-        self.pos[:, :, :num_to_mark] = self.max_cache_length
+        self.pos[:, :, :num_to_mark] = self.max_seq_length
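+        # max_seq_length bounds every real position, unlike max_cache_length, which long sequences can exceed.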
         return num_to_mark == self.global_tokens
 
 
@@ -448,6 +482,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheRandom(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
     ]
@@ -475,6 +510,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheWindow(KVCache):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "prompt_compression_strategy",
         # NB: "recent_window" is ignored as a relevant kwarg. It is fixed to self.max_cache_length - self.global_tokens.
@@ -520,6 +556,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheL2(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "recent_window",
         "prompt_compression_strategy",
@@ -569,6 +606,7 @@ def update_attn_history(self, attn):
 class KVCacheScissorhands(KVCacheWindow):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "global_tokens",
         "history_window_size",
         "drop_amount",
@@ -752,6 +790,7 @@ def _update(self, input_pos, k_val, v_val, input_ids=None):
 class KVCacheFastGen(KVCacheScissorhands):
     relevant_kwargs = [
         "max_cache_length",
+        "max_seq_length",
         "history_window_size",
         "recent_window",
         "attn_thresholding",
@@ -1116,18 +1155,147 @@ def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
         self.update_attn_history(cum_attn)
 
 
+class KVCacheAnalysis(KVCache):
+    relevant_kwargs = [
+        "max_cache_length",
+        "history_window_size",
+        "recent_window",
+        "attn_thresholding",
+        "token_ids",
+        "prompt_compression_strategy",
+        "min_recovery_frac",
+        "heavy_hitter_frac",
+        "global_tokens",
+        "drop_amount",
+        "attn_record_freq",
+        "max_seq_length",
+    ]
+
+    def __init__(
+        self,
+        max_batch_size,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        cache_strategy="scissor",
+        **kwargs,
+    ):
+        # The full (uncompressed) reference cache never applies prompt compression.
+        full_kwargs = {
+            "prompt_compression_strategy": None,
+            "global_tokens": 0,
+            "max_cache_length": kwargs["max_seq_length"],
+            "max_seq_length": kwargs["max_seq_length"],
+        }
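+        # With max_cache_length == max_seq_length, this cache never evicts; compression happens only in self.compressed.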
+        super().__init__(
+            max_batch_size, n_heads, head_dim, dtype, head_specific=False, **full_kwargs
+        )
+
+        # Initialize the compressed cache we want to analyze.
+        self.compressed = get_cache_constructor(cache_strategy=cache_strategy)[0](
+            max_batch_size,
+            n_heads,
+            head_dim,
+            dtype,
+            **kwargs,
+        )
+
+        self.register_buffer(
+            "attention_losses",
+            torch.full((self.max_seq_length,), fill_value=-1, dtype=dtype),
+        )
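+        # A value of -1 marks an unfilled slot; update_attn_history fills slots in decoding order.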
+
+    def return_attn(self):
+        return self.compressed.return_attn()
+
+    def update(self, input_pos, k_val, v_val, input_ids=None):
+        k, v, mask, _ = super().update(input_pos, k_val, v_val, input_ids=input_ids)
+        _, _, _, attn_callback = self.compressed.update(
+            input_pos, k_val, v_val, input_ids=input_ids
+        )
+
+        if attn_callback is not None and input_pos.shape[-1] == 1:
+            # Re-write the callback so it calls this class's update_attn_history rather than the compressed cache's,
+            # because the attention weights must first be filtered down to the tokens held in the compressed cache.
+            attn_callback = self.attn_history_callback()
+            assert attn_callback is not None
+
+        return k, v, mask, attn_callback
+
+    def _update(self, input_pos, k_val, v_val, input_ids=None):
+        # input_pos: [S], k_val: [B, H, S, D]
+        self.fill_contiguous(input_pos, k_val, v_val)
+        return input_pos.shape[-1]
+
+    def reset(self):
+        super().reset()
+        self.compressed.reset()
+        self.attention_losses.fill_(-1)
+
+    def update_attn_history(self, attn: torch.Tensor):
+        indices = self.compressed.pos.clone().long()
+
+        # Global token positions were overwritten with max_seq_length (see mark_global_tokens);
+        # restore their true indices before gathering attention weights.
+        indices[:, :, : self.compressed.global_tokens] = (
+            torch.arange(self.compressed.global_tokens, device=indices.device)
+            .view(1, 1, -1)
+            .expand(1, indices.shape[1], -1)
+        )
+        indices = indices[:, :, : min(indices.shape[-1], attn.shape[-1])]
+        attn_compressed = attn.squeeze(2).gather(2, indices).unsqueeze(2)
+        self.compressed.update_attn_history(attn_compressed)
+
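+        # Attention loss for this step: the probability mass assigned to tokens the compressed cache has evicted.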
+        attn_loss = (1 - attn_compressed.sum(dim=-1)).mean()
+        insert_idx = torch.where(self.attention_losses == -1)[0][0]
+        self.attention_losses[insert_idx] = attn_loss
+
+    def compute_statistics(self, seq_len):
+        """
+        Computes statistics about the cache.
+
+        Returns:
+            dict: The base statistics plus the mean attention loss accumulated so far.
1260+ """
1261+ stats = super ().compute_statistics (seq_len )
1262+ cutoff = torch .where (self .attention_losses == - 1 )[0 ]
1263+ if len (cutoff ) > 0 :
1264+ cutoff = cutoff [0 ]
1265+ else :
1266+ cutoff = len (self .attention_losses )
1267+ stats ["attention_loss" ] = (self .attention_losses [:cutoff ].sum () / cutoff ).item ()
1268+ return stats
1269+
1270+
 def get_cache_constructor(cache_strategy):
+    relevant_kwargs = None
     if cache_strategy == "full":
-        return KVCacheFull
+        cls = KVCacheFull
     elif cache_strategy == "l2":
-        return KVCacheL2
+        cls = KVCacheL2
     elif cache_strategy == "random":
-        return KVCacheRandom
+        cls = KVCacheRandom
     elif cache_strategy == "window":
-        return KVCacheWindow
+        cls = KVCacheWindow
     elif cache_strategy == "scissor":
-        return KVCacheScissorhands
+        cls = KVCacheScissorhands
     elif cache_strategy == "fastgen":
-        return KVCacheFastGen
+        cls = KVCacheFastGen
+    elif cache_strategy.startswith("debug"):
+        cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip()
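+        # Expose the wrapped strategy's relevant_kwargs, since KVCacheAnalysis forwards **kwargs to it.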
+        relevant_kwargs = get_cache_constructor(cache_strategy)[1]
+        cls = (
+            lambda max_batch_size, n_heads, head_dim, dtype, **kwargs: KVCacheAnalysis(
+                max_batch_size,
+                n_heads,
+                head_dim,
+                dtype,
+                cache_strategy=cache_strategy,
+                **kwargs,
+            )
+        )
     else:
         raise ValueError(f"Invalid cache strategy: {cache_strategy}")
+
+    return cls, relevant_kwargs or cls.relevant_kwargs