From dfbebc0f5e080d3d05d0bcbfffe918e973104815 Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Wed, 9 Jul 2025 06:48:56 +0000
Subject: [PATCH 1/4] fix: WAR for under-allocation in torch VSWA kvcachemanager

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

feat: Add max_free_gpu_memory_size support for KV cache configuration

- Introduced max_free_gpu_memory_size to manage GPU memory allocation for the KV cache.
- Updated KvCacheConfig and related methods to handle the new parameter.
- Modified estimation logic in KvCacheCreator to utilize max_free_gpu_memory_size for VSWA cases.
- Adjusted resource management to ensure compatibility with the new memory allocation strategy.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

address comments:
- VSWA path should work now
- Modify `KvCacheCreator` so it always provides KvCacheManager with a memory size instead of max_tokens
- Handle a user-provided `max_tokens` inside `KvCacheCreator` -- it raises a warning for the VSWA case and translates `max_tokens` to a memory size for the non-VSWA case

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Respect max_gpu_total_bytes in KVCacheManager for the non-VSWA case

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Add memory estimation for attention metadata to solve an OOM issue when CUDA graphs are enabled and the max window size is large

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Enhance KVCacheManager to maintain adjusted max attention window sizes. Introduced an adjusted dictionary to track window size mappings and updated the logic to reflect these changes in the max attention window vector. Updated unit tests to validate the new behavior and ensure expected outputs for various memory configurations.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

fix: Draft model calculation in KVCacheManager and KvCacheCreator

- Introduced a method to calculate the KV size per token and updated memory estimation to ensure proper handling of max_gpu_total_bytes.
- Changed KVCacheManager to still use `max_tokens` in the non-VSWA case, as its `calculate_max_num_blocks` does not account for the draft model token calculation.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Minor fix for description

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Enhance KVCacheManager and related components to track allocated GPU memory for the KV cache. Added `allocatedBytes` to `KvCacheStats` and updated the memory allocation logic in `KVCacheManager` and `KvCacheCreator` to ensure accurate memory usage reporting. Adjusted the Python bindings to expose the new `allocated_bytes` attribute.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Update logging

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Update Nanobind accordingly

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Update secondary pool memory allocation in KVCacheManager to use host_cache_size from the configuration if provided.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>

Refactor the fix for CUDA Graph OOM issue.
Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Add torch_empty_cache fixture to clear CUDA cache before tests Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> minor fix for rebase Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Add validation for max_gpu_total_bytes in KvCacheConfig - Introduced a field validator to ensure that max_gpu_total_bytes is non-negative, enhancing input validation for the configuration. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --- .../batch_manager/kvCacheManager.h | 5 + cpp/include/tensorrt_llm/executor/executor.h | 11 +- .../batch_manager/kvCacheManager.cpp | 34 +++--- cpp/tensorrt_llm/executor/kvCacheConfig.cpp | 24 +++- .../nanobind/batch_manager/kvCacheManager.cpp | 3 +- .../nanobind/executor/executorConfig.cpp | 14 ++- .../pybind/batch_manager/kvCacheManager.cpp | 3 +- .../pybind/executor/executorConfig.cpp | 13 ++- tensorrt_llm/_torch/pyexecutor/_util.py | 106 ++++++++++++++---- .../_torch/pyexecutor/py_executor_creator.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 29 ++++- tensorrt_llm/llmapi/llm_args.py | 15 ++- tests/integration/defs/conftest.py | 9 ++ .../test_disaggregated_single_gpu.py | 1 + .../_torch/executor/test_resource_manager.py | 18 ++- 15 files changed, 224 insertions(+), 63 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index df526a5dfbe..d5da697535f 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -192,6 +192,8 @@ struct KvCacheStats float cacheHitRate; // Number of free blocks for every configured attention-window size. 
std::map numFreeBlocksPerWindowSize; + // GPU bytes allocated for KV-cache + std::size_t allocatedBytes{}; }; // Basic building block of a paged KV cache - a single @@ -1474,6 +1476,7 @@ class KVCacheManager : public BaseKVCacheManager : static_cast(kvCacheStats.reusedBlocks) / static_cast(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks); kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize(); + kvCacheStats.allocatedBytes = mAllocatedBytes; return kvCacheStats; } @@ -1677,6 +1680,8 @@ class KVCacheManager : public BaseKVCacheManager runtime::ITensor::SharedPtr mBlockPoolPointers; runtime::ITensor::SharedPtr mLayerToPoolMapping; runtime::ITensor::SharedPtr mBlockScalePoolPointers; + // GPU bytes allocated for KV-cache + std::size_t mAllocatedBytes{0}; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 96345abf380..28c69074a3c 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1006,7 +1006,8 @@ class KvCacheConfig std::optional secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false, SizeType32 attentionDpEventsGatherPeriodMs = 5, - std::optional const& runtimeDefaults = std::nullopt); + std::optional const& runtimeDefaults = std::nullopt, + uint64_t const& maxGpuTotalBytes = 0); [[nodiscard]] bool getEnableBlockReuse() const; [[nodiscard]] bool getEnablePartialReuse() const; @@ -1022,11 +1023,12 @@ class KvCacheConfig [[nodiscard]] size_t getEventBufferMaxSize() const; [[nodiscard]] bool getUseUvm() const; [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const; + [[nodiscard]] uint64_t getMaxGpuTotalBytes() const; void setEnableBlockReuse(bool enableBlockReuse); void setEnablePartialReuse(bool enablePartialReuse); void setCopyOnPartialReuse(bool copyOnPartialReuse); - void setMaxTokens(SizeType32 maxTokens); + void setMaxTokens(std::optional maxTokens); void setMaxAttentionWindowVec(std::vector maxAttentionWindowVec); void setSinkTokenLength(SizeType32 sinkTokenLength); void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction); @@ -1037,6 +1039,7 @@ class KvCacheConfig void setEventBufferMaxSize(size_t eventBufferMaxSize); void setUseUvm(bool useUvm); void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs); + void setMaxGpuTotalBytes(uint64_t maxGpuTotalBytes); void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults); @@ -1095,6 +1098,10 @@ class KvCacheConfig /// @brief The period in milliseconds to gather attention DP events across ranks SizeType32 mAttentionDpEventsGatherPeriodMs; + /// @brief The maximum size in bytes of GPU memory that can be allocated for the KV cache. + /// If both mMaxGpuTotalBytes and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will + /// be allocated. 
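+    /// For example (illustrative numbers), if the fraction-derived budget is 60 GiB and
+    /// mMaxGpuTotalBytes is 32 GiB, only 32 GiB is allocated for the KV cache.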
+ uint64_t mMaxGpuTotalBytes; }; /// @brief Configuration class for the runtime perf knobs diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 0b793a041aa..364f2409bc4 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1681,27 +1681,29 @@ void KVCacheManager::allocatePools(bool useUvm) mBlockManager.allocatePools(useUvm); auto const numPools = mBlockManager.getNumPools(); - if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO) + uint64_t cacheSizeBytes = 0; + for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++) { - uint64_t cacheSizeBytes = 0; - for (SizeType32 poolIdx = 0; poolIdx < numPools; poolIdx++) - { - auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape(); - auto const cacheVolume = ITensor::volume(cacheShape); + auto const cacheShape = mBlockManager.getPrimaryPool(poolIdx)->getShape(); + auto const cacheVolume = ITensor::volume(cacheShape); #ifdef ENABLE_FP4 - auto const isFp4 = mDataType == nvinfer1::DataType::kFP4; + auto const isFp4 = mDataType == nvinfer1::DataType::kFP4; #else - auto const isFp4 = false; + auto const isFp4 = false; #endif - if (!isFp4) - { - cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize(); - } - else - { - cacheSizeBytes += (cacheVolume * 4) / 8; - } + if (!isFp4) + { + cacheSizeBytes += cacheVolume * BufferDataType(mDataType).getSize(); + } + else + { + cacheSizeBytes += (cacheVolume * 4) / 8; } + } + // Save the total number of bytes allocated for the KV-cache for KvCacheStats + mAllocatedBytes = cacheSizeBytes; + if (tc::Logger::getLogger()->getLevel() <= tc::Logger::INFO) + { TLLM_LOG_INFO("Number of tokens per block: %d.", mBlockManager.getTokensPerBlock()); auto const maxNumTokens = mBlockManager.getNumPrimaryBlocks() * mBlockManager.getTokensPerBlock(); diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 21cf314c875..1e83ba4b3a6 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -28,7 +28,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co std::optional const& crossKvCacheFraction, std::optional secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, SizeType32 attentionDpEventsGatherPeriodMs, - std::optional const& runtimeDefaults) + std::optional const& runtimeDefaults, uint64_t const& maxGpuTotalBytes) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) , mOnboardBlocks(onboardBlocks) @@ -38,6 +38,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co , mCopyOnPartialReuse{copyOnPartialReuse} , mUseUvm{useUvm} , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) + , mMaxGpuTotalBytes{maxGpuTotalBytes} { if (maxTokens) { @@ -63,6 +64,10 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co { fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value()); } + if (maxGpuTotalBytes) + { + setMaxGpuTotalBytes(maxGpuTotalBytes); + } TLLM_CHECK_WITH_INFO( mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0"); } @@ -137,6 +142,11 @@ SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const return mAttentionDpEventsGatherPeriodMs; } +uint64_t KvCacheConfig::getMaxGpuTotalBytes() const +{ + return mMaxGpuTotalBytes; +} + void 
KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse) { mEnableBlockReuse = enableBlockReuse; @@ -152,9 +162,12 @@ void KvCacheConfig::setCopyOnPartialReuse(bool copyOnPartialReuse) mCopyOnPartialReuse = copyOnPartialReuse; } -void KvCacheConfig::setMaxTokens(SizeType32 maxTokens) +void KvCacheConfig::setMaxTokens(std::optional maxTokens) { - TLLM_CHECK(maxTokens > 0); + if (maxTokens) + { + TLLM_CHECK(maxTokens.value() > 0); + } mMaxTokens = maxTokens; } @@ -219,6 +232,11 @@ void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEve mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs; } +void KvCacheConfig::setMaxGpuTotalBytes(uint64_t maxGpuTotalBytes) +{ + mMaxGpuTotalBytes = maxGpuTotalBytes; +} + void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults) { if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 412698215aa..b278570dad6 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -303,7 +303,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def_rw("reused_blocks", &tbk::KvCacheStats::reusedBlocks) .def_rw("missed_blocks", &tbk::KvCacheStats::missedBlocks) .def_rw("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate) - .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize); + .def_rw("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize) + .def_ro("allocated_bytes", &tbk::KvCacheStats::allocatedBytes); nb::class_(m, "TempAttentionWindowInputs") .def(nb::init<>()) diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp index 505ecfca595..0334eb14f6a 100644 --- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp @@ -21,6 +21,7 @@ #include "tensorrt_llm/nanobind/common/customCasters.h" #include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include #include #include #include @@ -111,11 +112,11 @@ void initConfigBindings(nb::module_& m) self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), - self.getAttentionDpEventsGatherPeriodMs()); + self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes()); }; auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state) { - if (state.size() != 14) + if (state.size() != 15) { throw std::runtime_error("Invalid state!"); } @@ -125,20 +126,21 @@ void initConfigBindings(nb::module_& m) nb::cast(state[6]), nb::cast>(state[7]), nb::cast>(state[8]), nb::cast(state[9]), nb::cast(state[10]), nb::cast(state[11]), nb::cast(state[12]), - nb::cast(state[13])); + nb::cast(state[13]), std::nullopt, nb::cast(state[14])); }; nb::class_(m, "KvCacheConfig") .def(nb::init const&, std::optional> const&, std::optional const&, std::optional const&, std::optional const&, bool, std::optional const&, std::optional, size_t const&, bool, bool, bool, - SizeType32, std::optional const&>(), + SizeType32, 
std::optional const&, uint64_t const&>(), nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(), nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(), nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(), nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(), nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(), nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false, - nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none()) + nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none(), + nb::arg("max_gpu_total_bytes") = 0) .def_prop_rw( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -163,6 +165,8 @@ void initConfigBindings(nb::module_& m) .def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) .def_prop_rw("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs, &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs) + .def_prop_rw( + "max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes) .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) .def("__getstate__", kvCacheConfigGetstate) .def("__setstate__", kvCacheConfigSetstate); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 54835e81d7f..fcbfaf9c64a 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -299,7 +299,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks) .def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks) .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate) - .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize); + .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize) + .def_readonly("allocated_bytes", &tbk::KvCacheStats::allocatedBytes); py::class_(m, "TempAttentionWindowInputs") .def(py::init<>()) diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index 0e279a3e47b..74e2fe56c16 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -104,11 +104,11 @@ void initConfigBindings(pybind11::module_& m) self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), - self.getAttentionDpEventsGatherPeriodMs()); + self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes()); }; auto kvCacheConfigSetstate = [](py::tuple const& state) { - if (state.size() != 14) + if (state.size() != 15) { throw std::runtime_error("Invalid state!"); } @@ -117,20 +117,21 @@ void 
initConfigBindings(pybind11::module_& m) state[4].cast>(), state[5].cast>(), state[6].cast(), state[7].cast>(), state[8].cast>(), state[9].cast(), state[10].cast(), state[11].cast(), state[12].cast(), - state[13].cast()); + state[13].cast(), std::nullopt, state[14].cast()); }; py::class_(m, "KvCacheConfig") .def(py::init const&, std::optional> const&, std::optional const&, std::optional const&, std::optional const&, bool, std::optional const&, std::optional, size_t const&, bool, bool, bool, - SizeType32, std::optional const&>(), + SizeType32, std::optional const&, uint64_t const&>(), py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(), py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(), py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(), py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(), py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false, - py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none()) + py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none(), + py::arg("max_gpu_total_bytes") = 0) .def_property( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -140,6 +141,8 @@ void initConfigBindings(pybind11::module_& m) "sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength) .def_property("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction, &tle::KvCacheConfig::setFreeGpuMemoryFraction) + .def_property( + "max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes) .def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize) .def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks) .def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 20e3aaaa093..e3654ea2997 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -98,8 +98,7 @@ def _get_free_gpu_memory_fraction(self) -> float: fraction = 0.9 return fraction - def _cal_max_tokens(self, peak_memory, total_gpu_memory, fraction, - alloc_kv_tokens: int) -> int: + def _get_kv_size_per_token(self): model_config = self._model_engine.model.model_config mapping = self._mapping kv_size_per_token = self._get_cache_size_per_token( @@ -108,18 +107,25 @@ def _cal_max_tokens(self, peak_memory, total_gpu_memory, fraction, draft_model_config = self._draft_model_engine.model.model_config kv_size_per_token += self._get_cache_size_per_token( draft_model_config, mapping) + return kv_size_per_token + + def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction, + allocated_bytes: int) -> int: + """ + Calculate the max KV cache capacity. + + NOTE: `allocated_bytes` is the total KV-cache memory that must be pre-allocated during the estimation phase (for both the main and draft models) so the estimation run can complete successfully. 
When computing `available_kv_mem`, add this amount back in. + """ + kv_size_per_token = self._get_kv_size_per_token() available_kv_mem = (total_gpu_memory - peak_memory + - alloc_kv_tokens * kv_size_per_token) * fraction + allocated_bytes) * fraction logger.info( f"Peak memory during memory usage profiling (torch + non-torch): {peak_memory / (GB):.2f} GiB, " f"available KV cache memory when calculating max tokens: {available_kv_mem / (GB):.2f} GiB, " f"fraction is set {fraction}, kv size is {kv_size_per_token}. device total memory {total_gpu_memory / (GB):.2f} GiB, " - f", tmp kv_mem { (alloc_kv_tokens * kv_size_per_token) / (GB):.2f} GiB" - ) - max_tokens = int((available_kv_mem) // kv_size_per_token) - max_tokens = max(max_tokens, 0) - return max_tokens + f", tmp kv_mem { (allocated_bytes) / (GB):.2f} GiB") + return int(available_kv_mem) def _create_dummy_context_requests( self, input_seq_len: int) -> List[trtllm.Request]: @@ -187,12 +193,17 @@ def try_prepare_estimation(self) -> bool: estimating_kv_cache = False if 'cp_type' not in self._mapping.cp_config: estimating_kv_cache = True - self._executor_config.kv_cache_config.max_tokens = self._get_token_num_for_estimation( - ) + max_attention_window_from_config = self._executor_config.kv_cache_config.max_attention_window + max_seq_len = max( + max_attention_window_from_config + ) if max_attention_window_from_config is not None else self._executor_config.max_seq_len + self._executor_config.kv_cache_config.max_tokens = max( + self._get_token_num_for_estimation(), max_seq_len) return estimating_kv_cache - def estimate_max_tokens(self, py_executor: PyExecutor) -> None: + def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None: """Perform KV cache capacity estimation. + NOTE: for VSWA case, we calculate and set kv cache memory instead of using max_tokens in kv_cache_config. This updates `kv_cache_config`. """ @@ -258,19 +269,74 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None: logger.info( f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB" ) + + # get kv cache stats for both model and draft model kv_stats = py_executor.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats() - - kv_cache_max_tokens = self._cal_max_tokens( - peak_memory, total_gpu_memory, fraction, - kv_stats.max_num_blocks * kv_stats.tokens_per_block) - + kv_stats_draft = py_executor.resource_manager.resource_managers.get( + ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats( + ) if self._draft_model_engine is not None else None + + # get total allocated bytes + allocated_bytes = kv_stats.allocated_bytes + ( + kv_stats_draft.allocated_bytes if kv_stats_draft is not None else 0) + + # calculate max memory from peak memory and free gpu memory fraction + kv_cache_max_memory = self._cal_max_memory(peak_memory, + total_gpu_memory, fraction, + allocated_bytes) + + max_attention_window = executor_config.kv_cache_config.max_attention_window + is_vswa = max_attention_window and len(max_attention_window) > 1 + + # NOTE: + # KvCacheCreator currently controls KV-cache capacity using two parameters in KVCacheConfig: + # • max_tokens + # • max_gpu_total_bytes + # Ideally, the internal logic would rely solely on max_gpu_total_bytes, + # leaving max_tokens as a user-defined constraint. 
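+        # Illustrative example (assumed numbers): if the fraction-derived budget is
+        # 50 GiB, a user-provided max_tokens corresponds to 40 GiB, and
+        # max_gpu_total_bytes is 32 GiB, the final capacity is min(50, 40, 32) = 32 GiB.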
+ + # ---------------------------handle max_tokens--------------------------------- + # if user provided max_tokens, calculate max memory from max_tokens if self._max_kv_tokens_in is not None: - kv_cache_max_tokens = min(kv_cache_max_tokens, - self._max_kv_tokens_in) + # raise error if it is VSWA case + if is_vswa: + logger.warning( + "max_tokens should not be set for VSWA case as it is ambiguous concept for VSWA." + ) + # calculate max memory from max_tokens + kv_cache_max_memory_from_max_tokens = self._max_kv_tokens_in * self._get_kv_size_per_token( + ) + kv_cache_max_memory = min(kv_cache_max_memory, + kv_cache_max_memory_from_max_tokens) + logger.info( + f"max_tokens={self._max_kv_tokens_in} is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB" + ) + if is_vswa: + # For VSWA KvCacheManager now it can only use max_gpu_total_bytes + executor_config.kv_cache_config.max_tokens = None + else: + # For non-VSWA KvCacheManager, its logic still relies on max_tokens, need to improve in the future. + executor_config.kv_cache_config.max_tokens = int( + kv_cache_max_memory // self._get_kv_size_per_token()) + # ---------------------------handle max_tokens--------------------------------- + + # ---------------------------handle max_gpu_total_bytes--------------------------------- + # if user provided max_gpu_total_bytes, set max memory from max_gpu_total_bytes + if executor_config.kv_cache_config.max_gpu_total_bytes > 0: + kv_cache_max_memory = min( + kv_cache_max_memory, + executor_config.kv_cache_config.max_gpu_total_bytes) + logger.info( + f"max_gpu_total_bytes={executor_config.kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB" + ) - logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}") - executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens + logger.info( + f"Estimated max memory in KV cache : {kv_cache_max_memory / (GB):.2f} GiB" + ) + # set max_gpu_total_bytes + executor_config.kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory + # ---------------------------handle max_gpu_total_bytes--------------------------------- def _create_kv_cache_manager( self, model_engine: PyTorchModelEngine) -> KVCacheManager: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 12686728cda..6a14988f5df 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -428,7 +428,7 @@ def create_py_executor( assert kv_cache_creator is not None with mem_monitor.observe_creation_stage( _ExecutorCreationStage.MODEL_EXTRA): - kv_cache_creator.estimate_max_tokens(py_executor) + kv_cache_creator.configure_kv_cache_capacity(py_executor) kv_cache_creator.teardown_managers(resources) del py_executor # free before constructing new diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 4066b45cf89..688bb0bc0b7 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -264,6 +264,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], assert isinstance( kv_cache_config, KvCacheConfigCpp ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp" + blocks_per_window = self.calculate_max_num_blocks_from_cpp( kv_cache_config=kv_cache_config, model_config=model_config, @@ -686,13 +687,14 @@ def 
_get_window_size_to_layers(self) -> dict[int, list[int]]: @staticmethod def adjust_window_sizes_for_vswa( window_size_to_layers: Dict[int, List[int]], + max_attention_window_vec: List[int], kv_cache_config: KvCacheConfigCpp, model_config: ModelConfig, pool_memory_bytes: int, kv_factor: int, dtype: DataType, is_cross_attention: bool = False, - ) -> Dict[int, List[int]]: + ) -> Tuple[Dict[int, List[int]], List[int]]: assert is_cross_attention is False, 'Cross attention is not supported' @@ -717,7 +719,8 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: if required_mem_bytes_per_seq < pool_memory_bytes: # No need to adjust the window sizes. - return copy.deepcopy(window_size_to_layers) + return (copy.deepcopy(window_size_to_layers), + max_attention_window_vec) logger.debug( f'Adjusting the window sizes {list(window_size_to_layers)} to fit ' @@ -730,6 +733,8 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: accum_max_tokens = 0 prev_window_size = 0 + adjusted_dict = {} + adjusted_max_attention_window_vec = max_attention_window_vec.copy() for window_size in sorted(window_size_to_layers): layers = window_size_to_layers[window_size] @@ -773,11 +778,18 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: adjusted_window_size_to_layers[accum_max_tokens] = layers.copy() else: adjusted_window_size_to_layers[accum_max_tokens].extend(layers) + adjusted_dict[window_size] = accum_max_tokens + # also update adjusted_max_attention_window_vec + adjusted_max_attention_window_vec = [ + adjusted_dict.get(v, v) + for v in adjusted_max_attention_window_vec + ] remaining_layers -= set(layers) prev_window_size = window_size - return adjusted_window_size_to_layers + return (adjusted_window_size_to_layers, + adjusted_max_attention_window_vec) def calculate_max_num_blocks_from_cpp( self, @@ -814,8 +826,11 @@ def calculate_max_num_blocks_from_cpp( logger.debug(f"window_size_to_layers: {window_size_to_layers}") free_mem, total_mem = torch.cuda.mem_get_info() - primary_pool_memory_bytes = free_mem - secondary_pool_memory_bytes = 0 + # Respect max_gpu_total_bytes if provided + free_gpu_memory_fraction = kv_cache_config.free_gpu_memory_fraction if kv_cache_config.free_gpu_memory_fraction else 0.9 + primary_pool_memory_bytes = kv_cache_config.max_gpu_total_bytes if kv_cache_config.max_gpu_total_bytes > 0 else int( + free_mem * free_gpu_memory_fraction) + secondary_pool_memory_bytes = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0 logger.debug( f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \n" f"secondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB" @@ -823,8 +838,9 @@ def calculate_max_num_blocks_from_cpp( # Adjust the window sizes to fit the memory if even a single sequence # cannot fit in the memory. 
- window_size_to_layers = self.adjust_window_sizes_for_vswa( + window_size_to_layers, max_attention_window_vec = self.adjust_window_sizes_for_vswa( window_size_to_layers=window_size_to_layers, + max_attention_window_vec=self.max_attention_window_vec, model_config=model_config, kv_cache_config=kv_cache_config, pool_memory_bytes=primary_pool_memory_bytes, @@ -832,6 +848,7 @@ def calculate_max_num_blocks_from_cpp( dtype=self.dtype, is_cross_attention=is_cross_attention, ) + self.max_attention_window_vec = max_attention_window_vec blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks( config=kv_cache_config, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index da5071e3b0a..1bdeb6412c7 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -991,6 +991,11 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): ) use_uvm: bool = Field(default=False, description="Whether to use UVM for the KV cache.") + max_gpu_total_bytes: int = Field( + default=0, + description= + "The maximum size in bytes of GPU memory that can be allocated for the KV cache. If both `max_gpu_total_bytes` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be allocated." + ) # This is a pure python field, not a pybind field. It is only for the Pytorch backend. dtype: str = Field(default="auto", @@ -1021,7 +1026,15 @@ def _to_pybind(self): use_uvm=self.use_uvm, attention_dp_events_gather_period_ms=self. attention_dp_events_gather_period_ms, - ) + max_gpu_total_bytes=self.max_gpu_total_bytes) + + @field_validator('max_gpu_total_bytes') + @classmethod + def validate_max_gpu_total_bytes(cls, v: int): + if v < 0: + raise ValueError( + "kv_cache_config.max_gpu_total_bytes must be non-negative") + return v @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index 49b485cefcb..f3e86b850f3 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -2386,3 +2386,12 @@ def timeout_manager(timeout_from_command_line, timeout_from_marker): timeout_value = timeout_from_command_line return TimeoutManager(timeout_value) + + +@pytest.fixture +def torch_empty_cache() -> None: + """ + Manually empty the torch CUDA cache before each test, to reduce risk of OOM errors. 
+ """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 93611de040b..d605a33983f 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -350,6 +350,7 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph, @pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"]) @pytest.mark.parametrize("generation_overlap", [False]) @pytest.mark.parametrize("eagle3_one_model", [True, False]) +@pytest.mark.usefixtures("torch_empty_cache") def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path, generation_overlap, eagle3_one_model): diff --git a/tests/unittest/_torch/executor/test_resource_manager.py b/tests/unittest/_torch/executor/test_resource_manager.py index dc67e9bbfcd..4d8b9ae7896 100644 --- a/tests/unittest/_torch/executor/test_resource_manager.py +++ b/tests/unittest/_torch/executor/test_resource_manager.py @@ -438,6 +438,7 @@ def test_adjust_window_sizes_for_vswa(self): 200: [4, 5, 6], 7000: [7, 8], } + max_attention_window_vec = [100] * 4 + [200] * 3 + [7000] * 2 model_config = self.MockModelConfig() model_config.num_attention_heads = 2 @@ -465,6 +466,7 @@ def test_adjust_window_sizes_for_vswa(self): 100: [0, 1, 2, 3], 130: [4, 5, 6, 7, 8], }, + [100] * 4 + [130] * 5, None, "limited_memory_clamped_windows"), ( @@ -476,6 +478,7 @@ def test_adjust_window_sizes_for_vswa(self): 200: [4, 5, 6], 1017: [7, 8], }, + [100] * 4 + [200] * 3 + [1017] * 2, None, "less_limited_memory_clamped_windows"), ( @@ -487,6 +490,7 @@ def test_adjust_window_sizes_for_vswa(self): 200: [4, 5, 6], 7000: [7, 8], }, + [100] * 4 + [200] * 3 + [7000] * 2, None, "sufficient_memory_no_clamping"), ( @@ -495,6 +499,7 @@ def test_adjust_window_sizes_for_vswa(self): { 51: [0, 1, 2, 3, 4, 5, 6, 7, 8], }, + [51] * 9, None, "very_limited_memory_all_clamped"), ( @@ -506,15 +511,17 @@ def test_adjust_window_sizes_for_vswa(self): 100: [0, 1, 2, 3], 134: [4, 5, 6, 7, 8], }, + [100] * 4 + [134] * 5, 134, "less_limited_memory_but_clamped_by_max_tokens"), ] - for memory_bytes, expected_window_sizes, max_tokens, description in test_cases: + for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases: with self.subTest(case=description, memory_bytes=memory_bytes): kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens) - adjusted = KVCacheManager.adjust_window_sizes_for_vswa( + adjusted, adjusted_max_attention_window_vec = KVCacheManager.adjust_window_sizes_for_vswa( window_size_to_layers=window_size_to_layers, + max_attention_window_vec=max_attention_window_vec, model_config=model_config, kv_cache_config=kv_cache_config, pool_memory_bytes=memory_bytes, @@ -529,6 +536,13 @@ def test_adjust_window_sizes_for_vswa(self): f"Memory bytes: {memory_bytes}\n" f"Actual: {adjusted}\n" f"Expected: {expected_window_sizes}") + self.assertEqual( + adjusted_max_attention_window_vec, + expected_max_attention_window_vec, + f"Test case '{description}' failed.\n" + f"Memory bytes: {memory_bytes}\n" + f"Actual: {adjusted_max_attention_window_vec}\n" + f"Expected: {expected_max_attention_window_vec}") if __name__ == "__main__": From 6837e601953353d955eb0161b9baa87dc02b19c0 Mon Sep 17 00:00:00 2001 From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Date: Tue, 12 
Aug 2025 20:21:20 +0000 Subject: [PATCH 2/4] feat: Enhance max_attention_window validation in KvCacheConfig - Added a field validator for max_attention_window to ensure it is a non-empty list of positive integers. - Implemented checks to prevent redundant patterns in the list, ensuring only minimal repeating patterns are accepted. - Updated KvCacheCreator to use a set for determining VSWA status based on max_attention_window. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 2 +- .../_torch/pyexecutor/resource_manager.py | 2 +- tensorrt_llm/llmapi/llm_args.py | 31 +++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e3654ea2997..7cd93269d9a 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -287,7 +287,7 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None: allocated_bytes) max_attention_window = executor_config.kv_cache_config.max_attention_window - is_vswa = max_attention_window and len(max_attention_window) > 1 + is_vswa = max_attention_window and len(set(max_attention_window)) > 1 # NOTE: # KvCacheCreator currently controls KV-cache capacity using two parameters in KVCacheConfig: diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 688bb0bc0b7..23257bc952a 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -250,7 +250,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], else 0) # Determine if this is VSWA (Variable Sliding Window Attention) - self.is_vswa = len(self.max_attention_window_vec) > 1 + self.is_vswa = len(set(self.max_attention_window_vec)) > 1 # Calculate blocks per window using appropriate method if self.is_vswa: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1bdeb6412c7..aff6d74744d 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1036,6 +1036,37 @@ def validate_max_gpu_total_bytes(cls, v: int): "kv_cache_config.max_gpu_total_bytes must be non-negative") return v + @field_validator('max_attention_window') + @classmethod + def validate_max_attention_window(cls, v: Optional[List[int]]): + # Allow unset + if v is None: + return v + + # Must be a non-empty list of positive integers + if not isinstance(v, list) or len(v) == 0: + raise ValueError( + "kv_cache_config.max_attention_window must be a non-empty list of positive integers" + ) + for i in v: + if not isinstance(i, int): + raise ValueError( + "kv_cache_config.max_attention_window must contain only integers" + ) + if i <= 0: + raise ValueError( + "kv_cache_config.max_attention_window values must be positive" + ) + + # Must not be a redundant repetition of a shorter pattern + n = len(v) + for k in range(1, n): + if n % k == 0 and v == v[:k] * (n // k): + raise ValueError( + f"kv_cache_config.max_attention_window should contain only the minimal repeating pattern; use {v[:k]} instead of {v}" + ) + return v + @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror): From 33cc07de84b8e6b7fc2f4a0917052c57a87b3c93 Mon Sep 17 00:00:00 2001 From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Date: Wed, 13 Aug 2025 05:33:18 +0000 Subject: [PATCH 3/4] 
test: Add test to protect the memory allocation behavior. - Changed primary and secondary pool memory allocation to use instance variables instead of local variables for better clarity and maintainability. - Updated logging to reflect the new instance variable usage. - Added unit tests to validate memory allocation behavior in KVCacheManager. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 15 +- tensorrt_llm/llmapi/llm_args.py | 8 - .../_torch/executor/test_resource_manager.py | 146 ++++++++++++++++++ 3 files changed, 153 insertions(+), 16 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 23257bc952a..f9ad784a4cb 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -264,7 +264,6 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], assert isinstance( kv_cache_config, KvCacheConfigCpp ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp" - blocks_per_window = self.calculate_max_num_blocks_from_cpp( kv_cache_config=kv_cache_config, model_config=model_config, @@ -828,12 +827,12 @@ def calculate_max_num_blocks_from_cpp( free_mem, total_mem = torch.cuda.mem_get_info() # Respect max_gpu_total_bytes if provided free_gpu_memory_fraction = kv_cache_config.free_gpu_memory_fraction if kv_cache_config.free_gpu_memory_fraction else 0.9 - primary_pool_memory_bytes = kv_cache_config.max_gpu_total_bytes if kv_cache_config.max_gpu_total_bytes > 0 else int( + self._primary_pool_memory_bytes = kv_cache_config.max_gpu_total_bytes if kv_cache_config.max_gpu_total_bytes > 0 else int( free_mem * free_gpu_memory_fraction) - secondary_pool_memory_bytes = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0 + self._secondary_pool_memory_bytes = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0 logger.debug( - f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \n" - f"secondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB" + f"primary_pool_memory_bytes is set to {self._primary_pool_memory_bytes/1024**3}GB, \n" + f"secondary_pool_memory_bytes is set to {self._secondary_pool_memory_bytes/1024**3}GB" ) # Adjust the window sizes to fit the memory if even a single sequence @@ -843,7 +842,7 @@ def calculate_max_num_blocks_from_cpp( max_attention_window_vec=self.max_attention_window_vec, model_config=model_config, kv_cache_config=kv_cache_config, - pool_memory_bytes=primary_pool_memory_bytes, + pool_memory_bytes=self._primary_pool_memory_bytes, kv_factor=self.kv_factor, dtype=self.dtype, is_cross_attention=is_cross_attention, @@ -858,8 +857,8 @@ def calculate_max_num_blocks_from_cpp( model_config=model_config, world_config=world_config_cpp, window_size_to_layers=window_size_to_layers, - allotted_primary_mem_bytes=primary_pool_memory_bytes, - allotted_secondary_mem_bytes=secondary_pool_memory_bytes, + allotted_primary_mem_bytes=self._primary_pool_memory_bytes, + allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes, extra_cost_memory=extra_cost_memory, kv_factor=self.kv_factor, ) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index aff6d74744d..e033c740443 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1057,14 +1057,6 @@ def validate_max_attention_window(cls, v: Optional[List[int]]): raise 
ValueError( "kv_cache_config.max_attention_window values must be positive" ) - - # Must not be a redundant repetition of a shorter pattern - n = len(v) - for k in range(1, n): - if n % k == 0 and v == v[:k] * (n // k): - raise ValueError( - f"kv_cache_config.max_attention_window should contain only the minimal repeating pattern; use {v[:k]} instead of {v}" - ) return v diff --git a/tests/unittest/_torch/executor/test_resource_manager.py b/tests/unittest/_torch/executor/test_resource_manager.py index 4d8b9ae7896..d7af41d174d 100644 --- a/tests/unittest/_torch/executor/test_resource_manager.py +++ b/tests/unittest/_torch/executor/test_resource_manager.py @@ -3,6 +3,8 @@ import subprocess import sys import unittest +from typing import NamedTuple, Tuple +from unittest.mock import patch import numpy as np import torch @@ -13,11 +15,13 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager, PeftCacheConfig, PeftCacheManager) +from tensorrt_llm.bindings import LayerType from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp from tensorrt_llm.bindings import executor as tllm from tensorrt_llm.bindings.internal.batch_manager import \ PeftTaskNotCachedException from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.mapping import Mapping DataType = tensorrt_llm.bindings.DataType LoraModule = tensorrt_llm.bindings.LoraModule @@ -544,6 +548,148 @@ def test_adjust_window_sizes_for_vswa(self): f"Actual: {adjusted_max_attention_window_vec}\n" f"Expected: {expected_max_attention_window_vec}") + @staticmethod + def _create_model_config_for_kv_cache_manager() -> ModelConfigCpp: + """ + Create a simple model config for KVCacheManager test. + """ + + model_config_params = { + "vocab_size": 0, + "num_layers": 4, + "num_attention_layers": 4, + "num_rnn_layers": 0, + "num_heads": 64, + "hidden_size": 64, + "data_type": DataType.HALF + } + num_kv_heads = 8 + + model_config = ModelConfigCpp(**model_config_params) + model_config.layer_types = [LayerType.ATTENTION + ] * model_config.num_attention_layers() + model_config.set_num_kv_heads(num_kv_heads) + + return model_config + + @staticmethod + def _create_kv_cache_config_for_kv_cache_manager( + params: dict) -> tllm.KvCacheConfig: + """ + Create a KV cache config for KVCacheManager test. 
+ """ + return tllm.KvCacheConfig(**params) + + def test_calculate_max_num_blocks_from_cpp(self): + # Construct a minimal mapping (single-rank, no TP/PP) + mapping = Mapping(world_size=1, tp_size=1, pp_size=1) + + # Construct model config + model_config = TestResourceManager._create_model_config_for_kv_cache_manager( + ) + + # Construct KV cache config + free_gpu_memory_fraction = 0.1 + max_attention_window = [64, 128] + max_gpu_total_bytes = 32 * 1024 * 1024 # 32MB + enable_block_reuse = False + host_cache_size = 32 * 1024 * 1024 # 32MB + + # mock values for torch.cuda.mem_get_info to return a fixed value + fixed_free_mem = 128 * 1024 * 1024 # 128MB + fixed_total_mem = 256 * 1024 * 1024 # 256MB + + class MemTestCase(NamedTuple): + case_name: str + kv_cache_config_params: dict + expected_memory_bytes: Tuple[ + int, + int] # (primary_pool_memory_bytes, secondary_pool_memory_bytes) + + test_cases = [ + # Case 1: + # max_gpu_total_bytes is set, even if free_gpu_memory_fraction is set, we will use max_gpu_total_bytes + # host_cache_size is set, we will use host_cache_size + MemTestCase( + case_name="max_gpu_total_bytes is set, host_cache_size is set", + kv_cache_config_params={ + "max_attention_window": max_attention_window, + "free_gpu_memory_fraction": free_gpu_memory_fraction, + "max_gpu_total_bytes": max_gpu_total_bytes, + "enable_block_reuse": enable_block_reuse, + "host_cache_size": host_cache_size, + }, + expected_memory_bytes=(max_gpu_total_bytes, host_cache_size), + ), + + # Case 2: + # max_gpu_total_bytes is not set, we will use free_gpu_memory_fraction + # host_cache_size is not set, we will use 0 + MemTestCase( + case_name= + "max_gpu_total_bytes is not set, host_cache_size is not set", + kv_cache_config_params={ + "max_attention_window": max_attention_window, + "free_gpu_memory_fraction": free_gpu_memory_fraction, + "enable_block_reuse": enable_block_reuse, + }, + # NOTE: use np.float32 to avoid float precision issue between python(double in most cases) and cpp binding(float) + expected_memory_bytes=(int( + fixed_free_mem * np.float32(free_gpu_memory_fraction)), 0), + ), + ] + + tokens_per_block = 32 + model_config.tokens_per_block = tokens_per_block + max_seq_len = max(max_attention_window) + max_batch_size = 1 + max_beam_width = 1 + + for case_name, kv_cache_config_params, expected_memory_bytes in test_cases: + with self.subTest(case=case_name): + kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager( + kv_cache_config_params) + with patch('torch.cuda.mem_get_info', + return_value=(fixed_free_mem, fixed_total_mem)): + # Create a real KVCacheManager, it will run calculate_max_num_blocks_from_cpp in __init__ + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=tensorrt_llm.bindings.internal. 
+ batch_manager.CacheType.SELF, + num_layers=model_config.num_attention_layers(), + num_kv_heads=model_config.num_kv_heads( + 0 + ), # NOTE: assume same number of kv heads for all layers + head_dim=model_config.head_size, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=model_config.data_type, + model_config=model_config, + max_beam_width=max_beam_width, + ) + try: + expected_primary, expected_secondary = expected_memory_bytes + self.assertEqual( + manager._primary_pool_memory_bytes, + expected_primary, + f"Test case '{case_name}' failed.\n" + f"Expected primary pool memory bytes: {expected_primary}\n" + f"Actual primary pool memory bytes: {manager._primary_pool_memory_bytes}" + ) + self.assertEqual( + manager._secondary_pool_memory_bytes, + expected_secondary, + f"Test case '{case_name}' failed.\n" + f"Expected secondary pool memory bytes: {expected_secondary}\n" + f"Actual secondary pool memory bytes: {manager._secondary_pool_memory_bytes}" + ) + except Exception as e: + self.fail(f"Test case '{case_name}' failed: {e}") + finally: + manager.shutdown() + if __name__ == "__main__": unittest.main() From 6e93c1b91f4985c5aff564fef9381039af93a18d Mon Sep 17 00:00:00 2001 From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Date: Fri, 15 Aug 2025 06:22:02 +0000 Subject: [PATCH 4/4] revert: revert the change for prepare estimation: - we shouldn't use max_seq_len as the kv config max_tokens as it doesn't need that much - and it makes preparation to have OOM more easily especially long seq. Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 7cd93269d9a..8191318a750 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -193,12 +193,8 @@ def try_prepare_estimation(self) -> bool: estimating_kv_cache = False if 'cp_type' not in self._mapping.cp_config: estimating_kv_cache = True - max_attention_window_from_config = self._executor_config.kv_cache_config.max_attention_window - max_seq_len = max( - max_attention_window_from_config - ) if max_attention_window_from_config is not None else self._executor_config.max_seq_len - self._executor_config.kv_cache_config.max_tokens = max( - self._get_token_num_for_estimation(), max_seq_len) + self._executor_config.kv_cache_config.max_tokens = self._get_token_num_for_estimation( + ) return estimating_kv_cache def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
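
Taken together, the four patches reduce KV-cache capacity selection to a byte budget that is the minimum of the fraction-derived budget, the (optional) max_tokens-derived budget, and the (optional) max_gpu_total_bytes limit. The standalone sketch below summarizes that selection. It is illustrative only: the function name, argument names, and the numbers in the example are assumptions rather than part of the patch; the actual logic lives in `KvCacheCreator.configure_kv_cache_capacity` in `tensorrt_llm/_torch/pyexecutor/_util.py` and also updates `executor_config.kv_cache_config` in place.

```python
from typing import Optional

GB = 1 << 30


def select_kv_cache_budget(total_gpu_memory: int,
                           peak_memory: int,
                           allocated_bytes: int,
                           fraction: float,
                           kv_size_per_token: int,
                           max_tokens: Optional[int],
                           max_gpu_total_bytes: int,
                           is_vswa: bool) -> dict:
    """Simplified sketch of the budget selection in configure_kv_cache_capacity.

    Parameter names are illustrative; only the min() composition mirrors the patch.
    """
    # Fraction-derived budget: memory left after the profiling run, with the
    # estimation-phase KV allocation added back, scaled by free_gpu_memory_fraction.
    budget = int((total_gpu_memory - peak_memory + allocated_bytes) * fraction)

    # A user-provided max_tokens is translated into bytes and acts as an upper bound.
    # (For VSWA it is ambiguous across window sizes, so the patch only warns.)
    if max_tokens is not None:
        budget = min(budget, max_tokens * kv_size_per_token)

    # The non-VSWA manager still sizes itself by token count; VSWA relies on bytes only.
    tokens = None if is_vswa else budget // kv_size_per_token

    # A user-provided max_gpu_total_bytes is a hard upper bound as well.
    if max_gpu_total_bytes > 0:
        budget = min(budget, max_gpu_total_bytes)

    return {"max_gpu_total_bytes": budget, "max_tokens": tokens}


if __name__ == "__main__":
    # Assumed numbers, for illustration only.
    result = select_kv_cache_budget(
        total_gpu_memory=80 * GB,
        peak_memory=30 * GB,
        allocated_bytes=2 * GB,
        fraction=0.9,
        kv_size_per_token=100_000,
        max_tokens=None,
        max_gpu_total_bytes=32 * GB,
        is_vswa=True,
    )
    print(result)  # budget is clamped to the 32 GiB max_gpu_total_bytes
```

Expressing the budget in bytes rather than tokens is what lets VSWA pools with different window sizes share a single capacity number, while the non-VSWA path keeps its token-based sizing until `calculate_max_num_blocks` accounts for draft-model tokens.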