1818#pragma once
1919
2020#include " cacheTransBuffer.h"
21- #include " dataTransceiver.h"
2221#include " tensorrt_llm/batch_manager/kvCacheManager.h"
2322#include " tensorrt_llm/batch_manager/kvCacheUtils.h"
2423#include " tensorrt_llm/common/envUtils.h"
2524#include " tensorrt_llm/common/logger.h"
25+ #include " tensorrt_llm/executor/cacheCommunicator.h"
2626#include " tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2727#include " tensorrt_llm/executor/dataTransceiverState.h"
2828#include " tensorrt_llm/runtime/bufferManager.h"
@@ -38,6 +38,88 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
3838
3939BlockRange getBlockRangeForReceiving (BaseKVCacheManager* cacheManager, LlmRequest const & llmRequest);
4040
41+ using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
42+ using Connection = tensorrt_llm::executor::kv_cache::Connection;
43+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
44+
45+ class TransferSession
46+ {
47+ public:
48+ TransferSession (std::vector<Connection const *> connections, DataContext dataContext,
49+ executor::DataTransceiverState const & selfState, executor::DataTransceiverState otherState,
50+ runtime::BufferManager const & bufferManager, LlmRequest const * llmRequest = nullptr )
51+ : mConnections (std::move(connections))
52+ , mDataContext (dataContext)
53+ , mSelfState (&selfState)
54+ , mOtherState (std::move(otherState))
55+ , mBufferManager (&bufferManager)
56+ , mRequest (llmRequest)
57+ {
58+ TLLM_CHECK (!mConnections .empty ());
59+ }
60+
61+ [[nodiscard]] std::vector<Connection const *> const & getConnections () const
62+ {
63+ return mConnections ;
64+ }
65+
66+ // should be called only during the initialization of the TransferSession
67+ void setConnection (size_t idx, Connection const * conn)
68+ {
69+ mConnections .at (idx) = conn;
70+ }
71+
72+ [[nodiscard]] DataContext const & getDataContext () const
73+ {
74+ return mDataContext ;
75+ }
76+
77+ [[nodiscard]] executor::DataTransceiverState const & getSelfState () const
78+ {
79+ return *mSelfState ;
80+ }
81+
82+ [[nodiscard]] executor::DataTransceiverState const & getOtherState () const
83+ {
84+ return mOtherState ;
85+ }
86+
87+ [[nodiscard]] runtime::BufferManager const & getBufferManager () const
88+ {
89+ return *mBufferManager ;
90+ }
91+
92+ void send (size_t idx, void const * data, size_t size)
93+ {
94+ mConnections .at (idx)->send (mDataContext , data, size);
95+ }
96+
97+ void recv (size_t idx, void * data, size_t size)
98+ {
99+ mConnections .at (idx)->recv (mDataContext , data, size);
100+ }
101+
102+ [[nodiscard]] LlmRequest const & getLlmRequest () const
103+ {
104+ TLLM_CHECK (mRequest != nullptr );
105+ return *mRequest ;
106+ }
107+
108+ // in CacheSender, the LlmRequest is not available until the sendSync is called
109+ void setLlmRequest (LlmRequest const & llmRequest)
110+ {
111+ mRequest = &llmRequest;
112+ }
113+
114+ private:
115+ std::vector<Connection const *> mConnections ;
116+ DataContext mDataContext ;
117+ executor::DataTransceiverState const * mSelfState ; // stored in CacheReceiver/CacheSender
118+ executor::DataTransceiverState mOtherState ;
119+ runtime::BufferManager const * mBufferManager ;
120+ LlmRequest const * mRequest ;
121+ };
122+
41123// Used to support the cache transmission with different layouts and different protocols.
42124class BaseCacheFormatter
43125{
@@ -78,6 +160,66 @@ class BaseCacheFormatter
78160 virtual ~BaseCacheFormatter () = default ;
79161};
80162
163+ class KvCacheMeasureHelper
164+ {
165+ public:
166+ KvCacheMeasureHelper (std::string output_path)
167+ : mOutputPath (std::move(output_path))
168+ {
169+ }
170+
171+ void appendKVCacheTransfer (LlmRequest::RequestIdType requestId, double duration, size_t size)
172+ {
173+ auto bandwidth = size * 8 / (duration / 1000 ) / 1e9 ;
174+ if (mOutputPath .empty ())
175+ {
176+ return ;
177+ }
178+
179+ std::lock_guard<std::mutex> lock (mMutex );
180+ mRequestKVCacheTranfserMeasure [requestId].emplace_back (duration, bandwidth);
181+ }
182+
183+ ~KvCacheMeasureHelper ()
184+ {
185+ if (!mRequestKVCacheTranfserMeasure .empty () && !mOutputPath .empty ())
186+ {
187+ auto rank = mpi::MpiComm::world ().getRank ();
188+ std::string outFilePath = mOutputPath + " rank_" + std::to_string (rank) + " .txt" ;
189+ std::ofstream outFile (outFilePath);
190+
191+ TLLM_CHECK_WITH_INFO (outFile.is_open (), " Cannot write to file " + outFilePath);
192+
193+ size_t numTransferMeasure = mRequestKVCacheTranfserMeasure .begin ()->second .size ();
194+
195+ outFile << " RequestID" ;
196+ for (size_t i = 0 ; i < numTransferMeasure; i++)
197+ {
198+ outFile << " ,TimeDuration,Bandwidth" ;
199+ }
200+ outFile << ' \n ' ;
201+
202+ for (auto const & [requestID, measures] : mRequestKVCacheTranfserMeasure )
203+ {
204+ outFile << requestID;
205+
206+ for (auto const & [time, bandwidth] : measures)
207+ {
208+ outFile << " ," << time << " ," << bandwidth;
209+ }
210+ outFile << ' \n ' ;
211+ }
212+
213+ outFile.close ();
214+ }
215+ }
216+
217+ private:
218+ std::map<LlmRequest::RequestIdType, std::vector<std::pair<double , double >>> mRequestKVCacheTranfserMeasure ;
219+ std::string mOutputPath ;
220+ std::mutex mMutex ;
221+ };
222+
81223// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
82224// parallel topology is completely identical, making it the preferred method.
83225class CacheFormatter final : public BaseCacheFormatter
0 commit comments