From 82cc93ee3bd6b097cb840e9743d69e99b17fd3a5 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 6 Jan 2025 12:44:16 +0000 Subject: [PATCH 01/25] Handle transformers breaking changes --- README.md | 6 ++-- kvpress/presses/base_press.py | 13 ++++----- kvpress/presses/expected_attention_press.py | 8 +++-- kvpress/presses/key_rerotation_press.py | 2 +- kvpress/presses/observed_attention_press.py | 4 +-- kvpress/presses/scorer_press.py | 2 +- kvpress/presses/snapkv_press.py | 14 +++++---- kvpress/presses/think_press.py | 10 +++++-- notebooks/new_press.ipynb | 2 +- pyproject.toml | 4 +-- .../presses/test_key_rerotation_press_rope.py | 29 ++++++++++--------- tests/test_pipeline.py | 7 +++-- 12 files changed, 55 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index 95f0e57b..2b9e59cf 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![kvpress](kvpress.jpg) -Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field. +Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache compression methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field. ## Installation @@ -60,7 +60,7 @@ All current presses are training free. Several of them inherit from `ScorerPress - `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104)) - `ObservedAttentionPress`: average attention weight observed during in pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048)) -Some presses relying on a different logic: +Some presses rely on a different logic: - `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018)) - `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) @@ -101,7 +101,7 @@ pipe(..., cache=cache) By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] -> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)). +> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`). ## FAQ diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 8a1bea2f..e5589e11 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -86,17 +86,13 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic Modified output of the forward pass of the layer. """ - # See e.g. LlamaDecoderLayer.forward for the output structure - if len(output) == 3: - _, attentions, cache = output - else: - attentions, cache = None, output[-1] hidden_states = kwargs["hidden_states"] + cache = kwargs["past_key_value"] q_len = hidden_states.shape[1] # Don't compress after pre-filling - if cache.seen_tokens > q_len: + if kwargs["cache_position"][-1] > q_len: return output if isinstance(cache, QuantizedCache): @@ -106,7 +102,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic keys = cache.key_cache[module.layer_idx] values = cache.value_cache[module.layer_idx] - keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs) + keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs) if isinstance(cache, QuantizedCache): cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key) @@ -138,6 +134,9 @@ def __call__(self, model: PreTrainedModel) -> Generator: hooks = [] try: for layer in model.model.layers: + if hasattr(model.model, "rotary_emb"): + # Make rotary embeddings available to the forward hook + layer.self_attn.rotary_emb = model.model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) yield diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 747a9597..49cb311a 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -39,7 +39,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ bsz, q_len, _ = hidden_states.shape - n, d = module.num_heads, module.head_dim + n, d = module.config.num_attention_heads, module.config.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] @@ -117,14 +117,16 @@ def score( # Compute scores bsz, num_key_value_heads, q_len, d = keys.shape - keys = repeat_kv(keys, module.num_key_value_groups).transpose(2, 3) + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads + + keys = repeat_kv(keys, num_key_value_groups).transpose(2, 3) scores = torch.matmul(mean_query.unsqueeze(2), keys).squeeze(2) / math.sqrt(d) if self.use_covariance: scores += torch.einsum("bhin, bhij, bhjn->bhn", keys, cov_query, keys) / d / 2 scores = F.softmax(scores, dim=-1) # Average scores across groups - scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len) scores = scores.mean(dim=2) # Rescale scores by the norm of the values diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 514178a5..6c5b7356 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -47,7 +47,7 @@ def compress( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.press.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) cos, sin = get_rope_embeddings(module, keys) # Rerotate as follows diff --git a/kvpress/presses/observed_attention_press.py b/kvpress/presses/observed_attention_press.py index d11d2f01..4e9e78c9 100644 --- a/kvpress/presses/observed_attention_press.py +++ b/kvpress/presses/observed_attention_press.py @@ -54,8 +54,6 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic # attentions are needed as input for the hook, but unless the user wants to return them in the output, # we can remove them to save memory if not self.output_attentions: - output = list(output) - output[-2] = None - output = tuple(output) + output = (output[0], None) return output diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index ea97eab4..8bbdea61 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -62,7 +62,7 @@ def compress( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) # Prune keys and values keys = keys.gather(2, indices).contiguous() diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 9265862e..05ee6f9f 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -34,17 +34,20 @@ def compute_window_attention( """ bsz, q_len, _ = hidden_states.shape + num_heads = module.config.num_attention_heads + head_dim = module.config.head_dim + num_key_value_groups = num_heads // module.config.num_key_value_heads # Get last window_size queries if hasattr(module, "q_proj"): query_states = module.q_proj(hidden_states[:, -window_size:]) elif hasattr(module, "qkv_proj"): qkv = module.qkv_proj(hidden_states[:, -window_size:]) - query_states = qkv[..., : module.num_heads * module.head_dim] + query_states = qkv[..., : num_heads * head_dim] else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device) @@ -52,8 +55,8 @@ def compute_window_attention( query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) # Compute attention for first q_len - window_size tokens - key_states = repeat_kv(keys, module.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + key_states = repeat_kv(keys, num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) attention_mask = torch.ones_like(attn_weights) * float("-inf") attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1) attn_weights += attention_mask @@ -73,6 +76,7 @@ def score( ) -> torch.Tensor: bsz, num_key_value_heads, q_len, _ = keys.shape + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads assert q_len > self.window_size, "Query length should be greater than the window size" @@ -85,7 +89,7 @@ def score( scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) # Average per grioup (https://github.com/FasterDecoding/SnapKV/issues/22) - scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len - self.window_size) + scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.window_size) scores = scores.mean(2) # Add back the observation window. Use max score to make sure the window is not pruned. diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 2f882e56..5309532e 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -33,17 +33,19 @@ def compute_window_queries(self, module, hidden_states): Re-compute the last window_size query states """ bsz, q_len, _ = hidden_states.shape + num_heads = module.config.num_attention_heads + head_dim = module.config.head_dim # Get last window_size queries if hasattr(module, "q_proj"): query_states = module.q_proj(hidden_states[:, -self.window_size :]) elif hasattr(module, "qkv_proj"): qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) - query_states = qkv[..., : module.num_heads * module.head_dim] + query_states = qkv[..., : num_heads * head_dim] else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") - query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) @@ -71,9 +73,11 @@ def compress( # Compute scores per dimension bsz, num_key_value_heads, q_len, head_dim = keys.shape + num_key_value_groups = module.config.num_attention_heads // num_key_value_heads + queries = self.compute_window_queries(module, kwargs["hidden_states"]) queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) - queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2) + queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.config.head_dim).mean(2) keys_norm = torch.pow(keys, 2).mean(dim=2) key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim) diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index c9ed6769..3ffb2799 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -157,7 +157,7 @@ " # For demonstration, we show some details on the shape for the first layer\n", " if module.layer_idx == 0:\n", " print(f\"module: {module}\")\n", - " print(f\"Number of key value heads: {module.num_key_value_heads}\")\n", + " print(f\"Number of key value heads: {module.config.num_key_value_heads}\")\n", " print(f\"Sequence length: {hidden_states.shape[1]}\")\n", " print()\n", " print(f\"hidden_states shape: {hidden_states.shape}\")\n", diff --git a/pyproject.toml b/pyproject.toml index cb0a686f..85d8d6e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.1.0" +version = "0.1.1" readme = "README.md" [tool.poetry.dependencies] @@ -14,7 +14,7 @@ scipy = "^1.13.1" matplotlib = "^3.9.0" bs4 = "^0.0.2" torch = "^2.3.1" -transformers = "^4.45.1" +transformers = "^4.48.0" nvitop = "^1.3.2" sentencepiece = "^0.2.0" protobuf = "^5.27.2" diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index b612a6b8..f7082d10 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -31,21 +31,22 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam original_press = RandomPressStoreIndices(compression_ratio=0.5) key_rerotation_press = KeyRerotationPress(press=original_press) - module = unit_test_model.model.layers[0].self_attn - hidden_states = torch.randn( - 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype - ) + with key_rerotation_press(unit_test_model): + module = unit_test_model.model.layers[0].self_attn + hidden_states = torch.randn( + 8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype + ) - keys = get_keys_with_rope(module, hidden_states) + keys = get_keys_with_rope(module, hidden_states) - values = torch.randn_like(keys) - # Press result - keys_compressed, _ = key_rerotation_press.compress( - module, hidden_states, keys, values, attentions=None, kwargs=dict() - ) + values = torch.randn_like(keys) + # Press result + keys_compressed, _ = key_rerotation_press.compress( + module, hidden_states, keys, values, attentions=None, kwargs=dict() + ) - indices = original_press.indices - keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) + indices = original_press.indices + keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices) assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3) @@ -82,7 +83,7 @@ def score( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) self.indices = indices return scores @@ -108,6 +109,6 @@ def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hid def get_keys_without_pos_embedding(module, hidden_states): key_states = module.k_proj(hidden_states) key_states = key_states.view( - key_states.shape[0], key_states.shape[1], module.num_key_value_heads, module.head_dim + key_states.shape[0], key_states.shape[1], module.config.num_key_value_heads, module.config.head_dim ).transpose(1, 2) return key_states diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 96993492..5e371db5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -114,13 +114,14 @@ def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811 model = unit_test_model questions = ["When was this article written?"] tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=torch.device("cpu")) - input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"] + compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device) + input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device) seq_len = 256 past_key_values: DynamicCache = model( - input_ids=torch.randint(0, 1000, (1, seq_len)), past_key_values=DynamicCache() + input_ids=torch.randint(0, 1000, (1, seq_len), device=device), past_key_values=DynamicCache() ).past_key_values assert past_key_values.get_seq_length() == seq_len From 6d5f34ee7ffea93c74e953aad556b997116962bb Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 7 Jan 2025 10:47:15 +0000 Subject: [PATCH 02/25] Add AdaKVPress (first version) --- kvpress/__init__.py | 5 ++++ kvpress/attention_patch.py | 48 ++++++++++++++++++++++++++++++++ kvpress/presses/adakv_press.py | 50 ++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+) create mode 100644 kvpress/attention_patch.py create mode 100644 kvpress/presses/adakv_press.py diff --git a/kvpress/__init__.py b/kvpress/__init__.py index e2693aaf..774eaa20 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -3,6 +3,7 @@ from kvpress.pipeline import KVPressTextGenerationPipeline +from kvpress.presses.adakv_press import AdaKVPress from kvpress.presses.base_press import BasePress from kvpress.presses.composed_press import ComposedPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress @@ -18,7 +19,11 @@ from kvpress.presses.think_press import ThinKPress from kvpress.presses.tova_press import TOVAPress +from kvpress.attention_patch import patch_attention_functions +patch_attention_functions() + __all__ = [ + "AdaKVPress", "BasePress", "ComposedPress", "ScorerPress", diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py new file mode 100644 index 00000000..705f168a --- /dev/null +++ b/kvpress/attention_patch.py @@ -0,0 +1,48 @@ +import torch +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +LARGE_NEGATIVE_FLOAT = -float(1e5) + + +def attention_patch(func): + """ + Decorator to udpate the keys and values before the attention computation at the indices provided in module.indices + The keys are updated to a fake key k such that for the input queries q, exp() = 0. The values are set to 0. + This is used to fake head-wise compression. A more optimal solution would be to create a new kernel. + """ + + def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs): + if query.shape[2] == key.shape[2]: + # Prefilling phase + module.indices = None + elif module.indices is not None: + # Decoding phase + bsz, num_heads, seq_len, head_dim = query.shape + num_key_value_heads = key.shape[1] + num_groups = num_heads // num_key_value_heads + + # Build a fake key k per key group such that for every query q, exp() = 0 + # To do so, use the least square method to find k such that q @ k ~ LARGE_NEGATIVE_FLOAT + q = query.view(bsz, num_groups, num_key_value_heads, seq_len, head_dim) + q = q.transpose(1, 2).reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim) + targets = LARGE_NEGATIVE_FLOAT * torch.ones(q.shape[:2]).to(q.device) + k = torch.linalg.lstsq(q.float(), targets)[0].to(q.dtype) + assert torch.exp(torch.einsum("hnd,hd->hn", q, k).max()) == 0, "Could not find fake keys" + k = k.view(bsz, num_key_value_heads, head_dim) + + # At indices, update the keys to the fake keys and the values to 0 + key[*module.indices] = k[*module.indices[:2]] + value[*module.indices] = 0 # TODO: do this only once in the forward_hook ? + + return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) + + return wrapper + + +def patch_attention_functions(): + """ + Add the update_keys_before_attention decorator to all attention functions in ALL_ATTENTION_FUNCTIONS + """ + + for name, func in ALL_ATTENTION_FUNCTIONS.items(): + ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py new file mode 100644 index 00000000..8a38afe1 --- /dev/null +++ b/kvpress/presses/adakv_press.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from dataclasses import dataclass + +import torch + +from kvpress.presses.base_press import BasePress +from kvpress.presses.scorer_press import ScorerPress + + +@dataclass +class AdaKVPress(BasePress): + """ + AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer + based on the scores, achieving head-specific compression. + """ + + scorer: ScorerPress + + @property + def compression_ratio(self): + return self.scorer.compression_ratio + + @compression_ratio.setter + def compression_ratio(self, value): + self.scorer.compression_ratio = value + + def compress(self, module, hidden_states, keys, values, attentions, kwargs): + if self.compression_ratio == 0: + return keys, values + + assert module.config._attn_implementation != "eager", "eager mode not supported" + + # Compute scores + scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) + bsz, num_key_value_heads, q_len = scores.shape + + # TODO: understand how floor_alpha works + + # Compute bottom-k across heads + n_pruned = int(num_key_value_heads * q_len * self.compression_ratio) + indices = torch.topk(-scores.view(bsz, -1), n_pruned, dim=1).indices.flatten() + + # Save indices for attention patching in the module + module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) + + # Return keys and values without compression (achieved with the attention patch) + return keys, values From 9a46d7a7868f622b626b859e9a7cf2d5174d23d4 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Tue, 7 Jan 2025 12:47:51 +0000 Subject: [PATCH 03/25] Add alpha_safeguard --- kvpress/presses/adakv_press.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 8a38afe1..8469efba 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -15,9 +15,11 @@ class AdaKVPress(BasePress): """ AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer based on the scores, achieving head-specific compression. + A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter) """ scorer: ScorerPress + alpha_safeguard: float = 0.20 @property def compression_ratio(self): @@ -37,7 +39,10 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) bsz, num_key_value_heads, q_len = scores.shape - # TODO: understand how floor_alpha works + # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head + n_safe = int(q_len * (1 - self.compression_ratio) * self.alpha_safeguard) + top_indices = torch.topk(scores, n_safe, dim=-1).indices + scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) # Compute bottom-k across heads n_pruned = int(num_key_value_heads * q_len * self.compression_ratio) From 26935585baa6cd9b760dd61e7751ded6a9613bf2 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 8 Jan 2025 11:39:09 +0000 Subject: [PATCH 04/25] Move from least squares to perceptron --- kvpress/attention_patch.py | 23 ++++++++++++++++------- kvpress/presses/adakv_press.py | 1 + 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 705f168a..e69a0185 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -1,7 +1,20 @@ -import torch from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -LARGE_NEGATIVE_FLOAT = -float(1e5) +LARGE_NEGATIVE_FLOAT = -1e5 + + +def search_hyperplane(X, max_iter=1000): + """ + Search for an hyperplane Y such that for every Xi, <= 1 (simple perceptron) + Returns LARGE_NEGATIVE_FLOAT * Y to ensure exp() = 0 + """ + Y = X.mean(1) + for _ in range(max_iter): + mask = (X * Y.unsqueeze(1)).sum(dim=2, keepdim=True) <= 1 + if not mask.any(): + return LARGE_NEGATIVE_FLOAT * Y + Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1) + raise ValueError("Could not find fake keys such that for every query q, exp() = 0") def attention_patch(func): @@ -22,17 +35,13 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is num_groups = num_heads // num_key_value_heads # Build a fake key k per key group such that for every query q, exp() = 0 - # To do so, use the least square method to find k such that q @ k ~ LARGE_NEGATIVE_FLOAT q = query.view(bsz, num_groups, num_key_value_heads, seq_len, head_dim) q = q.transpose(1, 2).reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim) - targets = LARGE_NEGATIVE_FLOAT * torch.ones(q.shape[:2]).to(q.device) - k = torch.linalg.lstsq(q.float(), targets)[0].to(q.dtype) - assert torch.exp(torch.einsum("hnd,hd->hn", q, k).max()) == 0, "Could not find fake keys" + k = search_hyperplane(q) k = k.view(bsz, num_key_value_heads, head_dim) # At indices, update the keys to the fake keys and the values to 0 key[*module.indices] = k[*module.indices[:2]] - value[*module.indices] = 0 # TODO: do this only once in the forward_hook ? return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 8469efba..b92853ff 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -52,4 +52,5 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) # Return keys and values without compression (achieved with the attention patch) + values[*module.indices] = 0 return keys, values From b16ab6a3a9435fcb237a2a15383c06c92bf56bdc Mon Sep 17 00:00:00 2001 From: SimJeg Date: Wed, 8 Jan 2025 15:57:42 +0000 Subject: [PATCH 05/25] Remove GQA --- kvpress/attention_patch.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index e69a0185..9a0dabc2 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -19,8 +19,8 @@ def search_hyperplane(X, max_iter=1000): def attention_patch(func): """ - Decorator to udpate the keys and values before the attention computation at the indices provided in module.indices - The keys are updated to a fake key k such that for the input queries q, exp() = 0. The values are set to 0. + Decorator to udpate the keys before the attention computation at the indices provided in module.indices + The keys are updated to a fake key k such that for the input queries q, exp() = 0 This is used to fake head-wise compression. A more optimal solution would be to create a new kernel. """ @@ -31,16 +31,13 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is elif module.indices is not None: # Decoding phase bsz, num_heads, seq_len, head_dim = query.shape - num_key_value_heads = key.shape[1] - num_groups = num_heads // num_key_value_heads # Build a fake key k per key group such that for every query q, exp() = 0 - q = query.view(bsz, num_groups, num_key_value_heads, seq_len, head_dim) - q = q.transpose(1, 2).reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim) + q = query.reshape(bsz * num_heads, seq_len, head_dim) k = search_hyperplane(q) - k = k.view(bsz, num_key_value_heads, head_dim) + k = k.view(bsz, num_heads, head_dim) - # At indices, update the keys to the fake keys and the values to 0 + # At indices, update the keys to the fake keys key[*module.indices] = k[*module.indices[:2]] return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) From d19edc7d3b986fb4f33185dca4f6a410074b056c Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 08:20:07 +0000 Subject: [PATCH 06/25] Fix attention patch --- kvpress/attention_patch.py | 27 ++++++++++++++------------- kvpress/presses/adakv_press.py | 3 --- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 9a0dabc2..fcafe8ef 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -1,18 +1,16 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -LARGE_NEGATIVE_FLOAT = -1e5 - -def search_hyperplane(X, max_iter=1000): +def search_hyperplane(X, max_iter: int = 1000, epsilon: float = 1e-5): """ - Search for an hyperplane Y such that for every Xi, <= 1 (simple perceptron) - Returns LARGE_NEGATIVE_FLOAT * Y to ensure exp() = 0 + Search for an hyperplane Y such that for every Xi, <= epsilon (simple perceptron) + Returns - Y / espilon ** 2 to ensure exp() = 0 """ Y = X.mean(1) for _ in range(max_iter): - mask = (X * Y.unsqueeze(1)).sum(dim=2, keepdim=True) <= 1 + mask = (X * Y.unsqueeze(1)).sum(dim=2, keepdim=True) <= epsilon if not mask.any(): - return LARGE_NEGATIVE_FLOAT * Y + return -Y / epsilon**2 Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1) raise ValueError("Could not find fake keys such that for every query q, exp() = 0") @@ -26,18 +24,21 @@ def attention_patch(func): def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs): if query.shape[2] == key.shape[2]: - # Prefilling phase + # Prefilling module.indices = None elif module.indices is not None: - # Decoding phase + # Decoding: build fake keys k s.t. exp() = 0 bsz, num_heads, seq_len, head_dim = query.shape + num_key_value_heads = key.shape[1] + num_groups = num_heads // num_key_value_heads # Build a fake key k per key group such that for every query q, exp() = 0 - q = query.reshape(bsz * num_heads, seq_len, head_dim) + q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim) + q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim) k = search_hyperplane(q) - k = k.view(bsz, num_heads, head_dim) + k = k.view(bsz, num_key_value_heads, head_dim) - # At indices, update the keys to the fake keys + # At indices, update the keys to the fake keys and the values to 0 key[*module.indices] = k[*module.indices[:2]] return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) @@ -47,7 +48,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is def patch_attention_functions(): """ - Add the update_keys_before_attention decorator to all attention functions in ALL_ATTENTION_FUNCTIONS + Add the attention_patch decorator to functions in ALL_ATTENTION_FUNCTIONS """ for name, func in ALL_ATTENTION_FUNCTIONS.items(): diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index b92853ff..2cb8c08a 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -50,7 +50,4 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Save indices for attention patching in the module module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) - - # Return keys and values without compression (achieved with the attention patch) - values[*module.indices] = 0 return keys, values From 79156c76c9e76df8edb8f60d387ea7f1e4d8e5b5 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 09:07:36 +0000 Subject: [PATCH 07/25] Align with ScorerPress --- kvpress/attention_patch.py | 3 ++- kvpress/presses/adakv_press.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index fcafe8ef..0ac665b8 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -1,3 +1,4 @@ +import torch from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -8,7 +9,7 @@ def search_hyperplane(X, max_iter: int = 1000, epsilon: float = 1e-5): """ Y = X.mean(1) for _ in range(max_iter): - mask = (X * Y.unsqueeze(1)).sum(dim=2, keepdim=True) <= epsilon + mask = torch.bmm(X, Y.unsqueeze(-1)) <= epsilon if not mask.any(): return -Y / epsilon**2 Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 2cb8c08a..ab2130ed 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -40,12 +40,13 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): bsz, num_key_value_heads, q_len = scores.shape # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head - n_safe = int(q_len * (1 - self.compression_ratio) * self.alpha_safeguard) + n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition + n_safe = int(n_kept * self.alpha_safeguard) top_indices = torch.topk(scores, n_safe, dim=-1).indices scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max) # Compute bottom-k across heads - n_pruned = int(num_key_value_heads * q_len * self.compression_ratio) + n_pruned = num_key_value_heads * (q_len - n_kept) indices = torch.topk(-scores.view(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices for attention patching in the module From 002ac9db5d97c91088f7517bdcab184ac0fd0e22 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 10:39:45 +0000 Subject: [PATCH 08/25] Update evaluate --- evaluation/evaluate.py | 18 +++++++++--------- kvpress/presses/adakv_press.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index b82018b2..05c7abb2 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -17,6 +17,7 @@ from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer from kvpress import ( + AdaKVPress, ExpectedAttentionPress, KnormPress, ObservedAttentionPress, @@ -47,6 +48,7 @@ "observed_attention": ObservedAttentionPress(), "random": RandomPress(), "snapkv": SnapKVPress(), + "adasnapkv": AdaKVPress(SnapKVPress()), "streaming_llm": StreamingLLMPress(), } @@ -110,6 +112,7 @@ def evaluate( df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas() if fraction < 1.0: df = df.sample(frac=fraction, random_state=42) + save_filename = save_filename.with_name(save_filename.stem + f"__fraction{fraction:.2f}" + save_filename.suffix) if compress_questions: df["context"] = df["context"] + df["question"] @@ -122,24 +125,21 @@ def evaluate( press.compression_ratio = compression_ratio # Initialize pipeline with the correct attention implementation + model_kwargs = {"torch_dtype": "auto"} if isinstance(press, ObservedAttentionPress): - model_kwargs = {"attn_implementation": "eager"} + model_kwargs["attn_implementation"] = "eager" else: try: import flash_attn # noqa: F401 - model_kwargs = {"attn_implementation": "flash_attention_2"} + model_kwargs["attn_implementation"] = "flash_attention_2" except ImportError: - model_kwargs = {} + pass if device == "auto": - pipe = pipeline( - "kv-press-text-generation", model=model, device_map="auto", torch_dtype="auto", model_kwargs=model_kwargs - ) + pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs) else: - pipe = pipeline( - "kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs=model_kwargs - ) + pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) # Run pipeline on each context df["predicted_answer"] = None diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index ab2130ed..587b16d2 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -34,6 +34,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): return keys, values assert module.config._attn_implementation != "eager", "eager mode not supported" + assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input" # Compute scores scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) @@ -47,7 +48,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): # Compute bottom-k across heads n_pruned = num_key_value_heads * (q_len - n_kept) - indices = torch.topk(-scores.view(bsz, -1), n_pruned, dim=1).indices.flatten() + indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices for attention patching in the module module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) From 58179356e719996b0e86ed772503837fe5c4803c Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 11:02:58 +0000 Subject: [PATCH 09/25] Fix attention patch --- evaluation/evaluate.py | 7 ++++++- kvpress/attention_patch.py | 10 +++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 05c7abb2..2a1fd296 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -24,6 +24,8 @@ RandomPress, SnapKVPress, StreamingLLMPress, + ThinKPress, + TOVAPress, ) logger = logging.getLogger(__name__) @@ -43,13 +45,16 @@ } PRESS_DICT = { + "adasnapkv": AdaKVPress(SnapKVPress()), + "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()), "expected_attention": ExpectedAttentionPress(), "knorm": KnormPress(), "observed_attention": ObservedAttentionPress(), "random": RandomPress(), "snapkv": SnapKVPress(), - "adasnapkv": AdaKVPress(SnapKVPress()), "streaming_llm": StreamingLLMPress(), + "think": ThinKPress(), + "tova": TOVAPress(), } diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 0ac665b8..57778fe0 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -2,16 +2,16 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -def search_hyperplane(X, max_iter: int = 1000, epsilon: float = 1e-5): +def search_hyperplane(X, max_iter: int = 1000): """ - Search for an hyperplane Y such that for every Xi, <= epsilon (simple perceptron) - Returns - Y / espilon ** 2 to ensure exp() = 0 + Search for an hyperplane Y such that for every Xi, <= 0 (simple perceptron) + Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp() = 0 """ Y = X.mean(1) for _ in range(max_iter): - mask = torch.bmm(X, Y.unsqueeze(-1)) <= epsilon + mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0 if not mask.any(): - return -Y / epsilon**2 + return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2 Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1) raise ValueError("Could not find fake keys such that for every query q, exp() = 0") From 31f6b12cb9003c17818a4225d293c7d38526e714 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 11:20:29 +0000 Subject: [PATCH 10/25] Some cleaning --- README.md | 1 + kvpress/attention_patch.py | 17 +++++++++-------- kvpress/presses/adakv_press.py | 2 +- kvpress/presses/composed_press.py | 5 +++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 2b9e59cf..6ada295d 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ Some presses rely on a different logic: - `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) Finally we provide special presses: +- `AdaKVPress`: prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions (see [paper](https://arxiv.org/abs/2407.11550)) - `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio - `ComposedPress`: compose multiple presses together by chaining their forward hooks - `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`. diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 57778fe0..26b63f4e 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -4,10 +4,11 @@ def search_hyperplane(X, max_iter: int = 1000): """ - Search for an hyperplane Y such that for every Xi, <= 0 (simple perceptron) + Search for an hyperplane Y such that for every Xi, <= 0 Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp() = 0 + Raises a ValueError if no such hyperplane is found """ - Y = X.mean(1) + Y = X.mean(1) # this initialization is enough for most cases for _ in range(max_iter): mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0 if not mask.any(): @@ -18,16 +19,16 @@ def search_hyperplane(X, max_iter: int = 1000): def attention_patch(func): """ - Decorator to udpate the keys before the attention computation at the indices provided in module.indices - The keys are updated to a fake key k such that for the input queries q, exp() = 0 - This is used to fake head-wise compression. A more optimal solution would be to create a new kernel. + Decorator to udpate the keys before the attention computation at the indices provided in module.masked_key_indices + The keys are updated with a fake key k such that exp() = 0 to fake head-wise compression + This solution is not optimal as it does not reduce peak memory and slightly increase runtime """ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs): if query.shape[2] == key.shape[2]: # Prefilling - module.indices = None - elif module.indices is not None: + module.masked_key_indices = None + elif module.masked_key_indices is not None: # Decoding: build fake keys k s.t. exp() = 0 bsz, num_heads, seq_len, head_dim = query.shape num_key_value_heads = key.shape[1] @@ -40,7 +41,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is k = k.view(bsz, num_key_value_heads, head_dim) # At indices, update the keys to the fake keys and the values to 0 - key[*module.indices] = k[*module.indices[:2]] + key[*module.masked_key_indices] = k[*module.masked_key_indices[:2]] return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index 587b16d2..e4a2a1e1 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -51,5 +51,5 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() # Save indices for attention patching in the module - module.indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) + module.masked_key_indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) return keys, values diff --git a/kvpress/presses/composed_press.py b/kvpress/presses/composed_press.py index 7fdf8fba..cb2b4aab 100644 --- a/kvpress/presses/composed_press.py +++ b/kvpress/presses/composed_press.py @@ -2,6 +2,7 @@ from kvpress.presses.base_press import BasePress from kvpress.presses.observed_attention_press import ObservedAttentionPress +from kvpress.presses.adakv_press import AdaKVPress @dataclass @@ -15,8 +16,8 @@ class ComposedPress(BasePress): def __post_init__(self): self.compression_ratio = None assert not any( - isinstance(press, ObservedAttentionPress) for press in self.presses - ), "ComposedPress cannot contains ObservedAttentionPress because attentions pruning is not handled" + isinstance(press, (ObservedAttentionPress, AdaKVPress)) for press in self.presses + ), "ComposedPress cannot contains ObservedAttentionPress or AdaKVPress" def forward_hook(self, module, input, kwargs, output): self.compression_ratio = 1.0 From f5cb200977a1e3d4c2d2dcae14f41e20413d7f4d Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 11:36:11 +0000 Subject: [PATCH 11/25] Add check --- kvpress/presses/per_layer_compression_press.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kvpress/presses/per_layer_compression_press.py b/kvpress/presses/per_layer_compression_press.py index 0e497375..80c6db77 100644 --- a/kvpress/presses/per_layer_compression_press.py +++ b/kvpress/presses/per_layer_compression_press.py @@ -32,6 +32,7 @@ def __post_init__(self): self.press.__init__ # type:ignore[misc] ).parameters ), f"compression_ratio can't be set in the provided press: {self.press.__class__}" + assert isinstance(self.press, ScorerPress), "PerLayerCompressionPress requires a ScorerPress as input" def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): original_compression_ratio = self.press.compression_ratio # type:ignore[attr-defined] From 64f0e995930dc58ff91d5b6c58c0ef8f5df0266f Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 13:00:49 +0000 Subject: [PATCH 12/25] Add first test --- tests/presses/test_presses.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 50056b9a..4bb3a89e 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -7,9 +7,7 @@ from torch import nn from transformers import DynamicCache -from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress -from kvpress.presses.scorer_press import ScorerPress -from kvpress.presses.think_press import ThinKPress +from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress, AdaKVPress, ThinKPress, ScorerPress from tests.default_presses import default_presses from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 @@ -24,7 +22,7 @@ def test_composed_press(unit_test_model): # noqa: F811 @pytest.mark.parametrize("press_dict", default_presses) -@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress]) +@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress]) def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 cls = press_dict["cls"] for kwargs in press_dict["kwargs"]: @@ -33,6 +31,11 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 press = ComposedPress(presses=[press]) if isinstance(wrapper_press, KeyRerotationPress): press = KeyRerotationPress(press=press) + if isinstance(wrapper_press, AdaKVPress): + if not isinstance(press, ScorerPress): + return + else: + press = AdaKVPress(press=press) with press(unit_test_model): input_ids = unit_test_model.dummy_inputs["input_ids"] From 6227bcd5a852b339071aa3bd0584bf1d5f5449e6 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 14:20:26 +0000 Subject: [PATCH 13/25] Adress PR feedback --- kvpress/attention_patch.py | 4 ++-- pyproject.toml | 2 +- tests/test_attention_patch.py | 9 +++++++++ 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 tests/test_attention_patch.py diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 26b63f4e..6e4541db 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -4,8 +4,8 @@ def search_hyperplane(X, max_iter: int = 1000): """ - Search for an hyperplane Y such that for every Xi, <= 0 - Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp() = 0 + Given a tensor X of shape (bsz, seq_len, head_dim), search for an hyperplane Y (bsz, head_dim) + such that for every i, <= 0. Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp() = 0 Raises a ValueError if no such hyperplane is found """ Y = X.mean(1) # this initialization is enough for most cases diff --git a/pyproject.toml b/pyproject.toml index 7f60fbec..4ded782b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.1.1" +version = "0.2.0" readme = "README.md" [tool.poetry.dependencies] diff --git a/tests/test_attention_patch.py b/tests/test_attention_patch.py new file mode 100644 index 00000000..0a12a935 --- /dev/null +++ b/tests/test_attention_patch.py @@ -0,0 +1,9 @@ +import torch +from kvpress.attention_patch import search_hyperplane + + +def test_search_hyperplane(): + bsz, seq_len, head_dim = 50, 500, 128 + X = torch.rand(bsz, seq_len, head_dim) + Y = search_hyperplane(X) + assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0 \ No newline at end of file From 7e2ba3b8909d2f8c3e90d4e789bcb9373481ed99 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 14:53:05 +0000 Subject: [PATCH 14/25] Update position embeddings --- kvpress/presses/base_press.py | 3 --- kvpress/presses/expected_attention_press.py | 15 +++++++-------- kvpress/presses/key_rerotation_press.py | 16 ++-------------- kvpress/presses/simlayerkv_press.py | 7 +++++-- kvpress/presses/snapkv_press.py | 12 ++++++------ kvpress/presses/think_press.py | 7 +++---- kvpress/presses/tova_press.py | 4 +++- 7 files changed, 26 insertions(+), 38 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index e5589e11..564a9123 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -134,9 +134,6 @@ def __call__(self, model: PreTrainedModel) -> Generator: hooks = [] try: for layer in model.model.layers: - if hasattr(model.model, "rotary_emb"): - # Make rotary embeddings available to the forward hook - layer.self_attn.rotary_emb = model.model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) yield diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 49cb311a..fb79c520 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import inspect import math from dataclasses import dataclass @@ -66,13 +65,9 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): cov = cov.permute(0, 3, 1, 2) # RoPE rotation matrix on next n_future_positions - if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: - position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) - cos, sin = module.rotary_emb(mu, position_ids) - cos, sin = cos[0], sin[0] - else: - cos, sin = module.rotary_emb(mu, q_len + self.n_future_positions) - cos, sin = cos[q_len:], sin[q_len:] + position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) + cos, sin = self.rotary_emb(mu, position_ids) + cos, sin = cos[0], sin[0] Id = torch.eye(d, device=cos.device, dtype=cos.dtype) P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype) @@ -137,3 +132,7 @@ def score( scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item()) return scores + + def __call__(self, model): + self.rotary_emb = model.model.rotary_emb + return super().__call__(model) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 6c5b7356..da6dd226 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import inspect from dataclasses import dataclass import torch @@ -49,7 +48,7 @@ def compress( indices = scores.topk(n_kept, dim=-1).indices indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) - cos, sin = get_rope_embeddings(module, keys) + cos, sin = kwargs["position_embeddings"] # Rerotate as follows # 1. keys = RoPE(W_k * hidden_states) # 2. keys_unrotated = RoPE^-1(keys) @@ -61,19 +60,8 @@ def compress( # 3. Prune keys keys = keys.gather(2, indices).contiguous() # 4. Apply RoPE - cos, sin = get_rope_embeddings(module, keys) + cos, sin = cos[:, :n_kept], sin[:, :n_kept] keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1)) values = values.gather(2, indices).contiguous() return keys, values - - -def get_rope_embeddings(module, x): - length = x.shape[2] - # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape - if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: - position_ids = torch.arange(length).unsqueeze(0).to(x.device) - cos, sin = module.rotary_emb(x, position_ids) - else: - cos, sin = module.rotary_emb(x, length) - return cos, sin diff --git a/kvpress/presses/simlayerkv_press.py b/kvpress/presses/simlayerkv_press.py index 8693015c..6113899e 100644 --- a/kvpress/presses/simlayerkv_press.py +++ b/kvpress/presses/simlayerkv_press.py @@ -43,13 +43,16 @@ def is_lazy( module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, + position_embeddings: torch.Tensor, ) -> bool: """ Compute the attention weights of the last tokens over the initial and recent tokens. The layer is considered lazy if the sum of these attention weights is above the lazy_threshold. """ - attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, self.n_last) + attn_weights = SnapKVPress.compute_window_attention( + module, hidden_states, keys, self.n_last, position_embeddings + ) attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum() return score.item() > self.lazy_threshold @@ -91,7 +94,7 @@ def compress( return keys, values # Compression - if self.is_lazy(module, hidden_states, keys): + if self.is_lazy(module, hidden_states, keys, kwargs["position_embeddings"]): # If layer is lazy, only keep the initial and recent KV pairs keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2) values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2) diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 05ee6f9f..852c92f2 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -26,9 +26,7 @@ class SnapKVPress(ScorerPress): kernel_size: int = 5 @staticmethod - def compute_window_attention( - module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor, window_size: int - ) -> torch.Tensor: + def compute_window_attention(module, hidden_states, keys, window_size, position_embeddings): """ Compute the last window_size queries and associated attention weights for the first q_len - window_size keys. """ @@ -50,8 +48,8 @@ def compute_window_attention( query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE - position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device) - cos, sin = module.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings + cos, sin = cos[:, -window_size:], sin[:, -window_size:] query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) # Compute attention for first q_len - window_size tokens @@ -83,7 +81,9 @@ def score( if attentions is not None: attn_weights = attentions[..., -self.window_size :, : -self.window_size] else: - attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size) + attn_weights = self.compute_window_attention( + module, hidden_states, keys, self.window_size, kwargs["position_embeddings"] + ) scores = attn_weights.mean(dim=-2) scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1) diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 5309532e..fd84eb52 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -28,7 +28,7 @@ class ThinKPress(BasePress): key_channel_compression_ratio: float = 0.0 window_size: int = 32 - def compute_window_queries(self, module, hidden_states): + def compute_window_queries(self, module, hidden_states, position_embeddings): """ Re-compute the last window_size query states """ @@ -48,8 +48,7 @@ def compute_window_queries(self, module, hidden_states): query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2) # Apply RoPE - position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device) - cos, sin = module.rotary_emb(query_states, position_ids) + cos, sin = position_embeddings query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) return query_states @@ -75,7 +74,7 @@ def compress( bsz, num_key_value_heads, q_len, head_dim = keys.shape num_key_value_groups = module.config.num_attention_heads // num_key_value_heads - queries = self.compute_window_queries(module, kwargs["hidden_states"]) + queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"]) queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.config.head_dim).mean(2) keys_norm = torch.pow(keys, 2).mean(dim=2) diff --git a/kvpress/presses/tova_press.py b/kvpress/presses/tova_press.py index 6a9eb8e0..5ca96cd1 100644 --- a/kvpress/presses/tova_press.py +++ b/kvpress/presses/tova_press.py @@ -37,7 +37,9 @@ def score( if attentions is not None: attn_weights = attentions[..., -1:, :-1] else: - attn_weights = SnapKVPress.compute_window_attention(module, hidden_states, keys, 1) + attn_weights = SnapKVPress.compute_window_attention( + module, hidden_states, keys, 1, kwargs["position_embeddings"] + ) # Average across heads and repeat num_key_value_head times scores = attn_weights.mean(1) From 21fb7a6c6ca9e5c07e070776b8c63871425df92b Mon Sep 17 00:00:00 2001 From: SimJeg Date: Thu, 9 Jan 2025 15:05:18 +0000 Subject: [PATCH 15/25] Fix ThinkPress --- kvpress/presses/key_rerotation_press.py | 2 ++ kvpress/presses/think_press.py | 1 + 2 files changed, 3 insertions(+) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index da6dd226..561a3f43 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -39,6 +39,8 @@ def compress( if self.press.compression_ratio == 0: return keys, values + assert isinstance(self.press, ScorerPress) + # Compute scores from base press scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index fd84eb52..2b608feb 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -49,6 +49,7 @@ def compute_window_queries(self, module, hidden_states, position_embeddings): # Apply RoPE cos, sin = position_embeddings + cos, sin = cos[:, -self.window_size:], sin[:, -self.window_size:] query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1)) return query_states From 64d026002c81cb3d6c3e97d14df293312b20d11c Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 09:42:15 +0000 Subject: [PATCH 16/25] Fix ExpectedAttentionPress with wrapper --- kvpress/presses/base_press.py | 2 +- kvpress/presses/expected_attention_press.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 564a9123..88a2d025 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -134,8 +134,8 @@ def __call__(self, model: PreTrainedModel) -> Generator: hooks = [] try: for layer in model.model.layers: + layer.self_attn.rotary_emb = model.model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) - yield finally: for forward_hook in hooks: diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index fb79c520..fbca07d4 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -66,7 +66,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): # RoPE rotation matrix on next n_future_positions position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device) - cos, sin = self.rotary_emb(mu, position_ids) + cos, sin = module.rotary_emb(mu, position_ids) cos, sin = cos[0], sin[0] Id = torch.eye(d, device=cos.device, dtype=cos.dtype) @@ -132,7 +132,3 @@ def score( scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item()) return scores - - def __call__(self, model): - self.rotary_emb = model.model.rotary_emb - return super().__call__(model) From bbbb75528e6084463223f81a8694f59ed1afd3a2 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 10:35:11 +0000 Subject: [PATCH 17/25] Fix head_dim and patch --- kvpress/attention_patch.py | 4 ++-- kvpress/presses/expected_attention_press.py | 2 +- kvpress/presses/key_rerotation_press.py | 2 +- kvpress/presses/scorer_press.py | 2 +- kvpress/presses/snapkv_press.py | 2 +- kvpress/presses/think_press.py | 4 ++-- .../presses/test_key_rerotation_press_rope.py | 24 +++++++++++++++---- tests/test_attention_patch.py | 2 +- 8 files changed, 29 insertions(+), 13 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 6e4541db..012d2b9d 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -24,7 +24,7 @@ def attention_patch(func): This solution is not optimal as it does not reduce peak memory and slightly increase runtime """ - def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs): + def wrapper(module, query, key, value, attention_mask, dropout, **kwargs): if query.shape[2] == key.shape[2]: # Prefilling module.masked_key_indices = None @@ -43,7 +43,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is # At indices, update the keys to the fake keys and the values to 0 key[*module.masked_key_indices] = k[*module.masked_key_indices[:2]] - return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs) + return func(module, query, key, value, attention_mask, dropout, **kwargs) return wrapper diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index fbca07d4..3b1695e7 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -38,7 +38,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): """ bsz, q_len, _ = hidden_states.shape - n, d = module.config.num_attention_heads, module.config.head_dim + n, d = module.config.num_attention_heads, module.head_dim # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 561a3f43..17abe556 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -48,7 +48,7 @@ def compress( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.press.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) cos, sin = kwargs["position_embeddings"] # Rerotate as follows diff --git a/kvpress/presses/scorer_press.py b/kvpress/presses/scorer_press.py index 8bbdea61..ea97eab4 100644 --- a/kvpress/presses/scorer_press.py +++ b/kvpress/presses/scorer_press.py @@ -62,7 +62,7 @@ def compress( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) # Prune keys and values keys = keys.gather(2, indices).contiguous() diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index 852c92f2..371e9368 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -33,7 +33,7 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ bsz, q_len, _ = hidden_states.shape num_heads = module.config.num_attention_heads - head_dim = module.config.head_dim + head_dim = module.head_dim num_key_value_groups = num_heads // module.config.num_key_value_heads # Get last window_size queries diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 2b608feb..6f1b829b 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -34,7 +34,7 @@ def compute_window_queries(self, module, hidden_states, position_embeddings): """ bsz, q_len, _ = hidden_states.shape num_heads = module.config.num_attention_heads - head_dim = module.config.head_dim + head_dim = module.head_dim # Get last window_size queries if hasattr(module, "q_proj"): @@ -77,7 +77,7 @@ def compress( queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"]) queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim) - queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.config.head_dim).mean(2) + queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.head_dim).mean(2) keys_norm = torch.pow(keys, 2).mean(dim=2) key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim) diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py index f7082d10..f890dc6f 100644 --- a/tests/presses/test_key_rerotation_press_rope.py +++ b/tests/presses/test_key_rerotation_press_rope.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 +import inspect from dataclasses import dataclass import pytest @@ -10,7 +11,6 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half from kvpress import KeyRerotationPress, ScorerPress -from kvpress.presses.key_rerotation_press import get_rope_embeddings from tests.fixtures import unit_test_model # noqa: F401 @@ -42,7 +42,12 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam values = torch.randn_like(keys) # Press result keys_compressed, _ = key_rerotation_press.compress( - module, hidden_states, keys, values, attentions=None, kwargs=dict() + module, + hidden_states, + keys, + values, + attentions=None, + kwargs={"position_embeddings": get_rope_embeddings(module, keys)}, ) indices = original_press.indices @@ -83,7 +88,7 @@ def score( q_len = hidden_states.shape[1] n_kept = int(q_len * (1 - self.compression_ratio)) indices = scores.topk(n_kept, dim=-1).indices - indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim) + indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim) self.indices = indices return scores @@ -109,6 +114,17 @@ def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hid def get_keys_without_pos_embedding(module, hidden_states): key_states = module.k_proj(hidden_states) key_states = key_states.view( - key_states.shape[0], key_states.shape[1], module.config.num_key_value_heads, module.config.head_dim + key_states.shape[0], key_states.shape[1], module.config.num_key_value_heads, module.head_dim ).transpose(1, 2) return key_states + + +def get_rope_embeddings(module, x): + length = x.shape[2] + # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape + if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters: + position_ids = torch.arange(length).unsqueeze(0).to(x.device) + cos, sin = module.rotary_emb(x, position_ids) + else: + cos, sin = module.rotary_emb(x, length) + return cos, sin diff --git a/tests/test_attention_patch.py b/tests/test_attention_patch.py index 0a12a935..9333609d 100644 --- a/tests/test_attention_patch.py +++ b/tests/test_attention_patch.py @@ -6,4 +6,4 @@ def test_search_hyperplane(): bsz, seq_len, head_dim = 50, 500, 128 X = torch.rand(bsz, seq_len, head_dim) Y = search_hyperplane(X) - assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0 \ No newline at end of file + assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0 From a86534a00731010cdf6feaa8900d663299e6b8cd Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 10:36:30 +0000 Subject: [PATCH 18/25] Fix flake8 --- tests/presses/test_presses.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 4bb3a89e..d76192a6 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -7,7 +7,15 @@ from torch import nn from transformers import DynamicCache -from kvpress import ComposedPress, KeyRerotationPress, KnormPress, ObservedAttentionPress, AdaKVPress, ThinKPress, ScorerPress +from kvpress import ( + ComposedPress, + KeyRerotationPress, + KnormPress, + ObservedAttentionPress, + AdaKVPress, + ThinKPress, + ScorerPress, +) from tests.default_presses import default_presses from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401 From 7ea10c740da7534ec911f6d5448e3ce41430ad45 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 13:17:10 +0000 Subject: [PATCH 19/25] Update notebook --- kvpress/attention_patch.py | 2 +- notebooks/new_press.ipynb | 167 ++++++++++++++++++++++++++++++++++++- 2 files changed, 164 insertions(+), 5 deletions(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 012d2b9d..4c66b36f 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -40,7 +40,7 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs): k = search_hyperplane(q) k = k.view(bsz, num_key_value_heads, head_dim) - # At indices, update the keys to the fake keys and the values to 0 + # At indices, update the keys to the fake keys key[*module.masked_key_indices] = k[*module.masked_key_indices[:2]] return func(module, query, key, value, attention_mask, dropout, **kwargs) diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 3ffb2799..59f8cc4e 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -26,9 +26,108 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6758cc9db344df3840d72945b23f5d2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 59%|#####8 | 1.81G/3.09G [00:00 Date: Mon, 13 Jan 2025 13:21:56 +0000 Subject: [PATCH 20/25] merge --- evaluation/evaluate.py | 7 ++++++- kvpress/presses/key_rerotation_press.py | 3 +++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 2a1fd296..d165fdab 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -119,6 +119,11 @@ def evaluate( df = df.sample(frac=fraction, random_state=42) save_filename = save_filename.with_name(save_filename.stem + f"__fraction{fraction:.2f}" + save_filename.suffix) + if max_context_length is not None: + save_filename = save_filename.with_name( + save_filename.stem + f"__max_context{max_context_length}" + save_filename.suffix + ) + if compress_questions: df["context"] = df["context"] + df["question"] df["question"] = "" @@ -127,7 +132,7 @@ def evaluate( # Load press assert press_name in PRESS_DICT press = PRESS_DICT[press_name] - press.compression_ratio = compression_ratio + press.compression_ratio = compression_ratio # type:ignore[attr-defined] # Initialize pipeline with the correct attention implementation model_kwargs = {"torch_dtype": "auto"} diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 17abe556..521f122c 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -27,6 +27,9 @@ class KeyRerotationPress(BasePress): press: ScorerPress + def __post_init__(self): + assert isinstance(self.press, ScorerPress) + def compress( self, module: nn.Module, From e78f8fc6c954cea9ff78b46d2398b181058637c5 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 14:06:24 +0000 Subject: [PATCH 21/25] Update python version --- .github/workflows/python-publish.yml | 2 +- .github/workflows/style.yml | 2 +- .github/workflows/test.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 6a67207a..a27515ff 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: 3.10.11 + python-version: 3.12.0 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 19f6b008..238cea29 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -13,7 +13,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.10.11 + python-version: 3.12.0 - name: Install Poetry run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 07eafff1..400e929d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.10.11 + python-version: 3.12.0 - name: Install Poetry run: | From 67654a7b9adee7de0c4e893ccd6895a46d24f0f2 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 14:16:38 +0000 Subject: [PATCH 22/25] Update patch --- kvpress/attention_patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kvpress/attention_patch.py b/kvpress/attention_patch.py index 4c66b36f..c40e1e41 100644 --- a/kvpress/attention_patch.py +++ b/kvpress/attention_patch.py @@ -41,7 +41,8 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs): k = k.view(bsz, num_key_value_heads, head_dim) # At indices, update the keys to the fake keys - key[*module.masked_key_indices] = k[*module.masked_key_indices[:2]] + batch_indices, head_indices, seq_indices = module.masked_key_indices + key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices] return func(module, query, key, value, attention_mask, dropout, **kwargs) From f804690cddec66dc420e4d857bf37fde6b0d7634 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 14:18:49 +0000 Subject: [PATCH 23/25] Back to 3.10.11 --- .github/workflows/python-publish.yml | 2 +- .github/workflows/style.yml | 2 +- .github/workflows/test.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index a27515ff..6a67207a 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v3 with: - python-version: 3.12.0 + python-version: 3.10.11 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 238cea29..19f6b008 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -13,7 +13,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.12.0 + python-version: 3.10.11 - name: Install Poetry run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 400e929d..07eafff1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.12.0 + python-version: 3.10.11 - name: Install Poetry run: | From 9edb610a54178a1572e73755f7e187656fb3f00f Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 15:04:47 +0000 Subject: [PATCH 24/25] Add docs --- kvpress/__init__.py | 1 + kvpress/presses/adakv_press.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 774eaa20..86748862 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -20,6 +20,7 @@ from kvpress.presses.tova_press import TOVAPress from kvpress.attention_patch import patch_attention_functions +# Patch the attention functions to support head-wise compression patch_attention_functions() __all__ = [ diff --git a/kvpress/presses/adakv_press.py b/kvpress/presses/adakv_press.py index e4a2a1e1..a4ea1c33 100644 --- a/kvpress/presses/adakv_press.py +++ b/kvpress/presses/adakv_press.py @@ -16,11 +16,16 @@ class AdaKVPress(BasePress): AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer based on the scores, achieving head-specific compression. A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter) + This press has been reviewed by Yuan Feng, first author of AdaKV. """ scorer: ScorerPress alpha_safeguard: float = 0.20 + def __post_init__(self): + assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input" + assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]" + @property def compression_ratio(self): return self.scorer.compression_ratio @@ -34,7 +39,6 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): return keys, values assert module.config._attn_implementation != "eager", "eager mode not supported" - assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input" # Compute scores scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs) @@ -50,6 +54,9 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs): n_pruned = num_key_value_heads * (q_len - n_kept) indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten() - # Save indices for attention patching in the module - module.masked_key_indices = (torch.arange(bsz).repeat_interleave(n_pruned), indices // q_len, indices % q_len) + # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details + batch_indices = torch.arange(bsz).repeat_interleave(n_pruned) + head_indices = indices // q_len + seq_indices = indices % q_len + module.masked_key_indices = (batch_indices, head_indices, seq_indices) return keys, values From 3ac3df29a10c236e487899692ef0c29c5eafd163 Mon Sep 17 00:00:00 2001 From: SimJeg Date: Mon, 13 Jan 2025 17:27:33 +0000 Subject: [PATCH 25/25] Fix --- kvpress/presses/key_rerotation_press.py | 2 - notebooks/new_press.ipynb | 109 +++--------------------- 2 files changed, 10 insertions(+), 101 deletions(-) diff --git a/kvpress/presses/key_rerotation_press.py b/kvpress/presses/key_rerotation_press.py index 521f122c..0675fbcb 100644 --- a/kvpress/presses/key_rerotation_press.py +++ b/kvpress/presses/key_rerotation_press.py @@ -42,8 +42,6 @@ def compress( if self.press.compression_ratio == 0: return keys, values - assert isinstance(self.press, ScorerPress) - # Compute scores from base press scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs) diff --git a/notebooks/new_press.ipynb b/notebooks/new_press.ipynb index 59f8cc4e..a64ede71 100644 --- a/notebooks/new_press.ipynb +++ b/notebooks/new_press.ipynb @@ -16,6 +16,7 @@ "outputs": [], "source": [ "from dataclasses import dataclass\n", + "from contextlib import contextmanager\n", "\n", "import torch\n", "from torch import nn\n", @@ -29,101 +30,11 @@ "execution_count": 2, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e6758cc9db344df3840d72945b23f5d2", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model.safetensors: 59%|#####8 | 1.81G/3.09G [00:00