From 70920e83964c10d91467fe8442db8b608a3fd111 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 12 Jun 2025 14:44:12 +0200 Subject: [PATCH 1/4] add qk norm + gemma sliding window Signed-off-by: alessiodevoto --- kvpress/presses/base_press.py | 16 ++++++++++++++-- kvpress/presses/duo_attention_press.py | 6 ++++++ kvpress/presses/expected_attention_press.py | 16 ++++++++++++---- kvpress/presses/snapkv_press.py | 15 ++++++++++++--- kvpress/presses/think_press.py | 19 ++++++++++++++----- 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index 187c8b13..dd6f1394 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -16,6 +16,8 @@ PreTrainedModel, QuantizedCache, Qwen2ForCausalLM, + Qwen3ForCausalLM, + Gemma3ForCausalLM, ) logger = logging.getLogger(__name__) @@ -127,13 +129,23 @@ def __call__(self, model: PreTrainedModel) -> Generator: model : PreTrainedModel Model to apply the compression method to """ - - if not isinstance(model, (LlamaForCausalLM, MistralForCausalLM, Phi3ForCausalLM, Qwen2ForCausalLM)): + supported_models = ( + LlamaForCausalLM, + MistralForCausalLM, + Phi3ForCausalLM, + Qwen2ForCausalLM, + Qwen3ForCausalLM, + Gemma3ForCausalLM, + ) + if not isinstance(model, supported_models): logger.warning(f"Model {type(model)} not tested") hooks = [] try: for layer in model.model.layers: + if isinstance(model, Gemma3ForCausalLM) and layer.is_sliding: + # Skip layers with sliding window attention, only for Gemma3 + continue layer.self_attn.rotary_emb = model.model.rotary_emb hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) yield diff --git a/kvpress/presses/duo_attention_press.py b/kvpress/presses/duo_attention_press.py index 40abb3d1..e0982645 100644 --- a/kvpress/presses/duo_attention_press.py +++ b/kvpress/presses/duo_attention_press.py @@ -12,6 +12,8 @@ from datasets import load_dataset from transformers import AutoTokenizer from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from kvpress.presses.base_press import BasePress @@ -161,12 +163,16 @@ def duo_attention_on_the_fly(model, num_samples=50, q_len=500): # Mean query q = module.self_attn.q_proj(h) q = q.view(1, q.shape[1], -1, d) + if isinstance(module, (Gemma3Attention, Qwen3Attention)): + q = module.q_norm(q) q = q.mean(dim=1, keepdim=True) q = q.repeat(1, q_len, 1, 1).transpose(1, 2) # Mean key k = module.self_attn.k_proj(h) k = k.view(1, k.shape[1], -1, d) + if isinstance(module, (Gemma3Attention, Qwen3Attention)): + k = module.k_norm(k) k = k.mean(dim=1, keepdim=True) k = k.repeat(1, q_len, 1, 1).transpose(1, 2) diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 7dabc910..3b609967 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -12,6 +12,10 @@ from kvpress.presses.scorer_press import ScorerPress +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention +from transformers.models.phi3.modeling_phi3 import Phi3Attention + @dataclass class ExpectedAttentionPress(ScorerPress): @@ -44,10 +48,14 @@ def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor): # Remove first hidden_states that likely contain outliers h = hidden_states[:, self.n_sink :] - if hasattr(module, "q_proj"): - Wq = module.q_proj.weight - elif hasattr(module, "qkv_proj"): - Wq = module.qkv_proj.weight[: n * d] # type: ignore[index] + if isinstance(module, (Qwen3Attention, Gemma3Attention)): + # Qwen and Gemma use QK norm, which is not compatible with ExpectedAttentionPress (for now) + raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") + elif isinstance(module, Phi3Attention): + Wq = module.qkv_proj.weight[: n * d] + elif hasattr(module, "q_proj"): + # Assume Llama-like attention layer + Wq = module.q_proj.weight # type: ignore[assignment] else: raise NotImplementedError(f"ExpectedAttentionPress not yet implemented for {module.__class__}.") diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index c4830032..cee21a40 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -12,6 +12,10 @@ from kvpress.presses.scorer_press import ScorerPress +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention +from transformers.models.phi3.modeling_phi3 import Phi3Attention + @dataclass class SnapKVPress(ScorerPress): @@ -37,16 +41,21 @@ def compute_window_attention(module, hidden_states, keys, window_size, position_ num_key_value_groups = num_heads // module.config.num_key_value_heads # Get last window_size queries - if hasattr(module, "q_proj"): - query_states = module.q_proj(hidden_states[:, -window_size:]) - elif hasattr(module, "qkv_proj"): + if isinstance(module, Phi3Attention): qkv = module.qkv_proj(hidden_states[:, -window_size:]) query_states = qkv[..., : num_heads * head_dim] + elif hasattr(module, "q_proj"): + # Assume Llama-like attention layer + query_states = module.q_proj(hidden_states[:, -window_size:]) else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2) + # Support for Qwen3 and Gemma3 QK norm + if isinstance(module, (Qwen3Attention, Gemma3Attention)): + query_states = module.q_norm(query_states) + # Apply RoPE cos, sin = position_embeddings cos, sin = cos[:, -window_size:], sin[:, -window_size:] diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index 8e9cbf57..c8be4bb2 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -10,6 +10,10 @@ from kvpress.presses.base_press import BasePress +from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention +from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention +from transformers.models.phi3.modeling_phi3 import Phi3Attention + @dataclass class ThinKPress(BasePress): @@ -36,17 +40,22 @@ def compute_window_queries(self, module, hidden_states, position_embeddings): num_heads = module.config.num_attention_heads head_dim = module.head_dim - # Get last window_size queries - if hasattr(module, "q_proj"): - query_states = module.q_proj(hidden_states[:, -self.window_size :]) - elif hasattr(module, "qkv_proj"): - qkv = module.qkv_proj(hidden_states[:, -self.window_size :]) + # Get last self.window_size queries + if isinstance(module, Phi3Attention): + qkv = module.qkv_proj(hidden_states[:, -self.window_size:]) query_states = qkv[..., : num_heads * head_dim] + elif hasattr(module, "q_proj"): + # Assume Llama-like attention layer + query_states = module.q_proj(hidden_states[:, -self.window_size:]) else: raise NotImplementedError(f"SnapKV not yet implemented for {module.__class__}.") query_states = query_states.view(bsz, self.window_size, num_heads, head_dim).transpose(1, 2) + # Support for Qwen3 and Gemma3 QK norm + if isinstance(module, (Qwen3Attention, Gemma3Attention)): + query_states = module.q_norm(query_states) + # Apply RoPE cos, sin = position_embeddings cos, sin = cos[:, -self.window_size :], sin[:, -self.window_size :] From ec53981906581c881f6494f0c4de22cef59a0eca Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Thu, 12 Jun 2025 14:46:00 +0200 Subject: [PATCH 2/4] update version Signed-off-by: alessiodevoto --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a1b18aa1..d67fbefd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "kvpress" authors = ["Simon Jegou", "Maximilian Jeblick", "Jiwei Liu", "David Austin"] description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.2.5" +version = "0.2.6" readme = "README.md" [tool.poetry.dependencies] From b135bbeac6bdd80ba7c0828c42f529eb6f7338cf Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Sun, 15 Jun 2025 19:26:25 +0200 Subject: [PATCH 3/4] improve style Signed-off-by: alessiodevoto --- kvpress/presses/base_press.py | 24 ++++++++++++--------- kvpress/presses/expected_attention_press.py | 5 ++--- kvpress/presses/snapkv_press.py | 5 ++--- kvpress/presses/think_press.py | 5 ++--- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py index dd6f1394..50e231d3 100644 --- a/kvpress/presses/base_press.py +++ b/kvpress/presses/base_press.py @@ -22,6 +22,15 @@ logger = logging.getLogger(__name__) +SUPPORTED_MODELS = ( + LlamaForCausalLM, + MistralForCausalLM, + Phi3ForCausalLM, + Qwen2ForCausalLM, + Qwen3ForCausalLM, + Gemma3ForCausalLM, +) + @dataclass class BasePress: @@ -129,16 +138,11 @@ def __call__(self, model: PreTrainedModel) -> Generator: model : PreTrainedModel Model to apply the compression method to """ - supported_models = ( - LlamaForCausalLM, - MistralForCausalLM, - Phi3ForCausalLM, - Qwen2ForCausalLM, - Qwen3ForCausalLM, - Gemma3ForCausalLM, - ) - if not isinstance(model, supported_models): - logger.warning(f"Model {type(model)} not tested") + if not isinstance(model, SUPPORTED_MODELS): + logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}") + + if isinstance(model, Gemma3ForCausalLM): + logger.warning("Compression in Gemma3 is only applied to layer without sliding window attention") hooks = [] try: diff --git a/kvpress/presses/expected_attention_press.py b/kvpress/presses/expected_attention_press.py index 3b609967..1d8ce6a7 100644 --- a/kvpress/presses/expected_attention_press.py +++ b/kvpress/presses/expected_attention_press.py @@ -9,13 +9,12 @@ from torch import nn from torch.nn import functional as F from transformers.models.llama.modeling_llama import repeat_kv - -from kvpress.presses.scorer_press import ScorerPress - from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention +from kvpress.presses.scorer_press import ScorerPress + @dataclass class ExpectedAttentionPress(ScorerPress): diff --git a/kvpress/presses/snapkv_press.py b/kvpress/presses/snapkv_press.py index cee21a40..078de583 100644 --- a/kvpress/presses/snapkv_press.py +++ b/kvpress/presses/snapkv_press.py @@ -9,13 +9,12 @@ from torch import nn from torch.nn import functional as F from transformers.models.llama.modeling_llama import repeat_kv, rotate_half - -from kvpress.presses.scorer_press import ScorerPress - from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention +from kvpress.presses.scorer_press import ScorerPress + @dataclass class SnapKVPress(ScorerPress): diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py index c8be4bb2..6ee48280 100644 --- a/kvpress/presses/think_press.py +++ b/kvpress/presses/think_press.py @@ -7,13 +7,12 @@ import torch from torch import nn from transformers.models.llama.modeling_llama import rotate_half - -from kvpress.presses.base_press import BasePress - from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention from transformers.models.phi3.modeling_phi3 import Phi3Attention +from kvpress.presses.base_press import BasePress + @dataclass class ThinKPress(BasePress): From 6bf106ffb3f99688097d0f59041f150fa848a1a9 Mon Sep 17 00:00:00 2001 From: alessiodevoto Date: Mon, 16 Jun 2025 11:33:47 +0200 Subject: [PATCH 4/4] update init file Signed-off-by: alessiodevoto --- kvpress/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 7631557a..6d432e41 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -27,6 +27,7 @@ from kvpress.presses.pyramidkv_press import PyramidKVPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.lagkv_press import LagKVPress +from kvpress.presses.base_press import SUPPORTED_MODELS # Patch the attention functions to support head-wise compression patch_attention_functions()