Implements windowed KV-Cache Compression Strategy #9
Conversation

haileyschoelkopf left a comment:
@griff4692 left some comments but haven't finished going through everything just yet! Will also follow up with you to ensure I'm on the same page for certain design decisions.
generate.py (Outdated)
parser.add_argument(
    "--max_cache_length",
    type=float,
    default=512,
Suggested change:
-    default=512,
+    default=1,
default to not-windowed? Or I suppose this is ignored unless using "window" for cache_strategy?
Not windowed makes sense (1.0). If cache_strategy == full, max_cache_length has to be 512.
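To make the fractional semantics concrete, here is a sketch of how a float max_cache_length could be resolved (resolve_cache_length is a hypothetical helper, not the repo's actual code): values in (0, 1] are read as a fraction of the model's maximum sequence length, so 1.0 means the full, un-windowed context, while values above 1 are taken as an absolute slot count.

def resolve_cache_length(max_cache_length: float, max_seq_length: int) -> int:
    # Hypothetical helper: 1.0 keeps the full context (not windowed), 0.25
    # would keep a quarter of max_seq_length, and e.g. 256 is used directly
    # as an absolute number of cache slots.
    if 0 < max_cache_length <= 1:
        return round(max_cache_length * max_seq_length)
    return int(max_cache_length)

Under this reading, resolve_cache_length(1.0, 512) returns 512, which lines up with the cache_strategy == full case above.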
generate.py (Outdated)
# Optional Cache Kwargs depending on cache_strategy
parser.add_argument(
    "--global_tokens",
    default=128,
might 4 be a more reasonable default?
Yes - sorry, the defaults right now were somewhat random but I'll fix it to 4. I figured they'd all be adjusted during experimentation!
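For context, global_tokens appears to pin the first few positions (e.g. BOS and early prompt tokens) so they are never evicted, while the remaining slots act as a sliding window. A minimal sketch of the eviction arithmetic under that assumption (slot_to_evict is a hypothetical illustration, not the repo's code):

def slot_to_evict(num_inserted: int, max_cache_length: int, global_tokens: int) -> int:
    # Once the cache is full, overwrite the non-global region as a ring
    # buffer; the first `global_tokens` slots are left permanently in place.
    assert num_inserted >= max_cache_length, "cache not full yet; just append"
    window = max_cache_length - global_tokens
    return global_tokens + (num_inserted - max_cache_length) % window

Treating the window as a ring keeps each update O(1) per token, which is consistent with the speed-parity goal mentioned in the squashed commit below.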
self.updates = 0
...
def is_prefill(self):
I think I'm confused about the role of is_prefill here. Is this for certain methods which won't be using any fancy approaches during the prefill stage?
I didn't design this very well - is_prefill should probably exist outside of the Cache class. The only place it's used is in generate.py:
...
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache is not None:
    cache_k, cache_v, cache_mask = self.kv_cache.update(input_pos, k, v)
    # If we are in the prefill stage, we use the existing prompt kv-pairs
    if not self.kv_cache.is_prefill():
        k = cache_k
        v = cache_v
        mask = cache_mask.to(k.device)
...
The reason I added a switch for prefill is that during the prefill stage, we typically use full self-attention. If there's compression required to initialize the cache (|prompt| > max_cache_length), then this won't be full self-attention. I can explain further if it's unclear!
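Given the self.updates = 0 initialization shown earlier, one plausible implementation of the switch (an assumption, not necessarily the repo's exact code) is:

def is_prefill(self):
    # The first update() call ingests the whole prompt at once; until a
    # decoding step has been recorded, the cache is still in the prefill
    # stage and the caller keeps using the prompt's own kv-pairs.
    return self.updates == 0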
Do you think it's ok for the cache to essentially record the generation step (self.updates) or would you put that logic in generate.py?
# input_pos: [B, S]
logits = model(x, input_pos)
# Fix GPU
causal_mask = (
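The hunk is cut off at causal_mask = (. A typical construction, sketched here as a guess at the intent rather than the repo's exact code, builds the lower-triangular mask directly on the target device, which is presumably the point of the # Fix GPU note:

import torch

max_seq_length = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
# Allocating on `device` up front avoids a CPU-resident mask that would need
# a .to(device) copy on every decoding step.
causal_mask = torch.tril(
    torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=device)
)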
note to self: haven't finished looking at this or terminator_ids yet
Thanks! No rush - I'll update as you make suggestions.
model.py (Outdated)
k = cache_k
v = cache_v
# We also need to ask the cache for its dynamic mask which changes based on updates and possibly evictions
# TODO - why is this not always loaded on GPU?
@haileyschoelkopf - I'm wondering, given that the caches are attached to the model (see setup_caches), why the mask, which is registered as a buffer, isn't loaded onto the same device as the model (which is cuda). I don't know enough about how this works but curious if you do!
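For what it's worth, the relevant behavior is general PyTorch rather than repo-specific: tensors registered via register_buffer do follow Module.to(device), so a CPU-resident mask usually means either the cache was constructed after the model had already been moved, or the cache isn't registered as a proper submodule (e.g. kept in a plain Python list instead of an nn.ModuleList). A minimal demonstration:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers are tracked by .to()/.cuda() just like parameters.
        self.register_buffer("mask", torch.zeros(4, dtype=torch.bool))
        # A plain tensor attribute is invisible to .to() and stays on CPU.
        self.plain = torch.zeros(4, dtype=torch.bool)

model = Toy()
if torch.cuda.is_available():
    model = model.to("cuda")
    print(model.mask.device)   # cuda:0 -- moved with the module
    print(model.plain.device)  # cpu    -- left behind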
Force-pushed from 6c9dd9f to 538314b.
Squashed everything into a single commit to make it easier to follow.
- Creates cache.py
- Introduces global_tokens
- Formats repo with ruff
- Speed parity with full KV-cache
@haileyschoelkopf - removed mutable python function args and made a few other minor edits. Going to merge now and rebase my heavy hitters code onto the new ruff-reformatted Python files.