4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -193,7 +193,6 @@ CacheTransBufferManager::CacheTransBufferManager(
     : mCacheManager{cacheManager}
     , mBufferManager{std::make_shared<runtime::CudaStream>()}
 {
-
     // TODO: FP4 dataSize
     TLLM_CHECK(mCacheManager);
     mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
@@ -229,7 +228,7 @@ CacheTransBufferManager::CacheTransBufferManager(
     mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
     TLLM_LOG_INFO(
         "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
-        "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
+        "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d "
         "mUseFabricMemory:%d mDataType:%d",
         maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
         mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);
@@ -335,6 +334,7 @@ std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBuf
     std::optional<int> bufferId, int targetNum, size_t targetBufferEleSize,
     runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
 {
+    printf("[CacheTransBufferManager::getOrAllocateBuffers] targetNum:%d, targetBufferEleSize:%ld, mTransferBufferSize:%ld\n", targetNum, targetBufferEleSize, mTransferBufferSize);
     TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
     std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
     size_t bufferCoverTargetNum = std::min(
49 changes: 19 additions & 30 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -97,14 +97,11 @@ void MLACacheFormatter::format(TransferSession& session)
     auto& bufferManager = session.getBufferManager();
     TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
     TLLM_CHECK(!connections.empty());
-    // diff start
     if (!needSendCache(selfConfig, destConfig, selfIdx))
     {
         return;
     }
 
-    // diff end
-
     auto const numPools = mCacheManager->getBlockManager().getNumPools();
     auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
 
@@ -150,28 +147,28 @@ void MLACacheFormatter::format(TransferSession& session)
     auto cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();
 
     auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
-    // diff start
 
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
 
     size_t const pPDomainSize = targetInfo.mDomainPPSize;
-    TLLM_CHECK((cacheBlockSize * blockNum) % pPDomainSize == 0);
-    auto const targetBufferSize = (cacheBlockSize * blockNum) / pPDomainSize;
+    size_t const cPDomainSize = targetInfo.mDomainCPSize;
+    TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
+    // @B: This works as if all output caches are of the same size. Is this a fair assumption?
+    auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
+    TLLM_LOG_INFO("[MLACacheFormatter::format] BEFORE getOrAllocateSendBuffers cacheBlockSize: %zu, blockNum: %d, pPDomainSize: %zu, cPDomainSize: %zu, targetBufferSize: %zu", cacheBlockSize, blockNum, pPDomainSize, cPDomainSize, targetBufferSize);
     auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
-        cacheBufferId, pPDomainSize, targetBufferSize, bufferManager);
+        cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
     auto& outputSplitCaches = std::get<0>(result);
Comment on lines 153 to 161
🛠️ Refactor suggestion

Guard: ensure connections.size() matches PP×CP before modulo mapping.

Add a sanity check to prevent modulo on a smaller connection set, which would scatter to wrong peers.

     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     size_t const pPDomainSize = targetInfo.mDomainPPSize;
     size_t const cPDomainSize = targetInfo.mDomainCPSize;
+    TLLM_CHECK_WITH_INFO(
+        connections.size() == pPDomainSize * cPDomainSize,
+        "Mismatch: number of connections (%zu) must equal PP×CP (%zu).",
+        connections.size(), pPDomainSize * cPDomainSize);
📝 Committable suggestion

    size_t const pPDomainSize = targetInfo.mDomainPPSize;
    size_t const cPDomainSize = targetInfo.mDomainCPSize;
    TLLM_CHECK_WITH_INFO(
        connections.size() == pPDomainSize * cPDomainSize,
        "Mismatch: number of connections (%zu) must equal PP×CP (%zu).",
        connections.size(), pPDomainSize * cPDomainSize);
    TLLM_CHECK((cacheBlockSize * blockNum) % (pPDomainSize * cPDomainSize) == 0);
    // @B: This works as if all output caches are of the same size. Is this a fair assumption?
    auto const targetBufferSize = (cacheBlockSize * blockNum) / (pPDomainSize * cPDomainSize);
    TLLM_LOG_INFO("[MLACacheFormatter::format] BEFORE getOrAllocateSendBuffers cacheBlockSize: %zu, blockNum: %d, pPDomainSize: %zu, cPDomainSize: %zu, targetBufferSize: %zu",
        cacheBlockSize, blockNum, pPDomainSize, cPDomainSize, targetBufferSize);
    auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
        cacheBufferId, pPDomainSize * cPDomainSize, targetBufferSize, bufferManager);
    auto& outputSplitCaches = std::get<0>(result);
🤖 Prompt for AI Agents
cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp around lines 153 to 161:
the code assumes the number of connections equals pPDomainSize * cPDomainSize
before doing modulo-based mapping which can scatter to wrong peers if
connections is smaller; add a sanity guard that verifies connections.size() >=
pPDomainSize * cPDomainSize (or == if strict) and fail fast with an explanatory
TLLM_CHECK or error log if the condition is not met, before computing
targetBufferSize and calling getOrAllocateSendBuffers; ensure the check prevents
division/modulo mapping against a smaller connection set and include minimal
context in the error message (e.g., actual sizes) so callers can debug.
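To make the failure mode concrete, here is a minimal standalone sketch (hypothetical sizes, not code from this PR) of the modulo mapping the guard protects: each connection index picks a split-cache buffer via processIdx % (pPDomainSize * cPDomainSize), so if the connection set is smaller than PP×CP, some buffers are never transferred.

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Hypothetical domain sizes, for illustration only.
        size_t const pPDomainSize = 2;
        size_t const cPDomainSize = 2;
        size_t const numBuffers = pPDomainSize * cPDomainSize; // 4 split caches

        // Healthy case: one connection per (PP, CP) peer covers every buffer.
        for (size_t processIdx = 0; processIdx < numBuffers; ++processIdx)
        {
            printf("connection %zu -> buffer %zu\n", processIdx, processIdx % numBuffers);
        }

        // Undersized case: 3 connections against 4 buffers; buffer 3 is never
        // sent. The suggested TLLM_CHECK_WITH_INFO would fail fast here instead.
        size_t const connectionCount = 3;
        for (size_t processIdx = 0; processIdx < connectionCount; ++processIdx)
        {
            printf("connection %zu -> buffer %zu\n", processIdx, processIdx % numBuffers);
        }
        return 0;
    }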

     auto& bufferCoverTargetNum = std::get<1>(result);
     auto& onlyUseDynamicBuffer = std::get<2>(result);
-    auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
-    if (agentConnnecion != nullptr)
+    auto* agentConnnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
+    if (agentConnnection != nullptr)
     {
-        TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize, "Agent need all buffer pre-allocated");
+        TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent need all buffer pre-allocated");
         TLLM_CHECK(onlyUseDynamicBuffer == false);
     }
-    // diff end
 
-    // The size of outputSplitCaches should be equal to pPDomainSize
+    // The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
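For a concrete instance of the split computed above (hypothetical sizes, purely illustrative, not values from this PR):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Hypothetical values, for illustration only.
        size_t const cacheBlockSize = 1024; // elements per KV-cache block
        int const blockNum = 8;             // blocks being transferred
        size_t const pPDomainSize = 2;      // receiving-side PP peers
        size_t const cPDomainSize = 2;      // receiving-side CP peers

        size_t const total = cacheBlockSize * static_cast<size_t>(blockNum); // 8192
        size_t const domains = pPDomainSize * cPDomainSize;                  // 4

        // Mirrors the TLLM_CHECK: the payload must split evenly across PP*CP peers.
        if (total % domains != 0)
        {
            printf("uneven split: would trip the divisibility check\n");
            return 1;
        }
        size_t const targetBufferSize = total / domains; // 2048 elements per peer
        printf("targetBufferSize = %zu\n", targetBufferSize);
        return 0;
    }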
@@ -191,8 +188,9 @@ void MLACacheFormatter::format(TransferSession& session)

         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         auto startTime = std::chrono::steady_clock::now();
-        auto cacheIdx = processIdx % pPDomainSize;
+        auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
         size_t size;
+        // @B: What does this check mean?
         if (cacheIdx < bufferCoverTargetNum)
         {
             size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
@@ -252,7 +250,7 @@ void MLACacheFormatter::format(TransferSession& session)
     else
     {
         // concurrency num
-        auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize);
+        auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize * cPDomainSize);
 
         auto remainSendNum = connections.size();

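For intuition, the clamp above resolves like this (a sketch with assumed values, not PR code): concurrency never drops below one sender and never exceeds the PP×CP peer count, even when the pre-allocated buffers cover more targets.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Assumed domain sizes, for illustration only.
        size_t const pPDomainSize = 2;
        size_t const cPDomainSize = 2;
        size_t const peers = pPDomainSize * cPDomainSize; // 4 target peers

        // bufferCoverTargetNum = how many targets the pre-allocated buffers cover.
        size_t const samples[] = {0, 2, 8};
        for (size_t bufferCoverTargetNum : samples)
        {
            auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), peers);
            // Prints 0 -> 1, 2 -> 2, 8 -> 4.
            printf("bufferCoverTargetNum=%zu -> concurrencyNum=%zu\n", bufferCoverTargetNum, concurrencyNum);
        }
        return 0;
    }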
@@ -300,10 +298,9 @@ void MLACacheFormatter::unformat(TransferSession& session)
     auto& bufferManager = session.getBufferManager();
     auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
     bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
-    // diff start
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
-    // diff end
     auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest);
+    printf("[MLACacheFormatter::unformat] pickUpConnections.size(): %zu, connections.size(): %zu, blockRange.size(): %zu\n", pickUpConnections.size(), connections.size(), blockRange.size());
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
🛠️ Refactor suggestion

Replace stdout printf with logger.

Use TLLM_LOG_DEBUG to keep debug output consistent and filterable.

-    printf("[MLACacheFormatter::unformat] pickUpConnections.size(): %zu, connections.size(): %zu, blockRange.size(): %zu\n", pickUpConnections.size(), connections.size(), blockRange.size());
+    TLLM_LOG_DEBUG("[MLACacheFormatter::unformat] pickUpConnections:%zu, connections:%zu, blockRange:%zu",
+        pickUpConnections.size(), connections.size(), blockRange.size());

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp around line 304, there's
a debug print to stdout using printf; replace that call with TLLM_LOG_DEBUG to
make debug output consistent and filterable. Change the printf to a
TLLM_LOG_DEBUG call that includes the same formatted message and variables,
ensure the file includes the logger header if missing, and remove the stdout
printf so all debug messages use TLLM_LOG_DEBUG.

std::vector<runtime::ITensor::SharedPtr> outputBuffers;
auto const numPools = mCacheManager->getBlockManager().getNumPools();
@@ -346,10 +343,10 @@ void MLACacheFormatter::unformat(TransferSession& session)
     }
     else
     {
-        auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
-        if (agentConnnecion != nullptr)
+        auto* agentConnnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
+        if (agentConnnection != nullptr)
         {
-            cacheBufferId = agentConnnecion->getCacheBufferId();
+            cacheBufferId = agentConnnection->getCacheBufferId();
             TLLM_CHECK(cacheBufferId.has_value());
         }
         else
@@ -368,7 +365,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
         auto& bufferCoverTargetNum = std::get<1>(result);
         size_t remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? targetNum - bufferCoverTargetNum : 0;
         auto& onlyUseDynamicBuffer = std::get<2>(result);
-        if (agentConnnecion != nullptr)
+        if (agentConnnection != nullptr)
         {
             TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == targetNum, "Agent need buffer pre-allocated");
             TLLM_CHECK(onlyUseDynamicBuffer == false);
@@ -489,7 +486,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
         outputCachesPerWindow.emplace(window, outputBuffers);
         NVTX3_SCOPED_RANGE(formatInputConcatenate);
 
-        // recvSplitCaches size == ppdomainsize
+        // recvSplitCaches size == ppdomainsize * cPDomainSize.
         executor::kv_cache::concatKvCacheV2Dispatch(
             recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
}
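Schematically, the concat step reassembles the PP×CP received splits into one contiguous cache. A generic host-side sketch follows; the names and flat layout are assumptions for illustration, while the real concatKvCacheV2Dispatch operates on pooled device tensors per attention window.

    #include <cstddef>
    #include <vector>

    // Generic sketch: stitch one received split per (PP, CP) peer back together.
    // Purely illustrative; not the real kernel's memory layout or API.
    std::vector<float> concatSplits(std::vector<std::vector<float>> const& recvSplitCaches)
    {
        std::vector<float> out;
        for (auto const& split : recvSplitCaches) // size == pPDomainSize * cPDomainSize
        {
            out.insert(out.end(), split.begin(), split.end());
        }
        return out;
    }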
@@ -564,14 +561,6 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{