Skip to content

Commit 399252d

Browse files
committed
nixl support uneven layerNum
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 3d25e92 commit 399252d

File tree

5 files changed

+45
-12
lines changed

5 files changed

+45
-12
lines changed

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
186186
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
187187
{
188188
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
189-
mCacheTransBufferManager.get());
189+
mCacheTransBufferManager.get(), *mCacheState);
190190
TLLM_LOG_INFO("NIXL Connection Manager created");
191191
}
192192
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "connection.h"
1919
#include "tensorrt_llm/common/envUtils.h"
20+
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2021
#include <string>
2122
#include <unistd.h>
2223

@@ -34,6 +35,22 @@ std::string genUniqueAgentName()
3435
return std::string(hostname) + "_" + std::to_string(pid) + "_" + std::to_string(counter++);
3536
}
3637

38+
auto computeSendOffsetRatio(
39+
CacheState const& peerCacheState, size_t peerIdx, CacheState const& selfCacheState, int valideConnectionIdx)
40+
{
41+
auto peerTargetInfo = targetIRanks(selfCacheState, peerCacheState, peerIdx);
42+
// int ppRank = valideConnectionIdx % peerTargetInfo.mDomainPPSize;
43+
size_t offsetLayer = 0;
44+
for (int i = 0; i < valideConnectionIdx; i++)
45+
{
46+
offsetLayer += peerTargetInfo.getPeerPPDomainLayerNum(i);
47+
}
48+
49+
size_t selfSendLayer = peerTargetInfo.getPeerPPDomainLayerNum(valideConnectionIdx);
50+
51+
return std::make_pair(offsetLayer, selfSendLayer);
52+
}
53+
3754
AgentConnection::AgentConnection(
3855
std::string mAgentName, std::string mRemoteAgentName, AgentConnectionManager* mAgentConnectionManager)
3956
: mAgentName(mAgentName)
@@ -82,7 +99,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size
8299
reinterpret_cast<uintptr_t>(data), size, static_cast<uint32_t>(mAgentConnectionManager->getDeviceId())};
83100
MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}};
84101
auto dstBaseDesc = mSenderState.mReceiverBufferDesc;
85-
MemoryDesc dstDesc{dstBaseDesc.getAddr() + (mSenderState.validSegmentIdx * size), size, dstBaseDesc.getDeviceId()};
102+
auto offset = size / mSenderState.mOffsetRatio.second * mSenderState.mOffsetRatio.first;
103+
MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()};
86104
TLLM_LOG_DEBUG(
87105
"send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx);
88106
MemoryDescs dstDescs{MemoryType::kVRAM, {dstDesc}};
@@ -137,10 +155,12 @@ void AgentConnection::sendRequestAndBufferInfo(
137155
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
138156
}
139157

140-
void AgentConnection::setSenderState(MemoryDesc mReceiverBufferDesc, int validSegmentIdx)
158+
void AgentConnection::setSenderState(
159+
MemoryDesc mReceiverBufferDesc, int validSegmentIdx, std::pair<size_t, size_t> offsetRatio)
141160
{
142161
mSenderState.mReceiverBufferDesc = mReceiverBufferDesc;
143162
mSenderState.validSegmentIdx = validSegmentIdx;
163+
mSenderState.mOffsetRatio = offsetRatio;
144164
}
145165

146166
void AgentConnection::setHasLoadRemoteAgent(bool hasLoadRemoteAgent)
@@ -155,8 +175,9 @@ bool AgentConnection::hasLoadRemoteAgent() const
155175
}
156176

