Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 40 additions & 30 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ class DataResponder::Impl
}
}

/// Accounts for one send of the current request and, once no sends remain,
/// hands the ready response over for transmission.
///
/// Decrements the remaining-send counter stored in mRemainSendCount for the
/// request held in mCurrentRequest. When the counter reaches zero the counter
/// entry is erased, the requested block hashes are recorded on the response's
/// request object, the response is passed to sendAndRemoveResponse (on a
/// detached thread when common::getEnvParallelCacheSend() is true, otherwise
/// inline), and removeResponse(it) is called. mCurrentRequest is reset to
/// std::nullopt on every call, whether or not anything was sent.
///
/// @param blockHashes Block hashes to store on the request via
///                    setRequestedBlockHashes. Taken by const reference, so
///                    the callee receives a copy.
/// @param it          Iterator to the (request id, response) entry to send.
///                    It is dereferenced when the counter reaches zero, so it
///                    must be dereferenceable in that case.
///                    NOTE(review): presumably an iterator into
///                    mReadyResponses for mCurrentRequest -- confirm at callers.
///
/// NOTE(review): mCurrentRequest.value() presumably throws if no request is
/// current (it is assigned std::nullopt below, so it looks like std::optional)
/// -- callers must set it first.
void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
{
    auto reqId = mCurrentRequest.value();
    // operator[] default-inserts a missing entry; the check below guards the
    // resulting count. NOTE(review): only meaningful if the counter is signed.
    auto count = --mRemainSendCount[reqId];
    TLLM_CHECK(count >= 0);
    if (count == 0)
    {
        mRemainSendCount.erase(reqId);

        // TODO(zhengd): pass the hashes directly instead of update llmRequest
        auto llmRequest = it->second.mRequest;
        // blockHashes is a const reference: std::move on it could never move
        // and would only disguise the copy, so pass it as-is.
        llmRequest->setRequestedBlockHashes(blockHashes);

        if (common::getEnvParallelCacheSend())
        {
            // TODO: Use a thread pool and check for thread safety.
            // The id is copied and the response is moved into the thread's
            // argument storage before std::thread's constructor returns.
            // NOTE(review): the detached thread keeps a raw `this`; nothing
            // visible here guarantees the object outlives it -- confirm.
            std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
                .detach();
        }
        else
        {
            DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
        }
        // it->second has been moved from above; only the map entry is dropped here.
        removeResponse(it);
    }
    mCurrentRequest = std::nullopt;
}

void response() noexcept
{
try
Expand Down Expand Up @@ -193,40 +221,22 @@ class DataResponder::Impl
auto it = getCurrentResponse();
if (it != mReadyResponses.end())
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
sendResponse(blockHashes, it);
}
else
{
auto it = getCurrentResponse();
while (it == mReadyResponses.end())
{
mRemainSendCount.erase(reqId);

// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));

if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
break;
}
removeResponse(it);
it = getCurrentResponse();
}
mCurrentRequest = std::nullopt;
}
else
{
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
"This executor does not have a prepared KV cache for request ID: %zu, and the "
"mReadyResponses size is: %zu. mpi rank :%d ",
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
sendResponse(blockHashes, it);
}
}
}
Expand Down
3 changes: 0 additions & 3 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,13 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
accuracy/test_cli_flow.py::TestLlama3_3_70BInstruct::test_fp8_prequantized_tp4 SKIP (https://nvbugs/5409414)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5409414)
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype SKIP (https://nvbugs/5433543)
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope SKIP (https://nvbugs/5433543)
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)
Expand Down