-
Notifications
You must be signed in to change notification settings - Fork 148
AdaKVPress #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
AdaKVPress #38
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
82cc93e
Handle transformers breaking changes
SimJeg 6d5f34e
Add AdaKVPress (first version)
SimJeg 9a46d7a
Add alpha_safeguard
SimJeg 2693558
Move from least squares to perceptron
SimJeg b16ab6a
Remove GQA
SimJeg d19edc7
Fix attention patch
SimJeg 79156c7
Align with ScorerPress
SimJeg 002ac9d
Update evaluate
SimJeg 5817935
Fix attention patch
SimJeg 31f6b12
Some cleaning
SimJeg f5cb200
Add check
SimJeg 64f0e99
Add first test
SimJeg 7f06fce
Merge branch 'main' into simon/adakv-press-448
SimJeg 6227bcd
Adress PR feedback
SimJeg 7e2ba3b
Update position embeddings
SimJeg 21fb7a6
Fix ThinkPress
SimJeg 64d0260
Fix ExpectedAttentionPress with wrapper
SimJeg bbbb755
Fix head_dim and patch
SimJeg a86534a
Fix flake8
SimJeg 7ea10c7
Update notebook
SimJeg 8edba20
merge
SimJeg 54bb346
Merge branch 'main' into simon/adakv-press-448
SimJeg e78f8fc
Update python version
SimJeg 67654a7
Update patch
SimJeg f804690
Back to 3.10.11
SimJeg 9edb610
Add docs
SimJeg 3ac3df2
Fix
SimJeg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import torch | ||
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS | ||
|
|
||
|
|
||
| def search_hyperplane(X, max_iter: int = 1000): | ||
| """ | ||
| Given a tensor X of shape (bsz, seq_len, head_dim), search for an hyperplane Y (bsz, head_dim) | ||
| such that for every i, <X[:, i], Y> <= 0. Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0 | ||
| Raises a ValueError if no such hyperplane is found | ||
| """ | ||
| Y = X.mean(1) # this initialization is enough for most cases | ||
| for _ in range(max_iter): | ||
| mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0 | ||
| if not mask.any(): | ||
| return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2 | ||
| Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1) | ||
| raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0") | ||
|
|
||
|
|
||
| def attention_patch(func): | ||
| """ | ||
| Decorator to udpate the keys before the attention computation at the indices provided in module.masked_key_indices | ||
| The keys are updated with a fake key k such that exp(<q, k>) = 0 to fake head-wise compression | ||
| This solution is not optimal as it does not reduce peak memory and slightly increase runtime | ||
| """ | ||
|
|
||
| def wrapper(module, query, key, value, attention_mask, dropout, **kwargs): | ||
| if query.shape[2] == key.shape[2]: | ||
| # Prefilling | ||
| module.masked_key_indices = None | ||
| elif module.masked_key_indices is not None: | ||
| # Decoding: build fake keys k s.t. exp(<q, k>) = 0 | ||
| bsz, num_heads, seq_len, head_dim = query.shape | ||
| num_key_value_heads = key.shape[1] | ||
| num_groups = num_heads // num_key_value_heads | ||
|
|
||
| # Build a fake key k per key group such that for every query q, exp(<q, k>) = 0 | ||
| q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim) | ||
| q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim) | ||
| k = search_hyperplane(q) | ||
| k = k.view(bsz, num_key_value_heads, head_dim) | ||
|
|
||
| # At indices, update the keys to the fake keys | ||
| batch_indices, head_indices, seq_indices = module.masked_key_indices | ||
| key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices] | ||
|
|
||
| return func(module, query, key, value, attention_mask, dropout, **kwargs) | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| def patch_attention_functions(): | ||
| """ | ||
| Add the attention_patch decorator to functions in ALL_ATTENTION_FUNCTIONS | ||
| """ | ||
|
|
||
| for name, func in ALL_ATTENTION_FUNCTIONS.items(): | ||
| ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
|
|
||
| from kvpress.presses.base_press import BasePress | ||
| from kvpress.presses.scorer_press import ScorerPress | ||
|
|
||
|
|
||
| @dataclass | ||
| class AdaKVPress(BasePress): | ||
| """ | ||
| AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer | ||
| based on the scores, achieving head-specific compression. | ||
| A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter) | ||
| This press has been reviewed by Yuan Feng, first author of AdaKV. | ||
| """ | ||
|
|
||
| scorer: ScorerPress | ||
| alpha_safeguard: float = 0.20 | ||
|
|
||
| def __post_init__(self): | ||
| assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input" | ||
| assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]" | ||
|
|
||
| @property | ||
| def compression_ratio(self): | ||
| return self.scorer.compression_ratio | ||
|
|
||
| @compression_ratio.setter | ||
| def compression_ratio(self, value): | ||
| self.scorer.compression_ratio = value | ||
|
|
||
| def compress(self, module, hidden_states, keys, values, attentions, kwargs): | ||
| if self.compression_ratio == 0: | ||
| return keys, values | ||
|
|
||
| assert module.config._attn_implementation != "eager", "eager mode not supported" | ||
|
|
||
| # Compute scores | ||
| scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) | ||
| bsz, num_key_value_heads, q_len = scores.shape | ||
|
|
||
| # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head | ||
| n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition | ||
| n_safe = int(n_kept * self.alpha_safeguard) | ||
| top_indices = torch.topk(scores, n_safe, dim=-1).indices | ||
| scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) | ||
|
|
||
| # Compute bottom-k across heads | ||
| n_pruned = num_key_value_heads * (q_len - n_kept) | ||
| indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() | ||
|
|
||
| # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details | ||
| batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) | ||
| head_indices = indices // q_len | ||
| seq_indices = indices % q_len | ||
| module.masked_key_indices = (batch_indices, head_indices, seq_indices) | ||
| return keys, values |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| import torch | ||
| from kvpress.attention_patch import search_hyperplane | ||
|
|
||
|
|
||
| def test_search_hyperplane(): | ||
| bsz, seq_len, head_dim = 50, 500, 128 | ||
| X = torch.rand(bsz, seq_len, head_dim) | ||
| Y = search_hyperplane(X) | ||
| assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.