Commit 8cbbd16

[TRTLLM-6637][feat] Resolve KV cache divergence issue
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 8227616 commit 8cbbd16

7 files changed: +55 additions, -16 deletions

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 33 additions & 11 deletions
@@ -828,8 +828,10 @@ class GenericLlmRequest
         // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
-        mContextCurrentPosition = 0;
-        mPrepopulatedPromptLen = 0;
+        mContextCurrentPositionTarget = 0;
+        mContextCurrentPositionDraft = 0;
+        mPrepopulatedPromptLenTarget = 0;
+        mPrepopulatedPromptLenDraft = 0;
         mContextChunkSize = mPromptLen;
         mSeqSlot.reset();
     }
@@ -1049,7 +1051,7 @@ class GenericLlmRequest

     [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
     {
-        return mPrepopulatedPromptLen;
+        return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
     }

     void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
@@ -1066,7 +1068,10 @@ class GenericLlmRequest
             "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
             promptLen, mRequestId);
         TLLM_CHECK(prepopulatedPromptLen < promptLen);
-        mPrepopulatedPromptLen = prepopulatedPromptLen;
+
+        auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
+        auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
+        prePromptLen = prepopulatedPromptLen;

         if (prepopulatedPromptLen > 0)
         {
@@ -1081,7 +1086,7 @@ class GenericLlmRequest
                 chunkSize = flooredEndPosition - prepopulatedPromptLen;
                 TLLM_CHECK(chunkSize <= getContextChunkSize());
             }
-            setContextCurrentPosition(prepopulatedPromptLen);
+            contextCurrentPosition = prepopulatedPromptLen;
             setContextChunkSize(chunkSize);

             if (!isLastContextChunk())
@@ -1522,14 +1527,15 @@ class GenericLlmRequest

     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPosition = contextCurrentPosition;
+        mContextCurrentPositionDraft = contextCurrentPosition;
+        mContextCurrentPositionTarget = contextCurrentPosition;
     }

     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
     /// or end of the context is returned.
     [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
     {
-        return mContextCurrentPosition;
+        return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
     }

     /// Return the length of the context that has not yet been processed.
@@ -1570,14 +1576,16 @@ class GenericLlmRequest
     {
         // The number of cached token is encountered in mContextCurrentPosition,
         // so the start position of the context is mPrepopulatedPromptLen.
-        return mContextCurrentPosition == mPrepopulatedPromptLen;
+        return getContextCurrentPosition() == getPrepopulatedPromptLen();
     }

     /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
     void moveToNextContextChunk()
     {
         TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
-        mContextCurrentPosition += getContextChunkSize();
+
+        mContextCurrentPositionDraft += getContextChunkSize();
+        mContextCurrentPositionTarget += getContextChunkSize();
         setContextChunkSize(0);
     }

@@ -1843,6 +1851,16 @@ class GenericLlmRequest
         return mIsDummyRequest;
     }

+    void setUseDraftModel(bool useDraftModel)
+    {
+        mUseDraftModel = useDraftModel;
+    }
+
+    [[nodiscard]] bool useDraftModel() const
+    {
+        return mUseDraftModel;
+    }
+
     RequestIdType mRequestId;
     SizeType32 mPromptLen;
     SizeType32 mMaxNewTokens;
@@ -1885,7 +1903,8 @@ class GenericLlmRequest
     // Number of tokens already in KV cache before context phase.
     // A value > 0 indicates cached KV cache blocks were reused.
     // Up to inputLen - 1 tokens can be reused.
-    SizeType32 mPrepopulatedPromptLen{0};
+    SizeType32 mPrepopulatedPromptLenTarget{0};
+    SizeType32 mPrepopulatedPromptLenDraft{0};

     SizeType32 mMaxSentTokenLen;

@@ -1916,7 +1935,8 @@ class GenericLlmRequest
     // The size of the context chunk must be multiple of the KV-Cache block size except the last one.
     // Value `0` means Chunked-Context is disabled.
     SizeType32 mContextChunkSize{0};
-    SizeType32 mContextCurrentPosition{0};
+    SizeType32 mContextCurrentPositionTarget{0};
+    SizeType32 mContextCurrentPositionDraft{0};

     std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
     VecLogProbs mCumLogProbs;           // [beamSize]
@@ -2017,6 +2037,8 @@ class GenericLlmRequest

     bool mIsDummyRequest{false};

+    bool mUseDraftModel{false};
+
 private:
     void initialize(VecTokens const& inputTokens, bool outputLogProbs)
     {
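Reviewer note: this header change is the core of the fix. Each request now keeps separate prepopulated-prompt-length and context-position counters for the target and draft models, and the accessors select one side via mUseDraftModel, so KV-block reuse computed while preparing the draft model's cache no longer overwrites the target model's bookkeeping (which appears to be the divergence the commit title refers to). A minimal Python sketch of that selection pattern, using a hypothetical stand-in class rather than the real GenericLlmRequest API:

from dataclasses import dataclass


@dataclass
class RequestPositions:
    """Illustrative stand-in for the new per-model bookkeeping on a request."""
    use_draft_model: bool = False
    prepopulated_prompt_len_target: int = 0
    prepopulated_prompt_len_draft: int = 0
    context_current_position_target: int = 0
    context_current_position_draft: int = 0

    def get_prepopulated_prompt_len(self) -> int:
        # Mirrors: return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
        return (self.prepopulated_prompt_len_draft
                if self.use_draft_model else self.prepopulated_prompt_len_target)

    def set_prepopulated_prompt_len(self, n: int) -> None:
        # Roughly mirrors setPrepopulatedPromptLen: only the side selected by the flag is written.
        if self.use_draft_model:
            self.prepopulated_prompt_len_draft = n
            self.context_current_position_draft = n
        else:
            self.prepopulated_prompt_len_target = n
            self.context_current_position_target = n

    def move_to_next_context_chunk(self, chunk_size: int) -> None:
        # Both cursors advance together, as in moveToNextContextChunk().
        self.context_current_position_draft += chunk_size
        self.context_current_position_target += chunk_size


req = RequestPositions()
req.set_prepopulated_prompt_len(32)      # target model reuses 32 cached tokens
req.use_draft_model = True
req.set_prepopulated_prompt_len(16)      # draft model reuses only 16
assert req.get_prepopulated_prompt_len() == 16
req.use_draft_model = False
assert req.get_prepopulated_prompt_len() == 32   # target-side value is preserved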

cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp

Lines changed: 1 addition & 2 deletions
@@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
                 continue;
             }
             auto const seqSlot = llmReq->mSeqSlot.value();
-            if (llmReq->isContextInitState()
-                && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen())
+            if (llmReq->isContextInitState() && llmReq->isFirstContextChunk())
             {
                 // The request is in the first context forward step (considering kv cache reuse).
                 auto const& guideType = guidedDecodingParams->getGuideType();

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -248,7 +248,8 @@ void initBindings(nb::module_& m)
                 }
             })
         .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);

     nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
         .def(

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -253,7 +253,8 @@ void initBindings(pybind11::module_& m)
                 }
             })
         .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);

     py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
         .def(py::init<>(

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions
@@ -331,6 +331,7 @@ def __init__(
         self.py_is_draft = is_draft
         self.py_seq_slot = seq_slot
         self.py_target_seq_slot = target_seq_slot
+        self.use_draft_model = is_draft

         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 5 additions & 1 deletion
@@ -1156,7 +1156,11 @@ def get_resource_manager(self, name: str) -> BaseResourceManager:

     @nvtx_range("prepare_resources")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        for _, resource_manager in self.resource_managers.items():
+        for resource_mgr_type, resource_manager in self.resource_managers.items(
+        ):
+            # Delay the preparation of draft kv cache manager to ModelDrafter.prepare_draft_tokens.
+            if resource_mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
+                continue
             if hasattr(resource_manager, "prepare_resources"):
                 resource_manager.prepare_resources(scheduled_batch)

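Reviewer note: the executor-side change only defers work; the main prepare_resources loop skips the draft KV cache manager so the drafter can prepare it later with the per-request flag set. A small runnable sketch of the skip behaviour, using simplified stand-ins for ResourceManagerType and the managers (not the real tensorrt_llm classes):

from enum import Enum, auto


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = auto()
    DRAFT_KV_CACHE_MANAGER = auto()


class StubManager:
    """Stand-in resource manager that records whether it was prepared."""

    def __init__(self) -> None:
        self.prepared = False

    def prepare_resources(self, scheduled_batch) -> None:
        self.prepared = True


managers = {
    ResourceManagerType.KV_CACHE_MANAGER: StubManager(),
    ResourceManagerType.DRAFT_KV_CACHE_MANAGER: StubManager(),
}


def prepare_resources(scheduled_batch) -> None:
    for mgr_type, manager in managers.items():
        # Draft KV cache preparation is delayed to ModelDrafter.prepare_draft_tokens.
        if mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
            continue
        if hasattr(manager, "prepare_resources"):
            manager.prepare_resources(scheduled_batch)


prepare_resources(scheduled_batch=None)
assert managers[ResourceManagerType.KV_CACHE_MANAGER].prepared
assert not managers[ResourceManagerType.DRAFT_KV_CACHE_MANAGER].prepared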
tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 11 additions & 0 deletions
@@ -346,9 +346,20 @@ def prepare_draft_tokens(

         if resource_manager is None:
             raise ValueError("Resource manager is required")
+        kv_cache_manager = resource_manager.get_resource_manager(
+            self.draft_model_engine.kv_cache_manager_key)
+        if kv_cache_manager is not None:
+            # Set the use_draft_model flag for all requests to prepare resources for the draft model
+            for req in scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+            kv_cache_manager.prepare_resources(scheduled_requests)

         try:
             draft_batch = self._prepare_draft_batch(scheduled_requests)
+            # Reset the use_draft_model flag for all requests
+            for req in scheduled_requests.all_requests():
+                req.use_draft_model = False

             if draft_batch.batch_size == 0:
                 return
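Reviewer note: taken together, the drafter now owns draft KV cache preparation: it sets use_draft_model on every scheduled request, prepares the draft KV cache (so reuse lengths land in the draft-side fields), builds the draft batch, and then clears the flag so later target-model accesses read the target-side state again. One way to make the set/reset pairing harder to get wrong would be a context manager; the sketch below is a hypothetical refactoring idea, not code from this commit (which resets the flag after _prepare_draft_batch inside the try block):

from contextlib import contextmanager


@contextmanager
def draft_model_view(requests):
    """Temporarily route per-request KV bookkeeping to the draft-side fields."""
    for req in requests:
        req.use_draft_model = True
    try:
        yield
    finally:
        # Always restore the target-side view, even if preparation raises.
        for req in requests:
            req.use_draft_model = False


class _Req:
    """Tiny stand-in request so the sketch runs on its own."""
    use_draft_model = False


reqs = [_Req(), _Req()]
with draft_model_view(reqs):
    # Draft KV cache preparation and draft batch construction would happen here.
    assert all(r.use_draft_model for r in reqs)
assert not any(r.use_draft_model for r in reqs)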
