Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ComposedPress` ([source](kvpress/presses/composed_press.py)): compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress` ([source](kvpress/presses/key_rerotation_press.py)): rerotate pruned keys to have continuous RoPE embeddings
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

Expand Down
6 changes: 6 additions & 0 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
CriticalKVPress,
CriticalAdaKVPress,
AdaKVPress,
ExpectedAttentionPress,
KnormPress,
Expand Down Expand Up @@ -45,6 +47,10 @@
}

PRESS_DICT = {
"criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
"criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
"criti_snapkv": CriticalKVPress(SnapKVPress()),
"criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
"adasnapkv": AdaKVPress(SnapKVPress()),
"ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
"expected_attention": ExpectedAttentionPress(),
Expand Down
4 changes: 3 additions & 1 deletion kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

from kvpress.presses.criticalkv_press import CriticalKVPress, CriticalAdaKVPress
# Patch the attention functions to support head-wise compression
patch_attention_functions()

__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
Expand Down
161 changes: 161 additions & 0 deletions kvpress/presses/criticalkv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import dataclass

import torch
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress

logger = logging.getLogger(__name__)


class CriticalKVPress(ScorerPress):
"""
CriticalKV (https://arxiv.org/abs/2502.03805) rescales the scores of a ScorerPress by
the L1 norm of Wo @ values
"""

def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
self.press = press
self.epsilon = epsilon
self.first_stage_ratio = first_stage_ratio

assert isinstance(self.press, ScorerPress), "CriticalAdaKVPress requires a ScorerPress as input"
if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
logger.warning("use_vnorm should be disabled for CriticalAdaKVPress")

@property
def compression_ratio(self):
return self.press.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.press.compression_ratio = value

@staticmethod
def vwl1norm(values, module):
bsz, num_key_value_heads, q_len, _ = values.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
Wo = module.o_proj.weight.transpose(0, 1)
Wo = Wo.view(module.config.num_attention_heads, module.config.head_dim, module.config.hidden_size)
V = repeat_kv(values, num_key_value_groups)

# We use head-wise computation instead of direct matmul to reduce the memory usage of WoV.
# Future kernel fusion optimization could eliminate this intermediate variables to enhance performance.
head_WoV_norm_list = []
for head in range(V.size(1)):
head_WoV = V[: , head, : , ...].matmul(Wo[head, ...].unsqueeze(0))
head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
head_WoV_norm_list.append(head_WoV_norm)

# b_size, num_heads, q_len , k_len
WoV_norm = torch.stack(head_WoV_norm_list, dim=1)
WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len).mean(dim=2)
return WoV_norm

def score(self, module, hidden_states, keys, values, attentions, kwargs):
# Stage 1
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
q_len = keys.shape[2]
selection_budget = int((1 - self.compression_ratio) * q_len * self.first_stage_ratio)
top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices

# Stage 2
projected_norm = self.vwl1norm(values, module)
scores = (scores + self.epsilon) * projected_norm

# Merge the two stages
scores.scatter_(-1, top_k_index, torch.finfo(scores.dtype).max)

return scores


@dataclass
class CriticalAdaKVPress(BasePress):
"""
CriticalAdaKV (https://arxiv.org/abs/2502.03805) rescales the scores of a ScorerPress by
the L1 norm of Wo @ values and combines it with AdaKV (https://arxiv.org/abs/2407.11550).
"""

press: ScorerPress
alpha_safeguard: float = 0.20
epsilon: float = 1e-4
first_stage_ratio: float = 0.5

def __post_init__(self):
assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in 0, 1]"
assert isinstance(self.press, ScorerPress), "CriticalAdaKVPress requires a ScorerPress as input"
if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
logger.warning("use_vnorm should be disabled for CriticalAdaKVPress")

@property
def compression_ratio(self):
return self.press.compression_ratio

@compression_ratio.setter
def compression_ratio(self, value):
self.press.compression_ratio = value

def compress(self, module, hidden_states, keys, values, attentions, kwargs):

if self.compression_ratio == 0:
return keys, values

assert module.config._attn_implementation != "eager", "eager mode not supported"

# Compute scores
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
bsz, num_key_value_heads, q_len = scores.shape

# Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
n_kept = int(q_len * (1 - self.compression_ratio)) # ScorerPress definition
n_safe = int(n_kept * self.alpha_safeguard)
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

############################
# Start of CriticalKV code #
############################

# Budget allocation
budget_scores = scores.scatter(-1, top_indices, torch.finfo(scores.dtype).max)
budget_scores = budget_scores.reshape(bsz, -1)
top_indices = torch.topk(budget_scores, n_kept * num_key_value_heads, dim=-1).indices
top_indices_head_idx = top_indices // q_len
head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64)
head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten()))

# Stage 1
head_selection_budget_1st = (head_budgets * self.first_stage_ratio).to(torch.int64).tolist()
top_k_index = torch.topk(scores, max(head_selection_budget_1st), sorted=True, dim=-1).indices
for head_idx in range(num_key_value_heads):
phase1_budget = head_selection_budget_1st[head_idx]
scores[:, head_idx, :].scatter_(-1, top_k_index[:, head_idx, :phase1_budget], torch.finfo(scores.dtype).max)

# Stage 2
projected_norm = CriticalKVPress.vwl1norm(values, module)
scores = (scores + self.epsilon) * projected_norm
top_k_index = torch.topk(scores, max(head_budgets), sorted=True, dim=-1).indices
for head_idx in range(num_key_value_heads):
budget = head_budgets[head_idx]
scores[:, head_idx, :].scatter_(-1, top_k_index[:, head_idx, :budget], torch.finfo(scores.dtype).max)

##########################
# End of CriticalKV code #
##########################

# Compute bottom-k across heads
n_pruned = num_key_value_heads * (q_len - n_kept)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
head_indices = indices // q_len
seq_indices = indices % q_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
return keys, values
14 changes: 8 additions & 6 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from transformers import DynamicCache

from kvpress import (
CriticalKVPress,
CriticalAdaKVPress,
AdaKVPress,
ChunkPress,
ComposedPress,
Expand Down Expand Up @@ -42,7 +44,8 @@ def test_chunk_press(unit_test_model): # noqa: F811


@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress])
@pytest.mark.parametrize("wrapper_press", [None, ComposedPress, KeyRerotationPress, AdaKVPress, ChunkPress,
CriticalKVPress, CriticalAdaKVPress])
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
Expand All @@ -51,14 +54,13 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
press = ComposedPress(presses=[press])
if isinstance(wrapper_press, KeyRerotationPress):
press = KeyRerotationPress(press=press)
if isinstance(wrapper_press, AdaKVPress):
if not isinstance(press, ScorerPress):
return
if isinstance(wrapper_press, (AdaKVPress, CriticalKVPress, CriticalAdaKVPress)):
if isinstance(press, ScorerPress):
press = wrapper_press(press=press)
else:
press = AdaKVPress(press=press)
return
if isinstance(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=2)

with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
Expand Down