@@ -101,10 +101,52 @@ class DataResponder::Impl
101101 {
102102 TLLM_CHECK (mSender );
103103 TLLM_CUDA_CHECK (cudaGetDevice (&mDeviceId ));
104- mCurrentRequest = std::nullopt ;
105104 mResponseFuture = std::async (std::launch::async, &Impl::response, this );
106105 }
107106
107+ void sendResponse (RequestIdType reqId) noexcept
108+ {
109+ std::unique_lock lk (mSendMutex );
110+ // Send context cache is not called for this request yet.
111+ std::unique_lock<std::mutex> lkResp (mResponderMutex );
112+ auto readyResponseIt = mReadyResponses .find (reqId);
113+ if (readyResponseIt == mReadyResponses .end ())
114+ {
115+ return ;
116+ }
117+ lkResp.unlock ();
118+ auto it = mRequestInfoMap .find (reqId);
119+ if (it == mRequestInfoMap .end ())
120+ {
121+ return ;
122+ }
123+ auto blockHashes = it->second .getBlockHashes ();
124+ auto count = --mRemainSendCount [reqId];
125+ TLLM_CHECK (count >= 0 );
126+ if (count == 0 )
127+ {
128+ mRequestInfoMap .erase (it);
129+ mRemainSendCount .erase (reqId);
130+
131+ // TODO(zhengd): pass the hashes directly instead of update llmRequest
132+ auto llmRequest = readyResponseIt->second .mRequest ;
133+ llmRequest->setRequestedBlockHashes (std::move (blockHashes));
134+
135+ if (common::getEnvParallelCacheSend ())
136+ {
137+ // TODO: Use a thread pool and check for thread safety.
138+ std::thread (
139+ &DataResponder::Impl::sendAndRemoveResponse, this , reqId, std::move (readyResponseIt->second ))
140+ .detach ();
141+ }
142+ else
143+ {
144+ DataResponder::Impl::sendAndRemoveResponse (reqId, std::move (readyResponseIt->second ));
145+ }
146+ removeResponse (readyResponseIt);
147+ }
148+ }
149+
108150 [[nodiscard]] std::future<void > respondAndSendAsync (LlmRequest& llmRequest)
109151 {
110152 std::promise<void > promise;
@@ -115,6 +157,7 @@ class DataResponder::Impl
115157 mReadyResponses .emplace (
116158 llmRequest.mRequestId , Response{std::addressof (llmRequest), std::move (promise)});
117159 }
160+ sendResponse (llmRequest.mRequestId );
118161 std::unique_lock lkCond (mCondMutex );
119162 mAnyReady = true ;
120163 }
@@ -178,56 +221,18 @@ class DataResponder::Impl
178221 break ;
179222 }
180223 std::vector<size_t > blockHashes;
181- if (!isSending () && !mReadyResponses .empty ())
224+ auto const & requestInfo = mSender ->recvRequestInfo ();
225+ auto reqId = requestInfo.getRequestId ();
226+ blockHashes = requestInfo.getBlockHashes ();
182227 {
183- auto const & requestInfo = mSender ->recvRequestInfo ();
184- auto reqId = requestInfo.getRequestId ();
185- blockHashes = requestInfo.getBlockHashes ();
186-
187- mCurrentRequest = reqId;
228+ std::unique_lock lk (mSendMutex );
229+ mRequestInfoMap [reqId] = std::move (requestInfo);
188230 if (mRemainSendCount .find (reqId) == mRemainSendCount .end ())
189231 {
190232 mRemainSendCount [reqId] = mSender ->getCounterpartsCount (reqId);
191233 }
192234 }
193- auto it = getCurrentResponse ();
194- if (it != mReadyResponses .end ())
195- {
196- auto reqId = mCurrentRequest .value ();
197- auto count = --mRemainSendCount [reqId];
198- TLLM_CHECK (count >= 0 );
199- if (count == 0 )
200- {
201- mRemainSendCount .erase (reqId);
202-
203- // TODO(zhengd): pass the hashes directly instead of update llmRequest
204- auto llmRequest = it->second .mRequest ;
205- llmRequest->setRequestedBlockHashes (std::move (blockHashes));
206-
207- if (common::getEnvParallelCacheSend ())
208- {
209- // TODO: Use a thread pool and check for thread safety.
210- std::thread (
211- &DataResponder::Impl::sendAndRemoveResponse, this , it->first , std::move (it->second ))
212- .detach ();
213- }
214- else
215- {
216- DataResponder::Impl::sendAndRemoveResponse (it->first , std::move (it->second ));
217- }
218- removeResponse (it);
219- }
220- mCurrentRequest = std::nullopt ;
221- }
222- else
223- {
224- TLLM_CHECK_WITH_INFO (!mCurrentRequest .has_value (),
225- " This executor does not have a prepared KV cache for request ID: %zu, and the "
226- " mReadyResponses size is: %zu. mpi rank :%d " ,
227- mCurrentRequest .value (), mReadyResponses .size (), mpi::MpiComm::world ().getRank ());
228- std::unique_lock lk (mCondMutex );
229- mResponderCv .wait (lk, [this ]() { return (mAnyReady || mTerminate ); });
230- }
235+ sendResponse (reqId);
231236 }
232237 }
233238 catch (std::exception const & err)
@@ -264,31 +269,15 @@ class DataResponder::Impl
264269 }
265270 }
266271
267- [[nodiscard]] bool isSending () const
268- {
269- return mCurrentRequest .has_value ();
270- }
271-
272- [[nodiscard]] RequestIdType getCurrentRequestId () const
273- {
274- return mCurrentRequest .value ();
275- }
276-
277- [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse ()
278- {
279- std::unique_lock lk (mResponderMutex );
280- return mReadyResponses .find (getCurrentRequestId ());
281- }
282-
283272private:
284- std::optional<RequestIdType> mCurrentRequest ;
285273 std::map<RequestIdType, Response> mReadyResponses ;
286- std::mutex mResponderMutex , mCondMutex ;
274+ std::mutex mCondMutex , mSendMutex , mResponderMutex ;
287275 std::atomic<bool > mAnyReady {false }, mTerminate {false };
288276 std::condition_variable mResponderCv ;
289277 std::future<void > mResponseFuture ;
290278 std::unique_ptr<DataSender> mSender ;
291279 std::unordered_map<LlmRequest::RequestIdType, int > mRemainSendCount ;
280+ std::unordered_map<LlmRequest::RequestIdType, RequestInfo> mRequestInfoMap ;
292281 int mDeviceId {-1 };
293282};
294283
0 commit comments