Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
- `StreamingLLMPress` ([source](kvpress/presses/streaming_llm_press.py), [paper](https://arxiv.org/abs/2309.17453)): keep only the initial and recent tokens
- `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
- `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during in pre-filling phase
- `QFilterPress`: project the Key representations on the main SVD component of the Query vectors to approximate the attention scores. ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812))
Comment thread
SimJeg marked this conversation as resolved.
Outdated

Some presses rely on a different logic:
- `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
Expand All @@ -81,6 +82,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield to more uniform compression across long sequences
- `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.


For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

## Evaluation
Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.qfilter_press import QFilterPress

# Patch the attention functions to support head-wise compression
patch_attention_functions()
Expand All @@ -49,4 +50,5 @@
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
]
1 change: 1 addition & 0 deletions kvpress/presses/duo_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from kvpress.presses.base_press import BasePress


Comment thread
SimJeg marked this conversation as resolved.
Outdated
PATTERNS_DICT = {
"togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
"gradientai//Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10", # noqa: E501
Expand Down
57 changes: 57 additions & 0 deletions kvpress/presses/qfilter_press.py
Comment thread
SimJeg marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Comment thread
NathanGodey marked this conversation as resolved.
Outdated
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from dataclasses import dataclass

import torch
from huggingface_hub import PyTorchModelHubMixin, get_collection

from kvpress.presses.scorer_press import ScorerPress


class QFilters(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, num_layers: int, num_kv_heads: int, kv_head_dim: int):
super().__init__()
self.q_filters = torch.nn.Parameter(torch.randn(num_layers, num_kv_heads, kv_head_dim))

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path):
return super().from_pretrained(pretrained_model_name_or_path)


@dataclass
class QFilterPress(ScorerPress):
"""
Prune KV pairs with Q-filters
"""

def __post_init_from_model__(self, model):
model_name = model.config.name_or_path.split("/")[-1]
self.q_filters = self.load_q_filters(model_name)
self.q_filters.to(model.device, model.dtype)
Comment thread
SimJeg marked this conversation as resolved.
Outdated

@staticmethod
def load_q_filters(model_name):
try:
return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
except TypeError:
raise ValueError(
f"Could not load Q-filters for {model_name}. Available models: {QFilterPress.available_qfilters()}"
)

@staticmethod
def available_qfilters():
collection = get_collection("nthngdy/q-filters-67a4994dcb302a3d37f3d119", token=False)
return [x.item_id.split("/")[-1][:-6] for x in collection.items]

def score(self, module, hidden_states, keys, values, attentions, kwargs):
q_filter = self.q_filters[module.layer_idx][None, :, None]
scores = -(q_filter * keys).sum(dim=-1)
return scores

@contextmanager
def __call__(self, model):
self.__post_init_from_model__(model)
with super().__call__(model):
yield
2 changes: 2 additions & 0 deletions tests/default_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StreamingLLMPress,
ThinKPress,
TOVAPress,
QFilterPress,
)


Expand All @@ -31,6 +32,7 @@ def load_attention_pattern(model):
{"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
Comment thread
SimJeg marked this conversation as resolved.
{
"cls": SnapKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
Expand Down
8 changes: 8 additions & 0 deletions tests/presses/test_qfilters_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.presses.qfilter_press import QFilterPress


def test_load_qfilters():
for model_name in QFilterPress.available_qfilters():
QFilterPress.load_q_filters(model_name)