157177
AgentConnectionManager::AgentConnectionManager(
158-
batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager)
178+
batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager, CacheState cacheState)
159179
: mRegMemDescs(MemoryType::kVRAM, {})
180+
, mCacheState(std::move(cacheState))
160181
{
161182
TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
162183
TLLM_CHECK(mDeviceId != -1);
@@ -260,7 +281,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc
260281
auto remoteAgentName = requestAndBufferInfo.mAgentName;
261282
TLLM_LOG_DEBUG(" recv Address:%s", address.c_str());
262283
auto connection = connect(remoteAgentName, address, metadataOpt, true);
263-
connection->setSenderState(bufferDesc, validConnectionIdx);
284+
// to compute the offset.
285+
auto offsetRatio = computeSendOffsetRatio(requestInfo.getTransState().getCacheState().value(),
286+
requestInfo.getTransState().getCommState()->getSelfIdx(), mCacheState, validConnectionIdx);
287+
connection->setSenderState(bufferDesc, validConnectionIdx, offsetRatio);
264288
it2 = notifs.erase(it2);
265289
if (notifs.empty())
266290
{

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class AgentConnection : public Connection
175175
void recv(DataContext const& ctx, void* data, size_t size) const override;
176176
void sendRequestAndBufferInfo(
177177
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx);
178-
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx);
178+
void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx, std::pair<size_t, size_t> offsetRatio);
179179
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
180180
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
181181
[[nodiscard]] bool hasLoadRemoteAgent() const;
@@ -188,6 +188,7 @@ class AgentConnection : public Connection
188188
{
189189
MemoryDesc mReceiverBufferDesc{nullptr, 0, 0};
190190
int validSegmentIdx{0};
191+
std::pair<size_t, size_t> mOffsetRatio;
191192
SenderState() = default;
192193
};
193194

@@ -203,7 +204,8 @@ class AgentConnection : public Connection
203204
class AgentConnectionManager : public ConnectionManager
204205
{
205206
public:
206-
AgentConnectionManager(batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager);
207+
AgentConnectionManager(
208+
batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager, CacheState cacheState);
207209
~AgentConnectionManager();
208210
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
209211
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
@@ -222,6 +224,7 @@ class AgentConnectionManager : public ConnectionManager
222224
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;
223225
std::mutex mConnectionsMutex;
224226
CommState mCommState;
227+
CacheState mCacheState;
225228
batch_manager::kv_cache_manager::CacheTransBufferManager* mCacheTransBufferManager;
226229
std::mutex mNotificationMutex;
227230
std::unordered_map<std::string, std::list<std::string>> mUnhandledNotifications;

cpp/tests/unit_tests/executor/agentCommTest.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class AgentCommTest : public ::testing::Test
108108

109109
TEST_F(AgentCommTest, AgentConnectionManagerBasic)
110110
{
111-
auto connectionManager = std::make_unique<AgentConnectionManager>(mTransBufferManager.get());
111+
auto connectionManager = std::make_unique<AgentConnectionManager>(mTransBufferManager.get(), *mCacheState);
112112
ASSERT_TRUE(connectionManager != nullptr);
113113
ASSERT_TRUE(connectionManager->getCacheTransBufferManager() != nullptr);
114114
ASSERT_EQ(connectionManager->getDeviceId(), 0);
@@ -121,8 +121,8 @@ TEST_F(AgentCommTest, AgentConnectionManagerBasic)
121121

122122
TEST_F(AgentCommTest, AgentConnectionManagerConnect)
123123
{
124-
auto connectionManager0 = std::make_unique<AgentConnectionManager>(mTransBufferManager.get());
125-
auto connectionManager1 = std::make_unique<AgentConnectionManager>(mTransBufferManager.get());
124+
auto connectionManager0 = std::make_unique<AgentConnectionManager>(mTransBufferManager.get(), *mCacheState);
125+
auto connectionManager1 = std::make_unique<AgentConnectionManager>(mTransBufferManager.get(), *mCacheState);
126126
auto agentName0 = connectionManager0->getAgentName();
127127
auto agentName1 = connectionManager1->getAgentName();
128128
ASSERT_TRUE(!agentName0.empty());

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
776776

777777
setenv("TRTLLM_NIXL_PORT", std::to_string(port).c_str(), 1);
778778

779-
mConnectionManager
780-
= std::make_unique<texec::kv_cache::AgentConnectionManager>(mCacheTransBufferManager.get());
779+
mConnectionManager = std::make_unique<texec::kv_cache::AgentConnectionManager>(
780+
mCacheTransBufferManager.get(), *mCacheState);
781781
}
782782
else
783783
{
@@ -1488,6 +1488,12 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, Asymmetrical
14881488
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
14891489
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
14901490

1491+
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0EvenLayer, AsymmetricalCacheTestWithDP,
1492+
testing::Combine(testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(1),
1493+
testing::Values(1), testing::Values(5), testing::Values(2), testing::Values(4), testing::Values(16),
1494+
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
1495+
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
1496+
14911497
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP,
14921498
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(2),
14931499
testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),

0 commit comments

Comments
 (0)