Skip to content

Commit a0024f4

Browse files
authored
[None][doc] Facilitates the integration of the transfer agent (#7867)
Signed-off-by: Shixiaowei02 <[email protected]>
1 parent 653aa6b commit a0024f4

File tree

9 files changed

+154
-37
lines changed

9 files changed

+154
-37
lines changed

cpp/include/tensorrt_llm/executor/transferAgent.h

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ enum class MemoryType : uint8_t
4040
kFILE
4141
};
4242

43+
// `MemoryDesc` is used to describe a memory region, which can then be designated
44+
// as the source or destination of read/write operations.
4345
class MemoryDesc
4446
{
4547
public:
@@ -192,6 +194,8 @@ using RegisterDescs = MemoryDescs;
192194
using SyncMessage = std::string;
193195
using ConnectionInfoType = std::string;
194196

197+
// `AgentDesc` represents the unique identifier for reading and writing to the agent.
198+
// By accessing this identifier, the backend can establish the correct connection.
195199
class AgentDesc final
196200
{
197201
public:
@@ -209,15 +213,24 @@ class AgentDesc final
209213
std::string mBackendAgentDesc;
210214
};
211215

216+
// `TransferOp` is an enumeration that represents the types of transfer operations.
217+
// Currently, it supports two operations: `read` and `write`.
212218
enum class TransferOp : uint8_t
213219
{
214220
kREAD,
215221
kWRITE,
216222
};
217223

224+
// `TransferRequest` is used to represent the transfer requests supported by the underlying agent.
218225
class TransferRequest
219226
{
220227
public:
228+
/// @brief The constructor of `TransferRequest`.
229+
/// @param op Source data arrangement.
230+
/// @param srcDescs Description of the source memory region.
231+
/// @param dstDescs Description of the destination memory region.
232+
/// @param remoteName Name of the remote counterpart.
233+
/// @param syncMessage Synchronization information for the end of the transfer.
221234
TransferRequest(TransferOp op, TransferDescs srcDescs, TransferDescs dstDescs, std::string const& remoteName,
222235
std::optional<SyncMessage> syncMessage = std::nullopt)
223236
: mOp{op}
@@ -261,6 +274,7 @@ class TransferRequest
261274
std::optional<SyncMessage> mSyncMessage;
262275
};
263276

277+
// Data structure for checking the status of active transfer operations.
264278
class TransferStatus
265279
{
266280
public:
@@ -281,22 +295,52 @@ class BaseTransferAgent
281295
public:
282296
virtual ~BaseTransferAgent() = default;
283297

298+
/// @brief Register a memory region.
299+
/// @param descs Describe the memory regions to be registered.
284300
virtual void registerMemory(RegisterDescs const& descs) = 0;
285301

302+
/// @brief Unregister a memory region.
303+
/// @param descs Describe the memory regions to be unregistered.
286304
virtual void deregisterMemory(RegisterDescs const& descs) = 0;
287305

306+
/// @brief Initialize and establish a connection with a remote agent.
307+
/// @param name Specify the name of the remote agent.
308+
/// @param agentDesc Provide the necessary communication details for connecting to the remote agent.
288309
virtual void loadRemoteAgent(std::string const& name, AgentDesc const& agentDesc) = 0;
289-
virtual AgentDesc getLocalAgentDesc() = 0;
290310

311+
/// @brief Initialize and establish a connection with a remote agent.
312+
/// @param name Specify the name of the remote agent.
313+
/// @param connectionInfo Provide the necessary communication details for connecting to the remote agent.
314+
virtual void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) = 0;
315+
316+
/// @brief Invalidate a connection with a remote agent.
317+
/// @param name Specify the name of the remote agent.
291318
virtual void invalidateRemoteAgent(std::string const& name) = 0;
292319

320+
/// @brief Fetch the descriptor of the local agent.
321+
/// @return The descriptor of the local agent.
322+
virtual AgentDesc getLocalAgentDesc() = 0;
323+
324+
/// @brief Fetch the descriptor of the local agent.
325+
/// @return The descriptor of the local agent.
326+
virtual ConnectionInfoType getLocalConnectionInfo() = 0;
327+
328+
/// @brief Initiate the transfer by submitting the request.
329+
/// @param request Specify the transmission request.
330+
/// @return The status of the requests.
293331
[[nodiscard]] virtual std::unique_ptr<TransferStatus> submitTransferRequests(TransferRequest const& request) = 0;
332+
333+
/// @brief Generate a notification, not bound to a transfer, e.g., for control.
334+
/// @param name Specify the name of the remote agent to which the information should be sent.
335+
/// @param syncMessage The data or message intended for synchronization.
294336
virtual void notifySyncMessage(std::string const& name, SyncMessage const& syncMessage) = 0;
295337

338+
/// @brief Retrieve notification messages sent by other agents.
339+
/// @return A mapping from remote agent names to their respective notification messages.
296340
virtual std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() = 0;
297341

298-
virtual ConnectionInfoType getConnectionInfo() = 0;
299-
virtual void connectRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) = 0;
342+
/// @brief Check if metadata is available for a remote agent.
343+
/// @return Whether the metadata is available for a remote agent.
300344
virtual bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) = 0;
301345
};
302346

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,18 @@ using runtime::SizeType32;
145145
using AgentConnectionManager = tensorrt_llm::executor::kv_cache::AgentConnectionManager;
146146
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
147147

