diff --git a/README.md b/README.md
index 38859098..f45c60d1 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,7 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `ComposedPress` ([source](kvpress/presses/composed_press.py)): compose multiple presses together by chaining their forward hooks
 - `KeyRerotationPress` ([source](kvpress/presses/key_rerotation_press.py)): rerotate pruned keys to have continuous RoPE embeddings
 - `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
+- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
 
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
 
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index d165fdab..00fdbad8 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -17,6 +17,8 @@
 from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
 
 from kvpress import (
+    CriticalKVPress,
+    CriticalAdaKVPress,
     AdaKVPress,
     ExpectedAttentionPress,
     KnormPress,
@@ -45,6 +47,10 @@
 }
 
 PRESS_DICT = {
+    "criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
+    "criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
+    "criti_snapkv": CriticalKVPress(SnapKVPress()),
+    "criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
     "adasnapkv": AdaKVPress(SnapKVPress()),
     "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
     "expected_attention": ExpectedAttentionPress(),
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index c5c4d993..54e419c6 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -20,11 +20,13 @@
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.think_press import ThinKPress
 from kvpress.presses.tova_press import TOVAPress
-
+from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
 # Patch the attention functions to support head-wise compression
 patch_attention_functions()
 
 __all__ = [
+    "CriticalAdaKVPress",
+    "CriticalKVPress",
     "AdaKVPress",
     "BasePress",
     "ComposedPress",
diff --git a/kvpress/presses/criticalkv_press.py b/kvpress/presses/criticalkv_press.py
new file mode 100644
index 00000000..9574a980
--- /dev/null
+++ b/kvpress/presses/criticalkv_press.py
@@ -0,0 +1,190 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from dataclasses import dataclass
+
+import torch
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from kvpress.presses.base_press import BasePress
+from kvpress.presses.scorer_press import ScorerPress
+from kvpress.presses.expected_attention_press import ExpectedAttentionPress
+
+logger = logging.getLogger(__name__)
+
+
+class CriticalKVPress(ScorerPress):
+    """
+    CriticalKV (https://arxiv.org/abs/2502.03805) rescales the scores of a ScorerPress by
+    the L1 norm of Wo @ values, coupled with a two-stage selection.
+
+    Stage 1 protects the top `first_stage_ratio` share of the kept budget according to the
+    wrapped press' scores; stage 2 ranks the remaining pairs by (score + epsilon) * ||Wo @ v||_1.
+
+    Parameters
+    ----------
+    press : ScorerPress
+        Press whose `score` output is refined. Its `compression_ratio` is exposed by this press.
+    epsilon : float
+        Constant added to the scores before multiplying by the projected value norm.
+    first_stage_ratio : float
+        Fraction of the kept budget selected from the wrapped press' scores alone.
+    """
+
+    def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
+        self.press = press
+        self.epsilon = epsilon
+        self.first_stage_ratio = first_stage_ratio
+
+        assert isinstance(self.press, ScorerPress), "CriticalKVPress requires a ScorerPress as input"
+        if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
+            logger.warning("use_vnorm should be disabled for CriticalKVPress")
+
+    @property
+    def compression_ratio(self):
+        # The ratio is stored on the wrapped press only, so both objects always agree on it
+        return self.press.compression_ratio
+
+    @compression_ratio.setter
+    def compression_ratio(self, value):
+        self.press.compression_ratio = value
+
+    @staticmethod
+    def vwl1norm(values, module):
+        """
+        Compute the L1 norm of Wo @ values for each KV pair, averaged over the query heads of a group.
+
+        `values` is unpacked as (bsz, num_key_value_heads, q_len, head_dim) and `module` must expose
+        `o_proj` and `config.num_attention_heads`. Returns a (bsz, num_key_value_heads, q_len) tensor.
+        """
+        bsz, num_key_value_heads, q_len, head_dim = values.shape
+        num_attention_heads = module.config.num_attention_heads
+        num_key_value_groups = num_attention_heads // num_key_value_heads
+        # Split the input dimension of o_proj into one (head_dim, hidden_size) block per attention head
+        Wo = module.o_proj.weight.transpose(0, 1)
+        Wo = Wo.view(num_attention_heads, head_dim, -1)
+        V = repeat_kv(values, num_key_value_groups)
+
+        # We use head-wise computation instead of direct matmul to reduce the memory usage of WoV.
+        # Future kernel fusion optimization could eliminate these intermediate variables to enhance performance.
+        head_WoV_norm_list = []
+        for head in range(V.size(1)):
+            head_WoV = V[:, head, :, ...].matmul(Wo[head, ...].unsqueeze(0))
+            head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
+            head_WoV_norm_list.append(head_WoV_norm)
+
+        # (bsz, num_attention_heads, q_len) -> average over the query heads sharing a KV head
+        WoV_norm = torch.stack(head_WoV_norm_list, dim=1)
+        WoV_norm = WoV_norm.view(bsz, num_key_value_heads, num_key_value_groups, q_len).mean(dim=2)
+        return WoV_norm
+
+    def score(self, module, hidden_states, keys, values, attentions, kwargs):
+        """
+        Return the wrapped press' scores refined by the projected value norm (two-stage selection).
+        """
+        # Stage 1: indices selected from the wrapped press' scores alone
+        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
+        q_len = keys.shape[2]
+        selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio)
+        top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices
+
+        # Stage 2: rescale by the L1 norm of Wo @ values
+        projected_norm = self.vwl1norm(values, module)
+        scores = (scores + self.epsilon) * projected_norm
+
+        # Merge the two stages: stage-1 indices receive the largest representable score
+        scores.scatter_(-1, top_k_index, torch.finfo(scores.dtype).max)
+
+        return scores
+
+
+@dataclass
+class CriticalAdaKVPress(BasePress):
+    """
+    CriticalAdaKV (https://arxiv.org/abs/2502.03805) rescales the scores of a ScorerPress by
+    the L1 norm of Wo @ values and combines it with AdaKV (https://arxiv.org/abs/2407.11550).
+
+    Keys and values are returned unchanged: the pruned positions are stored in
+    `module.masked_key_indices` so that they can be masked during attention.
+    """
+
+    press: ScorerPress
+    alpha_safeguard: float = 0.20
+    epsilon: float = 1e-4
+    first_stage_ratio: float = 0.5
+
+    def __post_init__(self):
+        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"
+        assert isinstance(self.press, ScorerPress), "CriticalAdaKVPress requires a ScorerPress as input"
+        if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
+            logger.warning("use_vnorm should be disabled for CriticalAdaKVPress")
+
+    @property
+    def compression_ratio(self):
+        return self.press.compression_ratio
+
+    @compression_ratio.setter
+    def compression_ratio(self, value):
+        self.press.compression_ratio = value
+
+    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
+        """
+        Select the KV pairs to prune per head and register them in `module.masked_key_indices`.
+        """
+
+        if self.compression_ratio == 0:
+            return keys, values
+
+        assert module.config._attn_implementation != "eager", "eager mode not supported"
+
+        # Compute scores
+        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
+        bsz, num_key_value_heads, q_len = scores.shape
+
+        # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
+        n_kept = int(q_len * (1 - self.compression_ratio))  # ScorerPress definition
+        n_safe = int(n_kept * self.alpha_safeguard)
+        top_indices = torch.topk(scores, n_safe, dim=-1).indices
+        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)
+
+        ############################
+        # Start of CriticalKV code #
+        ############################
+
+        # Budget allocation (the safeguarded indices already hold the maximum score in `scores`)
+        budget_scores = scores.reshape(bsz, -1)
+        top_indices = torch.topk(budget_scores, n_kept * num_key_value_heads, dim=-1).indices
+        top_indices_head_idx = top_indices // q_len
+        # NOTE(review): budgets are accumulated over the batch dimension, presumably assuming bsz == 1
+        head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64)
+        head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten()))
+        head_budget_list = head_budgets.tolist()
+
+        # Stage 1
+        head_selection_budget_1st = (head_budgets * self.first_stage_ratio).to(torch.int64).tolist()
+        top_k_index = torch.topk(scores, max(head_selection_budget_1st), sorted=True, dim=-1).indices
+        for head_idx, phase1_budget in enumerate(head_selection_budget_1st):
+            scores[:, head_idx, :].scatter_(-1, top_k_index[:, head_idx, :phase1_budget], torch.finfo(scores.dtype).max)
+
+        # Stage 2
+        projected_norm = CriticalKVPress.vwl1norm(values, module)
+        scores = (scores + self.epsilon) * projected_norm
+        top_k_index = torch.topk(scores, max(head_budget_list), sorted=True, dim=-1).indices
+        for head_idx, budget in enumerate(head_budget_list):
+            scores[:, head_idx, :].scatter_(-1, top_k_index[:, head_idx, :budget], torch.finfo(scores.dtype).max)
+
+        ##########################
+        # End of CriticalKV code #
+        ##########################
+
+        # Compute bottom-k across heads
+        n_pruned = num_key_value_heads * (q_len - n_kept)
+        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
+
+        # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
+        batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
+        head_indices = indices // q_len
+        seq_indices = indices % q_len
+        module.masked_key_indices = (batch_indices, head_indices, seq_indices)
+        return keys, values
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index ee2e0fa9..27adab64 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -8,6 +8,8 @@
 from transformers import DynamicCache
 
 from kvpress import (
+    CriticalKVPress,
+    CriticalAdaKVPress,
     AdaKVPress,
     ChunkPress,
     ComposedPress,
@@ -42,7 +44,8 @@ def test_chunk_press(unit_test_model):  # noqa: F811
 
 
 @pytest.mark.parametrize("press_dict", default_presses)
-@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress])
+@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress,
+                                           CriticalKVPress, CriticalAdaKVPress])
 def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
     cls = press_dict["cls"]
     for kwargs in press_dict["kwargs"]:
@@ -51,14 +54,13 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
             press = ComposedPress(presses=[press])
         if isinstance(wrapper_press, KeyRerotationPress):
             press = KeyRerotationPress(press=press)
-        if isinstance(wrapper_press, AdaKVPress):
-            if not isinstance(press, ScorerPress):
-                return
+        if isinstance(wrapper_press, (AdaKVPress, CriticalKVPress, CriticalAdaKVPress)):
+            if isinstance(press, ScorerPress):
+                press = wrapper_press(press=press)
             else:
-                press = AdaKVPress(press=press)
+                return
         if isinstance(wrapper_press, ChunkPress):
             press = ChunkPress(press=press, chunk_length=2)
-
         with press(unit_test_model):
             input_ids = unit_test_model.dummy_inputs["input_ids"]
             unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values