From 01dcd7db72e5b49a11db8774b029edf68befa852 Mon Sep 17 00:00:00 2001
From: YuhuiXu
Date: Thu, 28 Nov 2024 09:30:36 +0000
Subject: [PATCH 1/3] add think_press

---
 evaluation/evaluate.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index b82018b2..bcb7c106 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -23,6 +23,7 @@
     RandomPress,
     SnapKVPress,
     StreamingLLMPress,
+    ThinKPress,
 )
 
 logger = logging.getLogger(__name__)
@@ -48,6 +49,7 @@
     "random": RandomPress(),
     "snapkv": SnapKVPress(),
     "streaming_llm": StreamingLLMPress(),
+    "think": ThinKPress(),
 }
 
 

From f3e09b7a9bbf1288041cd448d60329b934098986 Mon Sep 17 00:00:00 2001
From: YuhuiXu
Date: Thu, 28 Nov 2024 09:32:17 +0000
Subject: [PATCH 2/3] add think_press

---
 kvpress/presses/base_press.py  | 28 +++++++----
 kvpress/presses/think_press.py | 85 ++++++++++++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 9 deletions(-)
 create mode 100644 kvpress/presses/think_press.py

diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 553187aa..7b50214b 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -108,15 +108,25 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
 
         with torch.no_grad():
             scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
-
-        # Prune KV pairs with the lowest scores
-        n_kept = int(q_len * (1 - self.compression_ratio))
-        indices = scores.topk(n_kept, dim=-1).indices
-        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
-
-        # Update cache
-        keys = keys.gather(2, indices).contiguous()
-        values = values.gather(2, indices).contiguous()
+        # ThinK uses the channel wise pruning
+        if self.__class__.__name__ == "ThinKPress":
+            n_prune = int(module.head_dim * self.compression_ratio)
+            scores = scores.view(scores.shape[0], keys.shape[1], -1, scores.shape[-1]).sum(dim=-2)
+            prune_idx = scores.topk(n_prune, dim=-1, largest=False).indices
+            mask = torch.zeros_like(scores, dtype=torch.bool).scatter(-1, prune_idx, True)
+            # Zero the pruned channels of every key except the last window_size ones
+            n_old = q_len - self.window_size
+            mask = mask.unsqueeze(2).expand(-1, -1, n_old, -1)
+            keys = torch.cat([keys[:, :, :n_old].masked_fill(mask, 0), keys[:, :, n_old:]], dim=-2)
+        else:
+            # Prune KV pairs with the lowest scores
+            n_kept = int(q_len * (1 - self.compression_ratio))
+            indices = scores.topk(n_kept, dim=-1).indices
+            indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
+
+            # Update cache
+            keys = keys.gather(2, indices).contiguous()
+            values = values.gather(2, indices).contiguous()
         if isinstance(cache, QuantizedCache):
             cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
             cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
