@@ -325,6 +325,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
325325 {
326326 for (auto const & block : outputBuffers)
327327 {
328+ llmRequest.updateKvCacheSize (block->getSizeInBytes ());
328329 session.recv (pickUpConnections[i], block->data (), block->getSizeInBytes ());
329330 }
330331 }
@@ -378,13 +379,15 @@ void MLACacheFormatter::unformat(TransferSession& session)
378379 if (processIdx >= remainNoCoverTargetNum)
379380 {
380381 auto & buffer = recvSplitCaches.at (processIdx);
382+ llmRequest.updateKvCacheSize (buffer->getSizeInBytes ());
381383 session.recv (pickUpConnections.at (processIdx), buffer->data (), buffer->getSizeInBytes ());
382384 }
383385 else if (bufferCoverTargetNum > 0 )
384386 {
385387 auto recvBufferIdx = processIdx % bufferCoverTargetNum
386388 + remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc
387389 auto & buffer = recvSplitCaches.at (recvBufferIdx);
390+ llmRequest.updateKvCacheSize (buffer->getSizeInBytes ());
388391 session.recv (pickUpConnections.at (processIdx), buffer->data (), buffer->getSizeInBytes ());
389392 bufferManager.copy (*recvSplitCaches.at (recvBufferIdx), *recvSplitCaches.at (processIdx));
390393 bufferManager.getStream ().synchronize ();
@@ -401,6 +404,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
401404 auto recvSlice = runtime::ITensor::slice (preAllocRecvBuffer, 0 , recvSize);
402405 auto copySlice = runtime::ITensor::slice (
403406 recvSplitCaches.at (processIdx), targetBufferSize - remainRecvSize, recvSize);
407+ llmRequest.updateKvCacheSize (recvSlice->getSizeInBytes ());
404408 session.recv (pickUpConnections.at (processIdx), recvSlice->data (), recvSlice->getSizeInBytes ());
405409 bufferManager.copy (*recvSlice, *copySlice);
406410 bufferManager.getStream ().synchronize ();
0 commit comments