Merged
Commits
138 commits
1f3af6a
Fix transformers <4.48
maxjeblick Jan 7, 2025
d78ac5f
Add infinitebench benchmark
maxjeblick Jan 7, 2025
80da31d
Support transformers 4.48 (#39)
SimJeg Jan 13, 2025
319ff44
AdaKVPress (#38)
SimJeg Jan 13, 2025
3e8bac7
Update README (#41)
SimJeg Jan 21, 2025
cc1d61f
Add chunk press
maxjeblick Jan 21, 2025
6398626
Update README (#42)
SimJeg Jan 21, 2025
bf314f9
Fix style errors
maxjeblick Feb 12, 2025
2bf6b2c
Add CriticalKVPress (#46)
FFY0 Feb 12, 2025
6518ed3
Add epsilon to ExpectedAttentionPress (#47)
SimJeg Feb 12, 2025
8cc1e3b
Fix distributed inference (#49)
SimJeg Feb 13, 2025
4221186
Add DuoAttentionPress (#50)
SimJeg Feb 18, 2025
4ca556a
add ChunkKV
Dominic789654 Mar 5, 2025
beaf605
Update copyright date (#60)
SimJeg Mar 13, 2025
c578e4a
Add QFilterPress (#54)
NathanGodey Mar 17, 2025
6c90873
Add longbench benchmark
Xnhyacinth Mar 19, 2025
0dafec8
Add DuoAttention on the fly (#63)
SimJeg Mar 19, 2025
e6c97ea
Add PyramidKVPress
figuremout Apr 16, 2025
48c1a19
Fix style errors (#68)
maxjeblick Apr 16, 2025
462dffc
Add FinchPress (#69)
SimJeg Apr 17, 2025
b95e8da
Update README (#70)
SimJeg Apr 17, 2025
8e2d22b
Align dev dependencies with new poetry format (#71)
emmanuel-ferdman May 5, 2025
adaca4e
Reorganize dependencies (#74)
fanqiNO1 May 6, 2025
bb00088
Add missing SPDX headers
maxjeblick Jun 6, 2025
0e6d429
Add LagKVPress (#77)
JoelSeniorLiang Jun 10, 2025
994d5c4
Support Qwen3 and Gemma3 (#81)
alessiodevoto Jun 16, 2025
f1e7372
Fix FinchPress for Qwen (#82)
alessiodevoto Jun 20, 2025
cbdbafa
Add KeyDiffPress (#86)
figuremout Jul 2, 2025
ed135fc
Fix RoPE with Yarn (#85)
giulio98 Jul 7, 2025
0b6046e
Improve documentation (#90)
maxjeblick Jul 7, 2025
78ce6ca
add Alessio to authors (#92)
maxjeblick Jul 7, 2025
f7df944
Fix failing tests (#94)
maxjeblick Jul 8, 2025
f9fcc5d
Refactor evaluation (#96)
alessiodevoto Jul 9, 2025
f0004ce
Fix QFilters and DuoAttention when used with wrapper presses (#97)
alessiodevoto Jul 14, 2025
18d0a79
Add HuggingFace leaderboard (#98)
alessiodevoto Jul 17, 2025
ee9cf0b
update links (#101)
alessiodevoto Jul 21, 2025
fe74118
Add KVzipPress
Janghyun1230 Jul 25, 2025
fbeef0f
Test head-wise compression (#103)
alessiodevoto Jul 25, 2025
39de87b
add seed for eval
alessiodevoto Jul 28, 2025
f23be36
inference mode
alessiodevoto Jul 28, 2025
7683124
update transformer dependency
alessiodevoto Jul 28, 2025
cc48f13
update pipeline
alessiodevoto Jul 28, 2025
0a90e3a
update seed with try-except
alessiodevoto Jul 28, 2025
5d0ec7e
fix style
alessiodevoto Jul 28, 2025
0fa955a
remove try
alessiodevoto Jul 28, 2025
ffa22d4
remove comments
alessiodevoto Jul 28, 2025
c9c28e3
first refactor-no quantization
alessiodevoto Jul 29, 2025
a5398cd
refactor - quantization not working
alessiodevoto Jul 29, 2025
c1e4f56
update (flash attn 2 and quantized cache not working)
alessiodevoto Jul 30, 2025
a66e33b
wip on fixing fa2 bug
maxjeblick Aug 6, 2025
5dca63f
run backbone model only for prefill
giulio98 Jul 28, 2025
ccc9d96
Fix transformers<4.54.0
alessiodevoto Jul 28, 2025
70cfcbd
Migration to uv (#108)
alessiodevoto Aug 5, 2025
31c55f3
improve test
maxjeblick Aug 6, 2025
5b7c1c1
formatting
maxjeblick Aug 6, 2025
1a6095e
update readme
maxjeblick Aug 6, 2025
ed8f995
update readme
maxjeblick Aug 6, 2025
8073406
update transformers version
maxjeblick Aug 6, 2025
eb4a9e0
wip on fixing fa2 bug
maxjeblick Aug 6, 2025
366e342
wip on pipeline fixes
maxjeblick Aug 8, 2025
8ae8efd
refactor code
maxjeblick Aug 8, 2025
f2d0f5e
refactor code
maxjeblick Aug 8, 2025
d7c44c7
refactor code
maxjeblick Aug 8, 2025
11a3a93
fix flash attn bug
maxjeblick Aug 8, 2025
7c2255e
add back 8b model for ruler tests
maxjeblick Aug 8, 2025
6528d96
try fixing ruler test
maxjeblick Aug 8, 2025
a9866ba
Optimized covariance transform in ExpectedAttentionPress (#111)
neuralsorcerer Aug 7, 2025
a766682
fix ruler integration tests
maxjeblick Aug 8, 2025
2e73373
adjust to no compression
maxjeblick Aug 8, 2025
9a03f11
fix pos ids
maxjeblick Aug 8, 2025
6539aaf
skip PyramidKVPress, KVzipPress
maxjeblick Aug 8, 2025
b3a43ca
wip on pipeline
maxjeblick Aug 8, 2025
99529c9
remove unneeded parts during prefilling
maxjeblick Aug 8, 2025
fb6dfbf
param compression ratio
maxjeblick Aug 8, 2025
9f5a673
wip on fixing pipeline
maxjeblick Aug 12, 2025
26f10fb
fix kvzip press
maxjeblick Aug 12, 2025
fdbbc6f
fix kvzip press
maxjeblick Aug 12, 2025
a5f4841
fix kvzip press
maxjeblick Aug 12, 2025
b557067
fix kvzip press
maxjeblick Aug 12, 2025
ee6321c
fix kvzip press
maxjeblick Aug 12, 2025
d5cad9e
better docstrings
maxjeblick Aug 12, 2025
0c6201d
better docstrings
maxjeblick Aug 12, 2025
4be229c
fix tests
maxjeblick Aug 12, 2025
8d7b1ad
move back to model.model forward pass
maxjeblick Aug 12, 2025
27e3c13
update (flash attn 2 and quantized cache not working)
alessiodevoto Jul 30, 2025
c0a9df1
wip on fixing fa2 bug
maxjeblick Aug 6, 2025
85b8080
run backbone model only for prefill
giulio98 Jul 28, 2025
d1c8d84
Fix transformers<4.54.0
alessiodevoto Jul 28, 2025
6bb0403
Migration to uv (#108)
alessiodevoto Aug 5, 2025
61a11df
improve test
maxjeblick Aug 6, 2025
13cf7ee
formatting
maxjeblick Aug 6, 2025
11dc278
update readme
maxjeblick Aug 6, 2025
c7ca196
update readme
maxjeblick Aug 6, 2025
9f9b48b
update transformers version
maxjeblick Aug 6, 2025
65d137a
wip on fixing fa2 bug
maxjeblick Aug 6, 2025
d8e482b
wip on pipeline fixes
maxjeblick Aug 8, 2025
30f37e5
refactor code
maxjeblick Aug 8, 2025
170792c
refactor code
maxjeblick Aug 8, 2025
458625f
refactor code
maxjeblick Aug 8, 2025
343aeca
fix flash attn bug
maxjeblick Aug 8, 2025
3aea31c
add back 8b model for ruler tests
maxjeblick Aug 8, 2025
b86bcd2
try fixing ruler test
maxjeblick Aug 8, 2025
3a765e4
fix ruler integration tests
maxjeblick Aug 8, 2025
7b83594
adjust to no compression
maxjeblick Aug 8, 2025
05cec0a
fix pos ids
maxjeblick Aug 8, 2025
0e3556f
skip PyramidKVPress, KVzipPress
maxjeblick Aug 8, 2025
9da9e79
wip on pipeline
maxjeblick Aug 8, 2025
ad43414
remove unneeded parts during prefilling
maxjeblick Aug 8, 2025
0d7d28f
param compression ratio
maxjeblick Aug 8, 2025
02e13c8
wip on fixing pipeline
maxjeblick Aug 12, 2025
a57e812
fix kvzip press
maxjeblick Aug 12, 2025
95e44d0
fix kvzip press
maxjeblick Aug 12, 2025
4c7a86a
fix kvzip press
maxjeblick Aug 12, 2025
a08299a
fix kvzip press
maxjeblick Aug 12, 2025
71a3a73
better docstrings
maxjeblick Aug 12, 2025
2e58c87
better docstrings
maxjeblick Aug 12, 2025
9535560
fix tests
maxjeblick Aug 12, 2025
51c96eb
move back to model.model forward pass
maxjeblick Aug 12, 2025
2df9919
fix merge conflicts
maxjeblick Aug 12, 2025
e8bf0da
remove dco fix script
maxjeblick Aug 12, 2025
9b96fe4
fix style
maxjeblick Aug 13, 2025
15dfee8
fix version of transformers
maxjeblick Aug 13, 2025
ddb66ed
fix version of transformers
maxjeblick Aug 14, 2025
bd76240
fix style
maxjeblick Aug 14, 2025
5090951
fix style
maxjeblick Aug 14, 2025
cd26de8
fix style
maxjeblick Aug 14, 2025
bc7dc1d
move to transformers 4.56
maxjeblick Aug 14, 2025
0391dc3
fix quantized cache
maxjeblick Aug 14, 2025
d868d7a
fix cumulative len for quantized cache
maxjeblick Aug 14, 2025
f9465cf
update comment
maxjeblick Aug 14, 2025
ca423ab
update transformers to 4.55.3
maxjeblick Aug 21, 2025
c7afcd1
fix style/refactor
maxjeblick Aug 21, 2025
a50d85e
fix past_key_value
maxjeblick Aug 25, 2025
065dccd
fix past_key_value
maxjeblick Aug 25, 2025
0e539d2
fix past_key_value
maxjeblick Aug 25, 2025
714461e
fix past_key_value
maxjeblick Aug 25, 2025
1bfa1c4
update compat
alessiodevoto Sep 1, 2025
f37b768
merge main into max/transformers_compat
alessiodevoto Sep 1, 2025
44 changes: 22 additions & 22 deletions README.md
@@ -23,6 +23,27 @@ git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --all-groups
```
<details><summary>
Advanced installation settings
</summary>

To install optional packages, you can use [uv](https://docs.astral.sh/uv/).
To install with flash attention, just run:

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra flash-attn
```

To install with dependencies for evaluation, run

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra eval
```
</details>

## Usage

@@ -203,25 +224,4 @@ with press(model):

However, the `generate` method does not allow excluding the question from the compression, which would artificially favor methods such as SnapKV. Ideally, we want a compression method that works regardless of what comes after the context (_e.g._ for use cases such as chat or document question answering). Finally, the `generate` method does not support generating answers for multiple questions at once.

</details>


## Advanced installation settings
To install optional packages, you can use [uv](https://docs.astral.sh/uv/).
To install with flash attention, just run:

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra flash-attn
```

To install with dependencies for evaluation, run

```bash
git clone https://github.com/NVIDIA/kvpress.git
cd kvpress
uv sync --extra eval
```

Note that optional dependencies can be combined.
</details>
4 changes: 4 additions & 0 deletions kvpress/attention_patch.py
@@ -78,6 +78,10 @@ def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
batch_indices, head_indices, seq_indices = module.masked_key_indices
key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]

# see https://github.com/NVIDIA/kvpress/pull/115#issuecomment-3183785597
# cu_seq_lens_k are only in kwargs if model.generate is used.
if "cu_seq_lens_k" in kwargs:
kwargs["cu_seq_lens_k"][-1] = key.shape[-2]
return func(module, query, key, value, attention_mask, dropout, **kwargs)

return wrapper
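The guard added in this hunk can be illustrated in isolation. The following is a minimal sketch, not the library's code: a plain Python list stands in for the `cu_seq_lens_k` tensor, and the helper name `sync_last_cu_seq_len` is hypothetical. It only shows the pattern of updating the last cumulative length when the entry is present and leaving `kwargs` untouched otherwise.

```python
def sync_last_cu_seq_len(kwargs: dict, key_len: int) -> dict:
    """Set the final cumulative key length to key_len, if the entry exists.

    Hypothetical helper for illustration only. In the diff the value is a
    tensor and key_len comes from key.shape[-2]; here a list is used instead.
    """
    if "cu_seq_lens_k" in kwargs:
        # Overwrite in place so callers holding a reference see the new value.
        kwargs["cu_seq_lens_k"][-1] = key_len
    return kwargs


# Entry present: the last element is replaced by the new key length.
with_entry = {"cu_seq_lens_k": [0, 10]}
sync_last_cu_seq_len(with_entry, 7)

# Entry absent: nothing is added or changed.
without_entry = sync_last_cu_seq_len({}, 7)
```

Checking for the key first keeps the wrapper safe on code paths that never pass `cu_seq_lens_k`.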
60 changes: 25 additions & 35 deletions kvpress/pipeline.py
@@ -7,15 +7,14 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress

logger = logging.getLogger(__name__)

@@ -224,20 +223,6 @@ def _forward(

return answers

def output_attentions(self, press: BasePress):
if isinstance(press, ObservedAttentionPress):
return True
if isinstance(press, (KeyRerotationPress, PerLayerCompressionPress)) and isinstance(
press.press, ObservedAttentionPress
):
return True
return False

def postprocess(self, model_outputs, single_question):
if single_question:
return {"answer": model_outputs[0]}
return {"answers": model_outputs}

def generate_answer(
self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
) -> str:
@@ -260,7 +245,6 @@ def generate_answer(
str
The generated answer.
"""

cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
position_ids = torch.arange(
context_length, context_length + question_ids.shape[1], device=self.model.device
@@ -292,28 +276,34 @@ def generate_answer(
if new_id.item() in should_stop_token_ids:
break
answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)

# Remove the generated tokens from the cache
cache.key_cache = [
cache.key_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
cache.value_cache = [
cache.value_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
if hasattr(cache, "_quantized_key_cache"):
cache._quantized_key_cache = [
cache._quantized_key_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
cache._quantized_value_cache = [
cache._quantized_value_cache[layer_idx][:, :, :sequence_length]
for layer_idx, sequence_length in enumerate(cache_seq_lengths)
]
for layer_idx, sequence_length in enumerate(cache_seq_lengths):
cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]

if isinstance(cache, QuantizedCache):
for layer_idx, sequence_length in enumerate(cache_seq_lengths):
cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[
:, :, :sequence_length
]
cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[
:, :, :sequence_length
]

return answer
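The snapshot-then-truncate pattern used by `generate_answer` can be sketched without any tensor library. This is an illustrative simplification: `ToyLayer`, `ToyCache`, `snapshot_lengths` and `truncate_to` are hypothetical names, and each layer stores a flat list per position instead of a tensor sliced along its sequence dimension.

```python
from dataclasses import dataclass, field


@dataclass
class ToyLayer:
    # Stand-in for one cache layer; real layers hold tensors, not lists.
    keys: list = field(default_factory=list)
    values: list = field(default_factory=list)


@dataclass
class ToyCache:
    layers: list


def snapshot_lengths(cache: ToyCache) -> list:
    """Record the current length of every layer before decoding starts."""
    return [len(layer.keys) for layer in cache.layers]


def truncate_to(cache: ToyCache, lengths: list) -> None:
    """Drop everything appended after the snapshot, layer by layer."""
    for layer, n in zip(cache.layers, lengths):
        layer.keys = layer.keys[:n]
        layer.values = layer.values[:n]


# Layers may have different lengths after compression.
cache = ToyCache(layers=[ToyLayer([1, 2, 3], [1, 2, 3]), ToyLayer([1, 2], [1, 2])])
lengths = snapshot_lengths(cache)

# Simulate decoding one token, which appends to every layer.
for layer in cache.layers:
    layer.keys.append(99)
    layer.values.append(99)

# Restore the cache so it can be reused for the next question.
truncate_to(cache, lengths)
```

Storing one length per layer, rather than a single global length, matters when a press compresses layers by different amounts.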

def output_attentions(self, press: BasePress):
if isinstance(press, ObservedAttentionPress):
return True
if hasattr(press, "press") and isinstance(press.press, ObservedAttentionPress):
return True
return False

def postprocess(self, model_outputs, single_question):
if single_question:
return {"answer": model_outputs[0]}
return {"answers": model_outputs}
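The simplified `output_attentions` check relies on duck typing: any wrapper that exposes a `press` attribute is inspected, instead of an explicit list of wrapper classes. A minimal sketch of that idea follows, using stand-in classes (`ObservedStandIn`, `OtherPress`, `Wrapper`) and a hypothetical function name `needs_attentions` rather than the real kvpress classes.

```python
class ObservedStandIn:
    """Stand-in for a press that scores tokens from attention weights."""


class OtherPress:
    """Stand-in for a press that does not need attention weights."""


class Wrapper:
    """Stand-in for any wrapper press holding an inner press in .press."""

    def __init__(self, press):
        self.press = press


def needs_attentions(press) -> bool:
    """Return True if the press, or the press it directly wraps, needs attentions."""
    if isinstance(press, ObservedStandIn):
        return True
    # getattr with a default plays the role of the hasattr check in the diff.
    inner = getattr(press, "press", None)
    return isinstance(inner, ObservedStandIn)


direct = needs_attentions(ObservedStandIn())
wrapped = needs_attentions(Wrapper(ObservedStandIn()))
unrelated = needs_attentions(Wrapper(OtherPress()))
double_wrapped = needs_attentions(Wrapper(Wrapper(ObservedStandIn())))
```

Like the check in the diff, this sketch only looks one level deep, so a doubly wrapped press is not detected.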


PIPELINE_REGISTRY.register_pipeline(
"kv-press-text-generation",
32 changes: 19 additions & 13 deletions kvpress/presses/base_press.py
@@ -107,7 +107,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
kwargs : dict
Keyword arguments passed to the attention layer's forward method, including:
- hidden_states: Input embeddings to the attention layer
- past_key_value: The KV cache object being modified
- past_key_values: The KV cache object being modified
- cache_position: Position indices indicating where we are in the sequence
- position_embeddings: RoPE embeddings if applicable
output : list
@@ -123,31 +123,37 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
"""

hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_value"]
cache = kwargs["past_key_values"]
q_len = hidden_states.shape[1]

# Don't compress after pre-filling
if kwargs["cache_position"][-1] > q_len:
return output

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx])
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx])
keys = cache_layer._dequantize( # type: ignore[index]
cache_layer._quantized_keys # type: ignore[index]
)
values = cache_layer._dequantize( # type: ignore[index]
cache_layer._quantized_values # type: ignore[index]
)

else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]
keys = cache_layer.keys
values = cache_layer.values

keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)

if isinstance(cache, QuantizedCache):
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache._seen_tokens = keys.shape[2]
cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.cumulative_length = keys.shape[2]
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values
cache_layer.keys = keys
cache_layer.values = values

return output
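The hunk above moves from cache-level lists (`cache.key_cache[idx]`) to per-layer objects (`cache.layers[idx].keys`). The PR itself switches entirely to the new layout; the sketch below is only an illustration of how code that must tolerate both layouts could branch on the attribute. The helper `get_layer_kv` is hypothetical, and `SimpleNamespace` objects stand in for real cache classes.

```python
from types import SimpleNamespace


def get_layer_kv(cache, layer_idx: int):
    """Return (keys, values) for one layer from either cache layout.

    Hypothetical helper: prefers the per-layer layout and falls back to the
    older list-based attributes when `layers` is missing.
    """
    if hasattr(cache, "layers"):
        layer = cache.layers[layer_idx]
        return layer.keys, layer.values
    return cache.key_cache[layer_idx], cache.value_cache[layer_idx]


# Stand-ins for the two layouts seen in the diff (strings replace tensors).
new_style = SimpleNamespace(layers=[SimpleNamespace(keys="k0", values="v0")])
old_style = SimpleNamespace(key_cache=["k0"], value_cache=["v0"])

new_result = get_layer_kv(new_style, 0)
old_result = get_layer_kv(old_style, 0)
```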

38 changes: 23 additions & 15 deletions kvpress/presses/kvzip_press.py
@@ -6,7 +6,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from types import MethodType
from typing import Any, Generator, List, cast
from typing import Generator, List

import torch
from torch import nn
@@ -116,7 +116,10 @@ def __call__(self, model: PreTrainedModel) -> Generator:

def wrapped_forward(model_self, *args, **kwargs):
self._context_ids = kwargs["input_ids"]
self._cache = kwargs["past_key_values"]
assert (
"past_key_value" in kwargs or "past_key_values" in kwargs
), f"KVzipPress requires 'past_key_value' or 'past_key_values' during prefilling. Got {kwargs.keys()}"
self._cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None)
return original_forward(*args, **kwargs)

model.model.forward = MethodType(wrapped_forward, model.model)
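One subtlety when falling back between `past_key_values` and `past_key_value` with `or`: an object that defines `__len__` and is currently empty evaluates as falsy, so `or` would skip it. Whether that applies here depends on the cache class in use, which this diff does not confirm. The sketch below shows the pitfall and an explicit `is not None` lookup, using a hypothetical `EmptyToyCache` stand-in and a hypothetical helper `find_cache`.

```python
class EmptyToyCache:
    """Stand-in for a container-like cache that holds no layers yet."""

    def __len__(self) -> int:
        return 0


def find_cache(kwargs: dict):
    """Return the first cache found under either keyword name.

    Uses `is not None` so that an empty (falsy) cache is still returned.
    """
    for name in ("past_key_values", "past_key_value"):
        cache = kwargs.get(name)
        if cache is not None:
            return cache
    raise KeyError("expected 'past_key_values' or 'past_key_value' in kwargs")


empty_cache = EmptyToyCache()
call_kwargs = {"past_key_values": empty_cache}

# The `or` fallback skips the empty cache because bool(empty_cache) is False.
via_or = call_kwargs.get("past_key_values") or call_kwargs.get("past_key_value")

# The explicit lookup returns the same object that was passed in.
via_lookup = find_cache(call_kwargs)
```

If the cache is always non-empty by the time it is read, the two approaches behave the same.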
@@ -149,29 +152,34 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
"""

hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_value"]
cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None)

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
keys = cache._dequantize(cache._quantized_key_cache[module.layer_idx]) # type: ignore[attr-defined]
values = cache._dequantize(cache._quantized_value_cache[module.layer_idx]) # type: ignore[attr-defined]
keys = cache_layer._dequantize( # type: ignore[index]
cache_layer._quantized_keys # type: ignore[index]
)
values = cache_layer._dequantize( # type: ignore[index]
cache_layer._quantized_values # type: ignore[index]
)

else:
keys = cache.key_cache[module.layer_idx]
values = cache.value_cache[module.layer_idx]
keys = cache_layer.keys
values = cache_layer.values

# Compute importance scores for KV pairs in the prefilled context,
# retaining only the originally prefilled KV pairs.
keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs)

if isinstance(cache, QuantizedCache):
cache = cast(Any, cache) # to ignore attr-defined style errors
cache._quantized_key_cache[module.layer_idx] = cache._quantize(keys, axis=cache.axis_key)
cache._quantized_value_cache[module.layer_idx] = cache._quantize(values, axis=cache.axis_value)
cache.key_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache.value_cache[module.layer_idx] = torch.zeros(0, dtype=keys.dtype, device=keys.device)
cache._seen_tokens = keys.shape[2]
cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.cumulative_length = keys.shape[2]
else:
cache.key_cache[module.layer_idx] = keys
cache.value_cache[module.layer_idx] = values
cache_layer.keys = keys
cache_layer.values = values

return output

25 changes: 6 additions & 19 deletions kvpress/presses/observed_attention_press.py
@@ -30,23 +30,19 @@ class ObservedAttentionPress(ScorerPress):
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
output_attentions : bool, default=False
Whether to return attention weights in model output.
Controls whether attention weights are included in output after compression.
Attention weights are always needed internally for scoring but can be removed
from output to save memory.
output_attentions : bool, default=True
Whether to output the attention weights. Must be set True but we keep it for backward compatibility.
"""

compression_ratio: float = 0.0
output_attentions: bool = False
output_attentions: bool = True

def __post_init__(self):
if not self.output_attentions:
logger.warning(
"Model will not return attentions in its output to save memory. "
"Set output_attentions=True if attentions are needed in the output."
# keep for backward compatibility, remove in version 1.0
raise ValueError(
"With transformers >= 4.54, ObservedAttentionPress will only work with output_attentions=True"
)
super().__post_init__()

def score(
self,
@@ -64,12 +60,3 @@ def score(
scores = scores / n_tokens_in_sum
scores = scores.view(bsz, num_key_value_heads, -1, n_tokens).mean(2)
return scores

def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
output = super().forward_hook(module, input, kwargs, output)
# attentions are needed as input for the hook, but unless the user wants to return them in the output,
# we can remove them to save memory
if not self.output_attentions:
output = (output[0], None)

return output
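The behavior change in this file, from a warning to a hard failure when `output_attentions` is false, can be sketched with a plain dataclass. `ToyObservedAttentionPress` is a stand-in with the same two fields shown in the diff; it does not inherit from the real `ScorerPress` and the error message is shortened.

```python
from dataclasses import dataclass


@dataclass
class ToyObservedAttentionPress:
    # Field names and defaults mirror the diff; everything else is a stand-in.
    compression_ratio: float = 0.0
    output_attentions: bool = True

    def __post_init__(self):
        # The flag is kept only for backward compatibility, so reject False early.
        if not self.output_attentions:
            raise ValueError("output_attentions must be True")


# The default configuration is accepted.
default_press = ToyObservedAttentionPress(compression_ratio=0.5)

# Explicitly disabling attentions now fails at construction time.
try:
    ToyObservedAttentionPress(output_attentions=False)
    rejected = False
except ValueError:
    rejected = True
```

Failing in `__post_init__` surfaces the misconfiguration when the press is created rather than later during a forward pass.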
5 changes: 2 additions & 3 deletions notebooks/speed_and_memory.ipynb
@@ -125,10 +125,9 @@
" if cache_implementation == \"dynamic\":\n",
" cache = DynamicCache()\n",
" elif cache_implementation == \"quantized\":\n",
" config = QuantizedCacheConfig(nbits=4)\n",
" cache = QuantoQuantizedCache(config)\n",
" cache = QuantoQuantizedCache(config=model.config, nbits=4)\n",
" else:\n",
" raise NotImplementedError(f\"Cache {cache_impl} not yet implemented\")\n",
" raise NotImplementedError(f\"Cache {cache_implementation} not yet implemented\")\n",
"\n",
" start = time()\n",
" model(inputs, num_logits_to_keep=1, past_key_values=cache)\n",