Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions kvpress/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS

# Patch the attention functions to support head-wise compression
patch_attention_functions()
Expand Down
20 changes: 18 additions & 2 deletions kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,21 @@
PreTrainedModel,
QuantizedCache,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForCausalLM,
)

logger = logging.getLogger(__name__)

# Model classes the presses have been tested with. BasePress.__call__ checks
# the given model against this tuple with isinstance() and logs a
# "not tested" warning (listing these classes) for any other model type.
SUPPORTED_MODELS = (
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForCausalLM,
)


@dataclass
class BasePress:
Expand Down Expand Up @@ -127,13 +138,18 @@ def __call__(self, model: PreTrainedModel) -> Generator:
model : PreTrainedModel
Model to apply the compression method to
"""
if not isinstance(model, SUPPORTED_MODELS):
logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")

if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
logger.warning(f"Model {type(model)} not tested")
if isinstance(model, Gemma3ForCausalLM):
logger.warning("Compression in Gemma3 is only applied to layer without sliding window attention")

hooks = []
try:
for layer in model.model.layers:
if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
Comment thread
SimJeg marked this conversation as resolved.
continue
layer.self_attn.rotary_emb = model.model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
yield
Expand Down
6 changes: 6 additions & 0 deletions kvpress/presses/duo_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention

from kvpress.presses.base_press import BasePress

Expand Down Expand Up @@ -161,12 +163,16 @@ def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
# Mean query
q = module.self_attn.q_proj(h)
q = q.view(1, q.shape[1], -1, d)
if isinstance(module, (Gemma3Attention, Qwen3Attention)):
Comment thread
SimJeg marked this conversation as resolved.
q = module.q_norm(q)
q = q.mean(dim=1, keepdim=True)
q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

# Mean key
k = module.self_attn.k_proj(h)
k = k.view(1, k.shape[1], -1, d)
if isinstance(module, (Gemma3Attention, Qwen3Attention)):
k = module.k_norm(k)
k = k.mean(dim=1, keepdim=True)
k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

Expand Down
15 changes: 11 additions & 4 deletions kvpress/presses/expected_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention

from kvpress.presses.scorer_press import ScorerPress

Expand Down Expand Up @@ -44,10 +47,14 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]

if hasattr(module, "q_proj"):
Wq = module.q_proj.weight
elif hasattr(module, "qkv_proj"):
Wq = module.qkv_proj.weight[: n * d] # type: ignore[index]
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
# Qwen and Gemma use QK norm, which is not compatible with ExpectedAttentionPress (for now)
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")
elif isinstance(module, Phi3Attention):
Wq = module.qkv_proj.weight[: n * d]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
Wq = module.q_proj.weight # type: ignore[assignment]
else:
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")

Expand Down
14 changes: 11 additions & 3 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention

from kvpress.presses.scorer_press import ScorerPress

Expand Down Expand Up @@ -37,16 +40,21 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
num_key_value_groups = num_heads // module.config.num_key_value_heads

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -window_size:])
elif hasattr(module, "qkv_proj"):
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -window_size:])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)

# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -window_size:], sin[:, -window_size:]
Expand Down
18 changes: 13 additions & 5 deletions kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import torch
from torch import nn
from transformers.models.llama.modeling_llama import rotate_half
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention

from kvpress.presses.base_press import BasePress

Expand Down Expand Up @@ -36,17 +39,22 @@ def compute_window_queries(self, module, hidden_states, position_embeddings):
num_heads = module.config.num_attention_heads
head_dim = module.head_dim

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -self.window_size :])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
# Get last self.window_size queries
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -self.window_size:])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -self.window_size:])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)

# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -self.window_size :], sin[:, -self.window_size :]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.5"
version = "0.2.6"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down