Skip to content

Commit e1d67c6

Browse files
ShunkangShunkang
authored andcommitted
Add NIXL support
Signed-off-by: Shunkang <[email protected]>
1 parent 2b3dc94 commit e1d67c6

File tree

4 files changed

+188
-33
lines changed

4 files changed

+188
-33
lines changed

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -231,21 +231,54 @@ class DataResponder::Impl
231231
{
232232
mRemainSendCount.erase(reqId);
233233

234-
// TODO(zhengd): pass the hashes directly instead of update llmRequest
235-
auto llmRequest = it->second.mRequest;
236-
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
234+
// Check if the request is cancelled
235+
bool isReady = true;
236+
{
237+
std::unique_lock lk(mResponderMutex);
238+
if (mCancelledRequests.find(reqId) != mCancelledRequests.end())
239+
{
240+
isReady = false;
241+
}
242+
}
243+
mSender->sendReadySignal(reqId, isReady);
237244

238-
if (common::getEnvParallelCacheSend())
245+
if (isReady)
239246
{
240-
// TODO: Use a thread pool and check for thread safety.
241-
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
242-
.detach();
247+
// TODO(zhengd): pass the hashes directly instead of update llmRequest
248+
auto llmRequest = it->second.mRequest;
249+
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
250+
251+
if (common::getEnvParallelCacheSend())
252+
{
253+
// TODO: Use a thread pool and check for thread safety.
254+
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
255+
.detach();
256+
}
257+
else
258+
{
259+
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
260+
}
261+
removeResponse(it);
243262
}
244263
else
245264
{
246-
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
265+
// TODO: if the generation does not require the kv cache, the request will
266+
// not be removed from mCancelledRequests.
267+
auto it = mReadyResponses.find(mCurrentRequest.value());
268+
{
269+
std::unique_lock lkResp(mResponderMutex);
270+
mReadyResponses.erase(it);
271+
mCancelledRequests.erase(mCurrentRequest.value());
272+
mRemainSendCount.erase(mCurrentRequest.value());
273+
}
274+
mCurrentRequest = std::nullopt;
275+
276+
if (mReadyResponses.empty())
277+
{
278+
std::unique_lock lk(mCondMutex);
279+
mAnyReady = false;
280+
}
247281
}
248-
removeResponse(it);
249282
}
250283
mCurrentRequest = std::nullopt;
251284
}
@@ -274,25 +307,18 @@ class DataResponder::Impl
274307
auto reqId = requestInfo.getRequestId();
275308
blockHashes = requestInfo.getBlockHashes();
276309

277-
bool isReady = true;
278310
{
279311
std::unique_lock lk(mResponderMutex);
280312
mCurrentRequest = reqId;
281-
if (mCancelledRequests.find(reqId) != mCancelledRequests.end())
282-
{
283-
isReady = false;
284-
}
285313
}
286-
mSender->sendReadySignal(reqId, isReady);
287314

288315
if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
289316
{
290317
mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
291318
}
292319
}
293320
auto it = getCurrentResponse();
294-
bool isReady = !isCancelled(mCurrentRequest.value());
295-
if (it != mReadyResponses.end() && isReady)
321+
if (it != mReadyResponses.end())
296322
{
297323
sendResponse(blockHashes, it);
298324
}
@@ -486,7 +512,9 @@ class DataRequester::Impl
486512
bool isReady = mReceiver->receiveReadySignal(session);
487513
if (!isReady)
488514
{
489-
// TODO: set the error state for the request
515+
// Reuse the error state for the cancelled request.
516+
llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
517+
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
490518
return;
491519
}
492520
mReceiver->receiveSync(session);

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,17 @@ void DataSenderImpl::sendReadySignal(LlmRequest::RequestIdType requestId, bool i
128128
auto connections = session.getConnections();
129129
for (size_t i = 0; i < connections.size(); i++)
130130
{
131-
connections.at(i)->send(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
131+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
132+
if (agentConnectionManager != nullptr)
133+
{
134+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
135+
TLLM_CHECK(agentConnection != nullptr);
136+
agentConnection->sendReadySignal(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG}, isReady);
137+
}
138+
else
139+
{
140+
connections.at(i)->send(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
141+
}
132142
}
133143
}
134144

@@ -282,7 +292,17 @@ bool DataReceiverImpl::receiveReadySignal(TransferSession& session)
282292
// TODO: check if the logic is correct
283293
for (size_t i = 0; i < connections.size(); i++)
284294
{
285-
connections.at(i)->recv(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
295+
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
296+
if (agentConnectionManager != nullptr)
297+
{
298+
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
299+
TLLM_CHECK(agentConnection != nullptr);
300+
isReady = agentConnection->recvReadySignal(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG});
301+
}
302+
else
303+
{
304+
connections.at(i)->recv(executor::kv_cache::DataContext{kREADY_SIGNAL_TAG}, &isReady, sizeof(isReady));
305+
}
286306
isReadyFinal &= isReady;
287307
}
288308

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) cons
104104
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo);
105105
}
106106

