4 changes: 4 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -183,6 +183,10 @@ class BlockRange
             auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
             mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
         }
+        if (cacheManager.isEnableIndexerKCache())
+        {
+            mIndexerKCachePool = cacheManager.getIndexerKCachePool();
+        }
     }

     BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -806,7 +806,7 @@ class CacheReceiver::Impl

         RequestInfo requestInfo(requestId, mSelfState);

-        if (mFormatter->getCacheManager()->getBlockManager().getNumPools() == 1)
+        if (!mFormatter->getCacheManager()->getBlockManager().isVariableWindow())
         {
             auto* cacheManager = mFormatter->getCacheManager();
             auto beam = 0;
9 changes: 1 addition & 8 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -876,14 +876,7 @@ void WindowBlockManager::allocatePools(bool useUvm)
         }

         nvinfer1::Dims cacheShape;
-        if (pool.containsIndexerKCache)
-        {
-            cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, blockSize});
-        }
-        else
-        {
-            cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
-        }
+        cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});

         TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(),
             mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads);
9 changes: 0 additions & 9 deletions examples/models/core/deepseek_v3/README.md
@@ -881,12 +881,3 @@ python quickstart_advanced.py --model_dir <YOUR_MODEL_DIR> --enable_chunked_pref
 - **GPU Memory:** Adjust `--max_batch_size` and `--max_num_tokens` if you encounter out-of-memory errors.
 - **Logs:** Check `/workspace/trt_bench.log` for detailed performance information and troubleshooting messages.
 - **Configuration Files:** Verify that the configuration files are correctly formatted to avoid runtime issues.
-
-## Known Issues
-- Support for KV Cache Reuse and Chunked Prefill in DeepSeek-V3.2-Exp is currently under development. When running `quickstart_advanced.py`, please include `--disable_kv_cache_reuse` to disable KV Cache Reuse. When using `trtllm-eval`/`trtllm-serve`/`trtllm-bench`, please include the following configuration in the extra llm_api options:
-```
-kv_cache_config:
-  enable_block_reuse: false
-  tokens_per_block: 64
-enable_chunked_prefill: false
-```
12 changes: 4 additions & 8 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -930,15 +930,15 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
                 start_idx=0,
             )

-            if len(chunk_groups) > 1:
+            if len(chunk_groups
+                   ) > 1 or metadata.enable_context_mla_with_cached_kv:
                 metadata.indexer_prefill_chunks = [
                     Indexer.prepare_one_prefill_chunk(
                         metadata,
                         chunk_specs,
                     ) for chunk_specs in chunk_groups
                 ]
             else:
-                # Single chunk - use non-chunked fallback path
                 metadata.indexer_prefill_chunks = None

         host_cu_seqlen_ks, host_cu_seqlen_ke = compute_cu_seqlen_kv_bounds_with_cache(
@@ -1018,9 +1018,9 @@ def prepare(metadata: DSAtrtllmAttentionMetadata):
         metadata.slot_mapping_scale[:total_tokens].copy_(
             metadata.host_slot_mapping_scale[:total_tokens], non_blocking=True)

-        # Only when MLA chunked prefill is enabled, we need to gather the full KV for indexer's logit computation.
+        # When chunked prefill or KVCache reuse is enabled, we need to gather the full KV for indexer's logit computation.
         # Indexer's own chunking does not need full KV gathering, instead it gathers only the current chunk with loop-based gathering.
-        _need_full_kv_gathering = num_contexts > 0 and has_mla_chunked_prefill
+        _need_full_kv_gathering = num_contexts > 0 and metadata.enable_context_mla_with_cached_kv
         if _need_full_kv_gathering:
             total_kv_len = metadata.host_ctx_kv_indptr[num_contexts].item()
             total_kv_per_request = seq_lens[:
@@ -1589,10 +1589,6 @@ def __init__(
         sparse_attn_config: "SparseAttentionConfig",
         **kwargs,
     ) -> None:
-
-        if kv_cache_config.enable_block_reuse:
-            raise NotImplementedError(
-                "DSA indexer K-cache manager does not support block reuse yet")
         self.quant_block_size = 128
         self.index_head_dim = sparse_attn_config.index_head_dim
         # Use a fixed tokens_per_block for indexer k cache due to DG kernel constraints
2 changes: 0 additions & 2 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -1055,7 +1055,6 @@ def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
         gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
         ctx_server_config["kv_cache_config"] = {
-            "enable_block_reuse": False,
             "free_gpu_memory_fraction": 0.7,
             "tokens_per_block": 64,
             "dtype": "fp8"
@@ -1072,7 +1071,6 @@ def test_auto_dtype(self, overlap_scheduler):
         ctx_server_config["enable_attention_dp"] = True
         ctx_server_config["enable_autotuner"] = False
         gen_server_config["kv_cache_config"] = {
-            "enable_block_reuse": False,
             "tokens_per_block": 64,
             "free_gpu_memory_fraction": 0.7,
             "dtype": "fp8"
14 changes: 4 additions & 10 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -2597,17 +2597,13 @@ def test_fp8_blockscale(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
         if get_sm_version() == 100 or get_sm_version() == 103:
             moe_backend = "DEEPGEMM" if moe_backend == "_DEFAULT" else moe_backend
             moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-            # TODO: Support block reuse for DeepSeek-V3.2
-            kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                            free_gpu_memory_fraction=0.6,
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                             tokens_per_block=64)
         else:
             if moe_backend != "_DEFAULT":
                 pytest.skip("Not supported MoE backend!")
             moe_config = MoeConfig()
-            # TODO: Support block reuse for DeepSeek-V3.2
-            kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                            free_gpu_memory_fraction=0.7,
+            kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                             tokens_per_block=64)

         pytorch_config = dict(
@@ -2670,8 +2666,7 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
                 "MOE TRTLLM backend does not support SM version 120 or 121")

         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                        free_gpu_memory_fraction=0.7,
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                         tokens_per_block=64)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,
@@ -2730,8 +2725,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size,
                 "MOE TRTLLM backend does not support SM version 120 or 121")

         moe_config = MoeConfig(backend=moe_backend, max_num_tokens=16384)
-        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
-                                        free_gpu_memory_fraction=0.7,
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
                                         tokens_per_block=64)
         cuda_graph_config = CudaGraphConfig(
             enable_padding=True,