From ef9596ac3ba3927c2177883820074e685f7e09e5 Mon Sep 17 00:00:00 2001
From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Date: Tue, 5 Aug 2025 09:39:53 +0000
Subject: [PATCH 1/2] [TRTLLM-6637][feat] Resolve KV cache divergence issue

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
---
 .../tensorrt_llm/batch_manager/llmRequest.h   | 44 ++++++++++++++-----
 .../batch_manager/guidedDecoder.cpp           |  3 +-
 .../nanobind/batch_manager/bindings.cpp       |  3 +-
 .../pybind/batch_manager/bindings.cpp         |  3 +-
 tensorrt_llm/_torch/pyexecutor/llm_request.py |  1 +
 .../_torch/pyexecutor/resource_manager.py     |  6 ++-
 .../_torch/speculative/model_drafter.py       | 11 +++++
 7 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 3320c6b0929..e4d13c9e17b 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -828,8 +828,10 @@ class GenericLlmRequest
         // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
-        mContextCurrentPosition = 0;
-        mPrepopulatedPromptLen = 0;
+        mContextCurrentPositionTarget = 0;
+        mContextCurrentPositionDraft = 0;
+        mPrepopulatedPromptLenTarget = 0;
+        mPrepopulatedPromptLenDraft = 0;
         mContextChunkSize = mPromptLen;
         mSeqSlot.reset();
     }
@@ -1049,7 +1051,7 @@ class GenericLlmRequest
 
     [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
     {
-        return mPrepopulatedPromptLen;
+        return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
     }
 
     void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
@@ -1066,7 +1068,10 @@ class GenericLlmRequest
             "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
             promptLen, mRequestId);
         TLLM_CHECK(prepopulatedPromptLen < promptLen);
-        mPrepopulatedPromptLen = prepopulatedPromptLen;
+
+        auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
+        auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
+        prePromptLen = prepopulatedPromptLen;
 
         if (prepopulatedPromptLen > 0)
         {
@@ -1081,7 +1086,7 @@ class GenericLlmRequest
             chunkSize = flooredEndPosition - prepopulatedPromptLen;
             TLLM_CHECK(chunkSize <= getContextChunkSize());
         }
-        setContextCurrentPosition(prepopulatedPromptLen);
+        contextCurrentPosition = prepopulatedPromptLen;
         setContextChunkSize(chunkSize);
 
         if (!isLastContextChunk())
@@ -1522,14 +1527,15 @@ class GenericLlmRequest
 
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
    {
-        mContextCurrentPosition = contextCurrentPosition;
+        mContextCurrentPositionDraft = contextCurrentPosition;
+        mContextCurrentPositionTarget = contextCurrentPosition;
     }
 
     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
     /// or end of the context is returned.
     [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
     {
-        return mContextCurrentPosition;
+        return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
     }
 
     /// Return the length of the context that has not yet been processed.
@@ -1570,14 +1576,16 @@ class GenericLlmRequest
     {
         // The number of cached token is encountered in mContextCurrentPosition,
         // so the start position of the context is mPrepopulatedPromptLen.
-        return mContextCurrentPosition == mPrepopulatedPromptLen;
+        return getContextCurrentPosition() == getPrepopulatedPromptLen();
     }
 
     /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
     void moveToNextContextChunk()
     {
         TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
-        mContextCurrentPosition += getContextChunkSize();
+
+        mContextCurrentPositionDraft += getContextChunkSize();
+        mContextCurrentPositionTarget += getContextChunkSize();
         setContextChunkSize(0);
     }
 
@@ -1843,6 +1851,16 @@ class GenericLlmRequest
         return mIsDummyRequest;
     }
 
+    void setUseDraftModel(bool useDraftModel)
+    {
+        mUseDraftModel = useDraftModel;
+    }
+
+    [[nodiscard]] bool useDraftModel() const
+    {
+        return mUseDraftModel;
+    }
+
     RequestIdType mRequestId;
     SizeType32 mPromptLen;
     SizeType32 mMaxNewTokens;
@@ -1885,7 +1903,8 @@ class GenericLlmRequest
     // Number of tokens already in KV cache before context phase.
     // A value > 0 indicates cached KV cache blocks were reused.
     // Up to inputLen - 1 tokens can be reused.
-    SizeType32 mPrepopulatedPromptLen{0};
+    SizeType32 mPrepopulatedPromptLenTarget{0};
+    SizeType32 mPrepopulatedPromptLenDraft{0};
 
     SizeType32 mMaxSentTokenLen;
 
@@ -1916,7 +1935,8 @@ class GenericLlmRequest
     // The size of the context chunk must be multiple of the KV-Cache block size except the last one.
     // Value `0` means Chunked-Context is disabled.
     SizeType32 mContextChunkSize{0};
-    SizeType32 mContextCurrentPosition{0};
+    SizeType32 mContextCurrentPositionTarget{0};
+    SizeType32 mContextCurrentPositionDraft{0};
 
     std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
     VecLogProbs mCumLogProbs;           // [beamSize]
@@ -2017,6 +2037,8 @@ class GenericLlmRequest
 
     bool mIsDummyRequest{false};
 
+    bool mUseDraftModel{false};
+
 private:
     void initialize(VecTokens const& inputTokens, bool outputLogProbs)
     {
diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
index 040dcd147e9..ea5f0981074 100644
--- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
+++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
@@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
                 continue;
             }
             auto const seqSlot = llmReq->mSeqSlot.value();
-            if (llmReq->isContextInitState()
-                && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen())
+            if (llmReq->isContextInitState() && llmReq->isFirstContextChunk())
             {
                 // The request is in the first context forward step (considering kv cache reuse).
                 auto const& guideType = guidedDecodingParams->getGuideType();
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 2ac069616e0..c170ca81015 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -248,7 +248,8 @@ void initBindings(nb::module_& m)
                 }
             })
         .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
 
     nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
         .def(
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 04faa90c2ff..5cf036e76cf 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -253,7 +253,8 @@ void initBindings(pybind11::module_& m)
                 }
             })
         .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
 
     py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
         .def(py::init<>(
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index a068327b6db..6c9a42494cc 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -333,6 +333,7 @@ def __init__(
         self.py_seq_slot = seq_slot
         # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot
+        self.use_draft_model = is_draft
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 56c4871542e..b5d703fd48e 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1156,7 +1156,11 @@ def get_resource_manager(self, name: str) -> BaseResourceManager:
 
     @nvtx_range("prepare_resources")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        for _, resource_manager in self.resource_managers.items():
+        for resource_mgr_type, resource_manager in self.resource_managers.items(
+        ):
+            # Delay the preparation of draft kv cache manager to ModelDrafter.prepare_draft_tokens.
+            if resource_mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
+                continue
             if hasattr(resource_manager, "prepare_resources"):
                 resource_manager.prepare_resources(scheduled_batch)
 
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index 7f11142c3fa..570ca22f0a5 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -346,9 +346,20 @@ def prepare_draft_tokens(
         if resource_manager is None:
             raise ValueError("Resource manager is required")
 
+        kv_cache_manager = resource_manager.get_resource_manager(
+            self.draft_model_engine.kv_cache_manager_key)
+        if kv_cache_manager is not None:
+            # Set the use_draft_model flag for all requests to prepare resources for the draft model
+            for req in scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+            kv_cache_manager.prepare_resources(scheduled_requests)
         try:
             draft_batch = self._prepare_draft_batch(scheduled_requests)
 
+            # Reset the use_draft_model flag for all requests
+            for req in scheduled_requests.all_requests():
+                req.use_draft_model = False
             if draft_batch.batch_size == 0:
                 return
 

From f7b69733399470103c2e2716c9ba399fa616d99c Mon Sep 17 00:00:00 2001
From: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Date: Thu, 7 Aug 2025 23:01:32 -0700
Subject: [PATCH 2/2] Add is_draft flag into KVCacheManager

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/_util.py       |  2 +
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 16 ++--
 .../_torch/pyexecutor/resource_manager.py     | 91 ++++++++++++-------
 .../_torch/speculative/model_drafter.py       | 11 ---
 4 files changed, 71 insertions(+), 49 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 52bd7089d74..43778c6ecc8 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -314,6 +314,7 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
     elif is_nemotron_hybrid(config):
         if executor_config.max_beam_width > 1:
@@ -376,6 +377,7 @@ def _create_kv_cache_manager(
             max_num_tokens=executor_config.max_num_tokens,
             model_config=binding_model_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
         # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 0dad7ba7817..26b93c0a226 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -17,7 +17,8 @@
 except ImportError:
     from cuda import cudart
 
-from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm._torch.pyexecutor.resource_manager import (
+    ResourceManagerType, request_context)
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
@@ -937,11 +938,14 @@ def _executor_loop(self):
                         self.guided_decoder.init_disagg_gen_requests(
                             scheduled_batch)
                 if self.drafter is not None and self.use_spec_decode:
-                    if self.guided_decoder is not None:
-                        self.guided_decoder.rollback_rejected_tokens(
-                            scheduled_batch)
-                    self.drafter.prepare_draft_tokens(
-                        scheduled_batch, self.resource_manager)
+                    with request_context(
+                            is_draft=True,
+                            scheduled_requests=scheduled_batch):
+                        if self.guided_decoder is not None:
+                            self.guided_decoder.rollback_rejected_tokens(
+                                scheduled_batch)
+                        self.drafter.prepare_draft_tokens(
+                            scheduled_batch, self.resource_manager)
 
                 batch_outputs = self._forward_step(scheduled_batch)
                 self._execute_guided_decoder(scheduled_batch,
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index b5d703fd48e..cfa34290d02 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -110,6 +110,33 @@ def get_pp_layers(
     return pp_layers, total_num_layers
 
 
+def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):
+
+    class RequestContext:
+
+        def __init__(self, is_draft: bool,
+                     scheduled_requests: ScheduledRequests):
+            self.is_draft = is_draft
+            self.scheduled_requests = scheduled_requests
+
+        def __enter__(self):
+            if not self.is_draft:
+                return
+
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if not self.is_draft:
+                return
+
+            # Clean up the state
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = False
+
+    return RequestContext(is_draft, scheduled_requests)
+
+
 class KVCacheManager(BaseResourceManager):
 
     def __init__(
@@ -132,6 +159,7 @@ def __init__(
         max_num_tokens: int = 8192,
        model_config: Optional[ModelConfig] = None,
         max_beam_width: int = 1,
+        is_draft: bool = False,
     ) -> None:
         self.mapping = mapping
         self.dtype = dtype
@@ -142,6 +170,7 @@ def __init__(
             spec_config=spec_config,
             layer_mask=layer_mask,
         )
+        self.is_draft = is_draft
         self.num_local_layers = len(self.pp_layers)
         self.layer_offsets = {
             idx: offset
@@ -366,34 +395,36 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
         return need_blocks
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        context_batch = scheduled_batch.context_requests
-        generation_batch = scheduled_batch.generation_requests
-        # allocate KV Cache
-        for req in context_batch:
-            req_beam_width = req.sampling_config.beam_width
-            if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
-                    'cp_type']:
-                if req.ctx_iters == 0:
-                    seq_len = sum(
-                        len(ctx_block) for ctx_block in req.ctx_blocks)
-                    self.impl.add_sequence(
-                        req.py_request_id,
-                        seq_len + (len(req.query_id) if self.mapping.cp_rank
-                                   == self.mapping.cp_size - 1 else 0),
-                        req_beam_width, req)
-            else:
-                if req.is_first_context_chunk:
-                    self.impl.add_sequence(req.py_request_id, req.prompt_len,
-                                           req_beam_width, req)
-                    for _ in range(self.num_extra_kv_tokens):
-                        self.impl.add_token(req.py_request_id)
-                    for _ in range(get_draft_token_length(req)):
-                        self.impl.add_token(req.py_request_id)
-
-        for req in generation_batch:
-            self.impl.add_token(req.py_request_id)
-            for _ in range(get_draft_token_length(req)):
+        with request_context(self.is_draft, scheduled_batch):
+            context_batch = scheduled_batch.context_requests
+            generation_batch = scheduled_batch.generation_requests
+            # allocate KV Cache
+            for req in context_batch:
+                req_beam_width = req.sampling_config.beam_width
+                if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
+                        'cp_type']:
+                    if req.ctx_iters == 0:
+                        seq_len = sum(
+                            len(ctx_block) for ctx_block in req.ctx_blocks)
+                        self.impl.add_sequence(
+                            req.py_request_id,
+                            seq_len + (len(req.query_id) if self.mapping.cp_rank
+                                       == self.mapping.cp_size - 1 else 0),
+                            req_beam_width, req)
+                else:
+                    if req.is_first_context_chunk:
+                        self.impl.add_sequence(req.py_request_id,
+                                               req.prompt_len, req_beam_width,
+                                               req)
+                        for _ in range(self.num_extra_kv_tokens):
+                            self.impl.add_token(req.py_request_id)
+                        for _ in range(get_draft_token_length(req)):
+                            self.impl.add_token(req.py_request_id)
+
+            for req in generation_batch:
                 self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)
 
     def add_dummy_requests(
         self,
@@ -1156,11 +1187,7 @@ def get_resource_manager(self, name: str) -> BaseResourceManager:
 
     @nvtx_range("prepare_resources")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        for resource_mgr_type, resource_manager in self.resource_managers.items(
-        ):
-            # Delay the preparation of draft kv cache manager to ModelDrafter.prepare_draft_tokens.
-            if resource_mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
-                continue
+        for _, resource_manager in self.resource_managers.items():
             if hasattr(resource_manager, "prepare_resources"):
                 resource_manager.prepare_resources(scheduled_batch)
 
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index 570ca22f0a5..7f11142c3fa 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -346,20 +346,9 @@ def prepare_draft_tokens(
         if resource_manager is None:
             raise ValueError("Resource manager is required")
 
-        kv_cache_manager = resource_manager.get_resource_manager(
-            self.draft_model_engine.kv_cache_manager_key)
-        if kv_cache_manager is not None:
-            # Set the use_draft_model flag for all requests to prepare resources for the draft model
-            for req in scheduled_requests.all_requests():
-                req.use_draft_model = True
-
-            kv_cache_manager.prepare_resources(scheduled_requests)
         try:
             draft_batch = self._prepare_draft_batch(scheduled_requests)
 
-            # Reset the use_draft_model flag for all requests
-            for req in scheduled_requests.all_requests():
-                req.use_draft_model = False
             if draft_batch.batch_size == 0:
                 return
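
A minimal sketch (not part of the patches) of the mechanism the two commits introduce, using hypothetical stand-ins rather than the TensorRT-LLM API: each request keeps separate target/draft copies of its KV-cache bookkeeping, and request_context flips use_draft_model so that a draft-model pass reads and updates only the draft-side copy, which is what keeps the draft and target KV cache state from overwriting each other.

# Illustrative sketch only -- mirrors the request_context helper added in
# resource_manager.py; the Request stand-in below is hypothetical.
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Dict, Iterator, List


@dataclass
class Request:
    # Hypothetical stand-in for LlmRequest: one flag plus per-model copies of
    # the prepopulated prompt length (mPrepopulatedPromptLenTarget/Draft).
    use_draft_model: bool = False
    prepopulated_prompt_len: Dict[str, int] = field(
        default_factory=lambda: {"target": 0, "draft": 0})

    def get_prepopulated_prompt_len(self) -> int:
        # Mirrors GenericLlmRequest::getPrepopulatedPromptLen(): the flag
        # selects which copy of the state is visible to the caller.
        return self.prepopulated_prompt_len[
            "draft" if self.use_draft_model else "target"]


@contextmanager
def request_context(is_draft: bool, requests: List[Request]) -> Iterator[None]:
    # Same shape as the helper added to resource_manager.py: set the flag
    # only for a draft-model pass and always restore it afterwards.
    if is_draft:
        for req in requests:
            req.use_draft_model = True
    try:
        yield
    finally:
        if is_draft:
            for req in requests:
                req.use_draft_model = False


if __name__ == "__main__":
    req = Request(prepopulated_prompt_len={"target": 96, "draft": 32})
    with request_context(is_draft=True, requests=[req]):
        assert req.get_prepopulated_prompt_len() == 32  # draft view
    assert req.get_prepopulated_prompt_len() == 96  # target view restored

The helper in the actual patch builds a small RequestContext class by hand instead of using contextlib, but the enter/exit behaviour sketched here is the same.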