From f692f2bfb26a052dce979ca831993ba560917c05 Mon Sep 17 00:00:00 2001 From: yuhangh <58161490+heyuhhh@users.noreply.github.com> Date: Thu, 27 Nov 2025 10:05:12 +0000 Subject: [PATCH 1/6] Add RocketKV doc Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com> --- examples/sparse_attention/RocketKV.md | 103 ++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 examples/sparse_attention/RocketKV.md diff --git a/examples/sparse_attention/RocketKV.md b/examples/sparse_attention/RocketKV.md new file mode 100644 index 00000000000..dd7e64ddec0 --- /dev/null +++ b/examples/sparse_attention/RocketKV.md @@ -0,0 +1,103 @@ +# RocketKV Sparse Attention + +This document details how to run RocketKV sparse attention with TensorRT-LLM. + +RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (coarse-grain) with dynamic KV token selection (fine-grain) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy. + +For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051). + +## Overview + +In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV addresses this via two stages: + +1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores. +2. **Generation Phase (Stage 2):** It utilizes a **fine-grain Top-K sparse attention**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which blocks of the KV cache are relevant for the current token, and loading only those blocks to do the attention computation. + +In TensorRT-LLM, RocketKV is implemented as a specialized attention backend within the PyTorch-based high-level API workflow. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs. + +## Support Matrix + + * GPU Compute Capability >= 10.0 (Blackwell or newer) + * FP16 / BF16 / FP8 + * Paged KV Cache + * Tensor Parallel + * Cuda Graph + +**Note:** RocketKV currently requires `enable_block_reuse=False` in the KV cache configuration, as the sparse eviction logic is incompatible with standard block reuse mechanisms. Also, RocketKV doesn't support `enable_chunked_prefill=True` for now. + +## Usage + +To use RocketKV, you need to configure the `RocketSparseAttentionConfig` and pass it to the `LLM` class. + +### Python API + +You can integrate RocketKV into your own scripts using the `tensorrt_llm.llmapi` interface. + + +```python +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import RocketSparseAttentionConfig, KvCacheConfig + +# 1. Define the RocketKV Sparse Attention Configuration +rocket_config = RocketSparseAttentionConfig( + window_size=32, # Size of the recent window to always keep + kernel_size=63, # Pooling kernel size for importance scoring + prompt_budget=2048, # Number of tokens to keep from the prompt (Stage 1) + topk=64, # Number of tokens to retrieve during generation (Stage 2) + topr=128, # Number of query channels to keep for scoring + kt_cache_dtype='float8_e5m2' # Dtype for the auxiliary Key-Token cache +) + +# 2. 
Initialize the LLM with the config and 'pytorch' backend
+# Note: Block reuse must be disabled for RocketKV
+kv_config = KvCacheConfig(enable_block_reuse=False)
+
+llm = LLM(
+    model="<model_path>",
+    backend='pytorch',  # RocketKV currently requires the PyTorch backend
+    sparse_attention_config=rocket_config,
+    kv_cache_config=kv_config,
+)
+
+# 3. Generate
+prompts = ["To be or not to be, that is the question."]
+sampling_params = SamplingParams(max_tokens=128)
+outputs = llm.generate(prompts, sampling_params)
+```
+
+### Running the Example Script
+
+We provide a reference script `examples/llm-api/llm_sparse_attention.py` to demonstrate RocketKV capabilities.
+
+**Example Command:**
+
+```bash
+# Adjust --model_path to your local Llama checkpoint
+python3 ../llm-api/llm_sparse_attention.py \
+  --model_path <model_path> \
+  --algo ROCKETKV \
+  --attention_backend TRTLLM \
+  --window_size 32 \
+  --kernel_size 63 \
+  --prompt_budget 2048 \
+  --topk 64 \
+  --topr 128 \
+  --kt_cache_dtype float8_e5m2 \
+  --max_seq_len 10240 \
+  --max_num_tokens 10240 \
+  --max_new_tokens 128
+```
+
+### Configuration Arguments
+
+The `RocketSparseAttentionConfig` allows fine-grained control over compression ratios and performance trade-offs:
+
+* **`prompt_budget`** (int, default=2048): The number of tokens to retain from the input prompt (context). RocketKV compresses the prompt to this size by evicting less important tokens based on importance scores.
+* **`topk`** (int, default=64): The number of KT pages to select dynamically during the generation phase. Note that the selection is performed at the granularity of KT cache pages, while the actual attention kernel retrieves data at the granularity of the KV cache page size.
+* **`topr`** (int/float, default=128): The number of query feature dimensions used when computing the relevance score between the Query and the KT Cache. This acts as a dimensionality reduction to speed up the selection process. That said, it is recommended to set it equal to `head_dim`, which skips the `topr_filter` computation entirely and gives better performance and accuracy.
+* **`window_size`** (int, default=32): The size of the sliding window in RocketKV. In the context phase, RocketKV uses the last `window_size` tokens of the Query and the Key prefix to compute importance scores for eviction. These recent tokens are always retained in the cache, along with the `prompt_budget - window_size` most important tokens from the prefix (see the sketch below).
+* **`kernel_size`** (int, default=63): The size of the 1D max-pooling kernel used in the context phase. It smooths attention scores to better identify locally important regions rather than just isolated high-score tokens.
+* **`kt_cache_dtype`** (str, default='float8_e5m2'): The data type for the auxiliary "Key-Token" (KT) cache used for relevance prediction.
+  * `float8_e5m2`: Recommended. Provides memory savings for the auxiliary structure and speedup for the prediction kernels.
+  * `bfloat16`: Standard precision.
+* **`page_size`** (int, default=4): The granularity of the sparse selection (KT page). Currently, only **powers of 2** are supported due to Triton kernel limitations. Accuracy is generally maintained well for `page_size <= 4`.
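+
+To build intuition for how `window_size`, `kernel_size`, and `prompt_budget` interact during the context phase, the snippet below is a minimal, single-head PyTorch sketch of the eviction scoring described above. It is **not** the TensorRT-LLM kernel implementation: the function name, tensor shapes, and exact scoring details are illustrative assumptions, and batching, GQA grouping, and paged-cache bookkeeping are omitted.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def stage1_keep_indices(q, k, window_size=32, kernel_size=63, prompt_budget=2048):
+    """Sketch of context-phase eviction: return indices of tokens kept in the KV cache.
+
+    q, k: [seq_len, head_dim] queries and keys of a single attention head.
+    """
+    seq_len, head_dim = q.shape
+    if seq_len <= prompt_budget:
+        return torch.arange(seq_len)  # nothing to evict
+
+    # Importance of each prefix token = its attention mass from the last `window_size` queries.
+    q_win = q[-window_size:]                           # [window_size, head_dim]
+    k_prefix = k[:-window_size]                        # [prefix_len, head_dim]
+    scores = (q_win @ k_prefix.T) / head_dim**0.5      # [window_size, prefix_len]
+    importance = scores.softmax(dim=-1).sum(dim=0)     # [prefix_len]
+
+    # 1D max pooling (`kernel_size`) smooths the scores so locally important regions survive.
+    pooled = F.max_pool1d(importance[None, None, :], kernel_size,
+                          stride=1, padding=kernel_size // 2)[0, 0]
+
+    # Keep the top (prompt_budget - window_size) prefix tokens plus the recent window.
+    keep_prefix = pooled.topk(prompt_budget - window_size).indices
+    keep_recent = torch.arange(seq_len - window_size, seq_len)
+    return torch.cat([keep_prefix, keep_recent]).sort().values
+```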
From 30b66b44db105aea31fdc11aa53d0044b91af4dd Mon Sep 17 00:00:00 2001
From: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Date: Fri, 28 Nov 2025 08:54:16 +0000
Subject: [PATCH 2/6] Add e2e LongBenchV2 tests for RocketKV

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
---
 tensorrt_llm/evaluate/longbench_v2.py         | 21 +++++---
 .../accuracy/references/longbench_v2.yaml     |  3 ++
 .../defs/accuracy/test_llm_api_pytorch.py     | 50 ++++++++++++++++++-
 .../test_lists/test-db/l0_b200.yml            |  1 +
 4 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/evaluate/longbench_v2.py b/tensorrt_llm/evaluate/longbench_v2.py
index 503e1bac7d0..caeac271f4f 100644
--- a/tensorrt_llm/evaluate/longbench_v2.py
+++ b/tensorrt_llm/evaluate/longbench_v2.py
@@ -67,7 +67,8 @@ def __init__(self,
                  random_seed: int = 0,
                  apply_chat_template: bool = False,
                  system_prompt: Optional[str] = None,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 max_output_length: int = 32000,
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,):
         """Initialize LongBench v2 evaluator.
 
         Args:
@@ -86,6 +87,7 @@ def __init__(self,
             random_seed: Random seed for reproducibility
             apply_chat_template: Whether to apply model's chat template
             system_prompt: System prompt to prepend
+            max_output_length: Maximum output length in tokens. Keep this value as small as possible to avoid unnecessary truncation of the input prompt.
             chat_template_kwargs: Chat template kwargs as JSON string
         """
         super().__init__(random_seed=random_seed,
@@ -103,6 +105,7 @@ def __init__(self,
         self.no_context = no_context
         self.rag = rag
         self.max_len = max_len
+        self.max_output_length = max_output_length
         self.output_dir = output_dir
 
         # Will be set during evaluation
@@ -307,10 +310,11 @@ def _post_process(self, pred: str) -> str:
         return pred
 
     def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
-        """Truncate prompt to max_len tokens using needle-in-haystack strategy.
+        """Truncate prompt using needle-in-haystack strategy.
 
-        If the prompt exceeds max_len, it takes the first half and last half
+        If the prompt exceeds (max_len - max_output_length), it takes the first half and last half
         to preserve both context beginning and end.
+        max_output_length is subtracted from max_len to reserve budget for the output tokens.
 
         Args:
             prompt: The prompt string to truncate
@@ -325,8 +329,9 @@ def _truncate_prompt(self, prompt: str, tokenizer: Any) -> str:
 
         try:
             input_ids = tokenizer.encode(prompt, add_special_tokens=False)
-            if len(input_ids) > self.max_len:
-                half = self.max_len // 2
+            max_input_len = self.max_len - self.max_output_length
+            if len(input_ids) > max_input_len:
+                half = max_input_len // 2
                 truncated_ids = input_ids[:half] + input_ids[-half:]
                 prompt = tokenizer.decode(truncated_ids,
                                           skip_special_tokens=True)
@@ -791,7 +796,8 @@ def _save_results(self, results: List[Dict], metrics: Dict[str, float]):
               type=int,
               default=128000,
               help=
-              "Maximum prompt length in tokens for truncation when building prompts.")
+              "Maximum combined input and output length in tokens for truncation when building prompts."
+ ) @click.option("--output_dir", type=str, default=None, @@ -858,7 +864,8 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str], random_seed=random_seed, apply_chat_template=apply_chat_template, system_prompt=system_prompt, - chat_template_kwargs=chat_template_kwargs) + max_output_length=max_output_length, + chat_template_kwargs=chat_template_kwargs,) evaluator.evaluate(llm, sampling_params) llm.shutdown() diff --git a/tests/integration/defs/accuracy/references/longbench_v2.yaml b/tests/integration/defs/accuracy/references/longbench_v2.yaml index 357dc405097..eae407f35a3 100644 --- a/tests/integration/defs/accuracy/references/longbench_v2.yaml +++ b/tests/integration/defs/accuracy/references/longbench_v2.yaml @@ -7,3 +7,6 @@ DeepSeek-R1-0528: kv_cache_quant_algo: FP8 spec_dec_algo: MTP accuracy: 52.093 +meta-llama/Llama-3.1-8B-Instruct: + - accuracy: 26.48 + sigma: 25.8 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index ed04151db2e..07e42c25458 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -25,7 +25,8 @@ from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig, EagleDecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, - SamplingParams, TorchCompileConfig) + RocketSparseAttentionConfig, SamplingParams, + TorchCompileConfig) from tensorrt_llm.quantization import QuantAlgo from ..conftest import (get_device_count, get_device_memory, llm_models_root, @@ -4611,3 +4612,50 @@ def test_auto_dtype(self): max_seq_len=4096) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + + +@skip_pre_blackwell +class TestLlama3_1_8B_Instruct_LongBenchV2(LlmapiAccuracyTestHarness): + MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" + MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/" + + def test_auto_dtype(self): + model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct/" + if not os.path.exists(model_dir): + pytest.skip(f"Model directory {model_dir} does not exist") + + # Configure model settings + kv_cache_config = KvCacheConfig(enable_block_reuse=False) + + cuda_graph_config = CudaGraphConfig(enable_padding=True, + max_batch_size=64) + + sparse_attention_config = RocketSparseAttentionConfig( + kt_cache_dtype="float8_e5m2", ) + + pytorch_config = dict(cuda_graph_config=cuda_graph_config, + kv_cache_config=kv_cache_config, + sparse_attention_config=sparse_attention_config, + enable_chunked_prefill=False) + + MAX_LEN = 128000 + MAX_NEW_TOKENS = 1024 + + with LLM(model_dir, + max_seq_len=MAX_LEN, + max_num_tokens=128000, + max_batch_size=64, + **pytorch_config) as llm: + task = LongBenchV2(self.MODEL_NAME) + + sampling_params = SamplingParams( + max_tokens=MAX_NEW_TOKENS, + temperature=0.8, + top_p=0.95, + ) + + extra_evaluator_kwargs = dict(max_len=MAX_LEN, + max_output_length=MAX_NEW_TOKENS) + task.evaluate(llm, + sampling_params=sampling_params, + extra_evaluator_kwargs=extra_evaluator_kwargs) diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 58e341ec339..22923e42120 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -55,6 +55,7 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS] - 
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM] - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass] + - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_LongBenchV2::test_auto_dtype - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551 - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8] From a5bb15a48715e87df5c8392bd0e4e6d04f821e76 Mon Sep 17 00:00:00 2001 From: yuhangh <58161490+heyuhhh@users.noreply.github.com> Date: Mon, 1 Dec 2025 09:09:00 +0000 Subject: [PATCH 3/6] RocketKV doc changes Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com> --- examples/sparse_attention/RocketKV.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/sparse_attention/RocketKV.md b/examples/sparse_attention/RocketKV.md index dd7e64ddec0..90474f9aeb4 100644 --- a/examples/sparse_attention/RocketKV.md +++ b/examples/sparse_attention/RocketKV.md @@ -1,6 +1,6 @@ # RocketKV Sparse Attention -This document details how to run RocketKV sparse attention with TensorRT-LLM. +This document details enabling RocketKV sparse attention within TensorRT LLM. RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (coarse-grain) with dynamic KV token selection (fine-grain) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy. @@ -8,12 +8,12 @@ For more technical details, please refer to the paper: [RocketKV: Accelerating L ## Overview -In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV addresses this via two stages: +In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV mitigates this issue through a two-stage process: 1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores. 2. **Generation Phase (Stage 2):** It utilizes a **fine-grain Top-K sparse attention**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which blocks of the KV cache are relevant for the current token, and loading only those blocks to do the attention computation. -In TensorRT-LLM, RocketKV is implemented as a specialized attention backend within the PyTorch-based high-level API workflow. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs. +RocketKV is integrated into TensorRT LLM as a specialized attention backend, accessible via the LLM API. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs. ## Support Matrix @@ -27,11 +27,11 @@ In TensorRT-LLM, RocketKV is implemented as a specialized attention backend with ## Usage -To use RocketKV, you need to configure the `RocketSparseAttentionConfig` and pass it to the `LLM` class. +To enable RocketKV, configure `RocketSparseAttentionConfig` and pass it to the `LLM` class constructor. 
### Python API -You can integrate RocketKV into your own scripts using the `tensorrt_llm.llmapi` interface. +Integrate RocketKV into your workflows using the `tensorrt_llm.llmapi` interface. ```python From e17313399d87e8567605482351b5a654bf1651fb Mon Sep 17 00:00:00 2001 From: yuhangh <58161490+heyuhhh@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:27:45 +0000 Subject: [PATCH 4/6] RocketKV doc changes Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com> --- examples/sparse_attention/RocketKV.md | 30 +++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/examples/sparse_attention/RocketKV.md b/examples/sparse_attention/RocketKV.md index 90474f9aeb4..28983a42bdb 100644 --- a/examples/sparse_attention/RocketKV.md +++ b/examples/sparse_attention/RocketKV.md @@ -2,7 +2,7 @@ This document details enabling RocketKV sparse attention within TensorRT LLM. -RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (coarse-grain) with dynamic KV token selection (fine-grain) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy. +RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (in context phase) with dynamic KV token selection (in generation phase) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy. For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051). @@ -11,7 +11,7 @@ For more technical details, please refer to the paper: [RocketKV: Accelerating L In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV mitigates this issue through a two-stage process: 1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores. -2. **Generation Phase (Stage 2):** It utilizes a **fine-grain Top-K sparse attention**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which blocks of the KV cache are relevant for the current token, and loading only those blocks to do the attention computation. +2. **Generation Phase (Stage 2):** It utilizes a **blocked Top-K sparse attention**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which blocks of the KV cache are relevant for the current token, and loading only those blocks to do the attention computation. RocketKV is integrated into TensorRT LLM as a specialized attention backend, accessible via the LLM API. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs. @@ -23,7 +23,10 @@ RocketKV is integrated into TensorRT LLM as a specialized attention backend, acc * Tensor Parallel * Cuda Graph -**Note:** RocketKV currently requires `enable_block_reuse=False` in the KV cache configuration, as the sparse eviction logic is incompatible with standard block reuse mechanisms. Also, RocketKV doesn't support `enable_chunked_prefill=True` for now. +**Note:** +1. 
RocketKV currently requires `enable_block_reuse=False` in the KV cache configuration, as the sparse eviction logic is incompatible with standard block reuse mechanisms.
+2. RocketKV doesn't support `enable_chunked_prefill=True` for now.
+3. RocketKV doesn't support *disaggregated serving* either: disaggregated serving requires transmitting the KV cache from the prefill engine to the decode engine, and the Python KT cache manager that RocketKV currently uses cannot take part in this transmission.
 
 ## Usage
 
@@ -88,7 +91,26 @@ python3 ../llm-api/llm_sparse_attention.py \
   --max_new_tokens 128
 ```
 
-### Configuration Arguments
+
+### Usage with `trtllm-bench` and `trtllm-serve`
+
+For both `trtllm-bench` and `trtllm-serve`, sparse attention options must be provided via `--extra_llm_api_options config.yaml`. All sparse attention options can be specified in this YAML file, and the argument names and valid values match the corresponding fields described in the Configuration Arguments section. For example, a YAML configuration could look like this:
+
+```yaml
+backend: pytorch
+attn_backend: TRTLLM
+sparse_attention_config:
+  algorithm: rocket
+  kt_cache_dtype: float8_e5m2
+  window_size: 32
+  prompt_budget: 2048
+kv_cache_config:
+  enable_block_reuse: false
+enable_chunked_prefill: false
+```
+
+
+## Configuration Arguments
 
 The `RocketSparseAttentionConfig` allows fine-grained control over compression ratios and performance trade-offs:
 

From 1bfc3023017109701edc76a76c3b2433d78dbdbe Mon Sep 17 00:00:00 2001
From: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Date: Tue, 2 Dec 2025 05:52:25 +0000
Subject: [PATCH 5/6] pre-commit fix

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
---
 tensorrt_llm/evaluate/longbench_v2.py | 74 ++++++++++++++-------------
 1 file changed, 39 insertions(+), 35 deletions(-)

diff --git a/tensorrt_llm/evaluate/longbench_v2.py b/tensorrt_llm/evaluate/longbench_v2.py
index caeac271f4f..fce496be322 100644
--- a/tensorrt_llm/evaluate/longbench_v2.py
+++ b/tensorrt_llm/evaluate/longbench_v2.py
@@ -51,24 +51,26 @@ class LongBenchV2(Evaluator):
     DIFFICULTIES = ['easy', 'hard']
     LENGTHS = ['short', 'medium', 'long']
 
-    def __init__(self,
-                 dataset_path: str = 'THUDM/LongBench-v2',
-                 prompts_dir: Optional[str] = None,
-                 num_samples: Optional[int] = None,
-                 start_idx: int = 0,
-                 difficulty: Optional[str] = None,
-                 length: str = 'medium',
-                 domain: Optional[str] = None,
-                 cot: bool = False,
-                 no_context: bool = False,
-                 rag: int = 0,
-                 max_len: int = 128000,
-                 output_dir: Optional[str] = None,
-                 random_seed: int = 0,
-                 apply_chat_template: bool = False,
-                 system_prompt: Optional[str] = None,
-                 max_output_length: int = 32000,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None,):
+    def __init__(
+        self,
+        dataset_path: str = 'THUDM/LongBench-v2',
+        prompts_dir: Optional[str] = None,
+        num_samples: Optional[int] = None,
+        start_idx: int = 0,
+        difficulty: Optional[str] = None,
+        length: str = 'medium',
+        domain: Optional[str] = None,
+        cot: bool = False,
+        no_context: bool = False,
+        rag: int = 0,
+        max_len: int = 128000,
+        output_dir: Optional[str] = None,
+        random_seed: int = 0,
+        apply_chat_template: bool = False,
+        system_prompt: Optional[str] = None,
+        max_output_length: int = 32000,
+        chat_template_kwargs: Optional[dict[str, Any]] = None,
+    ):
         """Initialize LongBench v2 evaluator.
 
         Args:
@@ -849,23 +851,25 @@ def command(ctx, dataset_path: str, prompts_dir: Optional[str],
                                      temperature=0.6,
                                      top_p=0.95)
 
-    evaluator = LongBenchV2(dataset_path=dataset_path,
-                            prompts_dir=prompts_dir,
-                            num_samples=num_samples,
-                            start_idx=start_idx,
-                            difficulty=difficulty,
-                            length=length,
-                            domain=domain,
-                            cot=cot,
-                            no_context=no_context,
-                            rag=rag,
-                            max_len=max_len,
-                            output_dir=output_dir,
-                            random_seed=random_seed,
-                            apply_chat_template=apply_chat_template,
-                            system_prompt=system_prompt,
-                            max_output_length=max_output_length,
-                            chat_template_kwargs=chat_template_kwargs,)
+    evaluator = LongBenchV2(
+        dataset_path=dataset_path,
+        prompts_dir=prompts_dir,
+        num_samples=num_samples,
+        start_idx=start_idx,
+        difficulty=difficulty,
+        length=length,
+        domain=domain,
+        cot=cot,
+        no_context=no_context,
+        rag=rag,
+        max_len=max_len,
+        output_dir=output_dir,
+        random_seed=random_seed,
+        apply_chat_template=apply_chat_template,
+        system_prompt=system_prompt,
+        max_output_length=max_output_length,
+        chat_template_kwargs=chat_template_kwargs,
+    )
     evaluator.evaluate(llm, sampling_params)
     llm.shutdown()
 

From 685207004ae0d2b0dd81b6ab6dab72695b8f2470 Mon Sep 17 00:00:00 2001
From: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Date: Wed, 3 Dec 2025 02:36:38 +0000
Subject: [PATCH 6/6] Doc changes

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
---
 examples/sparse_attention/RocketKV.md | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/examples/sparse_attention/RocketKV.md b/examples/sparse_attention/RocketKV.md
index 28983a42bdb..c320350c9f7 100644
--- a/examples/sparse_attention/RocketKV.md
+++ b/examples/sparse_attention/RocketKV.md
@@ -4,14 +4,14 @@ This document details enabling RocketKV sparse attention within TensorRT LLM.
 
 RocketKV is a training-free, two-stage KV cache compression method designed to accelerate long-context LLM inference. It combines permanent KV token eviction (in context phase) with dynamic KV token selection (in generation phase) to significantly reduce memory bandwidth usage and increase throughput while maintaining high accuracy.
 
-For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051).
+For more technical details, please refer to the paper: [RocketKV: Accelerating Long-Context LLM Inference via Two-Stage KV Cache Compression](https://arxiv.org/pdf/2502.14051). An official implementation is available for reference: [RocketKV Repo](https://github.com/NVlabs/RocketKV).
 
 ## Overview
 
 In Transformer-based LLM inference, the KV cache grows linearly with sequence length, becoming a major bottleneck. RocketKV mitigates this issue through a two-stage process:
 
 1. **Context Phase (Stage 1):** It performs **permanent KV cache eviction**. Instead of storing the full history, it selects and keeps a `prompt_budget` of the most important tokens based on attention scores.
-2. **Generation Phase (Stage 2):** It utilizes a **blocked Top-K sparse attention**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which blocks of the KV cache are relevant for the current token, and loading only those blocks to do the attention computation.
+2. **Generation Phase (Stage 2):** It utilizes **dynamic Top-K token selection**. It maintains a lightweight, compressed auxiliary cache (KT Cache) to dynamically predict which tokens of the KV cache are relevant to the current token, and loads only those tokens for the attention computation.
 
 RocketKV is integrated into TensorRT LLM as a specialized attention backend, accessible via the LLM API. Specifically, the core sparse KV prediction kernels are implemented using **Triton** kernels, achieving highly optimized performance on modern NVIDIA GPUs.
 
@@ -36,7 +36,6 @@ To enable RocketKV, configure `RocketSparseAttentionConfig` and pass it to the `
 
 Integrate RocketKV into your workflows using the `tensorrt_llm.llmapi` interface.
 
-
 ```python
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import RocketSparseAttentionConfig, KvCacheConfig
@@ -109,6 +108,16 @@ kv_cache_config:
 enable_chunked_prefill: false
 ```
 
+Run the chosen command with the config file:
+```bash
+trtllm-eval/trtllm-bench/trtllm-serve --model <model_path> --extra_llm_api_options extra_config.yaml ...
+```
+
+For example, users can evaluate a model on the LongBenchV2 task with `trtllm-eval` like this:
+
+```bash
+trtllm-eval --model <model_path> --extra_llm_api_options extra_config.yaml longbench_v2 --max_output_length 1024 ...
+```
 
 ## Configuration Arguments
 
@@ -122,4 +131,4 @@ The `RocketSparseAttentionConfig` allows fine-grained control over compression r
 * **`kt_cache_dtype`** (str, default='float8_e5m2'): The data type for the auxiliary "Key-Token" (KT) cache used for relevance prediction.
   * `float8_e5m2`: Recommended. Provides memory savings for the auxiliary structure and speedup for the prediction kernels.
   * `bfloat16`: Standard precision.
-* **`page_size`** (int, default=4): The granularity of the sparse selection (KT page). Currently, only **powers of 2** are supported due to Triton kernel limitations. Accuracy is generally maintained well for `page_size <= 4`.
+* **`page_size`** (int, default=4): The granularity of the sparse token selection (KT page). Currently, only **powers of 2** are supported due to Triton kernel limitations. Accuracy is generally maintained well for `page_size <= 4`.
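+
+For intuition about the generation-phase selection that `topk`, `topr`, `page_size`, and `kt_cache_dtype` control, below is a simplified, single-head PyTorch sketch of how a per-step top-k page prediction could work. It is not the Triton kernel shipped with TensorRT-LLM; the function name, shapes, and scoring details are illustrative assumptions, and batching, GQA grouping, and FP8 storage of the KT cache are omitted.
+
+```python
+import torch
+
+
+def stage2_select_pages(q, kt_cache, topk=64, topr=128, page_size=4):
+    """Sketch of generation-phase prediction: return indices of KT pages to attend to.
+
+    q:        [head_dim] query of the current decode step for a single head.
+    kt_cache: [num_tokens, head_dim] compressed key representations (KT cache).
+    The attention kernel would then gather the corresponding KV cache blocks and
+    attend only over them.
+    """
+    head_dim = q.shape[0]
+    num_tokens = kt_cache.shape[0]
+
+    # Optional dimensionality reduction: keep only the `topr` largest-magnitude query channels.
+    if topr < head_dim:
+        channels = q.abs().topk(topr).indices
+        q_r, kt_r = q[channels], kt_cache[:, channels]
+    else:
+        q_r, kt_r = q, kt_cache
+
+    # Approximate relevance of every cached token to the current query.
+    token_scores = kt_r.float() @ q_r.float()                  # [num_tokens]
+
+    # Reduce token scores to page scores at `page_size` granularity.
+    num_pages = (num_tokens + page_size - 1) // page_size
+    padded = torch.full((num_pages * page_size,), float("-inf"))
+    padded[:num_tokens] = token_scores
+    page_scores = padded.view(num_pages, page_size).max(dim=1).values
+
+    # Keep the `topk` highest-scoring pages for the sparse attention step.
+    return page_scores.topk(min(topk, num_pages)).indices
+```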