Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions examples/sparse_attention/RocketKV.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# RocketKV Sparse Attention

This document details enabling RocketKV sparse attention within TensorRT LLM.

RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (in context phase) with dynamic KV token selection (in generation phase) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy.

For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051). Here is an official implement which provides a reference: [RocketKV Repo](https://github.com/NVlabs/RocketKV).

## Overview

In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV mitigates this issue through a two-stage process:

1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores.
2. **Generation Phase (Stage 2):** It utilizes a **dynamic Top-K token selection**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which tokens of the KV cache are relevant for the current token, and loading only those tokens to do the attention computation.

RocketKV is integrated into TensorRT LLM as a specialized attention backend, accessible via the LLM API. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs.

## Support Matrix

* GPU Compute Capability >= 10.0 (Blackwell or newer)
* FP16 / BF16 / FP8
* Paged KV Cache
* Tensor Parallel
* Cuda Graph

**Note:**
1. RocketKV currently requires `enable_block_reuse=False` in the KV cache configuration, as the sparse eviction logic is incompatible with standard block reuse mechanisms.
2. RocketKV doesn't support `enable_chunked_prefill=True` for now.
3. RocketKV doesn't support *disagg-serving* as well, since it needs the KV cache transmission from prefill engine to the decode engine. But currently RocketKV uses a python kt cache manager and it cannot support this transmission.

## Usage

To enable RocketKV, configure `RocketSparseAttentionConfig` and pass it to the `LLM` class constructor.

### Python API

Integrate RocketKV into your workflows using the `tensorrt_llm.llmapi` interface.

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import RocketSparseAttentionConfig, KvCacheConfig

# 1. Define the RocketKV Sparse Attention Configuration
rocket_config = RocketSparseAttentionConfig(
window_size=32, # Size of the recent window to always keep
kernel_size=63, # Pooling kernel size for importance scoring
prompt_budget=2048, # Number of tokens to keep from the prompt (Stage 1)
topk=64, # Number of tokens to retrieve during generation (Stage 2)
topr=128, # Number of query channels to keep for scoring
kt_cache_dtype='float8_e5m2' # Dtype for the auxiliary Key-Token cache
)

# 2. Initialize the LLM with the config and 'pytorch' backend
# Note: Block reuse must be disabled for RocketKV
kv_config = KvCacheConfig(enable_block_reuse=False)

llm = LLM(
model="<path_to_model>",
backend='pytorch', # RocketKV currently requires the PyTorch backend
sparse_attention_config=rocket_config,
kv_cache_config=kv_config,
)

# 3. Generate
prompts = ["To be or not to be, that is the question."]
sampling_params = SamplingParams(max_tokens=128)
outputs = llm.generate(prompts, sampling_params)
```

### Running the Example Script

We provide a reference script `examples/llm-api/llm_sparse_attention.py` to demonstrate RocketKV capabilities.

**Example Command:**

```bash
# Adjust --model_path to your local Llama checkpoint
python3 ../llm-api/llm_sparse_attention.py \
--model_path <path_to_model> \
--algo ROCKETKV \
--attention_backend TRTLLM \
--window_size 32 \
--kernel_size 63 \
--prompt_budget 2048 \
--topk 64 \
--topr 128 \
--kt_cache_dtype float8_e5m2 \
--max_seq_len 10240 \
--max_num_tokens 10240 \
--max_new_tokens 128
```


### Usage with `trtllm-bench` and `trtllm-serve`

Sparse attention options must be specified via `--extra_llm_api_options config.yaml` for both `trtllm-bench` and `trtllm-serve`. All sparse attetnion options can be specified in this YAML file and the argument names/valid values are the same as in their corresponding configuration described in the Configuration Arguments section. For example, a YAML configuration could look like this:

```
backend: pytorch
attn_backend: TRTLLM
sparse_attention_config:
algorithm: rocket
kt_cache_dtype: float8_e5m2
window_size: 32
prompt_budget: 2048
kv_cache_config:
enable_block_reuse: false
enable_chunked_prefill: false
```

Run the command with the config file:
```bash
trtllm-eval/trtllm-bench/trtllm-serve --model <model_path> --extra_llm_api_options extra_config.yaml ...
```

For example, users can evaluate a model with trtllm-eval on LongBenchV2 task like this:

```bash
trtllm-eval --model <path_to_model> --extra_llm_api_options extra_config.yaml longbench_v2 --max_output_length 1024 ...
```

## Configuration Arguments

The `RocketSparseAttentionConfig` allows fine-grained control over compression ratios and performance trade-offs:

* **`prompt_budget`** (int, default=2048): The number of tokens to retain from the input prompt (context). RocketKV compresses the prompt to this size by evicting less important tokens based on importance scores.
* **`topk`** (int, default=64): The number of KT pages to select dynamically during the generation phase. Note that the selection is performed at the granularity of KT cache pages, but the actual attention kernel retrieves data based on the granularity of KV cache page size.
* **`topr`** (int/float, default=128): The number of query feature dimensions to use when computing the relevance score between Query and KT Cache. This acts as a dimensionality reduction to speed up the selection process. However, it's recommended to set it equal to `head_dim` to skip `topr_filter` computations for better performance and accuracy.
* **`window_size`** (int, default=32): The size of the sliding window in RocketKV. In the context phase, RocketKV uses the last `window_size` tokens of the Query and the Key prefix to compute importance scores for eviction. These recent tokens are always retained in the cache, and `prompt_budget-window_size` important tokens in the prefix are retained in the cache also.
* **`kernel_size`** (int, default=63): The size of the 1D max-pooling kernel used in the context phase. It smooths attention scores to better identify locally important regions rather than just isolated high-score tokens.
* **`kt_cache_dtype`** (str, default='float8_e5m2'): The data type for the auxiliary "Key-Token" (KT) cache used for relevance prediction.
* `float8_e5m2`: Recommended. Provides memory savings for the auxiliary structure and speedup for the prediction kernels.
* `bfloat16`: Standard precision.
* **`page_size`** (int, default=4): The granularity of the sparse token selection (KT page). Currently, only **powers of 2** are supported due to Triton kernel limitations. Accuracy is generally maintained well for `page_size <= 4`.
87 changes: 49 additions & 38 deletions tensorrt_llm/evaluate/longbench_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,26 @@ class LongBenchV2(Evaluator):
DIFFICULTIES = ['easy', 'hard']
LENGTHS = ['short', 'medium', 'long']

def __init__(self,
dataset_path: str = 'THUDM/LongBench-v2',
prompts_dir: Optional[str] = None,
num_samples: Optional[int] = None,
start_idx: int = 0,
difficulty: Optional[str] = None,
length: str = 'medium',
domain: Optional[str] = None,
cot: bool = False,
no_context: bool = False,
rag: int = 0,
max_len: int = 128000,
output_dir: Optional[str] = None,
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
def __init__(
self,
dataset_path: str = 'THUDM/LongBench-v2',
prompts_dir: Optional[str] = None,
num_samples: Optional[int] = None,
start_idx: int = 0,
difficulty: Optional[str] = None,
length: str = 'medium',
domain: Optional[str] = None,
cot: bool = False,
no_context: bool = False,
rag: int = 0,
max_len: int = 128000,
output_dir: Optional[str] = None,
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
max_output_length: int = 32000,
chat_template_kwargs: Optional[dict[str, Any]] = None,
):
"""Initialize LongBench v2 evaluator.

Args:
Expand All @@ -86,6 +89,7 @@ def __init__(self,
random_seed: Random seed for reproducibility
apply_chat_template: Whether to apply model's chat template
system_prompt: System prompt to prepend
max_output_length: Maximum output length in tokens. Should keep this value as small as possible to avoid unnecessary truncation.
chat_template_kwargs: Chat template kwargs as JSON string
"""
super().__init__(random_seed=random_seed,
Expand All @@ -103,6 +107,7 @@ def __init__(self,
self.no_context = no_context
self.rag = rag
self.max_len = max_len
self.max_output_length = max_output_length
self.output_dir = output_dir

# Will be set during evaluation
Expand Down Expand Up @@ -307,10 +312,11 @@ def _post_process(self, pred: str) -> str:
return pred

def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
"""Truncate prompt to max_len tokens using needle-in-haystack strategy.
"""Truncate prompt using needle-in-haystack strategy.

If the prompt exceeds max_len, it takes the first half and last half
If the prompt exceeds (max_len - max_output_length), it takes the first half and last half
to preserve both context beginning and end.
We need to minus max_output_length from max_len to reserve budget for output tokens.

Args:
prompt: The prompt string to truncate
Expand All @@ -325,8 +331,9 @@ def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
try:
input_ids = tokenizer.encode(prompt, add_special_tokens=False)

if len(input_ids) > self.max_len:
half = self.max_len // 2
max_input_len = self.max_len - self.max_output_length
if len(input_ids) > max_input_len:
half = max_input_len // 2
truncated_ids = input_ids[:half] + input_ids[-half:]
prompt = tokenizer.decode(truncated_ids,
skip_special_tokens=True)
Expand Down Expand Up @@ -791,7 +798,8 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
type=int,
default=128000,
help=
"Maximum prompt length in tokens for truncation when building prompts.")
"Maximum input and output length in tokens for truncation when building prompts."
)
@click.option("--output_dir",
type=str,
default=None,
Expand Down Expand Up @@ -843,22 +851,25 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
temperature=0.6,
top_p=0.95)

evaluator = LongBenchV2(dataset_path=dataset_path,
prompts_dir=prompts_dir,
num_samples=num_samples,
start_idx=start_idx,
difficulty=difficulty,
length=length,
domain=domain,
cot=cot,
no_context=no_context,
rag=rag,
max_len=max_len,
output_dir=output_dir,
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
evaluator = LongBenchV2(
dataset_path=dataset_path,
prompts_dir=prompts_dir,
num_samples=num_samples,
start_idx=start_idx,
difficulty=difficulty,
length=length,
domain=domain,
cot=cot,
no_context=no_context,
rag=rag,
max_len=max_len,
output_dir=output_dir,
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
max_output_length=max_output_length,
chat_template_kwargs=chat_template_kwargs,
)

evaluator.evaluate(llm, sampling_params)
llm.shutdown()
3 changes: 3 additions & 0 deletions tests/integration/defs/accuracy/references/longbench_v2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ DeepSeek-R1-0528:
kv_cache_quant_algo: FP8
spec_dec_algo: MTP
accuracy: 52.093
meta-llama/Llama-3.1-8B-Instruct:
- accuracy: 26.48
sigma: 25.8
50 changes: 49 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
SamplingParams, TorchCompileConfig)
RocketSparseAttentionConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import (get_device_count, get_device_memory, llm_models_root,
Expand Down Expand Up @@ -4611,3 +4612,50 @@ def test_auto_dtype(self):
max_seq_len=4096) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


@skip_pre_blackwell
class TestLlama3_1_8B_Instruct_LongBenchV2(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/"

def test_auto_dtype(self):
model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/"
if not os.path.exists(model_dir):
pytest.skip(f"Model directory {model_dir} does not exist")

# Configure model settings
kv_cache_config = KvCacheConfig(enable_block_reuse=False)

cuda_graph_config = CudaGraphConfig(enable_padding=True,
max_batch_size=64)

sparse_attention_config = RocketSparseAttentionConfig(
kt_cache_dtype="float8_e5m2", )

pytorch_config = dict(cuda_graph_config=cuda_graph_config,
kv_cache_config=kv_cache_config,
sparse_attention_config=sparse_attention_config,
enable_chunked_prefill=False)

MAX_LEN = 128000
MAX_NEW_TOKENS = 1024

with LLM(model_dir,
max_seq_len=MAX_LEN,
max_num_tokens=128000,
max_batch_size=64,
**pytorch_config) as llm:
task = LongBenchV2(self.MODEL_NAME)

sampling_params = SamplingParams(
max_tokens=MAX_NEW_TOKENS,
temperature=0.8,
top_p=0.95,
)

extra_evaluator_kwargs = dict(max_len=MAX_LEN,
max_output_length=MAX_NEW_TOKENS)
task.evaluate(llm,
sampling_params=sampling_params,
extra_evaluator_kwargs=extra_evaluator_kwargs)
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_LongBenchV2::test_auto_dtype
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
Expand Down
Loading