Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ All current presses are training free. We provide the following presses associat
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
- `ThinKPress`: compress the dimension of the keys based on the channel attention score on the last 64 queries ([paper](https://arxiv.org/pdf/2407.21018)). Can be combined with any of the presses above.

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress

__all__ = [
"BasePress",
Expand All @@ -21,6 +22,7 @@
"RandomPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"apply_per_layer_compression",
Expand Down
111 changes: 111 additions & 0 deletions kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import inspect
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers.cache_utils import QuantizedCache
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress


@dataclass
class ThinKPress(BasePress):
"""
ThinK (https://arxiv.org/pdf/2407.21018) compresses the dimensions of the keys, and not the sequence length.
Hence it can be combined with any other press that compresses the sequence length, e.g.
press = ThinKPress(compression_ratio=0.5, inner_press=SnapKVPress(compression_ratio=0.5))

Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
we might implement them in the future, especially if other similar presses are requested.

This press has been reviewed by Yuhui Xu, first author of the ThinK paper.
"""

compression_ratio: float = 0.0
inner_press: Optional[BasePress] = None
window_size: int = 64

def compute_window_queries(self, module, hidden_states):
Comment thread
SimJeg marked this conversation as resolved.
"""
Re-compute the last window_size query states
"""

bsz, q_len, _ = hidden_states.shape

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -self.window_size :])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
query_states = qkv[..., : module.num_heads * module.head_dim]
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

# Apply RoPE
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
else:
cos, sin = module.rotary_emb(query_states, q_len)
cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
Comment thread
maxjeblick marked this conversation as resolved.
Outdated
query_states = (query_states * cos) + (rotate_half(query_states) * sin)

return query_states

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""
We first apply the inner press, then we prune the key dimensions. If other similar presses are requested,
we will create a dedicated DimensionBasePress class to avoid code duplication.
"""

# Apply the forward hook of the inner press
if self.inner_press is not None:
output = self.inner_press.forward_hook(module, input, kwargs, output)

# Don't compress if the compression ratio is 0 or this is not pre-filling
cache = output[-1]
hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]
assert q_len > self.window_size, "Query length should be greater than the window size"

if (self.compression_ratio == 0) or (cache.seen_tokens > q_len):
return output

# Get keys
if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
else:
keys = cache.key_cache[module.layer_idx]
bsz, num_key_value_heads, q_len, head_dim = keys.shape

# ThinK specific code
queries = self.compute_window_queries(module, kwargs["hidden_states"])

# Compute scores per dimension
queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim)
queries_norm = queries_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, module.head_dim).mean(2)
keys_norm = torch.pow(keys, 2).mean(dim=2)
key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim)

# Prune dimensions with the lowest scores by setting them to 0
n_pruned = int(head_dim * self.compression_ratio)
indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
indices = indices.unsqueeze(2).expand(-1, -1, q_len, -1)
keys = keys.scatter_(-1, indices, 0)

# Update cache
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
else:
cache.key_cache[module.layer_idx] = keys

return output
6 changes: 3 additions & 3 deletions notebooks/per_layer_compression_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "kvpress_2",
"display_name": ".venv",
"language": "python",
"name": "kvpress_2"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -230,7 +230,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
6 changes: 4 additions & 2 deletions tests/presses/test_presses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
SnapKVPress,
StreamingLLMPress,
TOVAPress,
ThinKPress,
)

from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401


def test_presses_run(unit_test_model): # noqa: F811
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress, ThinKPress]:
for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
press = cls(compression_ratio=compression_ratio)
if cls == SnapKVPress:
if cls in [SnapKVPress, ThinKPress]:
press.window_size = 2
with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"]
Expand Down