
Commit 40d129a

[None][fix] Fix cache buffer size for window (#8320)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent e265eb5 commit 40d129a

11 files changed (+38, −38 lines)

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 9 additions & 5 deletions
@@ -210,7 +210,11 @@ CacheTransBufferManager::CacheTransBufferManager(
         {
             auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
             auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
-            auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value());
+            auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
+            auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
+            // if windowSize % tokensPerBlock != 0
+            validTokenNum += tokensPerBlock; // add one more block
+
             bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer;
         }
     }
@@ -237,7 +241,7 @@ CacheTransBufferManager::CacheTransBufferManager(
     allocateBuffer();
 }

-size_t CacheTransBufferManager::preAllocBufferSize(
+size_t CacheTransBufferManager::preAllocBufferSize(size_t tokensPerBlock,
     std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
     std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
 {
@@ -256,9 +260,9 @@ size_t CacheTransBufferManager::preAllocBufferSize(
         TransferBufferSize = 0;
         for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow)
         {
-            auto validTokenNum
-                = (static_cast<size_t>(windowSize) < maxNumTokens.value() ? static_cast<size_t>(windowSize)
-                                                                          : maxNumTokens.value());
+            auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
+            auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
+            validTokenNum += tokensPerBlock; // add one more block
             TransferBufferSize += validTokenNum * cacheSizeBytesPerToken;
         }
     }
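
The sizing change above amounts to a small arithmetic rule: round the attention window up to a whole number of KV-cache blocks before clamping against maxNumTokens, then pad by one extra block so a window that is not a multiple of tokensPerBlock still fits in the transfer buffer. The standalone sketch below illustrates that rule with made-up numbers; the helper name validTokenCount and the byte-size figures are illustrative only, not the actual CacheTransBufferManager code.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Sketch of the per-window sizing rule from this commit (illustrative names):
// round the window up to whole blocks, clamp to maxNumTokens, then reserve one
// extra block for windows that are not a multiple of tokensPerBlock.
static std::size_t validTokenCount(std::size_t windowSize, std::size_t maxNumTokens, std::size_t tokensPerBlock)
{
    std::size_t alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
    return std::min(alignedWindowSize, maxNumTokens) + tokensPerBlock;
}

int main()
{
    // Example: a 1000-token window with 64-token blocks is padded to 1024 tokens,
    // plus one extra block, so 1088 tokens are budgeted for this window.
    std::size_t tokens = validTokenCount(/*windowSize=*/1000, /*maxNumTokens=*/4096, /*tokensPerBlock=*/64);
    std::size_t bytesPerTokenPerLayer = 2 * 8 * 128 * 2; // illustrative: K+V * heads * headDim * fp16 bytes
    std::printf("tokens=%zu, bytes per layer=%zu\n", tokens, tokens * bytesPerTokenPerLayer);
    return 0;
}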

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

Lines changed: 2 additions & 1 deletion
@@ -60,7 +60,8 @@ class CacheTransBufferManager
     CacheTransBufferManager(
         KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens = std::nullopt);

-    static size_t preAllocBufferSize(std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
+    static size_t preAllocBufferSize(size_t tokensPerBlock,
+        std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow,
         std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig = std::nullopt);

     std::optional<int> assignBufferIndexForSend();

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 12 additions & 1 deletion
@@ -524,7 +524,18 @@ class CacheSender::Impl

             if (isReady)
             {
-                asyncSendAndRemoveResponse(it->first, std::move(it->second));
+                if (dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager) != nullptr)
+                {
+                    // Our NIXL impl seems to support recv and send only on the same thread;
+                    // if we used ZMQ as the control path, we could avoid this issue.
+                    sendAndRemoveResponse(it->first, std::move(it->second));
+                }
+                else
+                {
+                    // If we send data on another thread, multiple ranks may send data for different
+                    // requests at the same time in the gen DP case.
+                    asyncSendAndRemoveResponse(it->first, std::move(it->second));
+                }
                 removeResponse(it);
             }
             else
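
The control-flow change above picks between a same-thread send and an asynchronous send based on the concrete connection-manager type. The sketch below shows that dispatch pattern in isolation, under stated assumptions: ConnectionManager, AgentConnectionManager, sendResponse, and dispatchSend are stand-ins, not the repository's real signatures, and std::async is used here only to represent "send on another thread".

#include <future>
#include <iostream>
#include <memory>

// Illustrative stand-ins; the real types live in tensorrt_llm::executor::kv_cache.
struct ConnectionManager { virtual ~ConnectionManager() = default; };
struct AgentConnectionManager : ConnectionManager {}; // NIXL-backed manager in the real code

void sendResponse(int requestId) { std::cout << "sent response for request " << requestId << "\n"; }

// Mirror of the branching above: the NIXL agent path sends on the calling thread,
// every other backend hands the send off to a background task.
void dispatchSend(ConnectionManager* manager, int requestId)
{
    if (dynamic_cast<AgentConnectionManager*>(manager) != nullptr)
    {
        sendResponse(requestId); // same-thread send required by the NIXL agent path
    }
    else
    {
        auto fut = std::async(std::launch::async, sendResponse, requestId); // background send
        fut.wait(); // wait only to keep this example deterministic
    }
}

int main()
{
    auto agent = std::make_unique<AgentConnectionManager>();
    auto other = std::make_unique<ConnectionManager>();
    dispatchSend(agent.get(), 1); // synchronous path
    dispatchSend(other.get(), 2); // asynchronous path
    return 0;
}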

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
     auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerTokenForDisagg(
         mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2);
     auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(
-        cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
+        mModelConfig.getTokensPerBlock(), cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);

     auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
         = BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig);

cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp

Lines changed: 2 additions & 1 deletion
@@ -182,5 +182,6 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
         .def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
             nb::arg("max_num_tokens") = std::nullopt)
         .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
-            nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("cache_transceiver_config") = nb::none());
+            nb::arg("tokens_per_block"), nb::arg("cache_size_bytes_per_token_per_window"),
+            nb::arg("cache_transceiver_config") = nb::none());
 }

cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp

Lines changed: 2 additions & 1 deletion
@@ -178,5 +178,6 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m)
         .def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), py::arg("cache_manager"),
             py::arg("max_num_tokens") = std::nullopt)
         .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
-            py::arg("cache_size_bytes_per_token_per_window"), py::arg("cache_transceiver_config") = py::none());
+            py::arg("tokens_per_block"), py::arg("cache_size_bytes_per_token_per_window"),
+            py::arg("cache_transceiver_config") = py::none());
 }

cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp

Lines changed: 4 additions & 4 deletions
@@ -114,8 +114,8 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize)
         {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}};
     tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{
         tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens};
-    size_t bufferSizeBytes
-        = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
+    size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize(
+        tokensPerBlock, cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
     auto bufferId = mTransBufferManager->assignBufferIndexForSend();
     EXPECT_TRUE(bufferId.has_value());
     EXPECT_EQ(bufferId.value(), 0);
@@ -156,8 +156,8 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2)
         tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens};
     std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow{
         {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}};
-    size_t bufferSizeBytes
-        = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
+    size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize(
+        tokensPerBlock, cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
     auto bufferId = mTransBufferManager->assignBufferIndexForSend();
     EXPECT_TRUE(bufferId.has_value());
     EXPECT_EQ(bufferId.value(), 0);

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 3 additions & 2 deletions
@@ -145,7 +145,8 @@ def __init__(self, kv_cache_manager: KVCacheManager, max_num_tokens: int):
                          max_num_tokens)

     @staticmethod
-    def pre_alloc_buffer_size(kv_cache_size_per_token: int,
+    def pre_alloc_buffer_size(tokens_per_block: int,
+                              kv_cache_size_per_token: int,
                               cache_transceiver_config: CacheTransceiverConfig):
         return CacheTransBufferManagerCpp.pre_alloc_buffer_size(
-            kv_cache_size_per_token, cache_transceiver_config)
+            tokens_per_block, kv_cache_size_per_token, cache_transceiver_config)

tests/integration/defs/cpp/test_multi_gpu.py

Lines changed: 2 additions & 18 deletions
@@ -108,22 +108,6 @@ def run_cache_transceiver_tests(build_dir: _pl.Path,
                      env=mgpu_env,
                      timeout=timeout)

-    # Nixl transfer agent tests
-    new_env = get_multi_gpu_env(kv_cache_type=KVCacheType.NIXL)
-
-    # Cache transceiver tests
-    cache_trans_test_8_proc = [
-        "mpirun",
-        "-n",
-        "8",
-        "--allow-run-as-root",
-        "cacheTransceiverTest",
-    ]
-    _cpp.run_command(cache_trans_test_8_proc,
-                     cwd=tests_dir,
-                     env=new_env,
-                     timeout=600)
-

 def run_user_buffer_tests(build_dir: _pl.Path, nprocs=2, timeout=300):
     tests_dir = build_dir / "tests" / "unit_tests" / "multi_gpu"
@@ -500,8 +484,8 @@ def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):

 @pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
                          indirect=True)
-@pytest.mark.parametrize("kvcache_type", [KVCacheType.MPI, KVCacheType.UCX],
-                         ids=["mpi_kvcache", "ucx_kvcache"])
+@pytest.mark.parametrize("kvcache_type", [KVCacheType.NIXL, KVCacheType.UCX],
+                         ids=["nixl_kvcache", "ucx_kvcache"])
 @pytest.mark.parametrize("nprocs", [2, 8], ids=["2proc", "8proc"])
 def test_cache_transceiver(build_google_tests, nprocs, kvcache_type, build_dir):

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 2 deletions
@@ -189,9 +189,8 @@ l0_dgx_h100:
   # ------------- CPP tests ---------------
   - cpp/test_multi_gpu.py::test_mpi_utils[90]
   - cpp/test_multi_gpu.py::test_fused_gemm_allreduce[4proc-90]
-  - cpp/test_multi_gpu.py::test_cache_transceiver[2proc-mpi_kvcache-90]
   - cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90]
-  - cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mpi_kvcache-90]
+  - cpp/test_multi_gpu.py::test_cache_transceiver[8proc-nixl_kvcache-90]
   - cpp/test_multi_gpu.py::test_cache_transceiver[8proc-ucx_kvcache-90]
   - cpp/test_multi_gpu.py::test_user_buffer[2proc-90]
   - cpp/test_multi_gpu.py::test_enc_dec[t5-90]
