From 77bf76cfb25d2f2c5e67f040957d2733aec1b5be Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Wed, 6 Aug 2025 02:35:51 +0000
Subject: [PATCH 1/7] context_tp_&gen_dp_opt

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../batch_manager/cacheFormatter.cpp          |  7 ++--
 .../batch_manager/mlaCacheFormatter.cpp       | 40 +++++++++++--------
 .../multi_gpu/cacheTransceiverTest.cpp        |  4 +-
 3 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
index 306cd64187e..168ea89693f 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -90,9 +90,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;

-    // only TP rank % dupHeadFactor == 0 need to send cache.
-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }

 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
@@ -140,11 +140,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
         {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
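The selection rule in the hunk above is easier to see in isolation. Below is a minimal sketch (the signature and names are illustrative, not the real API; the actual `needSendCache` derives `dupHeadFactor` from `TargetRanksInfo` and the parallel configs): instead of always electing TP rank 0 of each head-duplication group as the sender, the sender is chosen by the destination's DP rank, so different generation DP ranks are served by different duplicated context ranks.

```cpp
// Minimal sketch of the send-rank election above. Illustrative only:
// dupHeadFactor is how many context TP ranks hold the same duplicated KV heads.
bool needSendCacheSketch(int selfTpRankInDpGroup, int dupHeadFactor, int destDPRank)
{
    // Old rule: only rank 0 of each duplication group ever sent:
    //     return selfTpRankInDpGroup % dupHeadFactor == 0;
    // New rule: spread senders across the group, keyed by destination DP rank:
    return (destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor);
}
```

With `dupHeadFactor == 2`, generation DP rank 0 is served by the even context TP ranks and DP rank 1 by the odd ones, which is what the `CacheStateContextDP` expectations below change to.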
diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index 8a32d2b70c8..6fc8f25e136 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -72,10 +72,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     TLLM_CHECK(targetInfo.mDomainCPSize == 1);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
+
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
@@ -85,19 +87,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;

+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
+
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
     }

     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -108,14 +115,15 @@ bool MLACacheFormatter::needSendCache(
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank;
 }

 void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
     auto const& llmRequest = session.getLlmRequest();
-    TLLM_LOG_DEBUG(
+    TLLM_LOG_INFO(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);
     auto const& selfConfig = session.getSelfState().getCacheState().value();
     auto const& destConfig = session.getOtherState().getCacheState().value();
@@ -153,7 +161,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         && destConfig.getParallelConfig().mPipelineParallelism == selfConfig.getParallelConfig().mPipelineParallelism)
     {
-        TLLM_LOG_DEBUG("Try using zero-copy for the KV cache.");
+        TLLM_LOG_INFO("Try using zero-copy for the KV cache.");
         NVTX3_SCOPED_RANGE(sendBufferFun);

         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
@@ -165,7 +173,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         }
     }

-    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.",
+    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.",
         llmRequest.mRequestId);

     return;
@@ -279,7 +287,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
     {
         if (!common::getEnvEnableReceiveKVCacheParallel())
         {
-            TLLM_LOG_DEBUG("Disable parallel receiving of the KV cache.");
+            TLLM_LOG_INFO("Disable parallel receiving of the KV cache.");
             for (size_t i = 0; i < connections.size(); i++)
             {
                 sendBufferFun(deviceId, i);
@@ -318,7 +326,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
     }

     mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
-    TLLM_LOG_DEBUG(
+    TLLM_LOG_INFO(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);
 }

@@ -328,7 +336,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     auto const& llmRequest = session.getLlmRequest();
     TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
-    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
         "Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId, ctxReqId);
     auto const& selfConfig = session.getSelfState().getCacheState().value();
     auto const& destConfig = session.getOtherState().getCacheState().value();
@@ -362,7 +370,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         && destConfig.getParallelConfig().mPipelineParallelism == selfConfig.getParallelConfig().mPipelineParallelism)
     {
         // recv
-        TLLM_LOG_DEBUG("Try zcopy for KV cache");
+        TLLM_LOG_INFO("Try zcopy for KV cache");
         NVTX3_SCOPED_RANGE(recvBufferFun);
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(pickUpConnections.size() == 1);
@@ -374,7 +382,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
             session.recv(pickUpConnections[i], block->data(), block->getSizeInBytes());
         }
     }
-    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
         "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
         llmRequest.getContextPhaseParams().value().getReqId());
     return;
@@ -547,7 +555,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
     }

-    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
         "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
         llmRequest.getContextPhaseParams().value().getReqId());
 }
diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
index 21852a4e498..6005d8126d1 100644
--- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
+++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -1849,13 +1849,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
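On the receive side, `MLACacheFormatter::pickRecvConnections` above now selects a TP slice of the incoming connections by DP rank instead of always slice 0. A self-contained sketch of that indexing (simplified away from `TargetRanksInfo`; it assumes, as the `[tpranks, ppranks]` comment suggests, that connections are grouped by TP rank with PP ranks contiguous inside each group):

```cpp
// Sketch of the pickRecvConnections indexing above (illustrative).
// Connection index = tp * domainPPSize + pp under the assumed layout.
#include <cstdio>
#include <vector>

std::vector<int> pickRecvConnectionsSketch(int domainPPSize, int domainTPSize, int dpRank)
{
    std::vector<int> ret;
    for (int i = 0; i < domainPPSize; ++i)
    {
        // Before this series: ret.push_back(i); i.e. always the TP-0 slice.
        ret.push_back(i + (dpRank % domainTPSize) * domainPPSize);
    }
    return ret;
}

int main()
{
    // e.g. domainTPSize = 2, domainPPSize = 2:
    // DP rank 0 reads connections {0, 1}, DP rank 1 reads {2, 3}.
    for (int dp = 0; dp < 2; ++dp)
    {
        for (int c : pickRecvConnectionsSketch(/*domainPPSize=*/2, /*domainTPSize=*/2, dp))
        {
            std::printf("dpRank %d -> connection %d\n", dp, c);
        }
    }
    return 0;
}
```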
From 89c4f1a00a1dadb10ea33f84ff8fbd8946a9ae68 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Wed, 6 Aug 2025 06:47:26 +0000
Subject: [PATCH 2/7] info->debug

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../batch_manager/mlaCacheFormatter.cpp       | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index 6fc8f25e136..6df47b0ceb4 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -123,7 +123,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
 {
     NVTX3_SCOPED_RANGE(MLACacheFormatter_format);
     auto const& llmRequest = session.getLlmRequest();
-    TLLM_LOG_INFO(
+    TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "Start sending KV cache for request ID: %ld.", llmRequest.mRequestId);
     auto const& selfConfig = session.getSelfState().getCacheState().value();
     auto const& destConfig = session.getOtherState().getCacheState().value();
@@ -161,7 +161,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         && destConfig.getParallelConfig().mPipelineParallelism == selfConfig.getParallelConfig().mPipelineParallelism)
     {
-        TLLM_LOG_INFO("Try using zero-copy for the KV cache.");
+        TLLM_LOG_DEBUG("Try using zero-copy for the KV cache.");
         NVTX3_SCOPED_RANGE(sendBufferFun);

         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
@@ -173,7 +173,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
         }
     }

-    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.",
+    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.",
         llmRequest.mRequestId);

     return;
@@ -287,7 +287,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
     {
         if (!common::getEnvEnableReceiveKVCacheParallel())
         {
-            TLLM_LOG_INFO("Disable parallel receiving of the KV cache.");
+            TLLM_LOG_DEBUG("Disable parallel receiving of the KV cache.");
             for (size_t i = 0; i < connections.size(); i++)
             {
                 sendBufferFun(deviceId, i);
@@ -326,7 +326,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
     }

     mCacheTransBufferManager->freeBufferIndexForSend(cacheBufferId);
-    TLLM_LOG_INFO(
+    TLLM_LOG_DEBUG(
         mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID: %ld.", llmRequest.mRequestId);
 }

@@ -336,7 +336,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     auto const& llmRequest = session.getLlmRequest();
     TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
     auto const ctxReqId = llmRequest.getContextPhaseParams().value().getReqId();
-    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
         "Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId, ctxReqId);
     auto const& selfConfig = session.getSelfState().getCacheState().value();
     auto const& destConfig = session.getOtherState().getCacheState().value();
@@ -370,7 +370,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         && destConfig.getParallelConfig().mPipelineParallelism == selfConfig.getParallelConfig().mPipelineParallelism)
     {
         // recv
-        TLLM_LOG_INFO("Try zcopy for KV cache");
+        TLLM_LOG_DEBUG("Try zcopy for KV cache");
         NVTX3_SCOPED_RANGE(recvBufferFun);
         TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
         TLLM_CHECK(pickUpConnections.size() == 1);
@@ -382,7 +382,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
             session.recv(pickUpConnections[i], block->data(), block->getSizeInBytes());
         }
     }
-    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
         "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
         llmRequest.getContextPhaseParams().value().getReqId());
     return;
@@ -555,7 +555,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
         mCacheTransBufferManager->freeBufferIndexForRecv(cacheBufferId);
     }

-    TLLM_LOG_INFO(mpi::MpiComm::world().getRank(),
+    TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
         "End receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
         llmRequest.getContextPhaseParams().value().getReqId());
 }
From a02ea94232995186c36051a50989a2a7c4340bcd Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Thu, 7 Aug 2025 11:33:26 +0000
Subject: [PATCH 3/7] fix mla

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../batch_manager/mlaCacheFormatter.cpp       |  4 ++--
 .../multi_gpu/cacheTransceiverTest.cpp        | 17 +++++++++++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index 6df47b0ceb4..aa45c241aae 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -104,7 +104,7 @@ bool MLACacheFormatter::needSendCache(
         }

         int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
-        return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank % dupHeadFactor;
     }

     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -116,7 +116,7 @@ bool MLACacheFormatter::needSendCache(
         return true;
     }
     int dupHeadFactor = selfTPNum / destTPNum;
-    return selfTpRank % dupHeadFactor == destDPRank;
+    return selfTpRank % dupHeadFactor == destDPRank % dupHeadFactor;
 }

diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
index 6005d8126d1..9d1a834514d 100644
--- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
+++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -1432,6 +1432,18 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWi
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
     testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));

+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA4, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
+        testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA5, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(4),
+        testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
         testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
@@ -1472,6 +1484,11 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, Asymmetrical
     testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
     testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate3, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
+        testing::Values(2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));

 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2),
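The extra `% dupHeadFactor` in PATCH 3/7 matters as soon as the generation side has more DP ranks than the duplication factor: without it, a destination DP rank `>= dupHeadFactor` would match no context rank at all and its KV cache would never be sent. A toy check of that reasoning (the numbers are illustrative, not taken from a specific test):

```cpp
// Toy verification of the fix above; illustrative numbers only.
#include <cassert>

int main()
{
    int const dupHeadFactor = 2; // e.g. 4 context TP ranks per DP group, 2 on the generation side
    int const destDPRank = 3;    // a generation DP rank >= dupHeadFactor
    bool anySenderWithoutMod = false;
    bool anySenderWithMod = false;
    for (int selfTpRankInDpGroup = 0; selfTpRankInDpGroup < 4; ++selfTpRankInDpGroup)
    {
        anySenderWithoutMod |= (selfTpRankInDpGroup % dupHeadFactor == destDPRank);
        anySenderWithMod |= (selfTpRankInDpGroup % dupHeadFactor == destDPRank % dupHeadFactor);
    }
    assert(!anySenderWithoutMod); // the PATCH 1/7 form: no context rank would send
    assert(anySenderWithMod);     // the fixed form: ranks 1 and 3 send
    return 0;
}
```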
From ece1d22dd11b56b426f73ca6670183e1e4fb8943 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Mon, 8 Sep 2025 10:28:28 +0000
Subject: [PATCH 4/7] fix test

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../unit_tests/multi_gpu/cacheTransceiverTest.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
index 9d1a834514d..607b7ece627 100644
--- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
+++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -1433,14 +1433,14 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWi
     testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));

 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA4, AsymmetricalCacheTestWithDP,
-    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
-        testing::Values(1), testing::Values(4), testing::Values(16),
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
         testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
         testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));

 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA5, AsymmetricalCacheTestWithDP,
-    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(4),
-        testing::Values(1), testing::Values(4), testing::Values(16),
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(2), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
         testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
         testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));

@@ -1485,8 +1485,8 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, Asymmetrical
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
     testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate3, AsymmetricalCacheTestWithDP,
-    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
-        testing::Values(2), testing::Values(4), testing::Values(16),
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
         testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
         testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));
From ecd62263b8da77c5a1a6066f313db574a26f0f90 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Tue, 16 Sep 2025 09:32:54 +0000
Subject: [PATCH 5/7] async send kv cache

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 .../batch_manager/cacheTransBuffer.cpp        |  4 +-
 .../batch_manager/dataTransceiver.cpp         | 90 ++++++++++++++-----
 cpp/tensorrt_llm/common/envUtils.cpp          |  2 +-
 .../batch_manager/cacheTransBufferTest.cpp    |  8 +-
 .../legacy/advanced/disaggregated-service.md  |  2 +-
 5 files changed, 75 insertions(+), 31 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
index 33986426f54..424f028262d 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -219,7 +219,7 @@ CacheTransBufferManager::CacheTransBufferManager(
         = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
     mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
     mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
     mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
         && FabricMemory::supportFbaricMemory();
     if (mUseFabricMemory)
@@ -269,7 +269,7 @@ size_t CacheTransBufferManager::preAllocBufferSize(
         TransferBufferSize = FabricMemory::getAlignedSize(TransferBufferSize);
     }
     size_t RecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    size_t SendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    size_t SendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t PreAllocBufferSize = TransferBufferSize * (RecvBufferCount + SendBufferCount);
     return PreAllocBufferSize;
 }
diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index 86e95000ef0..b94a027efbf 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -256,6 +256,12 @@ class CacheSender::Impl
         TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
         mCurrentRequest = std::nullopt;
         mResponseFuture = std::async(std::launch::async, &Impl::response, this);
+        int asyncSendThreadNum = common::getEnvKVCacheSendMaxConcurrenceNum();
+        for (int i = 0; i < asyncSendThreadNum; i++)
+        {
+            mAsyncSendFutures.emplace_back(
+                std::async(std::launch::async, &Impl::handleAsyncSend, this, std::ref(mAsyncSendResource)));
+        }
     }

     [[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
@@ -294,9 +300,9 @@ class CacheSender::Impl

     void release(LlmRequest::RequestIdType requestId)
     {
+        std::unique_lock<std::mutex> lk(mMtxForMap);
         auto it = mRequestToSession.find(requestId);
         TLLM_CHECK(it != mRequestToSession.end());
-        std::unique_lock<std::mutex> lk(mMtxForMap);
         if (!common::getEnvKVCacheTransferOutputPath().empty())
         {
             if (!mMeasuresFile.is_open())
@@ -368,11 +374,15 @@ class CacheSender::Impl

     void sendSync(LlmRequest const& llmRequest)
     {
-        auto it = mRequestToSession.find(llmRequest.mRequestId);
-        TLLM_CHECK(it != mRequestToSession.end());
-        auto& session = it->second;
-        session.setLlmRequest(llmRequest);
-        mFormatter->format(session);
+        TransferSession* session = nullptr;
+        {
+            std::unique_lock<std::mutex> lk(mMtxForMap);
+            auto it = mRequestToSession.find(llmRequest.mRequestId);
+            TLLM_CHECK(it != mRequestToSession.end());
+            session = std::addressof(it->second);
+        }
+        session->setLlmRequest(llmRequest);
+        mFormatter->format(*session);
     }

     ~Impl()
@@ -387,6 +397,42 @@ class CacheSender::Impl
         std::promise<void> mPromise;
     };

+    struct AsyncSendResource
+    {
+        std::deque<Response> mSendQueue;
+        std::mutex mMtxForQueue;
+        std::condition_variable mCVforQueue;
+        std::atomic<bool> mTerminate{false};
+    };
+
+    void handleAsyncSend(AsyncSendResource& resource)
+    {
+        tensorrt_llm::common::setThreadName("dataTransAsyncSend");
+        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "Start handling async send");
+        while (!resource.mTerminate)
+        {
+            Response resp;
+            {
+                std::unique_lock<std::mutex> lk(resource.mMtxForQueue);
+                resource.mCVforQueue.wait(
+                    lk, [&resource] { return !resource.mSendQueue.empty() || resource.mTerminate; });
+                if (resource.mTerminate)
+                {
+                    if (!resource.mSendQueue.empty())
+                    {
+                        TLLM_LOG_WARNING("There are still %zu requests in the mSendQueue, but encountered terminate.",
+                            resource.mSendQueue.size());
+                    }
+                    break;
+                }
+                resp = std::move(resource.mSendQueue.front());
+                resource.mSendQueue.pop_front();
+            }
+            // TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "Start sending request %zu", resp.mRequest->mRequestId);
+            sendAndRemoveResponse(resp.mRequest->mRequestId, std::move(resp));
+        }
+    }
+
     void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept
     {
         try
@@ -409,6 +455,13 @@ class CacheSender::Impl
         }
     }

+    void asyncSendAndRemoveResponse(RequestIdType id, Response resp) noexcept
+    {
+        std::unique_lock<std::mutex> lk(mAsyncSendResource.mMtxForQueue);
+        mAsyncSendResource.mSendQueue.emplace_back(std::move(resp));
+        mAsyncSendResource.mCVforQueue.notify_one();
+    }
+
     void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
     {
         auto reqId = mCurrentRequest.value();
@@ -422,15 +475,7 @@ class CacheSender::Impl
             auto llmRequest = it->second.mRequest;
             llmRequest->setRequestedBlockHashes(std::move(blockHashes));

-            if (common::getEnvParallelCacheSend())
-            {
-                // TODO: Use a thread pool and check for thread safety.
-                std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
-            }
-            else
-            {
-                CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
-            }
+            asyncSendAndRemoveResponse(it->first, std::move(it->second));
             removeResponse(it);
         }
         mCurrentRequest = std::nullopt;
@@ -454,7 +499,7 @@ class CacheSender::Impl
                 break;
             }
             std::vector<size_t> blockHashes;
-            if (!isSending() && !mReadyResponses.empty())
+            if (!mReadyResponses.empty())
             {
                 auto const& requestInfo = recvRequestInfo();
                 auto reqId = requestInfo.getRequestId();
@@ -507,6 +552,12 @@ class CacheSender::Impl
         // We don't have to wait for the future. If another thread is sending data, it won't pay attention
         // to the terminate flag.
         mSenderCv.notify_all();
+        mAsyncSendResource.mTerminate = true;
+        mAsyncSendResource.mCVforQueue.notify_all();
+        for (auto& future : mAsyncSendFutures)
+        {
+            future.get();
+        }
     }

     void removeResponse(std::map<RequestIdType, Response>::iterator it)
@@ -522,11 +573,6 @@ class CacheSender::Impl
         }
     }

-    [[nodiscard]] bool isSending() const
-    {
-        return mCurrentRequest.has_value();
-    }
-
     [[nodiscard]] RequestIdType getCurrentRequestId() const
     {
         return mCurrentRequest.value();
@@ -546,6 +592,8 @@ class CacheSender::Impl
     std::condition_variable mSenderCv;
     std::future<void> mResponseFuture;
     std::unordered_map<RequestIdType, int> mRemainSendCount;
+    AsyncSendResource mAsyncSendResource;
+    std::vector<std::future<void>> mAsyncSendFutures;
     int mDeviceId{-1};

     executor::kv_cache::ConnectionManager* mManager;
diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp
index 59c9d2fffe4..11cd072b2ad 100644
--- a/cpp/tensorrt_llm/common/envUtils.cpp
+++ b/cpp/tensorrt_llm/common/envUtils.cpp
@@ -414,7 +414,7 @@ bool getEnvKVCacheTransferUseSyncBuffer()

 size_t getEnvKVCacheSendMaxConcurrenceNum()
 {
-    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(2);
+    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(1);
     return maxConcurrenceNum;
 }

diff --git a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
index 1bc13959940..d00c74be64b 100644
--- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
@@ -108,9 +108,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELFKONLY);
     std::map<SizeType32, size_t> cacheSizeBytesPerTokenPerWindow{
         {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}};
@@ -152,9 +150,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELF);
     tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{
         tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens};
diff --git a/docs/source/legacy/advanced/disaggregated-service.md b/docs/source/legacy/advanced/disaggregated-service.md
index 18112d93264..2c66ebaaa81 100644
--- a/docs/source/legacy/advanced/disaggregated-service.md
+++ b/docs/source/legacy/advanced/disaggregated-service.md
@@ -30,7 +30,7 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated

 * `TRTLLM_KVCACHE_TRANSFER_USE_ASYNC_BUFFER`: If set to `1`, TRT-LLM will use `cudaMallocAsync` to allocate buffers for KV cache transmission. The default value is `0`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

-* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
+* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `1`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

 There are some other useful environment variables that may help when encountering failures or performance issues.
From 2aa30d45c92d2d15047948fcab7335f7110eb12e Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Tue, 16 Sep 2025 10:30:39 +0000
Subject: [PATCH 6/7] remove log

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index b94a027efbf..527291b220b 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -408,7 +408,6 @@ class CacheSender::Impl
     void handleAsyncSend(AsyncSendResource& resource)
     {
         tensorrt_llm::common::setThreadName("dataTransAsyncSend");
-        TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "Start handling async send");
         while (!resource.mTerminate)
         {
             Response resp;
@@ -428,7 +427,6 @@ class CacheSender::Impl
                 resp = std::move(resource.mSendQueue.front());
                 resource.mSendQueue.pop_front();
             }
-            // TLLM_LOG_INFO(mpi::MpiComm::world().getRank(), "Start sending request %zu", resp.mRequest->mRequestId);
             sendAndRemoveResponse(resp.mRequest->mRequestId, std::move(resp));
         }
     }
From 5c7e1a4dddc9c05b0735bf6f16deb53ff581fabe Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Fri, 19 Sep 2025 06:39:33 +0000
Subject: [PATCH 7/7] fix test

Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
---
 cpp/tensorrt_llm/common/envUtils.cpp                        | 6 ------
 cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp | 2 +-
 docs/source/features/disagg-serving.md                      | 3 +--
 docs/source/legacy/advanced/disaggregated-service.md        | 1 -
 4 files changed, 2 insertions(+), 10 deletions(-)

diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp
index 11cd072b2ad..80be36c30c7 100644
--- a/cpp/tensorrt_llm/common/envUtils.cpp
+++ b/cpp/tensorrt_llm/common/envUtils.cpp
@@ -324,12 +324,6 @@ bool getEnvDisableSelectiveCacheTransfer()
     return disableSelectiveCacheTransfer;
 }

-bool getEnvParallelCacheSend()
-{
-    static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
-    return parallelCacheSend;
-}
-
 bool getEnvRequestKVCacheConcurrent()
 {
     static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT");
diff --git a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
index d00c74be64b..42ba14d84dc 100644
--- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
+++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
@@ -256,7 +256,7 @@ TEST_F(CacheTransBufferTest, TestBufferIndexAssignment1)
     SizeType32 tokensPerBlock = 8;
     std::optional<size_t> maxNumTokens = maxBlocksPerSeq * tokensPerBlock;
     setenv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT", "1", 1);
-    setenv("TRTLLM_PARALLEL_CACHE_SEND", "1", 1);
+    setenv("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM", "2", 1);
     SetUpCacheTransBuffer(4, 2, 64, tokensPerBlock, CacheType::kSELF, maxNumTokens, maxBlocksPerSeq);
     auto bufferId = mTransBufferManager->assignBufferIndexForSend();
     EXPECT_TRUE(bufferId.has_value());
diff --git a/docs/source/features/disagg-serving.md b/docs/source/features/disagg-serving.md
index 33a075d488f..041ab7f3784 100644
--- a/docs/source/features/disagg-serving.md
+++ b/docs/source/features/disagg-serving.md
@@ -192,7 +192,6 @@ For more information on how to use Dynamo with TensorRT-LLM, please refer to [th

 TRT-LLM uses some environment variables to control the behavior of disaggregated service.

-* `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`.

 * `TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP`: If set to `1`, generationExecutor will not overlap KV cache transfer with model inference. The default value is `0`.
@@ -206,7 +205,7 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated

 * `TRTLLM_KVCACHE_TRANSFER_USE_ASYNC_BUFFER`: If set to `1`, TRT-LLM will use `cudaMallocAsync` to allocate buffers for KV cache transmission. The default value is `0`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

-* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
+* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `1`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.

 There are some other useful environment variables that may help when encountering failures or performance issues.
diff --git a/docs/source/legacy/advanced/disaggregated-service.md b/docs/source/legacy/advanced/disaggregated-service.md
index 2c66ebaaa81..fe3e30c0039 100644
--- a/docs/source/legacy/advanced/disaggregated-service.md
+++ b/docs/source/legacy/advanced/disaggregated-service.md
@@ -16,7 +16,6 @@ An [architectural and performance overview](../../../docs/source/blogs/tech_blog

 TRT-LLM uses some environment variables to control the behavior of disaggregated service.

-* `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`.

 * `TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP`: If set to `1`, generationExecutor will not overlap KV cache transfer with model inference. The default value is `0`.
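For sizing purposes, the net effect of the series on pre-allocated transfer memory can be summarized by mirroring `preAllocBufferSize` from PATCH 5/7 with the environment lookups replaced by plain parameters (a sketch, not the real signature):

```cpp
// Sketch of the pre-allocation arithmetic shown in cacheTransBuffer.cpp above.
#include <cstddef>

std::size_t preAllocBufferSizeSketch(std::size_t transferBufferSize, bool requestKVCacheConcurrent,
    std::size_t recvBufferCountEnv, std::size_t sendMaxConcurrencyNum)
{
    std::size_t const recvBufferCount = requestKVCacheConcurrent ? recvBufferCountEnv : 1;
    // The TRTLLM_PARALLEL_CACHE_SEND gate is gone; the send-buffer count now
    // comes straight from TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM (default 1,
    // previously defaulting to 2 and only honoured behind the removed gate).
    std::size_t const sendBufferCount = sendMaxConcurrencyNum;
    return transferBufferSize * (recvBufferCount + sendBufferCount);
}
```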