107+
void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) const
108+
{
109+
ReadySignalInfo readySignalInfo{mRemoteAgentName, ctx, isReady};
110+
NotificationInfo notificationInfo{readySignalInfo};
111+
std::stringstream ss;
112+
NotificationInfo::serialize(notificationInfo, ss);
113+
mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str());
114+
}
115+
116+
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
117+
{
118+
ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
119+
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo);
120+
return true;
121+
}
122+
107123
void AgentConnection::sendRequestAndBufferInfo(
108124
batch_manager::RequestInfo& requestInfo, std::optional<size_t> cacheBufferId, int validConnectionIdx)
109125
{
@@ -401,11 +417,11 @@ int AgentConnectionManager::getDeviceId() const
401417
return mDeviceId;
402418
}
403419

404-
void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo const& syncInfo)
420+
template <typename NotificationType>
421+
void AgentConnectionManager::waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo)
405422
{
406423
while (true)
407424
{
408-
409425
updateUnhandledNotifications();
410426
std::scoped_lock lock(mNotificationMutex);
411427
auto it = mUnhandledNotifications.begin();
@@ -423,19 +439,43 @@ void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName,
423439
std::stringstream ss(*it2);
424440
NotificationInfo notificationInfo = NotificationInfo::deserialize(ss);
425441
bool erase = false;
426-
if (std::holds_alternative<NotificationSyncInfo>(notificationInfo.mInfo))
442+
if constexpr (std::is_same_v<NotificationType, NotificationSyncInfo>)
443+
{
444+
if (std::holds_alternative<NotificationSyncInfo>(notificationInfo.mInfo))
445+
{
446+
auto notificationData = std::get<NotificationSyncInfo>(notificationInfo.mInfo);
447+
if (notificationData.mContext.getTag() == expectedInfo.mContext.getTag()
448+
&& notificationData.mAgentName == expectedInfo.mAgentName)
449+
{
450+
erase = true;
451+
it2 = notifs.erase(it2);
452+
if (notifs.empty())
453+
{
454+
it = mUnhandledNotifications.erase(it);
455+
}
456+
return;
457+
}
458+
}
459+
}
460+
461+
if constexpr (std::is_same_v<NotificationType, ReadySignalInfo>)
427462
{
428-
auto notificationSyncInfo = std::get<NotificationSyncInfo>(notificationInfo.mInfo);
429-
if (notificationSyncInfo.mContext.getTag() == syncInfo.mContext.getTag()
430-
&& notificationSyncInfo.mAgentName == syncInfo.mAgentName)
463+
if (std::holds_alternative<ReadySignalInfo>(notificationInfo.mInfo))
431464
{
432-
erase = true;
433-
it2 = notifs.erase(it2);
434-
if (notifs.empty())
465+
auto readySignalData = std::get<ReadySignalInfo>(notificationInfo.mInfo);
466+
if (readySignalData.mContext.getTag() == expectedInfo.mContext.getTag()
467+
&& readySignalData.mAgentName == expectedInfo.mAgentName)
435468
{
436-
it = mUnhandledNotifications.erase(it);
469+
expectedInfo.mIsReady = readySignalData.mIsReady;
470+
471+
erase = true;
472+
it2 = notifs.erase(it2);
473+
if (notifs.empty())
474+
{
475+
it = mUnhandledNotifications.erase(it);
476+
}
477+
return;
437478
}
438-
return;
439479
}
440480
}
441481
if (!erase)
@@ -455,6 +495,22 @@ void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName,
455495
}
456496
}
457497

498+
// Explicit template instantiations
499+
template void AgentConnectionManager::waitForNotification<NotificationSyncInfo>(
500+
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo);
501+
template void AgentConnectionManager::waitForNotification<ReadySignalInfo>(
502+
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo);
503+
504+
void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo)
505+
{
506+
waitForNotification(remoteAgentName, syncInfo);
507+
}
508+
509+
void AgentConnectionManager::waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo)
510+
{
511+
waitForNotification(remoteAgentName, readySignalInfo);
512+
}
513+
458514
std::string const& AgentConnectionManager::getAgentName() const
459515
{
460516
return mAgentName;

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,42 @@ struct NotificationSyncInfo
101101
}
102102
};
103103

