Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
)

logger = logging.getLogger(__name__)
Expand All @@ -48,6 +49,7 @@
"random": RandomPress(),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"think": ThinKPress(),
}


Expand Down
2 changes: 2 additions & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress
from kvpress.presses.think_press import ThinKPress

__all__ = [
"BasePress",
Expand All @@ -22,6 +23,7 @@
"SnapKVPress",
"StreamingLLMPress",
"TOVAPress",
"ThinKPress",
"KVPressTextGenerationPipeline",
"apply_per_layer_compression",
]
29 changes: 20 additions & 9 deletions kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,26 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic

with torch.no_grad():
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

# Prune KV pairs with the lowest scores
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Update cache
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
# ThinK uses the channel wise pruning
if self.__class__.__name__ == "ThinKPress":
n_prune = int(module.head_dim * self.compression_ratio)
if keys.shape[1] != scores.shape[1]:
scores = scores.view(scores.shape[0], keys.shape[1], -1, scores.shape[-1]).sum(dim=-2)
_, indices = torch.topk(scores, n_prune, dim=-1, largest=False)
keep_idx = indices.sort().values
mask = torch.zeros(scores.shape, dtype=torch.bool).to(scores.device)
mask_k = mask.scatter(-1, keep_idx, 1)
mask_k = mask_k.unsqueeze(2).expand(-1, -1, q_len - self.window_size, -1)
keys = torch.cat([keys[:, :, :q_len - self.window_size, :].masked_fill(mask_k, 0), keys[:, :, q_len - self.window_size:, :]], dim=-2)
else:
# Prune KV pairs with the lowest scores
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Update cache
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
Expand Down
85 changes: 85 additions & 0 deletions kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import inspect
import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.base_press import BasePress


@dataclass
class ThinKPress(BasePress):
    """
    ThinK (https://arxiv.org/abs/2407.21018) prunes the key cache along the channel (head_dim) dimension
    instead of the sequence dimension. Each key channel is scored by the squared magnitude of its
    contribution to the query-key products of the last window_size queries; the press returns one score
    per (batch, head, channel) and the caller is expected to drop the lowest-scoring channels.
    """

    compression_ratio: float = 0.0
    window_size: int = 64
    # Not read by any method of this class; kept so existing constructor calls keep working.
    kernel_size: int = 5

    def _compute_window_queries(self, module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Project the last window_size hidden states to queries and apply RoPE.

        Returns a tensor of shape (bsz, module.num_heads, window_size, module.head_dim).
        Raises NotImplementedError if the module has neither a q_proj nor a qkv_proj projection.
        """

        bsz, q_len, _ = hidden_states.shape
        window_states = hidden_states[:, -self.window_size :]

        # Get last window_size queries
        if hasattr(module, "q_proj"):
            query_states = module.q_proj(window_states)
        elif hasattr(module, "qkv_proj"):
            # Fused projection: the query slice comes first
            qkv = module.qkv_proj(window_states)
            query_states = qkv[..., : module.num_heads * module.head_dim]
        else:
            raise NotImplementedError(f"ThinK not yet implemented for {module.__class__}.")

        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)

        # Apply RoPE (the rotary embedding signature differs between model implementations)
        if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
            position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
            cos, sin = module.rotary_emb(query_states, position_ids)
        else:
            cos, sin = module.rotary_emb(query_states, q_len)
            cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
        return (query_states * cos) + (rotate_half(query_states) * sin)

    def compute_window_attention(
        self, module: nn.Module, hidden_states: torch.Tensor, keys: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the per-channel query-key products between the last window_size queries and all keys.

        Returns a tensor of shape (bsz, num_heads, head_dim, window_size, q_len) whose entry
        [b, h, d, w, l] is query[b, h, w, d] * key[b, h, l, d]. This tensor is very large for long
        contexts; `score` does not use it and computes the reduced result directly.
        """

        query_states = self._compute_window_queries(module, hidden_states)
        key_states = repeat_kv(keys, module.num_key_value_groups)

        # Outer product over (window, sequence) for every channel
        return torch.matmul(
            query_states.permute(0, 1, 3, 2).unsqueeze(-1),
            key_states.transpose(2, 3).unsqueeze(-2),
        )

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        """
        Return one importance score per key channel, of shape (bsz, num_heads, head_dim).

        The score of channel d is sum over window queries w and keys l of (q[w, d] * k[l, d]) ** 2.
        values, attentions and kwargs are accepted for interface compatibility and are not used.
        """

        q_len = keys.shape[2]

        assert q_len > self.window_size, "Query length should be greater than the window size"

        query_states = self._compute_window_queries(module, hidden_states)

        # sum_{w,l} (q[w,d] * k[l,d])^2 == (sum_w q[w,d]^2) * (sum_l k[l,d]^2), so the
        # (head_dim, window_size, q_len) outer product never has to be materialised.
        # Accumulate in float32 so half-precision sums over long sequences do not overflow early.
        query_energy = query_states.float().pow(2).sum(dim=2)
        key_energy = keys.float().pow(2).sum(dim=2)

        # Same head layout as repeat_kv: each key/value head is repeated consecutively
        key_energy = key_energy.repeat_interleave(module.num_key_value_groups, dim=1)

        return (query_energy * key_energy).to(query_states.dtype)