Cache the KV projection history when generating #76
Conversation
Thank you for the quality code.
I think they are pretty much identical up to renaming (not intentionally, it's just that it's hard to implement it differently). Unless I've done something really stupid and fail to see it. :) My version:

```python
def forward(..., past_kv_proj=None):
    ...
    if past_kv_proj is not None:
        past_k_proj, past_v_proj = past_kv_proj
        ...
        k = torch.cat((past_k_proj, k), dim=2)
        v = torch.cat((past_v_proj, v), dim=2)
    ...
    present_kv_proj = (k, v)
    ...
```
The Huggingface version you linked to:

```python
def forward(
    ...,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    ...
):
    ...
    if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)
    ...
    present = (key, value)
```
So I don't really see where I keep updating previous keys and values. My intention was to take previous KV projections and concatenate them with the current projections, exactly as in the HuggingFace implementation.
I guess it's best for me to close this PR, given that the upstream version has diverged from mine (after integrating FlashAttention) and that I never intended this to be merged anyway. This PR is still linked to the original minGPT issue, so anyone interested in implementing cached generation can still find it.
Whoops, sorry, I didn't notice you were referring to Andrej, not me. But if you don't mind me answering anyway: mathematically, the version in the main branch of this repo and the Huggingface version with caching enabled compute exactly the same thing. I guess the easiest way to see this would be to run the version from the main branch and print the key and value vectors for every layer after every generated token. You'll notice that, once the vectors are generated for a token at some step, they never change at any later step. More theoretically, you could notice that the key and value projections for a given time step depend only on the tokens at that step and the steps before it (because of the causal attention mask) and on the model weights. If none of those dependencies change when we compute the following steps, the projections themselves don't change either, so caching them gives the same result as recomputing them.
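For anyone who wants to convince themselves numerically, here is a minimal self-contained sketch (not code from this PR or from Huggingface; shapes and weights are made up) that compares a full recomputation with the cached concatenation for a single causal attention head:

```python
import torch

torch.manual_seed(0)
d, T = 16, 5                                  # head dimension, tokens generated so far
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(T, d)                         # token embeddings (positions already added)

# Full recomputation over all T tokens, with a causal mask.
q, k, v = x @ Wq, x @ Wk, x @ Wv
mask = torch.tril(torch.ones(T, T)).bool()
att = (q @ k.T / d**0.5).masked_fill(~mask, float("-inf")).softmax(dim=-1)
full_out = att @ v

# Cached version: reuse K/V of the first T-1 tokens, project only the last one.
past_k, past_v = x[:-1] @ Wk, x[:-1] @ Wv     # what would sit in the cache
k_cat = torch.cat((past_k, x[-1:] @ Wk), dim=0)
v_cat = torch.cat((past_v, x[-1:] @ Wv), dim=0)
att_last = ((x[-1:] @ Wq) @ k_cat.T / d**0.5).softmax(dim=-1)  # the last query attends to all keys
cached_out = att_last @ v_cat

print(torch.allclose(full_out[-1:], cached_out, atol=1e-6))    # True: identical output for the new token
```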
Jumping in on this for a moment since I was tempted to implement this same optimization today, then found this PR.
Indeed this would seem to be a problem on one level, since position embeddings are part of the computation for the keys and values... but Huggingface's standard GPT2 also uses absolute position embeddings (which is why they recommend padding short sequences on the right rather than on the left), while still implementing the past_key_values cache. So how do they deal with generating past the maximum context length?
This is a very good question. I thought the answer was "they don't" (i.e., they just throw an exception or something at that point), but it's slightly more complicated than that. This comment explores the possible options.

Option 1 is giving up on caching and generating the next tokens "the old way" once the input is too long. This way you only get a speedup for some prefix of the generation. This could work, but IMO it would be too complex for an educational codebase such as nanoGPT.

Option 2 is cutting a prefix of the prompt so that more tokens can be generated using the cache (this is roughly what that comment proposes as an option).

Option 3 is the most interesting, and I don't think I fully understand it. Apparently you can keep generating tokens with "wrong" positions, and the generation quality will only degrade slightly? When I tried something like that, the model started outputting pure garbage pretty quickly.
Excellent find! I was searching for a discussion like that one in their repo. It seems to me that Option 3 is probably not desirable, as you already noted. I feel like the optimal solution might be something like Option 2, but I think their naming ("hole") is a bit confusing. If an option like that were offered as an argument to the generate function, it could already be useful on its own.

But instead of losing all speedup from cached keys/values once you exceed the maximum, there could be an option like "shift_by_tokens" that lets you select how far the model will shift the entire window all at once. So if e.g. the max context length was 100, then once generation reaches that limit, the window would shift by 20 tokens in one go rather than by one token at every step.

Perhaps that makes sense? As you said, the educational nature of this project makes the choice tricky, but exposing the trade-offs of caching seems useful.
This sounds promising, but I am a little unclear on how exactly this would work. As far as I understand, once you compute an embedding for a token by combining a token embedding and a positional embedding, there is no straightforward way to change either without recalculating all layers of the decoder. So, in your hypothetical situation with 100 tokens generated, you have a KV-history that was calculated from the pairs (token_1, pos_1), (token_2, pos_2), ..., (token_100, pos_100).

The already-existing cutting of the context essentially changes the history to (token_2, pos_1), (token_3, pos_2), ..., (token_101, pos_100), since we don't have an embedding for pos_101. After the cutting, we have to re-calculate everything from scratch, since all token embeddings have changed. If I'm understanding correctly what you're proposing, after clearing the leftmost 20 tokens we would be left with a cache built from (token_21, pos_21), ..., (token_100, pos_100). Even though you got rid of the leftmost tokens, you still can't change any of the positions that are already baked into those cached keys and values without recalculating them.

However, it's been a long time since I last looked at the codebase, so it is entirely possible that I'm missing something and/or misunderstanding what you are proposing. Perhaps it's better to just ignore what I'm saying and create a small proof-of-concept branch to see if your idea works. (And it's ultimately up to Andrej whether he wants to see something like this in his codebase; I don't have any say in this. :))
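To make the "baked in" part concrete, here is a tiny toy sketch (made-up shapes and weights, not nanoGPT code): the position embedding is added to the token embedding before the key projection, so a cached key computed at position 21 is simply not the same vector you would get for the same token at position 1.

```python
import torch

torch.manual_seed(0)
d = 16
Wk = torch.randn(d, d)                            # key projection of some attention head
tok_emb = torch.randn(d)                          # embedding of some token
pos_21, pos_1 = torch.randn(d), torch.randn(d)    # absolute position embeddings

k_cached = (tok_emb + pos_21) @ Wk                # what sits in the cache: the token was seen at position 21
k_needed = (tok_emb + pos_1) @ Wk                 # what we would need after "shifting" it to position 1

print(torch.allclose(k_cached, k_needed))         # False: the old position is baked into the cached key
```

And in deeper layers the entanglement is even worse, since every cached vector also depends on the positions of all earlier tokens through attention.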
I may not have written very clearly: I meant that when the cut happens, we would drop the cached keys/values and intentionally recalculate everything from scratch (for the newly cut range of tokens, e.g. the last 80 that are now our entire context in the example). But then the cached KVs from that step would still be usable for the next 20 steps (or whatever the shift size is). So when we get to token 101, we'd cut down to the last 80 tokens, recompute their keys/values at positions 1-80, and generate token 101 at position 81.

That means recomputing all the KVs at this one step, but then we get another 19 steps in a row with the new cached values, so we end up gaining cache-based speed-up on about 95% of our steps past the token limit (in this contrived example of 20 shift / 100 max), while losing some of the context window size to make it work. So the idea isn't to solve the need to recompute everything when generating longer text, but to let the user set a trade-off, so that we only recompute from scratch every shift_by_tokens steps instead of at every single step past the limit. But again, I might have misinterpreted.
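In other words, something along these lines (a rough sketch only: shift_by_tokens and the past_kv_proj interface are the hypothetical names used in this thread, not actual nanoGPT or PR code, and it assumes the model can infer the new token's absolute position from the cache length):

```python
import torch

@torch.no_grad()
def generate_with_periodic_shift(model, idx, max_new_tokens, block_size, shift_by_tokens=20):
    """Recompute the KV cache from scratch only once per shift when the context is full,
    instead of recomputing the whole history at every step."""
    cache = None
    for _ in range(max_new_tokens):
        if idx.size(1) > block_size:
            # Window is full: keep only the most recent (block_size - shift_by_tokens) tokens
            # (permanently losing the rest) and throw the cache away, so the expensive
            # from-scratch forward pass happens only once per shift.
            idx = idx[:, -(block_size - shift_by_tokens):]
            cache = None
        if cache is None:
            logits, cache = model(idx)                              # full recompute, rebuilds the cache
        else:
            logits, cache = model(idx[:, -1:], past_kv_proj=cache)  # cheap cached step
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # greedy sampling, for brevity
        idx = torch.cat((idx, next_token), dim=1)
    return idx
```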
Oooooh, now I see, thanks! This is really clever and makes perfect sense. I bet there are a lot of real-life scenarios where you have a model with absolute position embeddings, do not care all that much about the reduced context size, but would appreciate getting a performance boost. At this point I would ask @karpathy what he thinks about this idea (probably not here, since this is a closed PR, but in the old minGPT issue or a new one in this repo). If it turns out that Andrej doesn't want to introduce additional complexity after all, maybe it's even worth trying to propose this to Huggingface?
Edit: previous comment: For each transformer block where K and V need to be recalculated, attention is also calculated, and the attention dot product should dominate the runtime over calculating K and V. Suppose, in the best-case scenario, that the model consists only of transformer blocks and the attention dot product costs the same amount of time as the KV calculation; then the runtime reduction should be only 50%, but you are seeing a >80% runtime reduction when running on CPU. Maybe I am missing something here.
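For what it's worth, the back-of-envelope behind that 50% figure, as I read it (made-up unit costs, purely to spell out the stated assumption, not a real profile):

```python
# Assume projecting K and V and computing the attention dot products each cost 1 unit per step.
cost_kv_projection = 1.0
cost_attention = 1.0
without_cache = cost_kv_projection + cost_attention   # recompute both every step
with_cache = 0.0 + cost_attention                     # only the K/V recomputation is saved
print(1 - with_cache / without_cache)                 # 0.5 -> at most a 50% reduction under this assumption
```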
This PR is a mostly failed attempt to fix issue #95 from the minGPT repo.
The idea is to save the results of key and value projections in each self-attention layer for previously generated tokens. With saved projections, you can essentially convert all matrix-matrix multiplications at every generation step into matrix-vector multiplications, since you now only need to apply linear transformations to the very last token. This is a pretty standard optimization technique for sequential Transformer generation. For example, I think Huggingface calls the cached projections past_key_values.

The only positive impact of this PR is a tremendous speed-up of CPU generation. E.g., on my MacBook Air with Apple M1:
python3 sample.py [...] --device=cpu --dtype=float32 --num_samples=1
(I used the following hacky patch to generate sentences directly from pretrained GPT models and print generation times)
Unfortunately, on an A100, the speed-up is a rounding error even for GPT-XL:
python sample.py [...] --device=cuda:7 --dtype=float32 --num_samples=1
(For some reason, both bfloat16 and float16 are slower on A100 than float32, even if I don't make any code changes, so I didn't bother measuring them.)

Even if the speed-up was more pronounced, I don't think cached generation is worth it, for two reasons. One of them is that the cache becomes invalid once max_new_tokens exceeds block_size: at that point the absolute positions of the previous tokens change, and the cached KV history is no longer valid. I guess you could sidestep this with something like rotary positional embeddings, but then you lose the ability to initialize from stock GPT models.

So, this PR is more of a proof-of-concept and should not be merged. Although it might be a good idea to add a comment to GPT.generate() with an explanation of why it recomputes the previous tokens from scratch at every step, to prevent anyone else from going down this particular rabbit hole. :)
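Something along these lines could serve as that comment (just a sketch; the cropping line is paraphrased from the current GPT.generate(), and the wording is only a suggestion):

```python
# Inside the generation loop of GPT.generate(): the context is cropped to the last
# block_size tokens and the whole forward pass is recomputed at every step. We
# deliberately do not cache the per-layer K/V projections here: this model uses
# learned absolute position embeddings, so once the sequence is cropped, every
# remaining token gets a new position and any cached projections would be invalid.
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
logits, _ = self(idx_cond)
```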