148-
static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
148+
namespace
149+
{
150+
151+
int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
149152
{
150153
constexpr int32_t kDATA_TAG{43};
151154
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
152155
}
153156

154-
namespace fs = std::filesystem;
155-
156-
static fs::path getTransferOutputPath(char const* tag)
157+
std::filesystem::path getTransferOutputPath(char const* tag)
157158
{
159+
namespace fs = std::filesystem;
158160
auto outputPath = common::getEnvKVCacheTransferOutputPath();
159161
if (!outputPath.empty())
160162
{
@@ -166,13 +168,15 @@ static fs::path getTransferOutputPath(char const* tag)
166168
return {};
167169
}
168170

171+
} // namespace
172+
169173
struct ReceiveCacheResource
170174
{
171175
runtime::BufferManager mBufferManager;
172176
runtime::CudaEvent mCudaEvent;
173177

174-
ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent&& cudaEvent)
175-
: mBufferManager(bufferManager)
178+
ReceiveCacheResource(runtime::BufferManager&& bufferManager, runtime::CudaEvent cudaEvent)
179+
: mBufferManager(std::move(bufferManager))
176180
, mCudaEvent(std::move(cudaEvent))
177181
{
178182
}
@@ -343,8 +347,7 @@ class CacheSender::Impl
343347
TLLM_CHECK_WITH_INFO(mFormatter->inquireSupport(
344348
mSelfState.getCacheState().value(), info.getTransState().getCacheState().value()),
345349
"Disagg server does not currently support these cacheState, please check the cacheState of the context and "
346-
"gen "
347-
"executors");
350+
"gen executors");
348351
auto peerRelativeRanks = executor::kv_cache::targetIRanks(info.getTransState().getCacheState().value(),
349352
mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx())
350353
.mIRanks;
@@ -1024,7 +1027,6 @@ class CacheReceiver::Impl
10241027

