Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int block_size, const float* k_scale, const float* v_scale,
const int kv_scale_stride) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
Expand All @@ -226,21 +227,23 @@ __global__ void reshape_and_cache_flash_kernel(
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);

float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
if (is_contiguous_heads) {
// NHD layout

if (is_contiguous_heads && kv_scale_stride == 0) {
// NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
// kv cache: [num_blocks, block_size, num_heads, head_size]
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;

CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};

vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);

vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);

} else {
// HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
Expand All @@ -256,6 +259,16 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;

float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: k_scale[head * kv_scale_stride];
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: v_scale[head * kv_scale_stride];

CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};

// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
Expand Down Expand Up @@ -605,7 +618,8 @@ void reshape_and_cache(
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
reinterpret_cast<const float*>(v_scale.data_ptr()), \
kv_scale_stride);

void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
Expand All @@ -614,8 +628,9 @@ void reshape_and_cache_flash(
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, // [1] or [num_heads]
torch::Tensor& v_scale) { // [1] or [num_heads]
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
Expand All @@ -638,6 +653,12 @@ void reshape_and_cache_flash(
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
"k_scale and v_scale must have the same shape");
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
"k_scale and v_scale must be of shape [1] or [num_heads]");
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;

dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
Expand Down
253 changes: 139 additions & 114 deletions docs/features/quantization/quantized_kvcache.md
Original file line number Diff line number Diff line change
@@ -1,162 +1,187 @@
# Quantized KV Cache

## FP8 KV Cache
## FP8 KV Cache Overview

Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.
Efficient memory usage is crucial for working with large language models. Quantizing the KV (Key-Value) cache to FP8 format can significantly reduce its memory footprint. This optimization enables you to store more tokens in memory, leading to improved throughput and support for longer context windows.

### FP8 Formats
> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values.

[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:
### Supported FP8 KV-Cache Quantization Schemes

- E5M2 (5 exponent bits and 2 mantissa bits)
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3)
vLLM supports two main quantization strategies for the FP8 KV-cache:

The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
- **Per-tensor quantization:**
A single scale is applied for each Q, K, and V tensor individually. (`q/k/v_scale = [1]`)
- **Per-attention-head quantization:**
Each scale corresponds to an attention head: `q_scale = [num_heads]`, `k/v_scale = [num_kv_heads]`.

### Current Limitations
> **Note:**
> Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**.

For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
### Scale Calibration Approaches

### How FP8 KV Cache Works
You can configure how the quantization scales are computed in vLLM using three different approaches:

The FP8 KV cache implementation follows this workflow:
1. **No calibration (default scales):**
All quantization scales are set to `1.0`.
_Configure with:_
```python
kv_cache_dtype="fp8"
calculate_kv_scales=False
```

1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
2. **Random token calibration (on-the-fly):**
Scales are automatically estimated from a single batch of random tokens during warmup and then fixed.
_Configure with:_
```python
kv_cache_dtype="fp8"
calculate_kv_scales=True
```

This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
3. **[Recommended] Calibration with a dataset (via llm-compressor):**
Scales are estimated using a curated calibration dataset for maximum accuracy.
This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library.
_See example below!_

### Performance Impact
#### Additional `kv_cache_dtype` Options

The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
- `kv_cache_dtype="auto"`: Use the model's default data type
- `kv_cache_dtype="fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPUs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The support matrix for kv cache dtypes is a bit more complicated than this. I'd recommend leaving out any comment about support for now, and when #32477 lands we can point to that. This section is also a bit misleading and makes it seem like fp8 is an option which is distinct from fp8_e4m3 and fp8_e5m2 (when it actually acts as effectively an alias for fp8_e4m3). We should make that behavior clear. Also, can we use calculate_kv_scales with fp8_e5m2?

- `kv_cache_dtype="fp8_e5m2"`: Supported on CUDA 11.8+

- Processing longer context lengths for individual requests, or
- Handling more concurrent request batches
---

However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.
## Examples

Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.
### 1. No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`)

## Usage Example
All quantization scales are set to 1.0.

Here is an example of how to enable FP8 quantization:
```python
from vllm import LLM, SamplingParams

??? code
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```

```python
# To calculate kv cache scales on the fly enable the calculate_kv_scales
# parameter
---

from vllm import LLM, SamplingParams
### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`)

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
Scales are automatically estimated from a single batch of tokens during warmup.

The `kv_cache_dtype` argument specifies the data type for KV cache storage:
```python
from vllm import LLM, SamplingParams

- `"auto"`: Uses the model's default "unquantized" data type
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
- `"fp8_e5m2"`: Supported on CUDA 11.8+
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```

## Calibrated Scales for Better Accuracy
---

For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.
### 3. **[Recommended] Calibration Using a Dataset (with `llm-compressor`)**

### Installation
For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization.

First, install the required dependencies:
#### Install the required package

```bash
pip install llmcompressor
```

### Example Usage
#### Example: Quantize Llama Attention & KV Cache to FP8

Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):

??? code
```python
"""
Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy)
using llm-compressor one-shot calibration.
"""

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# -----------------------------
# Config
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor" # or "attn_head"
NUM_CALIB_SAMPLES = 512 # Good starting value
MAX_SEQ_LEN = 2048

# -----------------------------
# Helpers
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
"""Convert chat messages to tokens."""
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(
text,
padding=False,
max_length=MAX_SEQ_LEN,
truncation=True,
add_special_tokens=False,
)

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
def build_recipe(strategy: str) -> QuantizationModifier:
fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
return QuantizationModifier(
config_groups={
"attention": QuantizationScheme(
targets=["LlamaAttention"], # Quantize queries: q_scale
input_activations=fp8_args,
)
},
kv_cache_scheme=fp8_args, # Quantize KV cache: k/v_scale
)

# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto")
# -----------------------------
# Main
# -----------------------------
def main():
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
ds = ds.shuffle(seed=42)
ds = ds.map(
lambda ex: process_and_tokenize(ex, tokenizer),
remove_columns=ds.column_names,
)

# Select calibration dataset
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Configure calibration parameters
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
MAX_SEQUENCE_LENGTH = 2048

# Load and preprocess dataset
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(
text,
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

# Configure quantization settings
recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""

# Apply quantization
recipe = build_recipe(STRATEGY)
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
num_calibration_samples=NUM_CALIB_SAMPLES,
)

# Save quantized model: Llama-3.1-8B-Instruct-FP8-KV
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)

When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
if __name__ == "__main__":
main()
```

For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache).
Loading
Loading