Implement Scissorhands KV-cache compression & SnapKV prompt compression #11
Conversation
cache.py
Outdated
attn = attn.squeeze()
keys = attn.shape[1]
attn_is_low = (attn < 1 / keys).int()
self.attn_history[:, :, :keys, self.attn_counter % self.history_window_size] = (
Circular queue: % self.history_window_size ensures we always insert the latest attention value into the stalest slot.
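For reference, a minimal standalone sketch of the circular-buffer update described above (shapes and names here are illustrative, not the actual cache.py code):

```python
import torch

# Illustrative shapes only: (batch=1, 8 heads, 1024 cache slots, 400-step window).
history_window_size = 400
attn_history = torch.zeros(1, 8, 1024, history_window_size, dtype=torch.int32)

def record_step(attn_is_low: torch.Tensor, counter: int) -> int:
    """attn_is_low: (batch, heads, num_keys) 0/1 indicators for the current step."""
    keys = attn_is_low.shape[-1]
    # Modulo indexing turns the last dim into a circular queue: once the window is
    # full, each new observation overwrites the stalest slot.
    attn_history[:, :, :keys, counter % history_window_size] = attn_is_low
    return counter + 1
```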
)
self.attn_counter += 1
def refill_eviction_queue(self, input_pos: int):
This is not in the Scissorhands paper but I explain it in the main PR comment
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
k_rep = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
Change to k_rep instead of k since we don't want the repeated k, v passed to the attention callback
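A hedged sketch of the intent behind this change (the helper and its signature are hypothetical, not code from this PR):

```python
import torch
import torch.nn.functional as F

def attend_with_gqa(q, k, v, n_head, n_local_heads, mask=None):
    """Hypothetical helper: repeat KV heads only for the attention call.

    The un-repeated k, v are returned untouched so the KV cache and any
    attention callback keep seeing n_local_heads heads, not n_head copies.
    """
    k_rep = k.repeat_interleave(n_head // n_local_heads, dim=1)
    v_rep = v.repeat_interleave(n_head // n_local_heads, dim=1)
    y = F.scaled_dot_product_attention(q, k_rep, v_rep, attn_mask=mask)
    return y, k, v
```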
cache.py
Outdated
| """ | ||
| attn = attn.squeeze() | ||
| keys = attn.shape[1] | ||
| attn_is_low = (attn < 1 / keys).int() |
Treats a token as unimportant if its attention score is lower than uniform attention (1 / keys).
cache.py
Outdated
num_insertions = k_val.shape[2]
# Update global tokens to the prompt size if set to -1
if self.insertions == 0 and self.global_tokens == -1:
    self.global_tokens = num_insertions
I added this so that we have a setting where we always keep the prompt tokens -- by setting global_tokens=-1
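A tiny sketch of how the sentinel behaves (helper name is illustrative only):

```python
def resolve_global_tokens(global_tokens: int, insertions: int, prompt_len: int) -> int:
    # global_tokens=-1 is resolved to the prompt length on the very first insertion,
    # so every prompt token is pinned and never considered for eviction afterwards.
    if insertions == 0 and global_tokens == -1:
        return prompt_len
    return global_tokens
```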
if prompt_overflow:
    return k_val, v_val, self.compress_prompt

# If the cache requires attention weights to manage evictions, we need to pass self.update_attn_history as a callback
todo (me): check if we have a way to only update attn history every m steps, as done by Scissorhands
I don't think I saw this. I suspect we'll want to add it, as otherwise the overhead is quite high for this method.
I am recording the attention history every step but only aggregating and computing which tokens to evict every drop_amount (m) steps.
I feel like we have to store attention probs at each step (otherwise we'd have to recompute them when we need them) -- is there a better way?
Specifically, Scissorhands seems to say they only check and record attn probs every M steps for some hparam M.
So you'd only, every M steps, check the attention prob and appropriately increment the numerator and denominator based on just that step. Probably more brittle due to fewer observations over which to average / for which future-important tokens have the opportunity to reach higher-than-uniform attn probs, but it would reduce the computational overhead a lot.
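For concreteness, a rough sketch of that alternative every-M-steps bookkeeping under this reading of the paper (nothing here is in the PR; names and the value of M are placeholders):

```python
import torch

M = 10  # sampling period; the paper's hparam, value here is arbitrary

def maybe_record(step: int, attn: torch.Tensor,
                 numerator: torch.Tensor, denominator: torch.Tensor):
    """Every M-th step, count tokens whose attention fell below uniform."""
    if step % M != 0:
        return numerator, denominator          # skip most steps to cut overhead
    keys = attn.shape[-1]
    low = (attn < 1.0 / keys).int()            # (heads, keys) low-attention indicator
    numerator[..., :keys] += low               # times observed as unimportant
    denominator[..., :keys] += 1               # times observed at all
    return numerator, denominator

# Eviction would then drop the tokens with the highest numerator / denominator ratio.
```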
Let's keep discussing over Discord! Closing for now.
)

if attn_callback:
    # Mean pool over the grouped queries (average over self.n_head // self.n_local_heads)
is this choice documented in Scissorhands / SnapKV?
I couldn't find any mention of GQA... probably best to just check which models they test on ... they might not be trained with GQA
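For reference, a minimal sketch of the mean pooling being discussed (the shapes are assumptions, not the exact cache.py tensors):

```python
import torch

def pool_grouped_queries(attn: torch.Tensor, n_local_heads: int) -> torch.Tensor:
    """attn: (batch, n_head, q_len, num_keys) attention probabilities.

    With GQA, each group of n_head // n_local_heads query heads shares one KV head,
    so we average the attention within each group before updating the per-KV-head
    attention history.
    """
    b, n_head, q_len, num_keys = attn.shape
    group = n_head // n_local_heads
    return attn.view(b, n_local_heads, group, q_len, num_keys).mean(dim=2)
```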
Scissorhands records the number of times a token in the KV cache had a low attention score (< uniform probability) over a history window (defaulted to 400 tokens). It evicts the tokens with the highest fraction of unimportant attention scores per attention head. To avoid aggregating attention unimportance every step, they perform bulk evictions, which leaves empty slots in the KV cache. I modified this slightly: instead of periodically evicting in bulk, I periodically update an "eviction queue" which is then used to perform evictions at each step. This change allows us to avoid expensive re-calculation without having to perform bulk evictions; a rough sketch follows.
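A condensed sketch of that eviction-queue modification (single head, illustrative names, not the actual cache.py class):

```python
import torch

class EvictionQueueSketch:
    """Illustrative only: periodically rank tokens by how often their attention was
    below uniform, then pop one eviction index per decoding step."""

    def __init__(self, drop_amount: int, global_tokens: int):
        self.drop_amount = drop_amount        # refill period (the paper's m)
        self.global_tokens = global_tokens    # leading tokens that are never evicted
        self.eviction_queue: list[int] = []

    def refill(self, attn_history: torch.Tensor, valid_keys: int) -> None:
        # attn_history: (num_keys, history_window) binary "was attention low?" records
        unimportant_frac = attn_history[:valid_keys].float().mean(dim=-1)
        unimportant_frac[: self.global_tokens] = -1.0   # protect global / prompt tokens
        # Queue the drop_amount most-unimportant positions, worst first.
        self.eviction_queue = unimportant_frac.topk(self.drop_amount).indices.tolist()

    def next_eviction(self) -> int:
        # One eviction per decoding step; refill() is called every drop_amount steps.
        return self.eviction_queue.pop(0)
```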
SnapKV compresses long prompts by separating a prompt into a "prefix" and an "observation window". The method keeps every token in the observation window and compresses the prefix based on the attention scores from the observation window. This method is only called when prompt length > max cache length. In this case, we can't just insert into the KV cache in the update method. To compress, we need the attention scores, which we only get after running sdpa. So, in line 160 of cache.py, we pass a "compress_prompt" callback which returns the attention scores to the cache after sdpa is performed. In turn, this method first compresses the prompt and then calls the standard update method to insert it into the cache.
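A hedged sketch of that prompt-compression step (function name, shapes, and the simple sum-based scoring are assumptions for illustration, not the PR's exact implementation):

```python
import torch

def snapkv_compress_sketch(k, v, attn, max_cache_length, obs_window):
    """k, v: (batch, heads, prompt_len, head_dim); attn: (batch, heads, prompt_len, prompt_len)
    attention probabilities from the prompt's sdpa call.

    Keeps the observation window (last obs_window tokens) plus the prefix tokens that
    receive the most attention from that window, so the result fits in max_cache_length.
    """
    prompt_len = k.shape[2]
    prefix_len = prompt_len - obs_window
    keep_prefix = max_cache_length - obs_window
    # Score each prefix token by the attention it receives from the observation window.
    scores = attn[:, :, -obs_window:, :prefix_len].sum(dim=2)        # (batch, heads, prefix_len)
    keep_idx = scores.topk(keep_prefix, dim=-1).indices.sort(dim=-1).values
    keep_idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, k.shape[-1])
    k_prefix = k[:, :, :prefix_len].gather(2, keep_idx)
    v_prefix = v[:, :, :prefix_len].gather(2, keep_idx)
    # Re-attach the (uncompressed) observation window; the result is what would then
    # be handed to the standard update path.
    k_out = torch.cat([k_prefix, k[:, :, prefix_len:]], dim=2)
    v_out = torch.cat([v_prefix, v[:, :, prefix_len:]], dim=2)
    return k_out, v_out
```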