7 changes: 4 additions & 3 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -90,9 +90,9 @@ bool CacheFormatter::needSendCache(
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
 
-    // only TP rank % dupHeadFactor == 0 need to send cache.
-    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
+    return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
 }
 
 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
@@ -140,11 +140,12 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
         return ret;
     }
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
+    int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
 
     std::vector<size_t> ret;
     for (int i = 0; i < targetInfo.mDomainTPSize; i++)
     {
-        if (i % targetInfo.mPeerDupHeadFactor == 0)
+        if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
         {
             for (int j = 0; j < targetInfo.mDomainPPSize; j++)
             {
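To make the effect of the new condition concrete, here is a minimal standalone sketch of the routing rule (the rank counts are hypothetical and the real `CacheState`/`TargetRanksInfo` types are not used): instead of rank 0 of each duplicate-head group serving every destination, destination DP rank `d` is now served by the sender whose in-group TP rank matches `d` modulo the duplicate-head factor.

```cpp
#include <cstdio>

// Sketch only: with an assumed duplicate-head factor of 2 and 4 context TP
// ranks per DP group, print which context ranks send to each destination
// DP rank under the new needSendCache condition.
int main()
{
    int const dupHeadFactor = 2;
    int const tpRanksPerDpGroup = 4;
    for (int destDPRank = 0; destDPRank < 4; ++destDPRank)
    {
        std::printf("destDPRank %d <- context ranks:", destDPRank);
        for (int selfTpRankInDpGroup = 0; selfTpRankInDpGroup < tpRanksPerDpGroup; ++selfTpRankInDpGroup)
        {
            // The new condition from this diff.
            if ((destDPRank % dupHeadFactor) == (selfTpRankInDpGroup % dupHeadFactor))
            {
                std::printf(" %d", selfTpRankInDpGroup);
            }
        }
        std::printf("\n");
    }
    return 0;
}
```

With these assumed sizes, destination ranks 0 and 2 pull from context ranks {0, 2} while ranks 1 and 3 pull from {1, 3}, spreading sends across duplicate senders rather than concentrating them on rank 0.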
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -219,7 +219,7 @@ CacheTransBufferManager::CacheTransBufferManager(
         = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
     mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
     mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
     mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
         && FabricMemory::supportFbaricMemory();
     if (mUseFabricMemory)
@@ -269,7 +269,7 @@ size_t CacheTransBufferManager::preAllocBufferSize(
         TransferBufferSize = FabricMemory::getAlignedSize(TransferBufferSize);
     }
     size_t RecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    size_t SendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    size_t SendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    size_t PreAllocBufferSize = TransferBufferSize * (RecvBufferCount + SendBufferCount);
    return PreAllocBufferSize;
}
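Since the send-buffer count now comes straight from `getEnvKVCacheSendMaxConcurrenceNum()`, pre-allocation scales directly with that knob. A rough model of the arithmetic with hypothetical sizes (the real `TransferBufferSize` is derived from `maxNumTokens` or `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE`):

```cpp
#include <cstddef>
#include <cstdio>

int main()
{
    // Assumed illustrative values only.
    size_t const transferBufferSize = 64ull << 20; // 64 MiB per transfer buffer
    size_t const recvBufferCount = 1;              // default without TRTLLM_REQUEST_KV_CACHE_CONCURRENT
    size_t const sendBufferCount = 1;              // new default of TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM
    // Mirrors preAllocBufferSize: one buffer per concurrent recv plus one per concurrent send.
    size_t const preAlloc = transferBufferSize * (recvBufferCount + sendBufferCount);
    std::printf("pre-allocated: %zu MiB\n", preAlloc >> 20);
    return 0;
}
```

Raising the concurrency env var therefore raises the pre-allocated transfer memory proportionally.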
88 changes: 67 additions & 21 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -256,6 +256,12 @@ class CacheSender::Impl
         TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
         mCurrentRequest = std::nullopt;
         mResponseFuture = std::async(std::launch::async, &Impl::response, this);
+        int asyncSendThreadNum = common::getEnvKVCacheSendMaxConcurrenceNum();
+        for (int i = 0; i < asyncSendThreadNum; i++)
+        {
+            mAsyncSendFutures.emplace_back(
+                std::async(std::launch::async, &Impl::handleAsyncSend, this, std::ref(mAsyncSendResource)));
+        }
     }
 
     [[nodiscard]] std::future<void> sendAsync(LlmRequest& llmRequest)
@@ -294,9 +300,9 @@
 
     void release(LlmRequest::RequestIdType requestId)
     {
+        std::unique_lock<std::mutex> lk(mMtxForMap);
         auto it = mRequestToSession.find(requestId);
         TLLM_CHECK(it != mRequestToSession.end());
-        std::unique_lock<std::mutex> lk(mMtxForMap);
         if (!common::getEnvKVCacheTransferOutputPath().empty())
         {
             if (!mMeasuresFile.is_open())
@@ -368,11 +374,15 @@
 
     void sendSync(LlmRequest const& llmRequest)
     {
-        auto it = mRequestToSession.find(llmRequest.mRequestId);
-        TLLM_CHECK(it != mRequestToSession.end());
-        auto& session = it->second;
-        session.setLlmRequest(llmRequest);
-        mFormatter->format(session);
+        TransferSession* session = nullptr;
+        {
+            std::unique_lock<std::mutex> lk(mMtxForMap);
+            auto it = mRequestToSession.find(llmRequest.mRequestId);
+            TLLM_CHECK(it != mRequestToSession.end());
+            session = std::addressof(it->second);
+        }
+        session->setLlmRequest(llmRequest);
+        mFormatter->format(*session);
     }
 
     ~Impl()
@@ -387,6 +397,40 @@
         std::promise<void> mPromise;
     };
 
+    struct AsyncSendResource
+    {
+        std::deque<Response> mSendQueue;
+        std::mutex mMtxForQueue;
+        std::condition_variable mCVforQueue;
+        std::atomic<bool> mTerminate{false};
+    };
+
+    void handleAsyncSend(AsyncSendResource& resource)
+    {
+        tensorrt_llm::common::setThreadName("dataTransAsyncSend");
+        while (!resource.mTerminate)
+        {
+            Response resp;
+            {
+                std::unique_lock lk(resource.mMtxForQueue);
+                resource.mCVforQueue.wait(
+                    lk, [&resource] { return !resource.mSendQueue.empty() || resource.mTerminate; });
+                if (resource.mTerminate)
+                {
+                    if (!resource.mSendQueue.empty())
+                    {
+                        TLLM_LOG_WARNING("There are still %zu requests in the mSendQueue, but encountered terminate.",
+                            resource.mSendQueue.size());
+                    }
+                    break;
+                }
+                resp = std::move(resource.mSendQueue.front());
+                resource.mSendQueue.pop_front();
+            }
+            sendAndRemoveResponse(resp.mRequest->mRequestId, std::move(resp));
+        }
+    }
+
     void sendAndRemoveResponse(RequestIdType id, Response resp) noexcept
     {
         try
@@ -409,6 +453,13 @@
         }
     }
 
+    void asyncSendAndRemoveResponse(RequestIdType id, Response resp) noexcept
+    {
+        std::unique_lock lk(mAsyncSendResource.mMtxForQueue);
+        mAsyncSendResource.mSendQueue.emplace_back(std::move(resp));
+        mAsyncSendResource.mCVforQueue.notify_one();
+    }
+
     void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
     {
         auto reqId = mCurrentRequest.value();
@@ -422,15 +473,7 @@
             auto llmRequest = it->second.mRequest;
             llmRequest->setRequestedBlockHashes(std::move(blockHashes));
 
-            if (common::getEnvParallelCacheSend())
-            {
-                // TODO: Use a thread pool and check for thread safety.
-                std::thread(&CacheSender::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second)).detach();
-            }
-            else
-            {
-                CacheSender::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
-            }
+            asyncSendAndRemoveResponse(it->first, std::move(it->second));
             removeResponse(it);
         }
         mCurrentRequest = std::nullopt;
@@ -454,7 +497,7 @@
                 break;
             }
             std::vector<size_t> blockHashes;
-            if (!isSending() && !mReadyResponses.empty())
+            if (!mReadyResponses.empty())
            {
                auto const& requestInfo = recvRequestInfo();
                auto reqId = requestInfo.getRequestId();
@@ -507,6 +550,12 @@
         // We don't have to wait for the future. If another thread is sending data, it won't pay attention
         // to the terminate flag.
         mSenderCv.notify_all();
+        mAsyncSendResource.mTerminate = true;
+        mAsyncSendResource.mCVforQueue.notify_all();
+        for (auto& future : mAsyncSendFutures)
+        {
+            future.get();
+        }
     }
 
     void removeResponse(std::map<RequestIdType, Response>::iterator it)
@@ -522,11 +571,6 @@
         }
     }
 
-    [[nodiscard]] bool isSending() const
-    {
-        return mCurrentRequest.has_value();
-    }
-
     [[nodiscard]] RequestIdType getCurrentRequestId() const
     {
         return mCurrentRequest.value();
@@ -546,6 +590,8 @@
     std::condition_variable mSenderCv;
     std::future<void> mResponseFuture;
     std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
+    AsyncSendResource mAsyncSendResource;
+    std::vector<std::future<void>> mAsyncSendFutures;
     int mDeviceId{-1};
 
     executor::kv_cache::ConnectionManager* mManager;
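The new `AsyncSendResource`/`handleAsyncSend` pair replaces the old detach-a-thread-per-response path with a fixed pool of workers draining a mutex-protected queue. A self-contained sketch of the same pattern (simplified stand-in types, not the actual TRT-LLM classes):

```cpp
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <future>
#include <mutex>
#include <vector>

// Simplified stand-in for the Response payload.
struct Job
{
    int id;
};

struct JobQueue
{
    std::deque<Job> jobs;
    std::mutex mtx;
    std::condition_variable cv;
    bool terminate{false}; // guarded by mtx here; the real code uses std::atomic<bool>
};

// Worker loop in the spirit of handleAsyncSend: sleep until work or terminate
// arrives, pop one job under the lock, do the "send" outside it.
void worker(JobQueue& q)
{
    while (true)
    {
        Job job{};
        {
            std::unique_lock lk(q.mtx);
            q.cv.wait(lk, [&q] { return !q.jobs.empty() || q.terminate; });
            if (q.terminate)
            {
                break; // like the real code, jobs still queued at terminate are abandoned (it logs a warning)
            }
            job = q.jobs.front();
            q.jobs.pop_front();
        }
        std::printf("sending job %d\n", job.id); // stands in for sendAndRemoveResponse
    }
}

int main()
{
    JobQueue q;
    std::vector<std::future<void>> workers;
    for (int i = 0; i < 2; ++i) // pool size, as getEnvKVCacheSendMaxConcurrenceNum controls it
    {
        workers.emplace_back(std::async(std::launch::async, worker, std::ref(q)));
    }
    for (int i = 0; i < 4; ++i) // enqueueing, as asyncSendAndRemoveResponse does
    {
        std::lock_guard lk(q.mtx);
        q.jobs.push_back(Job{i});
        q.cv.notify_one();
    }
    {
        std::lock_guard lk(q.mtx);
        q.terminate = true;
    }
    q.cv.notify_all();
    for (auto& f : workers)
    {
        f.get();
    }
    return 0;
}
```

Bounding the pool also bounds peak concurrent sends, which the old detached-thread approach (one thread per response) could not do.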
22 changes: 15 additions & 7 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -72,10 +72,12 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
     TLLM_CHECK(targetInfo.mDomainCPSize == 1);
     TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
     std::vector<size_t> ret;
-    // targetInfo , mRanks [tpranks, dpranks]
+    // targetInfo , mRanks [tpranks, ppranks]
+    int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
+
     for (int i = 0; i < targetInfo.mDomainPPSize; i++)
     {
-        ret.push_back(i);
+        ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
     }
     return ret;
 }
@@ -85,19 +87,24 @@ bool MLACacheFormatter::needSendCache(
 {
     int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
+
+    int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
+        ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
+        : destConfig.getParallelConfig().mTensorParallelism;
+    int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
+
     if (selfConfig.getParallelConfig().mEnableAttentionDP)
     {
         int selfTPNumInDPGroup
             = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
-        int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
-            ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
-            : destConfig.getParallelConfig().mTensorParallelism;
 
         int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
         if (selfTPNumInDPGroup <= destTPNumInDPGroup)
         {
             return true;
         }
-        return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
+
+        int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
+        return selfTPrankINDPGroup % dupHeadFactor == destDPRank % dupHeadFactor;
     }
 
     int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -108,7 +115,8 @@
     {
         return true;
     }
-    return selfTpRank % (selfTPNum / destTPNum) == 0;
+    int dupHeadFactor = selfTPNum / destTPNum;
+    return selfTpRank % dupHeadFactor == destDPRank % dupHeadFactor;
 }
 
 void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
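The index arithmetic in `pickRecvConnections` is easier to see in isolation. Assuming connections are laid out as `mDomainTPSize` groups of `mDomainPPSize` ranks each (the sizes below are hypothetical), each generation DP rank now selects the PP slice of its own TP group instead of always taking group 0:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    // Assumed domain sizes for illustration.
    int const domainTPSize = 2;
    int const domainPPSize = 2;
    for (int dpRank = 0; dpRank < 4; ++dpRank)
    {
        std::vector<size_t> ret;
        for (int i = 0; i < domainPPSize; ++i)
        {
            // The new index computation from this diff.
            ret.push_back(i + (dpRank % domainTPSize) * domainPPSize);
        }
        std::printf("dpRank %d receives over connections:", dpRank);
        for (size_t r : ret)
        {
            std::printf(" %zu", r);
        }
        std::printf("\n");
    }
    return 0;
}
```

With these assumed sizes, DP ranks 0 and 2 pull connections {0, 1} while DP ranks 1 and 3 pull {2, 3}, matching the sender-side selection in `needSendCache`.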
8 changes: 1 addition & 7 deletions cpp/tensorrt_llm/common/envUtils.cpp
@@ -324,12 +324,6 @@ bool getEnvDisableSelectiveCacheTransfer()
     return disableSelectiveCacheTransfer;
 }
 
-bool getEnvParallelCacheSend()
-{
-    static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
-    return parallelCacheSend;
-}
-
 bool getEnvRequestKVCacheConcurrent()
 {
     static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT");
@@ -414,7 +408,7 @@ bool getEnvKVCacheTransferUseSyncBuffer()
 size_t getEnvKVCacheSendMaxConcurrenceNum()
 {
 
-    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(2);
+    static size_t const maxConcurrenceNum = getUInt64Env("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(1);
     return maxConcurrenceNum;
 }
 
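With `getEnvParallelCacheSend()` gone, the concurrency knob is read unconditionally, and its default drops from 2 to 1. A rough standalone approximation of the helper (the real `getUInt64Env` lives in envUtils and may differ in parsing and error handling):

```cpp
#include <cstdio>
#include <cstdlib>
#include <optional>
#include <string>

// Sketch of getUInt64Env: read an environment variable as an unsigned integer.
std::optional<size_t> getUInt64EnvSketch(char const* name)
{
    char const* value = std::getenv(name);
    if (value == nullptr)
    {
        return std::nullopt;
    }
    return static_cast<size_t>(std::stoull(value));
}

int main()
{
    // Unset -> 1, matching the new default in getEnvKVCacheSendMaxConcurrenceNum.
    size_t const sendConcurrency = getUInt64EnvSketch("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM").value_or(1);
    std::printf("send concurrency: %zu\n", sendConcurrency);
    return 0;
}
```

Note the function-local `static` in the real helper: the variable is read once per process, so it must be set before first use.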
10 changes: 3 additions & 7 deletions cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp
@@ -108,9 +108,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELFKONLY);
     std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow{
         {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}};
@@ -152,9 +150,7 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2)
     size_t recvbufferCount = tensorrt_llm::common::getEnvRequestKVCacheConcurrent()
         ? tensorrt_llm::common::getEnvKVCacheRecvBufferCount()
         : 1;
-    size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend()
-        ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum()
-        : 1;
+    size_t sendBufferCount = tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum();
     size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELF);
     tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{
         tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens};
@@ -260,7 +256,7 @@ TEST_F(CacheTransBufferTest, TestBufferIndexAssignment1)
     SizeType32 tokensPerBlock = 8;
     std::optional<size_t> maxNumTokens = maxBlocksPerSeq * tokensPerBlock;
     setenv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT", "1", 1);
-    setenv("TRTLLM_PARALLEL_CACHE_SEND", "1", 1);
+    setenv("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM", "2", 1);
     SetUpCacheTransBuffer(4, 2, 64, tokensPerBlock, CacheType::kSELF, maxNumTokens, maxBlocksPerSeq);
     auto bufferId = mTransBufferManager->assignBufferIndexForSend();
     EXPECT_TRUE(bufferId.has_value());
21 changes: 19 additions & 2 deletions cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -1432,6 +1432,18 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWi
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
     testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
 
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA4, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA5, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(2), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
+        testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
+
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
         testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
@@ -1472,6 +1484,11 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, Asymmetrical
     testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
     testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
     testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate3, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(1),
+        testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));
 
 INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
     testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2),
@@ -1849,13 +1866,13 @@ TEST(targetTest, CacheStateContextDP)
         /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ true);
+        /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
     verifyContext(
         /*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
-        /*expectNeedSend*/ false);
+        /*expectNeedSend*/ true);
     verifyContext(
         /*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
         /*expectNeedSend*/ false);
3 changes: 1 addition & 2 deletions docs/source/features/disagg-serving.md
@@ -192,7 +192,6 @@ For more information on how to use Dynamo with TensorRT-LLM, please refer to [th
 
 TRT-LLM uses some environment variables to control the behavior of disaggregated service.
 
-* `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`.
 
 * `TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP`: If set to `1`, generationExecutor will not overlap KV cache transfer with model inference. The default value is `0`.
 
@@ -206,7 +205,7 @@ TRT-LLM uses some environment variables to control the behavior of disaggregated
 
 * `TRTLLM_KVCACHE_TRANSFER_USE_ASYNC_BUFFER`: If set to `1`, TRT-LLM will use `cudaMallocAsync` to allocate buffers for KV cache transmission. The default value is `0`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
 
-* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `4`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
+* `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM`: The maximum number of concurrent KV cache sends. The default value is `1`. This environment variable only takes effect when `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE` is greater than 0.
 
 There are some other useful environment variables that may help when encountering failures or performance issues.
 
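For deployments that relied on the removed `TRTLLM_PARALLEL_CACHE_SEND`, raising `TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM` is now the single knob. A minimal POSIX sketch in the spirit of the updated unit tests (the value `4` is illustrative):

```cpp
#include <cstdlib>

int main()
{
    // Must be set before the executor/cache transceiver is created, because
    // the value is cached in a function-local static on first read.
    setenv("TRTLLM_KVCACHE_SEND_MAX_CONCURRENCY_NUM", "4", /*overwrite=*/1);
    // ... initialize TRT-LLM afterwards ...
    return 0;
}
```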