Skip to content

Commit 91c4af3

Browse files
authored
[https://nvbugs/5434320][bug] Fix disagg pp bug (#7099)
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 4b1898e commit 91c4af3

File tree

2 files changed

+40
-33
lines changed

2 files changed

+40
-33
lines changed

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,34 @@ class DataResponder::Impl
160160
}
161161
}
162162

163+
void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
164+
{
165+
auto reqId = mCurrentRequest.value();
166+
auto count = --mRemainSendCount[reqId];
167+
TLLM_CHECK(count >= 0);
168+
if (count == 0)
169+
{
170+
mRemainSendCount.erase(reqId);
171+
172+
// TODO(zhengd): pass the hashes directly instead of update llmRequest
173+
auto llmRequest = it->second.mRequest;
174+
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
175+
176+
if (common::getEnvParallelCacheSend())
177+
{
178+
// TODO: Use a thread pool and check for thread safety.
179+
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
180+
.detach();
181+
}
182+
else
183+
{
184+
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
185+
}
186+
removeResponse(it);
187+
}
188+
mCurrentRequest = std::nullopt;
189+
}
190+
163191
void response() noexcept
164192
{
165193
try
@@ -193,40 +221,22 @@ class DataResponder::Impl
193221
auto it = getCurrentResponse();
194222
if (it != mReadyResponses.end())
195223
{
196-
auto reqId = mCurrentRequest.value();
197-
auto count = --mRemainSendCount[reqId];
198-
TLLM_CHECK(count >= 0);
199-
if (count == 0)
224+
sendResponse(blockHashes, it);
225+
}
226+
else
227+
{
228+
auto it = getCurrentResponse();
229+
while (it == mReadyResponses.end())
200230
{
201-
mRemainSendCount.erase(reqId);
202-
203-
// TODO(zhengd): pass the hashes directly instead of update llmRequest
204-
auto llmRequest = it->second.mRequest;
205-
llmRequest->setRequestedBlockHashes(std::move(blockHashes));
206-
207-
if (common::getEnvParallelCacheSend())
208-
{
209-
// TODO: Use a thread pool and check for thread safety.
210-
std::thread(
211-
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
212-
.detach();
213-
}
214-
else
231+
std::unique_lock lk(mCondMutex);
232+
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
233+
if (mTerminate)
215234
{
216-
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
235+
break;
217236
}
218-
removeResponse(it);
237+
it = getCurrentResponse();
219238
}
220-
mCurrentRequest = std::nullopt;
221-
}
222-
else
223-
{
224-
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
225-
"This executor does not have a prepared KV cache for request ID: %zu, and the "
226-
"mReadyResponses size is: %zu. mpi rank :%d ",
227-
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
228-
std::unique_lock lk(mCondMutex);
229-
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
239+
sendResponse(blockHashes, it);
230240
}
231241
}
232242
}

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,13 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
258258
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
259259
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
260260
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
261-
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
262-
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
263261
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
264262
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
265263
accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4 SKIP (https://nvbugs/5409414)
266264
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5409414)
267265
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype SKIP (https://nvbugs/5433543)
268266
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope SKIP (https://nvbugs/5433543)
269267
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
270-
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
271268
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
272269
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
273270
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)

0 commit comments

Comments
 (0)