
Conversation

@griff4692 (Contributor)

  • Window cache evicts the earliest tokens first.
  • Used ruff to reformat the Python files

@haileyschoelkopf (Collaborator) left a comment


@griff4692 I left some comments but haven't finished going through everything just yet!

Will also follow up with you to ensure I'm on the same page for certain design decisions

generate.py Outdated
parser.add_argument(
    "--max_cache_length",
    type=float,
    default=512,
Collaborator

Suggested change:
-    default=512,
+    default=1,

default to not-windowed? Or I suppose this is ignored unless using "window" for cache_strategy?

Contributor Author

Not windowed makes sense (1.0)

If cache_strategy == full, max_cache_length has to be 512.
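
For illustration only (not code from this PR), here is one way a float max_cache_length could be resolved against the model's context window, treating values <= 1.0 as a fraction and larger values as an absolute token count; the helper name resolve_cache_length is made up:

def resolve_cache_length(max_cache_length: float, max_seq_length: int) -> int:
    # Values <= 1.0 are treated as a fraction of the context window;
    # anything larger is an absolute number of cached tokens.
    if max_cache_length <= 1.0:
        return max(1, int(max_cache_length * max_seq_length))
    return int(max_cache_length)

# With a 512-token context window:
assert resolve_cache_length(1.0, 512) == 512   # default of 1.0 -> full cache (not windowed)
assert resolve_cache_length(0.25, 512) == 128  # windowed to a quarter of the context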

generate.py Outdated
# Optional Cache Kwargs depending on cache_strategy
parser.add_argument(
    "--global_tokens",
    default=128,
Collaborator

might 4 be a more reasonable default?

Contributor Author

Yes - sorry, the defaults right now were somewhat random, but I'll fix it to 4. I figured they'd all be adjusted during experimentation!
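
If global_tokens means a prefix of positions that are never evicted (attention-sink style), eviction in a window cache might skip them roughly like this; the helper below is a hedged sketch with made-up names, not the PR's implementation:

import torch

def evict_slot(cached_pos: torch.Tensor, global_tokens: int = 4) -> int:
    # Skip the protected global prefix and overwrite the earliest remaining position.
    candidates = torch.where(cached_pos >= global_tokens)[0]
    return int(candidates[torch.argmin(cached_pos[candidates])])

# Example: positions 0-3 are global, so position 4 is the first to go.
assert evict_slot(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), global_tokens=4) == 4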


self.updates = 0

def is_prefill(self):
Collaborator

I think I'm confused about the role of is_prefill here. Is this for certain methods which won't be using any fancy approaches during the prefill stage?

@griff4692 (Contributor Author) May 30, 2024

I didn't design this very well. is_prefill should probably exist outside of the Cache class.

The only place it's used is in generate.py

...
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

if self.kv_cache is not None:
    cache_k, cache_v, cache_mask = self.kv_cache.update(input_pos, k, v)

# If we are in the prefill stage, we use the existing prompt kv-pairs
if not self.kv_cache.is_prefill():
    k = cache_k
    v = cache_v
    mask = cache_mask.to(k.device)
...

The reason I added a switch for prefill is that during the prefill stage, we typically use full self-attention. If there's compression required to initialize the cache (|prompt| > max_cache_length), then this won't be full self-attention. I can explain further if it's unclear!
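
As a toy illustration of that switch (a dummy cache with made-up details, not the PR's actual API): during prefill the prompt's own k/v are used so attention stays full, while later decoding steps attend over whatever the cache retained after eviction.

import torch

class DummyWindowCache:
    def __init__(self, max_cache_length: int):
        self.max_cache_length = max_cache_length
        self.updates = 0

    def is_prefill(self) -> bool:
        # The first update() corresponds to processing the prompt.
        return self.updates == 1

    def update(self, input_pos, k, v):
        self.updates += 1
        # Window behaviour: keep only the most recent max_cache_length positions.
        keep = min(k.shape[-2], self.max_cache_length)
        mask = torch.ones(keep, dtype=torch.bool)
        return k[..., -keep:, :], v[..., -keep:, :], mask

cache = DummyWindowCache(max_cache_length=4)
k = v = torch.randn(1, 2, 6, 8)  # a 6-token prompt, longer than the cache
cache_k, cache_v, cache_mask = cache.update(torch.arange(6), k, v)
if not cache.is_prefill():
    k, v = cache_k, cache_v      # decoding steps attend over the compressed cache
assert k.shape[-2] == 6          # prefill keeps full self-attention over the prompt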

@griff4692 (Contributor Author) May 30, 2024

Do you think it's ok for the cache to essentially record the generation step (self.updates) or would you put that logic in generate.py?

# input_pos: [B, S]
logits = model(x, input_pos)
# Fix GPU
causal_mask = (
Collaborator

note to self: haven't finished looking at this or terminator_ids yet

@griff4692 (Contributor Author)

> @griff4692 I left some comments but haven't finished going through everything just yet!
>
> Will also follow up with you to ensure I'm on the same page for certain design decisions

Thanks! No rush - I'll update as you make suggestions

model.py Outdated
k = cache_k
v = cache_v
# We also need to ask the cache for its dynamic mask which changes based on updates and possibly evictions
# TODO - why is this not always loaded on GPU?
Contributor Author

@haileyschoelkopf - I'm wondering, given that the caches are attached to the model (see setup_caches), why the mask, which is registered as a buffer, isn't loaded onto the same device as the model (which is cuda)? I don't know enough about how this works, but I'm curious if you do!
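
Not an answer from this PR, but a general PyTorch note that may be relevant: tensors registered with register_buffer() do follow model.to("cuda"), whereas plain tensor attributes stay on the device they were created on. A quick check (class name is illustrative):

import torch
import torch.nn as nn

class TinyCache(nn.Module):
    def __init__(self, size: int = 8):
        super().__init__()
        self.register_buffer("mask", torch.ones(size, size, dtype=torch.bool))
        self.plain_mask = torch.ones(size, size, dtype=torch.bool)  # not a buffer

cache = TinyCache()
if torch.cuda.is_available():
    cache.to("cuda")
    print(cache.mask.device)        # cuda:0 -- registered buffers move with the module
    print(cache.plain_mask.device)  # cpu    -- plain tensor attributes are left behind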

@griff4692 force-pushed the window branch 2 times, most recently from 6c9dd9f to 538314b on June 3, 2024 22:36
@griff4692 (Contributor Author)

Squashed everything into a single commit to make it easier to follow

- Creates cache.py
- Introduces global_tokens
- Formats repo with ruff
- Speed parity with full KV-cache
@griff4692 (Contributor Author)

@haileyschoelkopf - removed mutable Python function args and made a few other minor edits. Going to merge now and rebase my heavy hitters code onto the new main.
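
For context on the mutable-default pitfall mentioned above, the usual pattern looks like this (purely illustrative, not code from the PR):

# Risky: the list is created once at definition time and shared across calls.
def append_token_bad(token, cache=[]):
    cache.append(token)
    return cache

# Conventional fix: default to None and build a fresh list per call.
def append_token(token, cache=None):
    if cache is None:
        cache = []
    cache.append(token)
    return cache

assert append_token_bad(1) == [1] and append_token_bad(2) == [1, 2]  # state leaks across calls
assert append_token(1) == [1] and append_token(2) == [2]             # calls stay independent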

@griff4692 merged commit a4dd428 into main on Jun 4, 2024
@griff4692 deleted the window branch on June 4, 2024 11:17