# Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) #30141
Merged: LucasWilkinson merged 25 commits into `vllm-project:main` from `eldarkurtic:expand-static-scaled-fp8-quant` on Jan 22, 2026.
Changes from all commits (25 commits):
- `5a8b680` Add support for llmcompressor fp8 kv-cache quant (per-tensor and per-… (eldarkurtic)
- `db2b3d2` satisfy mypy checks until next compressed-tensors release (eldarkurtic)
- `1bf17b8` revert scales for cascade attn (eldarkurtic)
- `d6a6fac` fix markdown lint (eldarkurtic)
- `d68d150` fix scales for cascade attn (eldarkurtic)
- `03109d5` address PR reviews (eldarkurtic)
- `ce0e1bd` fix mock attention (eldarkurtic)
- `ff7f9d7` address PR reviews and fix tests (eldarkurtic)
- `c42f4b5` address PR reviews (eldarkurtic)
- `63806e2` refactoring (eldarkurtic)
- `eff6864` lint (eldarkurtic)
- `e961b66` minor fixes (eldarkurtic)
- `8fecf3b` fix (eldarkurtic)
- `773ce01` fix compressed-tensors test for kv-cache quant (eldarkurtic)
- `08fa0fc` prepare scales for group broadcast (eldarkurtic)
- `309142f` fix ruff (eldarkurtic)
- `6d83b18` fix (eldarkurtic)
- `af994b1` fix tests for attn_head (eldarkurtic)
- `a0b0d8a` fix compilation (eldarkurtic)
- `ea368be` tmp (eldarkurtic)
- `dad0ab2` fix per-tensor case and reshaping (eldarkurtic)
- `307d0c2` fix format (eldarkurtic)
- `efb7f8e` remove usage of env vars (eldarkurtic)
- `7234561` simplify test (eldarkurtic)
- `cf5c848` simplify convert_fp8_local (eldarkurtic)
# Quantized KV Cache

## FP8 KV Cache Overview

Efficient memory usage is crucial for working with large language models. Quantizing the KV (key-value) cache to FP8 format can significantly reduce its memory footprint. This optimization enables you to store more tokens in memory, leading to improved throughput and support for longer context windows.

> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values.

### Supported FP8 KV-Cache Quantization Schemes

vLLM supports two main quantization strategies for the FP8 KV cache:

- **Per-tensor quantization:** a single scale is applied to each of the Q, K, and V tensors individually (`q/k/v_scale = [1]`).
- **Per-attention-head quantization:** each scale corresponds to an attention head: `q_scale = [num_heads]`, `k/v_scale = [num_kv_heads]`.

> **Note:** Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**.
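To make the shape difference concrete, here is a toy amax-based scale computation in pure Python. This is a sketch only: the `448.0` FP8 E4M3 maximum and the helper names are illustrative assumptions, not vLLM APIs.

```python
FP8_E4M3_MAX = 448.0  # max finite value of float8_e4m3fn (assumed for this sketch)


def per_tensor_scale(values):
    """One scalar scale for a whole tensor: amax / fp8_max."""
    return max(abs(v) for v in values) / FP8_E4M3_MAX


def per_head_scales(heads):
    """One scale per attention head, matching k/v_scale = [num_kv_heads]."""
    return [per_tensor_scale(head) for head in heads]


# Toy K tensor with 2 KV heads: the outlier in head 1 inflates the single
# per-tensor scale, while per-head scales keep head 0 fine-grained.
k_heads = [[0.5, -2.0], [8.0, 1.0]]
flat = [v for head in k_heads for v in head]
print(per_tensor_scale(flat))    # 8.0 / 448.0
print(per_head_scales(k_heads))  # [2.0 / 448.0, 8.0 / 448.0]
```

A larger scale means coarser quantization steps, which is why per-head scales can preserve accuracy for heads without outliers.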
### Scale Calibration Approaches

You can configure how the quantization scales are computed in vLLM using three different approaches:

1. **No calibration (default scales):** all quantization scales are set to `1.0`. Configure with:

    ```python
    kv_cache_dtype="fp8"
    calculate_kv_scales=False
    ```

2. **Random token calibration (on the fly):** scales are automatically estimated from a single batch of random tokens during warmup and then fixed. Configure with:

    ```python
    kv_cache_dtype="fp8"
    calculate_kv_scales=True
    ```

3. **[Recommended] Calibration with a dataset (via llm-compressor):** scales are estimated using a curated calibration dataset for maximum accuracy. This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library. See the example below.
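Conceptually, the second option is a one-shot amax calibration: observe one warmup batch, derive a scale, then freeze it. A minimal pure-Python sketch follows; the class and the `448.0` FP8 maximum are illustrative assumptions, and vLLM's actual logic lives in its attention backends.

```python
FP8_MAX = 448.0  # assumed float8_e4m3fn max for this sketch


class KVScaleCalibrator:
    """Estimate a static scale from warmup data, then stop updating."""

    def __init__(self):
        self.amax = 0.0
        self.frozen = False

    def observe(self, batch):
        # Track the running absolute maximum until the scale is frozen.
        if not self.frozen:
            self.amax = max(self.amax, max(abs(v) for v in batch))

    def freeze(self):
        self.frozen = True
        return self.amax / FP8_MAX  # the fixed k/v scale used from now on


calib = KVScaleCalibrator()
calib.observe([0.1, -3.2, 1.5])  # single warmup batch
scale = calib.freeze()
calib.observe([100.0])           # later batches no longer change the scale
print(scale)  # 3.2 / 448.0
```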
#### Additional `kv_cache_dtype` Options

- `kv_cache_dtype="auto"`: use the model's default data type
- `kv_cache_dtype="fp8_e4m3"`: supported on CUDA 11.8+ and ROCm (AMD GPUs)
- `kv_cache_dtype="fp8_e5m2"`: supported on CUDA 11.8+
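The two `fp8_*` encodings trade dynamic range for precision: E4M3 (4 exponent bits, 3 mantissa bits) is finer-grained but has a smaller range than E5M2 (5 exponent bits, 2 mantissa bits). The format constants below follow the OCP FP8 specification and are stated here as assumptions, since this page does not list them:

```python
# (max finite value, mantissa bits) per the OCP FP8 spec (assumed values)
E4M3 = (448.0, 3)
E5M2 = (57344.0, 2)


def rel_step(mantissa_bits):
    """Relative spacing between adjacent representable values near 1.0."""
    return 2.0 ** -mantissa_bits


print(rel_step(E4M3[1]))  # 0.125: finer precision, smaller range (max 448)
print(rel_step(E5M2[1]))  # 0.25: coarser precision, larger range (max 57344)
```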
---

## Examples
### 1. No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`)

All quantization scales are set to `1.0`.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
---

### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`)

Scales are automatically estimated from a single batch of tokens during warmup.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
---

### 3. [Recommended] Calibration Using a Dataset (with `llm-compressor`)

For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization.

#### Install the Required Package

```bash
pip install llmcompressor
```

#### Example: Quantize Llama Attention and KV Cache to FP8

Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):
```python
"""
Quantize Llama attention + KV cache to FP8 (choose either the 'tensor' or the
'attn_head' strategy) using llm-compressor one-shot calibration.
"""

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from datasets import load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from transformers import AutoModelForCausalLM, AutoTokenizer

# -----------------------------
# Config
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor"  # or "attn_head"
NUM_CALIB_SAMPLES = 512  # Good starting value
MAX_SEQ_LEN = 2048


# -----------------------------
# Helpers
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
    """Convert chat messages to tokens."""
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQ_LEN,
        truncation=True,
        add_special_tokens=False,
    )


def build_recipe(strategy: str) -> QuantizationModifier:
    fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
    return QuantizationModifier(
        config_groups={
            "attention": QuantizationScheme(
                targets=["LlamaAttention"],  # Quantize queries: q_scale
                input_activations=fp8_args,
            )
        },
        kv_cache_scheme=fp8_args,  # Quantize KV cache: k/v_scale
    )


# -----------------------------
# Main
# -----------------------------
def main():
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
    ds = ds.shuffle(seed=42)
    ds = ds.map(
        lambda ex: process_and_tokenize(ex, tokenizer),
        remove_columns=ds.column_names,
    )

    recipe = build_recipe(STRATEGY)
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQ_LEN,
        num_calibration_samples=NUM_CALIB_SAMPLES,
    )

    save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    main()
```

When running the quantized model you must specify `kv_cache_dtype="fp8"` to enable KV cache quantization and use the calibrated scales:

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
# Directory produced by the script above with STRATEGY = "tensor"
llm = LLM(model="Llama-3.1-8B-Instruct-kvattn-fp8-tensor", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```

For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache).
**Review comment:** The support matrix for KV cache dtypes is a bit more complicated than this. I'd recommend leaving out any comment about support for now, and when #32477 lands we can point to that. This section is also a bit misleading: it makes it seem like `fp8` is an option distinct from `fp8_e4m3` and `fp8_e5m2`, when it actually acts as effectively an alias for `fp8_e4m3`. We should make that behavior clear. Also, can we use `calculate_kv_scales` with `fp8_e5m2`?