Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions kvpress/presses/base_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
PreTrainedModel,
QuantizedCache,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForCausalLM,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -127,13 +129,23 @@ def __call__(self, model: PreTrainedModel) -> Generator:
model : PreTrainedModel
Model to apply the compression method to
"""

if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)):
supported_models = (
Comment thread
SimJeg marked this conversation as resolved.
Outdated
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForCausalLM,
)
if not isinstance(model, supported_models):
logger.warning(f"Model {type(model)} not tested")

hooks = []
try:
for layer in model.model.layers:
if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
Comment thread
SimJeg marked this conversation as resolved.
continue
layer.self_attn.rotary_emb = model.model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
yield
Expand Down
6 changes: 6 additions & 0 deletions kvpress/presses/duo_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention

from kvpress.presses.base_press import BasePress

Expand Down Expand Up @@ -161,12 +163,16 @@ def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
# Mean query
q = module.self_attn.q_proj(h)
q = q.view(1, q.shape[1], -1, d)
if isinstance(module, (Gemma3Attention, Qwen3Attention)):
Comment thread
SimJeg marked this conversation as resolved.
q = module.q_norm(q)
q = q.mean(dim=1, keepdim=True)
q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

# Mean key
k = module.self_attn.k_proj(h)
k = k.view(1, k.shape[1], -1, d)
if isinstance(module, (Gemma3Attention, Qwen3Attention)):
k = module.k_norm(k)
k = k.mean(dim=1, keepdim=True)
k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

Expand Down
16 changes: 12 additions & 4 deletions kvpress/presses/expected_attention_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

from kvpress.presses.scorer_press import ScorerPress
Comment thread
SimJeg marked this conversation as resolved.
Outdated

from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention


@dataclass
class ExpectedAttentionPress(ScorerPress):
Expand Down Expand Up @@ -44,10 +48,14 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
# Remove first hidden_states that likely contain outliers
h = hidden_states[:, self.n_sink :]

if hasattr(module, "q_proj"):
Wq = module.q_proj.weight
elif hasattr(module, "qkv_proj"):
Wq = module.qkv_proj.weight[: n * d] # type: ignore[index]
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
# Qwen and Gemma use QK norm, which is not compatible with ExpectedAttentionPress (for now)
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")
elif isinstance(module, Phi3Attention):
Wq = module.qkv_proj.weight[: n * d]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
Wq = module.q_proj.weight # type: ignore[assignment]
else:
raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.")

Expand Down
15 changes: 12 additions & 3 deletions kvpress/presses/snapkv_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

Comment thread
SimJeg marked this conversation as resolved.
Outdated
from kvpress.presses.scorer_press import ScorerPress

from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention


@dataclass
class SnapKVPress(ScorerPress):
Expand All @@ -37,16 +41,21 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_
num_key_value_groups = num_heads // module.config.num_key_value_heads

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -window_size:])
elif hasattr(module, "qkv_proj"):
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -window_size:])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -window_size:])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)

# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -window_size:], sin[:, -window_size:]
Expand Down
19 changes: 14 additions & 5 deletions kvpress/presses/think_press.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

from kvpress.presses.base_press import BasePress

from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention


@dataclass
class ThinKPress(BasePress):
Expand All @@ -36,17 +40,22 @@ def compute_window_queries(self, module, hidden_states, position_embeddings):
num_heads = module.config.num_attention_heads
head_dim = module.head_dim

# Get last window_size queries
if hasattr(module, "q_proj"):
query_states = module.q_proj(hidden_states[:, -self.window_size :])
elif hasattr(module, "qkv_proj"):
qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
# Get last self.window_size queries
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states[:, -self.window_size:])
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states[:, -self.window_size:])
else:
raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2)

# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)

# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -self.window_size :], sin[:, -self.window_size :]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "kvpress"
authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"]
description = "Efficiently compress the KV cache of any pretrained transformer"
version = "0.2.5"
version = "0.2.6"
readme = "README.md"

[tool.poetry.dependencies]
Expand Down