diff --git a/README.md b/README.md index 6ada295d..7039274b 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ [![PyPI version](https://badge.fury.io/py/kvpress.svg)](https://badge.fury.io/py/kvpress) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Colab example notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP?usp=drive_link) +[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-blue)](https://huggingface.co/spaces/nvidia/kvpress) + ![kvpress](kvpress.jpg) -Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. This repository implements multiple KV cache compression methods and benchmarks using [🤗 transformers](https://huggingface.co/docs/transformers/en/index), aiming to simplify the development of new methods for researchers and developers in this field. + +Deploying long-context LLMs is costly due to the linear growth of the key-value (KV) cache in transformer models. For example, handling 1M tokens with Llama 3.1-70B in float16 requires up to 330GB of memory. kvpress implements multiple KV cache compression methods and benchmarks using 🤗 transformers, aiming to simplify the development of new methods for researchers and developers in this field. ## Installation @@ -12,41 +15,42 @@ Deploying long-context LLMs is costly due to the linear growth of the key-value pip install kvpress ``` -We recommend using [flash attention](https://github.com/Dao-AILab/flash-attention/) if possible: +If possible, install flash attention: ```bash pip install flash-attn --no-build-isolation ``` ## Usage -This repository provides a set of "presses" that compress the KV cache. A press is only applied during the pre-filling phase and is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline` that is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported. It handles chat templates and tokenization for you: +kvpress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you: ```python -from kvpress import ExpectedAttentionPress from transformers import pipeline +from kvpress import ExpectedAttentionPress device = "cuda:0" -model= "microsoft/Phi-3.5-mini-instruct" -pipe = pipeline("kv-press-text-generation", model=model, device=device, torch_dtype="auto", model_kwargs={"attn_implementation":"flash_attention_2"}) +model = "meta-llama/Llama-3.1-8B-Instruct" +model_kwargs = {"attn_implementation": "flash_attention_2"} +pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs) context = "A very long text you want to compress once and for all" -question = "\nA question about the compressed context" # optional - -press = ExpectedAttentionPress(compression_ratio=0.4) +question = "\nA question about the compressed context" # optional + +press = ExpectedAttentionPress(compression_ratio=0.5) answer = pipe(context, question=question, press=press)["answer"] ``` -In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example. +In the snippet above, the compression is only applied on the context tokens so that you can evaluate the compression for different questions. Check the [Wikipedia notebook demo](notebooks/wikipedia_demo.ipynb) for a more detailed example (also available on Colab [here](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP?usp=drive_link)). > [!IMPORTANT] > We focus on compression during the pre-filling phase as the KV cache becomes a bottleneck for long-context sequence (100k - 1M tokens) which are essentially long context prompts. This would typically apply to improving prompt caching systems. > [!NOTE] -> To use the `ObservedAttentionPress`, use `model_kwargs={"attn_implementation":"eager"}` in order to materialize the attention weights (this method is not compatible with flash attention). +> Use `model_kwargs={"attn_implementation":"flash_attention_2"}` to enable flash attention. To use the press `ObservedAttentionPress`, you need to specify `model_kwargs={"attn_implementation":"eager"}` as this press requires to materialize the attention weights -## Contributing with a new press +## Contributing -We welcome contributions! If you want to implement a new press, open an issue or a pull request. Refer to the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide to understand how presses work and what should be done to create a new one. +We welcome contributions! To add a new press, simply open an issue or submit a pull request. Check the [new_press.ipynb](notebooks/new_press.ipynb) notebook for a step-by-step guide. ## Available presses @@ -65,7 +69,7 @@ Some presses rely on a different logic: - `SimLayerKVPress`: identify "lazy" layers, and apply the StreamingLLM approach to them ([paper](https://arxiv.org/abs/2410.13846)) Finally we provide special presses: -- `AdaKVPress`: prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions (see [paper](https://arxiv.org/abs/2407.11550)) +- `AdaKVPress`: prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions ([paper](https://arxiv.org/abs/2407.11550)) - `PerLayerCompressionPress`: compress each layer with a different compression ratio (experimental). This press can be used with any other press that allows to set a compression_ratio - `ComposedPress`: compose multiple presses together by chaining their forward hooks - `KeyRerotationPress`: rerotate pruned keys to have continuous RoPE embeddings. This press can be used with any other press that inherits from `ScorerPress`. @@ -74,19 +78,17 @@ For a detailed list of existing KV cache compression methods, check [Awesome-KV- ## Evaluation -See the [speed_and_memory.ipynb](notebooks/speed_and_memory.ipynb) notebook on how to measure peak memory usage and total time gain. -drawing +The [speed_and_memory.ipynb](notebooks/speed_and_memory.ipynb) notebook can help you to measure peak memory usage and total time gain. +![memory](evaluation/assets/peak_memory_consumption_xkcd.png) -We provide a simple CLI to evaluate the performance of the different presses on several long-context datasets. +We provide a simple CLI to evaluate the performance of the different presses on several long-context datasets. Below we report the average performance on the RULER dataset with 4k context length for different presses. -_Average performance on the RULER dataset with 4k context length and Loogle Short Dependency QA task for 3 models and 7 presses_ -![RULER](evaluation/assets/ruler_4096_average%20score.png) -![Loogle](evaluation/assets/loogle_shortdep_qa.png) +![RULER](evaluation/assets/ruler_llama_xkcd.png) Please refer to the [evaluation](evaluation/README.md) directory for more details and results. -## KV cache quantization +## Quantization We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline: @@ -102,7 +104,7 @@ pipe(..., cache=cache) By default, the `DynamicCache` is used (no quantization). > [!IMPORTANT] -> To use the `QuantizedCache`, you need to install additional dependencies (e.g. `pip install optimum-quanto`). +> To use the `QuantizedCache`, you need to install additional dependencies (_e.g._ `pip install optimum-quanto`). ## FAQ @@ -112,7 +114,7 @@ By default, the `DynamicCache` is used (no quantization). ### Which models are supported ? -Some presses depend on the model architecture (_e.g._ `ExpectedAttentionPress` and `SnapKVPress`) hence they might not work with all models. We tested support for `LlamaForCausalLM`, `MistralForCausalLM`, `Phi3ForCausalLM` and `Qwen2ForCausalLM` but many other models might be supported out of the box because their implementation is often similar in transformers. +Some presses depend on the model architecture (_e.g._ `ExpectedAttentionPress` or `SnapKVPress`) hence they might not work with all models. We tested support for `LlamaForCausalLM`, `MistralForCausalLM`, `Phi3ForCausalLM` and `Qwen2ForCausalLM` but many other models might be supported out of the box because their implementation is often similar in transformers. diff --git a/evaluation/assets/peak_memory_consumption_xkcd.png b/evaluation/assets/peak_memory_consumption_xkcd.png new file mode 100644 index 00000000..993318c3 Binary files /dev/null and b/evaluation/assets/peak_memory_consumption_xkcd.png differ diff --git a/evaluation/assets/ruler_llama_xkcd.png b/evaluation/assets/ruler_llama_xkcd.png new file mode 100644 index 00000000..900cb3fb Binary files /dev/null and b/evaluation/assets/ruler_llama_xkcd.png differ