
Commit dfa8bc0

Merge branch 'NVIDIA:main' into qgai/deepseekeagle2models
2 parents e1f5058 + 523a17d commit dfa8bc0

13 files changed, +224 -74 lines changed

docs/source/features/kvcache.md

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ Before a block is evicted from GPU memory, it can optionally be offloaded to host
 
 When offloading is enabled, the client can prevent specific blocks from being offloaded by toggling block priority. Blocks with a priority lower than a certain threshold are not offloaded; they are evicted directly from GPU memory to reduce traffic between the GPU and the host. This priority threshold is set with ```secondary_offload_min_priority```. The default value is 35, meaning any block with a priority lower than 35 will not be offloaded.
 
+Here is an [example](../../../examples/llm-api/llm_kv_cache_offloading.py) that shows how to enable host offloading.
+
 ### Partial Reuse
 
 Partial reuse of a block can happen when some but not all tokens are matched. It is enabled by default, but can be disabled by setting ```enable_partial_reuse``` to False.
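For orientation, here is a minimal sketch (not part of this commit) of how the offloading options described above might be combined through the LLM API. It assumes `secondary_offload_min_priority` is accepted by the same `KvCacheConfig` used in the new example below, which itself only sets `host_cache_size`.

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Sketch: offload evicted blocks to 1 GiB of host RAM, but skip offloading for
# blocks whose priority is below the documented default threshold of 35.
# `secondary_offload_min_priority` is an assumption here; the shipped example
# below only sets `host_cache_size`.
kv_cache_config = KvCacheConfig(
    enable_block_reuse=True,            # allow matched blocks to be reused
    host_cache_size=1024**3,            # 1 GiB host (CPU) cache enables offloading
    secondary_offload_min_priority=35,  # lower-priority blocks are evicted, not offloaded
)

llm = LLM(model="Qwen/Qwen3-8B", kv_cache_config=kv_cache_config)
```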
examples/llm-api/llm_kv_cache_offloading.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
'''
This script demonstrates the effectiveness of KV cache host offloading in TensorRT-LLM.

**Scenario:**
The script simulates a scenario where the GPU's KV cache is severely limited,
while multiple requests with recurring prompts (like system prompts) are processed.

1. **Constrained GPU Cache:** The GPU KV cache is configured to be very small,
   only large enough to hold the state for a single request.
2. **Alternating Prompts:** Four requests are sent sequentially (batch size of 1)
   with two distinct prompts in an A, B, A, B pattern.
3. **Cache Eviction:** Due to the small GPU cache, processing prompt B will
   force the eviction of the cache generated for prompt A.

**Demonstration:**

* **Without Offloading (Default):**
    - When the first prompt 'A' is processed, its KV cache is stored on the GPU.
    - When prompt 'B' arrives, the cache manager needs space and discards the cache for 'A'.
    - When prompt 'A' is sent again, its cache must be recomputed from scratch.
    - **Expected Outcome:** The log will show `reused blocks: 0` and `cache hit rate: 0`.

* **With Offloading (`--enable_offloading`):**
    - When prompt 'B' arrives, the cache for 'A' is not discarded but is instead
      *offloaded* from the fast GPU VRAM to the slower (but larger) host CPU RAM.
    - When prompt 'A' is sent again, its KV cache is loaded back from host RAM
      to the GPU, which is significantly faster than recomputing it.
    - **Expected Outcome:** The log will show positive values for `reused blocks`
      and a non-zero `cache hit rate`, confirming that the cache was successfully
      reused from the host.

**How to Run & Verify:**

1. **Without Offloading:**
   ```bash
   TLLM_LOG_LEVEL=DEBUG python llm_kv_cache_offloading.py 2>&1 | tee offloading_disabled.log
   ```
   (Check the log for zero reuse)

2. **With Offloading:**
   ```bash
   TLLM_LOG_LEVEL=DEBUG python llm_kv_cache_offloading.py --enable_offloading 2>&1 | tee offloading_enabled.log
   ```
   (Check the log for non-zero reuse)
'''

import argparse

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig


def main(args):
    # Define two distinct prompts to simulate different requests or system prompts.
    prompt_a = (
        "Returns the per-iterations statistics computed since last call to this method. "
        "Contains at most iter_stats_max_iterations iterations.")
    prompt_b = ("Use for skipping decoding step for non generation model, "
                "and return the batch_output (such as mm_embeddings)")

    # Use a batch size of 1 to process requests sequentially, making the cache
    # eviction and reuse cycle easy to observe.
    max_batch_size = 1
    max_seq_len = 256

    # --- KV Cache Configuration ---
    # Set a small GPU KV cache size (in number of tokens). This is crucial for the demo,
    # as it's only large enough to hold the KV cache for a single request.
    kv_cache_max_tokens = 256
    # Define the size of a single cache block.
    kv_cache_page_size = 16
    # Enable a 1 GB host cache if offloading is requested, otherwise disable it (size 0).
    # This is the key toggle for the experiment.
    kv_cache_host_size = 1024**3 if args.enable_offloading else 0

    sampling_params = SamplingParams(max_tokens=max_seq_len)

    llm = LLM(
        model="Qwen/Qwen3-8B",
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        kv_cache_config=KvCacheConfig(
            enable_block_reuse=True,  # Enable reuse of cached blocks
            max_tokens=kv_cache_max_tokens,  # Max tokens in GPU cache
            tokens_per_block=kv_cache_page_size,
            host_cache_size=kv_cache_host_size  # Host cache size for offloading
        ))

    # Process four requests sequentially using two distinct prompts (A, B, A, B).
    # This pattern is designed to showcase the cache eviction and reuse behavior.
    print("--- First Round ---")
    # 1. Process prompt A. Its cache is stored on the GPU.
    output_a = llm.generate(prompt_a, sampling_params)
    print(
        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
    )
    # 2. Process prompt B. Its cache replaces/offloads A's cache.
    output_b = llm.generate(prompt_b, sampling_params)
    print(
        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
    )

    print("\n--- Second Round ---")
    # 3. Process prompt A again.
    #    - Without offloading: Must recompute from scratch.
    #    - With offloading: Recovers cache from host RAM.
    output_a = llm.generate(prompt_a, sampling_params)
    print(
        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
    )
    # 4. Process prompt B again.
    #    - Without offloading: Must recompute from scratch.
    #    - With offloading: Recovers cache from host RAM.
    output_b = llm.generate(prompt_b, sampling_params)
    print(
        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
    )

    llm.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=
        "A script to demonstrate the effectiveness of KV cache host offloading."
    )
    parser.add_argument('--enable_offloading',
                        action='store_true',
                        help='Enable host RAM for KV cache offloading.')
    args = parser.parse_args()
    main(args)

examples/scaffolding/contrib/TreeInference/run_mcts_example.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 
 import argparse
 
-from tensorrt_llm.scaffolding import (MCTSController,
-                                      NativeGenerationController, PRMController)
+from tensorrt_llm.scaffolding import NativeGenerationController, PRMController
+from tensorrt_llm.scaffolding.contrib.TreeInference import MCTSController
 from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
 from tensorrt_llm.scaffolding.worker import TRTLLMWorker

examples/scaffolding/contrib/TreeInference/run_tot_example.py

Lines changed: 2 additions & 2 deletions
@@ -3,8 +3,8 @@
 
 import argparse
 
-from tensorrt_llm.scaffolding import (NativeGenerationController, PRMController,
-                                      TOTController)
+from tensorrt_llm.scaffolding import NativeGenerationController, PRMController
+from tensorrt_llm.scaffolding.contrib.TreeInference import TOTController
 from tensorrt_llm.scaffolding.scaffolding_llm import ScaffoldingLlm
 from tensorrt_llm.scaffolding.worker import TRTLLMWorker

tensorrt_llm/scaffolding/__init__.py

Lines changed: 0 additions & 5 deletions
@@ -1,8 +1,5 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-
 from .benchmark import ScaffoldingBenchRequest, async_scaffolding_benchmark
-from .contrib.TreeInference.tree_controllers import (MCTSController,
-                                                     TOTController)
 from .controller import (BestOfNController, Controller, MajorityVoteController,
                          NativeGenerationController, NativeRewardController,
                          ParallelProcess, PRMController)
@@ -23,8 +20,6 @@
     "PRMController",
     "MajorityVoteController",
     "BestOfNController",
-    "MCTSController",
-    "TOTController",
     "Task",
     "GenerationTask",
     "RewardTask",
tensorrt_llm/scaffolding/contrib/TreeInference/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .tree_controllers import MCTSController, TOTController

__all__ = ["MCTSController", "TOTController"]
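For reference, after this refactor the tree-search controllers are imported from the contrib package, as the updated examples above do:

```python
# Core controllers remain in the top-level scaffolding package.
from tensorrt_llm.scaffolding import NativeGenerationController, PRMController
# MCTS/ToT controllers now come from the contrib TreeInference package.
from tensorrt_llm.scaffolding.contrib.TreeInference import MCTSController, TOTController
```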

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 36 additions & 0 deletions
@@ -2040,6 +2040,42 @@ def test_nvfp4_multi_gpus_corner_case(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        Tests a corner case of the NVFP4 model.
+        When max_seq_len and max_num_tokens are set to the same value, there are not
+        enough KV blocks for the dummy requests used in CUDA graph warmup when the
+        py_executor is created before KV cache estimation. CUDA graph capture is then
+        triggered during KV cache estimation, which may cause errors.
+        More info in https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(

tests/integration/test_lists/waives.txt

Lines changed: 1 addition & 0 deletions
@@ -310,6 +310,7 @@ examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
+examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5522332)
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5465143)
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5465143)
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)

tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py

Lines changed: 26 additions & 32 deletions
@@ -2,6 +2,7 @@
 import os
 from typing import Any, Dict, Optional
 
+import pytest
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -286,11 +287,12 @@ def generate_dynamic_shapes(max_batch_size, max_seq_len):
 
 
 def _hf_model_dir_or_hub_id(
-    hf_model_dir: str,
+    hf_model_subdir: str,
     hf_hub_id: str,
 ) -> str:
-    if os.path.isdir(hf_model_dir):
-        return hf_model_dir
+    llm_models_path = llm_models_root()
+    if llm_models_path and os.path.isdir((model_fullpath := llm_models_path / hf_model_subdir)):
+        return str(model_fullpath)
     else:
         return hf_hub_id
 
@@ -350,10 +352,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
 
 _SMALL_MODEL_CONFIGS = {
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct",
-            "meta-llama/Meta-Llama-3.1-8B-Instruct",
-        ),
+        "llm_models_subdir": "llama-3.1-model/Llama-3.1-8B-Instruct",
         "model_kwargs": {
             "num_hidden_layers": 1,
             "hidden_size": 64,
@@ -363,10 +362,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "mistralai/Mixtral-8x7B-Instruct-v0.1": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/Mixtral-8x7B-Instruct-v0.1",
-            "mistralai/Mixtral-8x7B-Instruct-v0.1",
-        ),
+        "llm_models_subdir": "Mixtral-8x7B-Instruct-v0.1",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "intermediate_size": 256,
@@ -377,10 +373,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "Qwen/Qwen3-30B-A3B": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B",
-            "Qwen/Qwen3-30B-A3B",
-        ),
+        "llm_models_subdir": "Qwen3/Qwen3-30B-A3B",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "intermediate_size": 256,
@@ -391,10 +384,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "microsoft/Phi-3-mini-4k-instruct": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct",
-            "microsoft/Phi-3-mini-4k-instruct",
-        ),
+        "llm_models_subdir": "Phi-3/Phi-3-mini-4k-instruct",
         "model_kwargs": {
             "num_hidden_layers": 2,
             "hidden_size": 128,
@@ -404,10 +394,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "meta-llama/Llama-4-Scout-17B-16E-Instruct": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/Llama-4-Scout-17B-16E-Instruct",
-            "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-        ),
+        "llm_models_subdir": "Llama-4-Scout-17B-16E-Instruct",
         "model_factory": "AutoModelForImageTextToText",
         "model_kwargs": {
             "text_config": {
@@ -426,10 +413,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "deepseek-ai/DeepSeek-V3": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/DeepSeek-V3",
-            "deepseek-ai/DeepSeek-V3",
-        ),
+        "llm_models_subdir": "DeepSeek-V3",
         "model_kwargs": {
             "first_k_dense_replace": 1,
             "num_hidden_layers": 2,
@@ -448,16 +432,13 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         },
     },
     "Qwen/Qwen2.5-3B-Instruct": {
-        "model": _hf_model_dir_or_hub_id(
-            f"{llm_models_root()}/Qwen/Qwen2.5-3B-Instruct",
-            "Qwen/Qwen2.5-3B-Instruct",
-        ),
+        "llm_models_subdir": "Qwen2.5-3B-Instruct",
         "model_kwargs": {
             "num_hidden_layers": 2,
         },
     },
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
-        "model": f"{llm_models_root()}/Mistral-Small-3.1-24B-Instruct-2503",
+        "llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
         "model_factory": "Mistral3VLM",
         "compile_backend": "torch-simple",
         "model_kwargs": {
@@ -487,6 +468,9 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, Any]:
 
     llm_args = copy.deepcopy(_SMALL_MODEL_CONFIGS[model_hub_id])
 
+    # check if should use llm_models_root or hf_hub_id
+    llm_args["model"] = _hf_model_dir_or_hub_id(llm_args.pop("llm_models_subdir"), model_hub_id)
+
    # add some defaults to llm_args
    llm_args["skip_loading_weights"] = True  # No weight loading to speed up things
    llm_args["free_mem_ratio"] = 0.00  # we don't need the cache and it may cause OOM issues
@@ -507,3 +491,13 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, Any]:
     }
 
     return experiment_config
+
+
+def get_small_model_config_pytest_param(
+    model_hub_id: str, pytest_param_kwargs=None, **llm_args_kwargs
+):
+    return pytest.param(
+        get_small_model_config(model_hub_id, **llm_args_kwargs),
+        id=model_hub_id,
+        **(pytest_param_kwargs or {}),
+    )
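To make the new helper's role concrete, here is a hedged usage sketch (not from this commit): it mirrors the updated multi-GPU test below and additionally shows `pytest_param_kwargs` being forwarded to `pytest.param`, for example to attach a mark to a single case. The test name and the skip mark are purely illustrative.

```python
import pytest
from _model_test_utils import get_small_model_config_pytest_param


@pytest.mark.parametrize(
    "experiment_config",
    [
        # The test id becomes the model hub id.
        get_small_model_config_pytest_param(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            attn_backend="flashinfer",
        ),
        # Illustrative only: pytest_param_kwargs is forwarded to pytest.param.
        get_small_model_config_pytest_param(
            "deepseek-ai/DeepSeek-V3",
            pytest_param_kwargs={"marks": pytest.mark.skip(reason="sketch only")},
        ),
    ],
)
def test_experiment_config_is_dict(experiment_config):
    # Each case receives the experiment-config dict built by get_small_model_config().
    assert isinstance(experiment_config, dict)
```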

tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py

Lines changed: 2 additions & 2 deletions
@@ -3,15 +3,15 @@
 from typing import Dict
 
 import pytest
-from _model_test_utils import get_small_model_config
+from _model_test_utils import get_small_model_config_pytest_param
 from build_and_run_ad import ExperimentConfig, main
 
 
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize(
     "experiment_config",
     [
-        get_small_model_config(
+        get_small_model_config_pytest_param(
             "meta-llama/Meta-Llama-3.1-8B-Instruct",
             attn_backend="flashinfer",
             compile_backend="torch-opt",
