Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

![kvpress](kvpress.jpg)

Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache pruning methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field.
Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache compression methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field.

## Installation

Expand Down Expand Up @@ -60,11 +60,12 @@ All current presses are training free. Several of them inherit from `ScorerPress
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))

Some presses relying on a different logic:
Some presses rely on a different logic:
- `ThinKPress`: compress the dimensions of the keys based on the channel attention score on the last queries ([paper](https://arxiv.org/pdf/2407.21018))
- `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846))

Finally, we provide special presses:
- `AdaKVPress`: prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions (see [paper](https://arxiv.org/abs/2407.11550))
- `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows setting a `compression_ratio`.
- `ComposedPress`: compose multiple presses together by chaining their forward hooks
- `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`.
Expand Down Expand Up @@ -101,7 +102,7 @@ pipe(..., cache=cache)
By default, the `DynamicCache` is used (no quantization).

> [!IMPORTANT]
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`, see also [this issue](https://github.com/huggingface/transformers/issues/34848)).
> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`).


## FAQ
Expand Down
23 changes: 14 additions & 9 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
AdaKVPress,
ExpectedAttentionPress,
KnormPress,
ObservedAttentionPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)

logger = logging.getLogger(__name__)
Expand All @@ -42,12 +45,16 @@
}

# Maps a press name (string key) to an already-constructed press instance.
# The two "ada*" entries wrap a scorer press (SnapKVPress / ExpectedAttentionPress) inside AdaKVPress.
PRESS_DICT = {
"adasnapkv": AdaKVPress(SnapKVPress()),
"ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
"expected_attention": ExpectedAttentionPress(),
"knorm": KnormPress(),
"observed_attention": ObservedAttentionPress(),
"random": RandomPress(),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"think": ThinKPress(),
"tova": TOVAPress(),
}


Expand Down Expand Up @@ -110,6 +117,7 @@ def evaluate(
df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()
if fraction < 1.0:
df = df.sample(frac=fraction, random_state=42)
save_filename = save_filename.with_name(save_filename.stem + f"__fraction{fraction:.2f}" + save_filename.suffix)

if compress_questions:
df["context"] = df["context"] + df["question"]
Expand All @@ -122,24 +130,21 @@ def evaluate(
press.compression_ratio = compression_ratio

# Initialize pipeline with the correct attention implementation
model_kwargs = {"torch_dtype": "auto"}
if isinstance(press, ObservedAttentionPress):
model_kwargs = {"attn_implementation": "eager"}
model_kwargs["attn_implementation"] = "eager"
else:
try:
import flash_attn # noqa: F401

model_kwargs = {"attn_implementation": "flash_attention_2"}
model_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
model_kwargs = {}
pass

if device == "auto":
pipe = pipeline(
"kv-press-text-generation", model=model, device_map="auto", torch_dtype="auto", model_kwargs=model_kwargs
)
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
else:
pipe = pipeline(
"kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs=model_kwargs
)
pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)

# Run pipeline on each context
df["predicted_answer"] = None
Expand Down
5 changes: 5 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
Expand All @@ -18,7 +19,11 @@
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress

from kvpress.attention_patch import patch_attention_functions
# Import-time side effect: every function registered in transformers' ALL_ATTENTION_FUNCTIONS
# is replaced by its attention_patch-decorated version as soon as kvpress is imported.
patch_attention_functions()

__all__ = [
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
Expand Down
57 changes: 57 additions & 0 deletions kvpress/attention_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def search_hyperplane(X, max_iter: int = 1000):
    """
    Build, for each batch entry, a vector whose dot product with every row of X is strongly negative.

    The search looks for a direction Y of shape (bsz, head_dim) such that <X[:, i], Y> > 0 for every i.
    Y starts as the mean of X over dim 1; at each iteration, the rows that still satisfy <x, Y> <= 0
    are averaged and added to Y. Once no such row remains, -1e5 * Y / ||Y|| ** 2 is returned, so that
    exp(<x, returned vector>) is meant to be 0 for every row x.

    Parameters
    ----------
    X : torch.Tensor
        Tensor of shape (bsz, seq_len, head_dim).
    max_iter : int
        Maximum number of correction steps before giving up.

    Returns
    -------
    torch.Tensor
        Tensor of shape (bsz, head_dim).

    Raises
    ------
    ValueError
        If no valid direction is found within max_iter iterations.
    """
    # NOTE(review): 1e5 is outside float16's finite range (max 65504); presumably X is
    # bfloat16 or float32 here - confirm before using this with float16 models.
    normal = X.mean(1)  # the mean is already a valid direction in most cases
    for _step in range(max_iter):
        # (bsz, seq_len, 1) boolean: True where a row is not yet on the positive side of the hyperplane
        violated = torch.bmm(X, normal.unsqueeze(-1)) <= 0
        if violated.any():
            # Push the direction towards the average of the offending rows (clamp avoids a division by 0)
            normal += (X * violated).sum(1) / violated.sum(1).clamp(min=1)
            continue
        squared_norm = normal.norm(dim=-1, keepdim=True) ** 2
        return -1e5 * normal / squared_norm
    raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")


def attention_patch(func):
    """
    Decorator to update the keys before the attention computation at the indices provided in
    module.masked_key_indices.

    The keys at these indices are overwritten with a fake key k such that exp(<q, k>) = 0, to fake
    head-wise compression. This solution is not optimal as it does not reduce peak memory and
    slightly increases runtime.

    Parameters
    ----------
    func : callable
        Attention function, called as
        func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs).

    Returns
    -------
    callable
        Wrapper with the same call signature. During decoding it edits `key` in place at the masked
        indices, then delegates to `func` and returns its result.
    """
    from functools import wraps  # function-scope import: only this decorator needs it

    @wraps(func)
    def wrapper(module, query, key, value, attention_mask, dropout, scaling=None, is_causal=None, **kwargs):
        # A module that never went through prefilling with this patch (or through a press) has no
        # such attribute: treat it as "nothing to mask" instead of raising AttributeError.
        masked_key_indices = getattr(module, "masked_key_indices", None)

        if query.shape[2] == key.shape[2]:
            # Prefilling: as many queries as keys, drop indices left over from a previous sequence
            module.masked_key_indices = None
        elif masked_key_indices is not None:
            # Decoding: build fake keys k s.t. exp(<q, k>) = 0
            bsz, num_heads, seq_len, head_dim = query.shape
            num_key_value_heads = key.shape[1]
            num_groups = num_heads // num_key_value_heads

            # Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
            q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
            q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
            k = search_hyperplane(q)
            k = k.view(bsz, num_key_value_heads, head_dim)

            # At the masked indices, overwrite the keys with the fake keys (values are left untouched).
            # Indexing with a tuple is equivalent to key[*indices] but also parses before Python 3.11.
            indices = tuple(masked_key_indices)
            key[indices] = k[indices[:2]]

        # NOTE(review): scaling and is_causal are forwarded positionally, which assumes every wrapped
        # function declares them in this order right after dropout - confirm for each backend.
        return func(module, query, key, value, attention_mask, dropout, scaling, is_causal, **kwargs)

    return wrapper


def patch_attention_functions():
    """
    Add the attention_patch decorator to functions in ALL_ATTENTION_FUNCTIONS.

    Each registered attention function is replaced in the registry by its decorated version.
    Calling this function again is a no-op for entries that were already patched, so the
    decorator is never stacked on top of itself.
    """

    # Iterate over a snapshot so the registry is not reassigned while its live view is being iterated
    for name, func in list(ALL_ATTENTION_FUNCTIONS.items()):
        if getattr(func, "_kvpress_patched", False):
            continue  # already wrapped by a previous call
        patched = attention_patch(func)
        patched._kvpress_patched = True  # marker read by the check above
        ALL_ATTENTION_FUNCTIONS[name] = patched
Comment thread
SimJeg marked this conversation as resolved.
55 changes: 55 additions & 0 deletions kvpress/presses/adakv_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import torch

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class AdaKVPress(BasePress):
    """
    AdaKV (https://arxiv.org/abs/2407.11550) selects the top-k keys and values among all heads in a layer
    based on the scores, achieving head-specific compression.
    A safeguard is applied to ensure a minimum fraction of KV pairs per head (alpha_safeguard parameter)

    This press does not remove anything from the cache: compress() returns keys and values unchanged
    and only stores the indices of the pruned KV pairs in module.masked_key_indices.
    """

    # Press providing the scores (its `score` method) and the compression ratio
    scorer: ScorerPress
    # Fraction of the per-head budget q_len * (1 - compression_ratio) that is always kept for every head
    alpha_safeguard: float = 0.20

    @property
    def compression_ratio(self):
        """Compression ratio, read from the wrapped scorer."""
        return self.scorer.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        """Set the compression ratio on the wrapped scorer."""
        self.scorer.compression_ratio = value

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Select the KV pairs to prune across all heads of the layer and record their indices.

        All arguments are forwarded as-is to self.scorer.score, which must return scores of shape
        (bsz, num_key_value_heads, q_len).

        Returns
        -------
        tuple
            (keys, values), unchanged. As a side effect, module.masked_key_indices is set to a tuple
            (batch_indices, head_indices, seq_indices) of 1D index tensors designating the pruned
            KV pairs. Nothing is set when the compression ratio is 0.

        Raises
        ------
        AssertionError
            If the attention implementation is "eager" or if the scorer is not a ScorerPress.
        """
        if self.compression_ratio == 0:
            return keys, values

        assert module.config._attn_implementation != "eager", "eager mode not supported"
        assert isinstance(self.scorer, ScorerPress), "AdaKVPress requires a ScorerPress as input"

        # Compute scores
        scores = self.scorer.score(module, hidden_states, keys, values, attentions, kwargs)
        bsz, num_key_value_heads, q_len = scores.shape

        # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head:
        # the n_safe best scores of each head are raised to the dtype maximum so that
        # the bottom-k below never selects them
        n_kept = int(q_len * (1 - self.compression_ratio))  # ScorerPress definition
        n_safe = int(n_kept * self.alpha_safeguard)
        top_indices = torch.topk(scores, n_safe, dim=-1).indices
        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

        # Compute bottom-k across heads: heads and positions are flattened together so that
        # the pruning budget is shared by all heads of the layer
        n_pruned = num_key_value_heads * (q_len - n_kept)
        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

        # Save indices for attention patching in the module.
        # The batch indices are created on the same device as the other two index tensors.
        batch_indices = torch.arange(bsz, device=indices.device).repeat_interleave(n_pruned)
        module.masked_key_indices = (batch_indices, indices // q_len, indices % q_len)
        return keys, values
12 changes: 4 additions & 8 deletions kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,13 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
Modified output of the forward pass of the layer.

"""
# See e.g. LlamaDecoderLayer.forward for the output structure
if len(output) == 3:
_, attentions, cache = output
else:
attentions, cache = None, output[-1]

hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_value"]
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if cache.seen_tokens > q_len:
if kwargs["cache_position"][-1] > q_len:
return output

if isinstance(cache, QuantizedCache):
Expand All @@ -106,7 +102,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]

keys, values = self.compress(module, hidden_states, keys, values, attentions, kwargs)
keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)

if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
Expand Down Expand Up @@ -138,8 +134,8 @@ def __call__(self, model: PreTrainedModel) -> Generator:
hooks = []
try:
for layer in model.model.layers:
layer.self_attn.rotary_emb = model.model.rotary_emb
Comment thread
maxjeblick marked this conversation as resolved.
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

yield
finally:
for forward_hook in hooks:
Expand Down
5 changes: 3 additions & 2 deletions kvpress/presses/composed_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from kvpress.presses.base_press import BasePress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.adakv_press import AdaKVPress


@dataclass
Expand All @@ -15,8 +16,8 @@ class ComposedPress(BasePress):
def __post_init__(self):
self.compression_ratio = None
assert not any(
isinstance(press, ObservedAttentionPress) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress because attentions pruning is not handled"
isinstance(press, (ObservedAttentionPress, AdaKVPress)) for press in self.presses
), "ComposedPress cannot contains ObservedAttentionPress or AdaKVPress"

def forward_hook(self, module, input, kwargs, output):
self.compression_ratio = 1.0
Expand Down
19 changes: 8 additions & 11 deletions kvpress/presses/expected_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


import inspect
import math
from dataclasses import dataclass

Expand Down Expand Up @@ -39,7 +38,7 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
"""

bsz, q_len, _ = hidden_states.shape
n, d = module.num_heads, module.head_dim
n, d = module.config.num_attention_heads, module.config.head_dim

# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]
Expand All @@ -66,13 +65,9 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
cov = cov.permute(0, 3, 1, 2)

# RoPE rotation matrix on next n_future_positions
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
cos, sin = module.rotary_emb(mu, position_ids)
cos, sin = cos[0], sin[0]
else:
cos, sin = module.rotary_emb(mu, q_len + self.n_future_positions)
cos, sin = cos[q_len:], sin[q_len:]
position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
cos, sin = module.rotary_emb(mu, position_ids)
cos, sin = cos[0], sin[0]

Id = torch.eye(d, device=cos.device, dtype=cos.dtype)
P = torch.zeros((d, d), device=cos.device, dtype=cos.dtype)
Expand Down Expand Up @@ -117,14 +112,16 @@ def score(

# Compute scores
bsz, num_key_value_heads, q_len, d = keys.shape
keys = repeat_kv(keys, module.num_key_value_groups).transpose(2, 3)
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

keys = repeat_kv(keys, num_key_value_groups).transpose(2, 3)
scores = torch.matmul(mean_query.unsqueeze(2), keys).squeeze(2) / math.sqrt(d)
if self.use_covariance:
scores += torch.einsum("bhin, bhij, bhjn->bhn", keys, cov_query, keys) / d / 2
scores = F.softmax(scores, dim=-1)

# Average scores across groups
scores = scores.view(bsz, num_key_value_heads, module.num_key_value_groups, q_len)
scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len)
scores = scores.mean(dim=2)

# Rescale scores by the norm of the values
Expand Down
20 changes: 5 additions & 15 deletions kvpress/presses/key_rerotation_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0


import inspect
from dataclasses import dataclass

import torch
Expand Down Expand Up @@ -40,16 +39,18 @@ def compress(
if self.press.compression_ratio == 0:
return keys, values

assert isinstance(self.press, ScorerPress)

# Compute scores from base press
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

# Get indices of KV pairs with the lowest scores
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.press.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.config.head_dim)

cos, sin = get_rope_embeddings(module, keys)
cos, sin = kwargs["position_embeddings"]
# Rerotate as follows
# 1. keys = RoPE(W_k * hidden_states)
# 2. keys_unrotated = RoPE^-1(keys)
Expand All @@ -61,19 +62,8 @@ def compress(
# 3. Prune keys
keys = keys.gather(2, indices).contiguous()
# 4. Apply RoPE
cos, sin = get_rope_embeddings(module, keys)
cos, sin = cos[:, :n_kept], sin[:, :n_kept]
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))

values = values.gather(2, indices).contiguous()
return keys, values


def get_rope_embeddings(module, x):
    """
    Return the (cos, sin) pair produced by module.rotary_emb for the positions 0 .. x.shape[2] - 1.

    Two calling conventions of module.rotary_emb are handled: when its forward method declares a
    `position_ids` parameter it receives a (1, length) tensor of positions placed on x.device,
    otherwise it receives the length as an integer.
    """
    rotary = module.rotary_emb
    n_positions = x.shape[2]
    takes_position_ids = "position_ids" in inspect.signature(rotary.forward).parameters

    # rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
    if takes_position_ids:
        second_arg = torch.arange(n_positions).unsqueeze(0).to(x.device)
    else:
        second_arg = n_positions
    cos, sin = rotary(x, second_arg)
    return cos, sin
Loading