
Commit 33978ba

Fix disagg pp bug
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent: c4535e6

5 files changed: +56 -68 lines changed


cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 52 additions & 63 deletions

@@ -101,10 +101,52 @@ class DataResponder::Impl
     {
         TLLM_CHECK(mSender);
         TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
-        mCurrentRequest = std::nullopt;
         mResponseFuture = std::async(std::launch::async, &Impl::response, this);
     }
 
+    void sendResponse(RequestIdType reqId) noexcept
+    {
+        std::unique_lock lk(mSendMutex);
+        // Send context cache is not called for this request yet.
+        std::unique_lock<std::mutex> lkResp(mResponderMutex);
+        auto readyResponseIt = mReadyResponses.find(reqId);
+        if (readyResponseIt == mReadyResponses.end())
+        {
+            return;
+        }
+        lkResp.unlock();
+        auto it = mRequestInfoMap.find(reqId);
+        if (it == mRequestInfoMap.end())
+        {
+            return;
+        }
+        auto blockHashes = it->second.getBlockHashes();
+        auto count = --mRemainSendCount[reqId];
+        TLLM_CHECK(count >= 0);
+        if (count == 0)
+        {
+            mRequestInfoMap.erase(it);
+            mRemainSendCount.erase(reqId);
+
+            // TODO(zhengd): pass the hashes directly instead of update llmRequest
+            auto llmRequest = readyResponseIt->second.mRequest;
+            llmRequest->setRequestedBlockHashes(std::move(blockHashes));
+
+            if (common::getEnvParallelCacheSend())
+            {
+                // TODO: Use a thread pool and check for thread safety.
+                std::thread(
+                    &DataResponder::Impl::sendAndRemoveResponse, this, reqId, std::move(readyResponseIt->second))
+                    .detach();
+            }
+            else
+            {
+                DataResponder::Impl::sendAndRemoveResponse(reqId, std::move(readyResponseIt->second));
+            }
+            removeResponse(readyResponseIt);
+        }
+    }
+
     [[nodiscard]] std::future<void> respondAndSendAsync(LlmRequest& llmRequest)
     {
         std::promise<void> promise;
@@ -115,6 +157,7 @@ class DataResponder::Impl
             mReadyResponses.emplace(
                 llmRequest.mRequestId, Response{std::addressof(llmRequest), std::move(promise)});
         }
+        sendResponse(llmRequest.mRequestId);
         std::unique_lock lkCond(mCondMutex);
         mAnyReady = true;
     }
@@ -178,56 +221,18 @@ class DataResponder::Impl
                     break;
                 }
                 std::vector<size_t> blockHashes;
-                if (!isSending() && !mReadyResponses.empty())
+                auto const& requestInfo = mSender->recvRequestInfo();
+                auto reqId = requestInfo.getRequestId();
+                blockHashes = requestInfo.getBlockHashes();
                 {
-                    auto const& requestInfo = mSender->recvRequestInfo();
-                    auto reqId = requestInfo.getRequestId();
-                    blockHashes = requestInfo.getBlockHashes();
-
-                    mCurrentRequest = reqId;
+                    std::unique_lock lk(mSendMutex);
+                    mRequestInfoMap[reqId] = std::move(requestInfo);
                     if (mRemainSendCount.find(reqId) == mRemainSendCount.end())
                     {
                         mRemainSendCount[reqId] = mSender->getCounterpartsCount(reqId);
                     }
                 }
-                auto it = getCurrentResponse();
-                if (it != mReadyResponses.end())
-                {
-                    auto reqId = mCurrentRequest.value();
-                    auto count = --mRemainSendCount[reqId];
-                    TLLM_CHECK(count >= 0);
-                    if (count == 0)
-                    {
-                        mRemainSendCount.erase(reqId);
-
-                        // TODO(zhengd): pass the hashes directly instead of update llmRequest
-                        auto llmRequest = it->second.mRequest;
-                        llmRequest->setRequestedBlockHashes(std::move(blockHashes));
-
-                        if (common::getEnvParallelCacheSend())
-                        {
-                            // TODO: Use a thread pool and check for thread safety.
-                            std::thread(
-                                &DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
-                                .detach();
-                        }
-                        else
-                        {
-                            DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
-                        }
-                        removeResponse(it);
-                    }
-                    mCurrentRequest = std::nullopt;
-                }
-                else
-                {
-                    TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
-                        "This executor does not have a prepared KV cache for request ID: %zu, and the "
-                        "mReadyResponses size is: %zu. mpi rank :%d ",
-                        mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
-                    std::unique_lock lk(mCondMutex);
-                    mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
-                }
+                sendResponse(reqId);
             }
         }
         catch (std::exception const& err)
@@ -264,31 +269,15 @@ class DataResponder::Impl
         }
     }
 
-    [[nodiscard]] bool isSending() const
-    {
-        return mCurrentRequest.has_value();
-    }
-
-    [[nodiscard]] RequestIdType getCurrentRequestId() const
-    {
-        return mCurrentRequest.value();
-    }
-
-    [[nodiscard]] std::map<RequestIdType, Response>::iterator getCurrentResponse()
-    {
-        std::unique_lock lk(mResponderMutex);
-        return mReadyResponses.find(getCurrentRequestId());
-    }
-
 private:
-    std::optional<RequestIdType> mCurrentRequest;
     std::map<RequestIdType, Response> mReadyResponses;
-    std::mutex mResponderMutex, mCondMutex;
+    std::mutex mCondMutex, mSendMutex, mResponderMutex;
     std::atomic<bool> mAnyReady{false}, mTerminate{false};
     std::condition_variable mResponderCv;
     std::future<void> mResponseFuture;
     std::unique_ptr<DataSender> mSender;
     std::unordered_map<LlmRequest::RequestIdType, int> mRemainSendCount;
+    std::unordered_map<LlmRequest::RequestIdType, RequestInfo> mRequestInfoMap;
     int mDeviceId{-1};
 };
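
The interesting part of this change is the rendezvous it introduces. The old response loop assumed a single in-flight request (mCurrentRequest) and a fixed arrival order; under pipeline parallelism the remote request info and the local ready response can arrive in either order, for several requests at once. Now the received request info is stashed in mRequestInfoMap, and sendResponse is invoked from both paths: respondAndSendAsync when the local KV cache response becomes ready, and the receive loop when the remote request info arrives. Whichever event completes second performs the send. Below is a minimal standalone sketch of that pattern; the class, member names, payload types, and single-mutex locking are simplified stand-ins for illustration, not the TensorRT-LLM API:

#include <cstdint>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

// Sketch: a send fires only once BOTH the locally prepared response and the
// remotely received request info exist for the same request ID.
class ResponderSketch
{
public:
    // Local path: the response for reqId is ready to be served.
    void onResponseReady(std::uint64_t reqId, std::string response)
    {
        std::lock_guard<std::mutex> lk(mMutex);
        mReadyResponses[reqId] = std::move(response);
        trySend(reqId);
    }

    // Remote path: the peer's request info for reqId has arrived.
    void onRequestInfo(std::uint64_t reqId, std::string info)
    {
        std::lock_guard<std::mutex> lk(mMutex);
        mRequestInfo[reqId] = std::move(info);
        trySend(reqId);
    }

private:
    // Precondition: mMutex held. A no-op until both halves are present.
    void trySend(std::uint64_t reqId)
    {
        auto respIt = mReadyResponses.find(reqId);
        auto infoIt = mRequestInfo.find(reqId);
        if (respIt == mReadyResponses.end() || infoIt == mRequestInfo.end())
        {
            return; // Other half not here yet; the later caller will send.
        }
        std::cout << "sending " << respIt->second << " using " << infoIt->second << '\n';
        mReadyResponses.erase(respIt);
        mRequestInfo.erase(infoIt);
    }

    std::mutex mMutex;
    std::unordered_map<std::uint64_t, std::string> mReadyResponses;
    std::unordered_map<std::uint64_t, std::string> mRequestInfo;
};

Because the per-request state lives in maps keyed by request ID, any number of requests can rendezvous concurrently, which the old single optional slot could not express.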

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 0 additions & 1 deletion

@@ -69,7 +69,6 @@ void LlmRequest::createSerializedResult(
 /// Note that there is some dependency on the order of operations in this method. Modify with care!
 std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
 {
-    TLLM_CHECK(!isDisaggContextCompleteState());
     if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
     {
         return std::nullopt;
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 0 deletions

@@ -815,6 +815,9 @@ def _executor_loop_pp(self):
                         ),
                         dest=self.dist.next_pp_rank,
                         tag=prev_microbatch_id)
+                    # TODO: remove this wait, without this wait
+                    # there is an intermittent hang on some nodes.
+                    self.send_handles[prev_microbatch_id].wait()
                     torch.cuda.nvtx.range_pop()
 
             # Stage 3: Finalize previous batch that finished tokens communication
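
The wait added here turns a previously fire-and-forget inter-stage send into one that completes before the loop continues. As a general property of non-blocking sends, the send buffer may only be reused (and ordering relative to later operations relied upon) once the returned handle has completed, so waiting on the handle is the conservative workaround for the intermittent hang the TODO describes. A rough analogy in plain MPI follows; it is an illustration only, and the internals of the self.dist wrapper used in the diff are not assumed:

#include <mpi.h>

// A non-blocking send returns a handle; completing the handle is what makes
// buffer reuse safe. This mirrors the effect of the wait() in the diff above.
int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0, size = 1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int const next = (rank + 1) % size;
    int const prev = (rank + size - 1) % size;
    int sendBuf = rank;
    int recvBuf = -1;
    MPI_Request sendHandle, recvHandle;

    // Post the receive first so the ring cannot deadlock.
    MPI_Irecv(&recvBuf, 1, MPI_INT, prev, 0, MPI_COMM_WORLD, &recvHandle);
    MPI_Isend(&sendBuf, 1, MPI_INT, next, 0, MPI_COMM_WORLD, &sendHandle);

    // Analogous to send_handles[prev_microbatch_id].wait(): overwriting
    // sendBuf for the next microbatch before this returns would race with
    // the in-flight send.
    MPI_Wait(&sendHandle, MPI_STATUS_IGNORE);
    MPI_Wait(&recvHandle, MPI_STATUS_IGNORE);

    MPI_Finalize();
    return 0;
}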

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ def multi_popen(server_configs):
             )
             raise
 
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir):
+    with (MyThreadPoolExecutor(max_workers=4) as thread_pool, temp_dir):
         with multi_popen(ctx_servers + gen_servers):
             with popen([
                 trtllm_serve_path, "disaggregated", "-c",
tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 3 deletions

@@ -259,17 +259,14 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
 accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4 SKIP (https://nvbugs/5409414)
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5409414)
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype SKIP (https://nvbugs/5433543)
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope SKIP (https://nvbugs/5433543)
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)