Skip to content

Commit 0388ff9

Browse files
authored
[https://nvbugs/5393961][fix] record kv-cache size in MLACacheFormatter (#6181)
Signed-off-by: Bo Deng <[email protected]>
1 parent d9a3530 commit 0388ff9

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -325,6 +325,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
325325
{
326326
for (auto const& block : outputBuffers)
327327
{
328+
llmRequest.updateKvCacheSize(block->getSizeInBytes());
328329
session.recv(pickUpConnections[i], block->data(), block->getSizeInBytes());
329330
}
330331
}
@@ -378,13 +379,15 @@ void MLACacheFormatter::unformat(TransferSession& session)
378379
if (processIdx >= remainNoCoverTargetNum)
379380
{
380381
auto& buffer = recvSplitCaches.at(processIdx);
382+
llmRequest.updateKvCacheSize(buffer->getSizeInBytes());
381383
session.recv(pickUpConnections.at(processIdx), buffer->data(), buffer->getSizeInBytes());
382384
}
383385
else if (bufferCoverTargetNum > 0)
384386
{
385387
auto recvBufferIdx = processIdx % bufferCoverTargetNum
386388
+ remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc
387389
auto& buffer = recvSplitCaches.at(recvBufferIdx);
390+
llmRequest.updateKvCacheSize(buffer->getSizeInBytes());
388391
session.recv(pickUpConnections.at(processIdx), buffer->data(), buffer->getSizeInBytes());
389392
bufferManager.copy(*recvSplitCaches.at(recvBufferIdx), *recvSplitCaches.at(processIdx));
390393
bufferManager.getStream().synchronize();
@@ -401,6 +404,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
401404
auto recvSlice = runtime::ITensor::slice(preAllocRecvBuffer, 0, recvSize);
402405
auto copySlice = runtime::ITensor::slice(
403406
recvSplitCaches.at(processIdx), targetBufferSize - remainRecvSize, recvSize);
407+
llmRequest.updateKvCacheSize(recvSlice->getSizeInBytes());
404408
session.recv(pickUpConnections.at(processIdx), recvSlice->data(), recvSlice->getSizeInBytes());
405409
bufferManager.copy(*recvSlice, *copySlice);
406410
bufferManager.getStream().synchronize();

0 commit comments

Comments
 (0)