1717
1818#include " connection.h"
1919#include " tensorrt_llm/common/envUtils.h"
20+ #include " tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2021#include < string>
2122#include < unistd.h>
2223
@@ -34,6 +35,22 @@ std::string genUniqueAgentName()
3435 return std::string (hostname) + " _" + std::to_string (pid) + " _" + std::to_string (counter++);
3536}
3637
38+ auto computeSendOffsetRatio (
39+ CacheState const & peerCacheState, size_t peerIdx, CacheState const & selfCacheState, int valideConnectionIdx)
40+ {
41+ auto peerTargetInfo = targetIRanks (selfCacheState, peerCacheState, peerIdx);
42+ // int ppRank = valideConnectionIdx % peerTargetInfo.mDomainPPSize;
43+ size_t offsetLayer = 0 ;
44+ for (int i = 0 ; i < valideConnectionIdx; i++)
45+ {
46+ offsetLayer += peerTargetInfo.getPeerPPDomainLayerNum (i);
47+ }
48+
49+ size_t selfSendLayer = peerTargetInfo.getPeerPPDomainLayerNum (valideConnectionIdx);
50+
51+ return std::make_pair (offsetLayer, selfSendLayer);
52+ }
53+
3754AgentConnection::AgentConnection (
3855 std::string mAgentName , std::string mRemoteAgentName , AgentConnectionManager* mAgentConnectionManager )
3956 : mAgentName (mAgentName )
@@ -82,7 +99,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
8299 reinterpret_cast <uintptr_t >(data), size, static_cast <uint32_t >(mAgentConnectionManager ->getDeviceId ())};
83100 MemoryDescs srcDescs{MemoryType::kVRAM , {srcDesc}};
84101 auto dstBaseDesc = mSenderState .mReceiverBufferDesc ;
85- MemoryDesc dstDesc{dstBaseDesc.getAddr () + (mSenderState .validSegmentIdx * size), size, dstBaseDesc.getDeviceId ()};
102+ auto offset = size / mSenderState .mOffsetRatio .second * mSenderState .mOffsetRatio .first ;
103+ MemoryDesc dstDesc{dstBaseDesc.getAddr () + offset, size, dstBaseDesc.getDeviceId ()};
86104 TLLM_LOG_DEBUG (
87105 " send dstDesc: %p, size: %ld ,validSegmentIdx: %ld" , dstDesc.getAddr (), size, mSenderState .validSegmentIdx );
88106 MemoryDescs dstDescs{MemoryType::kVRAM , {dstDesc}};
@@ -137,10 +155,12 @@ void AgentConnection::sendRequestAndBufferInfo(
137155 mAgentConnectionManager ->getAgent ()->notifySyncMessage (mRemoteAgentName , ss.str ());
138156}
139157
140- void AgentConnection::setSenderState (MemoryDesc mReceiverBufferDesc , int validSegmentIdx)
158+ void AgentConnection::setSenderState (
159+ MemoryDesc mReceiverBufferDesc , int validSegmentIdx, std::pair<size_t , size_t > offsetRatio)
141160{
142161 mSenderState .mReceiverBufferDesc = mReceiverBufferDesc ;
143162 mSenderState .validSegmentIdx = validSegmentIdx;
163+ mSenderState .mOffsetRatio = offsetRatio;
144164}
145165
146166void AgentConnection::setHasLoadRemoteAgent (bool hasLoadRemoteAgent)
@@ -155,8 +175,9 @@ bool AgentConnection::hasLoadRemoteAgent() const
155175}
156176
157177AgentConnectionManager::AgentConnectionManager (
158- batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager)
178+ batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager, CacheState cacheState )
159179 : mRegMemDescs(MemoryType::kVRAM , {})
180+ , mCacheState(std::move(cacheState))
160181{
161182 TLLM_CUDA_CHECK (cudaGetDevice (&mDeviceId ));
162183 TLLM_CHECK (mDeviceId != -1 );
@@ -260,7 +281,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc
260281 auto remoteAgentName = requestAndBufferInfo.mAgentName ;
261282 TLLM_LOG_DEBUG (" recv Address:%s" , address.c_str ());
262283 auto connection = connect (remoteAgentName, address, metadataOpt, true );
263- connection->setSenderState (bufferDesc, validConnectionIdx);
284+ // to compute the offset.
285+ auto offsetRatio = computeSendOffsetRatio (requestInfo.getTransState ().getCacheState ().value (),
286+ requestInfo.getTransState ().getCommState ()->getSelfIdx (), mCacheState , validConnectionIdx);
287+ connection->setSenderState (bufferDesc, validConnectionIdx, offsetRatio);
264288 it2 = notifs.erase (it2);
265289 if (notifs.empty ())
266290 {
0 commit comments