10251028
void request(AsyncResource& resource)
10261029
{
1027-
10281030
tensorrt_llm::common::setThreadName("dataTransRequest");
10291031
TLLM_CUDA_CHECK(cudaSetDevice(mDeviceId));
10301032

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ void AgentConnection::sendRequestAndBufferInfo(
144144
TLLM_CHECK(deviceId == mAgentConnectionManager->getDeviceId());
145145
MemoryDesc bufferDesc(
146146
reinterpret_cast<uintptr_t>(preAllocateBuffer->data()), preAllocateBuffer->getSize(), deviceId);
147-
std::string address = mAgentConnectionManager->getAgent()->getConnectionInfo();
147+
std::string address = mAgentConnectionManager->getAgent()->getLocalConnectionInfo();
148148
std::optional<std::string> metadataOpt = std::nullopt;
149149
if (mNeedSendMetadata)
150150
{
@@ -225,7 +225,7 @@ AgentConnectionManager::AgentConnectionManager(
225225
mRegMemDescs = MemoryDescs{MemoryType::kVRAM, MemDescs};
226226
m_Agent->registerMemory(mRegMemDescs);
227227

228-
AgentState localAgentState{mAgentName, m_Agent->getConnectionInfo()};
228+
AgentState localAgentState{mAgentName, m_Agent->getLocalConnectionInfo()};
229229
std::vector<AgentState> agentStates(mpi::MpiComm::session().getSize());
230230
if (mpi::MpiComm::session().getSize() > 1)
231231
{
@@ -411,10 +411,10 @@ AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentN
411411
}
412412
else
413413
{
414-
TLLM_CHECK_WITH_INFO(!isSender, "Sender shouldn't call connectRemoteAgent");
415-
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "mAgentName: %s connect to %s with connectRemoteAgent",
414+
TLLM_CHECK_WITH_INFO(!isSender, "Sender shouldn't call loadRemoteAgent");
415+
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "mAgentName: %s connect to %s with loadRemoteAgent",
416416
mAgentName.c_str(), remoteAgentName.c_str());
417-
m_Agent->connectRemoteAgent(remoteAgentName, connectionInfo);
417+
m_Agent->loadRemoteAgent(remoteAgentName, connectionInfo);
418418
}
419419
}
420420
else

cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,19 +469,19 @@ void NixlTransferAgent::notifySyncMessage(std::string const& name, SyncMessage c
469469
return notifs;
470470
}
471471

472-
ConnectionInfoType NixlTransferAgent::getConnectionInfo()
472+
ConnectionInfoType NixlTransferAgent::getLocalConnectionInfo()
473473
{
474474
return mAddress;
475475
}
476476

477-
void NixlTransferAgent::connectRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
477+
void NixlTransferAgent::loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo)
478478
{
479479
std::string ip = connectionInfo.substr(0, connectionInfo.find(":"));
480480
std::string port = connectionInfo.substr(connectionInfo.find(":") + 1);
481481
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
482-
"NixlTransferAgent::connectRemoteAgent connectRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
482+
"NixlTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s", connectionInfo.c_str(),
483483
name.c_str());
484-
TLLM_CHECK_WITH_INFO(!ip.empty() && !port.empty(), "connectRemoteAgent get empty ip or port, connectionInfo: %s",
484+
TLLM_CHECK_WITH_INFO(!ip.empty() && !port.empty(), "loadRemoteAgent get empty ip or port, connectionInfo: %s",
485485
connectionInfo.c_str());
486486
nixl_opt_args_t md_extra_params;
487487
md_extra_params.ipAddr = ip;
@@ -506,7 +506,7 @@ void NixlTransferAgent::connectRemoteAgent(std::string const& name, ConnectionIn
506506
}
507507
}
508508
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
509-
"NixlTransferAgent::connectRemoteAgent connectRemoteAgent to %s remoteagent name: %s success status: %s",
509+
"NixlTransferAgent::loadRemoteAgent loadRemoteAgent to %s remoteagent name: %s success status: %s",
510510
connectionInfo.c_str(), name.c_str(), nixlEnumStrings::statusStr(status).c_str());
511511
}
512512

cpp/tensorrt_llm/executor/cache_transmission/nixl_utils/transferAgent.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ class NixlTransferAgent final : public BaseTransferAgent
8484

8585
[[nodiscard]] std::unordered_map<std::string, std::vector<SyncMessage>> getNotifiedSyncMessages() override;
8686

87-
ConnectionInfoType getConnectionInfo() override;
87+
ConnectionInfoType getLocalConnectionInfo() override;
8888

89-
void connectRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
89+
void loadRemoteAgent(std::string const& name, ConnectionInfoType const& connectionInfo) override;
9090

9191
bool checkRemoteDescs(std::string const& name, MemoryDescs const& memoryDescs) override;
9292

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ TEST_F(TransferAgentTest, Basic)
8080
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
8181