104+
struct ReadySignalInfo
105+
{
106+
std::string mAgentName;
107+
DataContext mContext;
108+
bool mIsReady;
109+
110+
static void serialize(ReadySignalInfo const& readySignalInfo, std::ostream& os)
111+
{
112+
namespace su = executor::serialize_utils;
113+
su::serialize(readySignalInfo.mAgentName, os);
114+
su::serialize(readySignalInfo.mContext.getTag(), os);
115+
su::serialize(readySignalInfo.mIsReady, os);
116+
}
117+
118+
static ReadySignalInfo deserialize(std::istream& is)
119+
{
120+
namespace su = executor::serialize_utils;
121+
auto agentName = su::deserialize<decltype(mAgentName)>(is);
122+
auto contextTag = su::deserialize<decltype(mContext.getTag())>(is);
123+
DataContext context{contextTag};
124+
auto isReady = su::deserialize<decltype(mIsReady)>(is);
125+
return ReadySignalInfo{agentName, context, isReady};
126+
}
127+
128+
static size_t serializedSize(ReadySignalInfo const& readySignalInfo)
129+
{
130+
namespace su = executor::serialize_utils;
131+
return su::serializedSize(readySignalInfo.mAgentName) + su::serializedSize(readySignalInfo.mContext.getTag())
132+
+ su::serializedSize(readySignalInfo.mIsReady);
133+
}
134+
};
135+
104136
struct NotificationInfo
105137
{
106138

107-
std::variant<RequestAndBufferInfo, NotificationSyncInfo> mInfo;
139+
std::variant<RequestAndBufferInfo, NotificationSyncInfo, ReadySignalInfo> mInfo;
108140

109141
static void serialize(NotificationInfo const& notificationInfo, std::ostream& os)
110142
{
@@ -118,6 +150,10 @@ struct NotificationInfo
118150
{
119151
NotificationSyncInfo::serialize(std::get<NotificationSyncInfo>(notificationInfo.mInfo), os);
120152
}
153+
else if (std::holds_alternative<ReadySignalInfo>(notificationInfo.mInfo))
154+
{
155+
ReadySignalInfo::serialize(std::get<ReadySignalInfo>(notificationInfo.mInfo), os);
156+
}
121157
else
122158
{
123159
TLLM_THROW("Unknown variant type");
@@ -130,6 +166,7 @@ struct NotificationInfo
130166
auto variantIdx = su::deserialize<std::size_t>(is);
131167
constexpr std::size_t requestAndBufferInfoIdx{0};
132168
constexpr std::size_t notificationSyncInfoIdx{1};
169+
constexpr std::size_t readySignalInfoIdx{2};
133170
if (variantIdx == requestAndBufferInfoIdx)
134171
{
135172
return NotificationInfo{RequestAndBufferInfo::deserialize(is)};
@@ -138,6 +175,10 @@ struct NotificationInfo
138175
{
139176
return NotificationInfo{NotificationSyncInfo::deserialize(is)};
140177
}
178+
else if (variantIdx == readySignalInfoIdx)
179+
{
180+
return NotificationInfo{ReadySignalInfo::deserialize(is)};
181+
}
141182
else
142183
{
143184
TLLM_THROW("Unknown variant type");
@@ -157,6 +198,10 @@ struct NotificationInfo
157198
{
158199
totalSize += NotificationSyncInfo::serializedSize(std::get<NotificationSyncInfo>(notificationInfo.mInfo));
159200
}
201+
else if (std::holds_alternative<ReadySignalInfo>(notificationInfo.mInfo))
202+
{
203+
totalSize += ReadySignalInfo::serializedSize(std::get<ReadySignalInfo>(notificationInfo.mInfo));
204+
}
160205
else
161206
{
162207
TLLM_THROW("Unknown variant type");
@@ -179,6 +224,8 @@ class AgentConnection : public Connection
179224
[[nodiscard]] std::optional<size_t> getCacheBufferId() const;
180225
void setHasLoadRemoteAgent(bool hasLoadRemoteAgent);
181226
[[nodiscard]] bool hasLoadRemoteAgent() const;
227+
void sendReadySignal(DataContext const& ctx, bool isReady) const;
228+
bool recvReadySignal(DataContext const& ctx) const;
182229

183230
private:
184231
std::string mAgentName;
@@ -216,7 +263,11 @@ class AgentConnectionManager : public ConnectionManager
216263
std::optional<std::string> metadata = std::nullopt, bool isSender = false);
217264
int getDeviceId() const;
218265
[[nodiscard]] std::string const& getAgentName() const;
219-
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo const& syncInfo);
266+
267+
template <typename NotificationType>
268+
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
269+
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
270+
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
220271

221272
private:
222273
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;

0 commit comments

Comments
 (0)