diff --git a/kvpress/presses/think_press.py b/kvpress/presses/think_press.py
new file mode 100644
index 00000000..e1615d5b
--- /dev/null
+++ b/kvpress/presses/think_press.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import inspect
+from dataclasses import dataclass
+
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
+
+from kvpress.presses.base_press import BasePress
+
+
+@dataclass
+class ThinKPress(BasePress):
+    """
+    ThinK (https://arxiv.org/abs/2407.21018) prunes key channels instead of KV pairs: the queries of the latest
+    window_size tokens are used to score each channel of the keys, and BasePress.forward_hook zeroes the
+    lowest-scoring channels of every key outside that window. Values are left untouched.
+    """
+
+    compression_ratio: float = 0.0
+    window_size: int = 64
+    kernel_size: int = 5  # currently unused by ThinKPress
+
+    def compute_window_attention(self, module, hidden_states, keys):
+        """
+        Compute the channel-wise products between the last window_size queries and all the keys.
+
+        Returns a tensor of shape (bsz, num_heads, head_dim, window_size, q_len).
+        """
+
+        bsz, q_len, _ = hidden_states.shape
+
+        # Get last window_size queries
+        if hasattr(module, "q_proj"):
+            query_states = module.q_proj(hidden_states[:, -self.window_size :])
+        elif hasattr(module, "qkv_proj"):
+            qkv = module.qkv_proj(hidden_states[:, -self.window_size :])
+            query_states = qkv[..., : module.num_heads * module.head_dim]
+        else:
+            raise NotImplementedError(f"ThinK not yet implemented for {module.__class__}.")
+
+        query_states = query_states.view(bsz, self.window_size, module.num_heads, module.head_dim).transpose(1, 2)
+
+        # Apply RoPE
+        if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
+            position_ids = torch.arange(q_len - self.window_size, q_len).unsqueeze(0).to(query_states.device)
+            cos, sin = module.rotary_emb(query_states, position_ids)
+        else:
+            cos, sin = module.rotary_emb(query_states, q_len)
+            cos, sin = cos[-self.window_size :].unsqueeze(0), sin[-self.window_size :].unsqueeze(0)
+        query_states = (query_states * cos) + (rotate_half(query_states) * sin)
+
+        # Outer product of queries and keys along the token axes, computed separately for every channel.
+        # NOTE: this materializes bsz * num_heads * head_dim * window_size * q_len elements at once.
+        key_states = repeat_kv(keys, module.num_key_value_groups)
+        channel_attn = torch.matmul(
+            query_states.permute(0, 1, 3, 2).unsqueeze(-1), key_states.transpose(2, 3).unsqueeze(-2)
+        )
+
+        return channel_attn
+
+    def score(
+        self,
+        module: nn.Module,
+        hidden_states: torch.Tensor,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        attentions: torch.Tensor,
+        kwargs,
+    ) -> torch.Tensor:
+        """
+        Score each key channel by the sum of squared query-key products; shape (bsz, num_heads, head_dim).
+        """
+
+        bsz, num_key_value_heads, q_len, _ = keys.shape
+
+        assert q_len > self.window_size, "Query length should be greater than the window size"
+
+        channel_attn = self.compute_window_attention(module, hidden_states, keys)
+        channel_score = channel_attn.pow_(2).sum(dim=(-1, -2))
+
+        return channel_score

From ccc420b33a307e1be0509ef8a7a1415a825d57c1 Mon Sep 17 00:00:00 2001
From: YuhuiXu
Date: Thu, 28 Nov 2024 09:38:24 +0000
Subject: [PATCH 3/3] add think_press

---
 kvpress/__init__.py           | 2 ++
 kvpress/presses/base_press.py | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index 2f1e0409..3b0935e6 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -12,6 +12,7 @@
 from kvpress.presses.snapkv_press import SnapKVPress
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
+from kvpress.presses.think_press import ThinKPress
 from kvpress.presses.tova_press import TOVAPress
 
 __all__ = [
     "BasePress",
@@ -22,6 +23,7 @@
     "SnapKVPress",
     "StreamingLLMPress",
     "TOVAPress",
+    "ThinKPress",
     "KVPressTextGenerationPipeline",
     "apply_per_layer_compression",
 ]
diff --git a/kvpress/presses/base_press.py b/kvpress/presses/base_press.py
index 7b50214b..efa628d2 100644
--- a/kvpress/presses/base_press.py
+++ b/kvpress/presses/base_press.py
@@ -111,7 +111,8 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         # ThinK uses the channel wise pruning
         if self.__class__.__name__ == "ThinKPress":
             n_prune = int(module.head_dim * self.compression_ratio)
-            scores = scores.view(scores.shape[0], keys.shape[1], -1, scores.shape[-1]).sum(dim=-2)
+            if keys.shape[1] != scores.shape[1]:
+                scores = scores.view(scores.shape[0], keys.shape[1], -1, scores.shape[-1]).sum(dim=-2)
             prune_idx = scores.topk(n_prune, dim=-1, largest=False).indices
             mask = torch.zeros_like(scores, dtype=torch.bool).scatter(-1, prune_idx, True)
             # Zero the pruned channels of every key except the last window_size ones