8282
// nixlAgent0->loadRemoteAgent(agent1);
83-
auto connectionInfo = nixlAgent1->getConnectionInfo();
84-
nixlAgent0->connectRemoteAgent(agent1, connectionInfo);
83+
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
84+
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
8585
bool checked = false;
8686
do
8787
{
@@ -116,8 +116,8 @@ TEST_F(TransferAgentTest, Basic2)
116116
RegisteredHostMemory regMem1(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
117117

118118
// nixlAgent0->loadRemoteAgent(agent1);
119-
auto connectionInfo = nixlAgent1->getConnectionInfo();
120-
nixlAgent0->connectRemoteAgent(agent1, connectionInfo);
119+
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
120+
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
121121
bool checked = false;
122122
do
123123
{
@@ -159,8 +159,8 @@ TEST_F(TransferAgentTest, DeviceMemory)
159159
MemoryDescs{MemoryType::kVRAM, {MemoryDesc{dev_ptr1, size, deviceId}}}, nixlAgent1.get());
160160

161161
// nixlAgent0->loadRemoteAgent(agent1);
162-
auto connectionInfo = nixlAgent1->getConnectionInfo();
163-
nixlAgent0->connectRemoteAgent(agent1, connectionInfo);
162+
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
163+
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
164164
bool checked = false;
165165
do
166166
{
@@ -201,8 +201,8 @@ TEST_F(TransferAgentTest, Connect)
201201
nixlAgent2->registerMemory(memDescs0);
202202

203203
// nixlAgent0->loadRemoteAgent(agent1);
204-
auto connectionInfo = nixlAgent1->getConnectionInfo();
205-
nixlAgent0->connectRemoteAgent(agent1, connectionInfo);
204+
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
205+
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
206206
bool checked = false;
207207
do
208208
{
@@ -213,7 +213,7 @@ TEST_F(TransferAgentTest, Connect)
213213
status->wait();
214214

215215
TLLM_CHECK(memory0 == memory1);
216-
nixlAgent2->connectRemoteAgent(agent1, connectionInfo);
216+
nixlAgent2->loadRemoteAgent(agent1, connectionInfo);
217217
checked = false;
218218
do
219219
{
@@ -251,8 +251,8 @@ TEST_F(TransferAgentTest, SyncMessage)
251251
RegisteredHostMemory regMem3(MemoryDescs{MemoryType::kDRAM, {MemoryDesc{memory1}}}, nixlAgent1.get());
252252

253253
// nixlAgent0->loadRemoteAgent(agent1);
254-
auto connectionInfo = nixlAgent1->getConnectionInfo();
255-
nixlAgent0->connectRemoteAgent(agent1, connectionInfo);
254+
auto connectionInfo = nixlAgent1->getLocalConnectionInfo();
255+
nixlAgent0->loadRemoteAgent(agent1, connectionInfo);
256256
bool checked = false;
257257
do
258258
{
@@ -287,8 +287,8 @@ TEST_F(TransferAgentTest, SyncMessage)
287287
TLLM_CHECK(notif2[agent0][0] == syncMessage2);
288288

289289
// nixlAgent1->loadRemoteAgent(agent0);
290-
auto connectionInfo2 = nixlAgent0->getConnectionInfo();
291-
nixlAgent1->connectRemoteAgent(agent0, connectionInfo2);
290+
auto connectionInfo2 = nixlAgent0->getLocalConnectionInfo();
291+
nixlAgent1->loadRemoteAgent(agent0, connectionInfo2);
292292
std::string syncMessage3 = "three_agent_sync_message";
293293
nixlAgent1->notifySyncMessage(agent0, syncMessage3);
294294
auto notif3 = nixlAgent0->getNotifiedSyncMessages();
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Introduction to KV Cache Transmission
2+
3+
This article provides a general overview of the components used for device-to-device transmission of KV cache, which is relied upon by dist-serving. It is intended as a reference for users who wish to understand the internal implementation or develop extended functionalities.
4+
5+
## Table of Contents
6+
7+
- [Workflow](#workflow)
8+
- [Key Components](#key-components)
9+
- [Transceiver](#transceiver)
10+
- [Sender and Receiver](#sender-and-receiver)
11+
- [Formatter](#formatter)
12+
- [Connection](#connection)
13+
- [Transfer Agent](#transfer-agent)
14+
- [Customization](#customization)
15+
- [Encapsulation and Overloading of Low-Level Communication Libraries](#encapsulation-and-overloading-of-low-level-communication-libraries)
16+
- [Modifications to Upper-Level Runtime Logic](#modifications-to-upper-level-runtime-logic)
17+
- [Evolution Outlook](#evolution-outlook)
18+
19+
## Workflow
20+
21+
<img src="https://github.com/NVIDIA/TensorRT-LLM/blob/rel/docs/source/media/kv_transfer.png?raw=true" alt="KV Cache Transfer Overview" width="500" height="auto">
22+
23+
1. Context phase completes computation, KV cache stays in device memory awaiting transmission.
24+
2. Context returns its communicator handle to the user, who selects the generation executor for continued communication.
25+
3. If no prior connection exists, it's established now. Generation phase shares its cache layout with context.
26+
4. Generation phase requests KV cache for specific tokens.
27+
5. Context sends KV cache to generation phase.
28+
6. Generation phase resumes computation, context releases KV cache.
29+
30+
## Key Components
31+
32+
### Transceiver
33+
34+
Responsible for coordinating the sending and receiving of cache among different ranks within the same executor.
35+
36+
### Sender and Receiver
37+
38+
Responsible for transmitting control plane messages. That is, during per-request transmission, the receiver bound to the generation informs the sender of the specific information it requires. The sender then sends the corresponding KV cache based on these messages.
39+
40+
### Formatter
41+
42+
Performs KV cache data transmission and correctly handles the mapping between caches across different TP/PP configurations.
43+
44+
### Connection
45+
46+
Bidirectional byte-stream protocol facility. Apart from essential operations such as connection establishment, it mainly provides send and receive functionalities. UCX accesses the system through this facility. The `AgentConnection` data structure adapts the upper-layer bidirectional send/receive semantics into a unidirectional read/write operation model.
47+
48+
### Transfer Agent
49+
50+
Unidirectional byte-stream read/write protocol facility. Apart from essential operations such as connection establishment, it primarily provides read and write functionalities. NIXL accesses the system through this facility.
51+
52+
## Customization
53+
54+
At the current stage, the customization work mainly involves inheriting the low-level data plane interfaces to enable the invocation of third-party communication libraries, as well as defining the data structures required for establishing connections in the data plane.
55+
56+
### Encapsulation and Overloading of Low-Level Communication Libraries
57+
58+
Each layer of interface described in the previous section supports overloading. Here, based on whether the underlying library uses a unidirectional or bidirectional protocol, we describe the customization methods respectively.
59+
60+
If the underlying library you are integrating uses a unidirectional communication model, with read/write as its primary interfaces, you should inherit the `executor::kv_cache::BaseTransferAgent` data structure. This structure mainly provides interfaces for memory registration, remote agent loading, and transfer request submission.
61+
62+
If the underlying library you are integrating uses a bidirectional communication model, you should inherit the `executor::kv_cache::Connection` data structure. This structure mainly provides send and receive interfaces.
63+
64+
### Modifications to Upper-Level Runtime Logic
65+
66+
This corresponds to the communication info section shown in the figure above. Since different underlying communication connections may require completely different setup methods—for example, some use IP and port, others require a world rank, and some communication libraries establish connections using binary-transparent metadata—we provide sufficient flexibility to allow users to customize this part as needed.
67+
68+
## Evolution Outlook
69+
70+
Currently, the architecture of KV transfer is being optimized. First, we plan to move the control plane logic up to Python to enable better integration with the Python runtime. In addition, we are reevaluating the current design choice of initiating communication only after the context computation is completed, which was originally made for flexibility. Lastly, since some control logic is still being transmitted through the data plane, we aim to clarify the relationship between the control and data planes, and to simplify and streamline the code logic of the data plane. Due to the modular architecture, these iterative enhancements are only loosely coupled with the `TransferAgent`. We aim to minimize the impact of future upgrades on third-party integrations.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Welcome to TensorRT LLM's Documentation!
8585
developer-guide/ci-overview.md
8686
developer-guide/dev-containers.md
8787
developer-guide/api-change.md
88+
developer-guide/kv-transfer.md
8889

8990

9091
.. toctree::

docs/source/media/kv_transfer.png

198 KB
Loading

0 commit comments

Comments
 (0)