44 changes: 33 additions & 11 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -828,8 +828,10 @@ class GenericLlmRequest
// for enc-dec models, pause means saving generated tokens to the prompt, but the encoder phase must be re-done
mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
: LlmRequestState::kCONTEXT_INIT;
mContextCurrentPosition = 0;
mPrepopulatedPromptLen = 0;
mContextCurrentPositionTarget = 0;
mContextCurrentPositionDraft = 0;
mPrepopulatedPromptLenTarget = 0;
mPrepopulatedPromptLenDraft = 0;
mContextChunkSize = mPromptLen;
mSeqSlot.reset();
}
@@ -1049,7 +1051,7 @@ class GenericLlmRequest

[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
{
return mPrepopulatedPromptLen;
return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
}

void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
@@ -1066,7 +1068,10 @@ class GenericLlmRequest
"Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
promptLen, mRequestId);
TLLM_CHECK(prepopulatedPromptLen < promptLen);
mPrepopulatedPromptLen = prepopulatedPromptLen;

auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
prePromptLen = prepopulatedPromptLen;

if (prepopulatedPromptLen > 0)
{
@@ -1081,7 +1086,7 @@ class GenericLlmRequest
chunkSize = flooredEndPosition - prepopulatedPromptLen;
TLLM_CHECK(chunkSize <= getContextChunkSize());
}
setContextCurrentPosition(prepopulatedPromptLen);
contextCurrentPosition = prepopulatedPromptLen;
setContextChunkSize(chunkSize);

if (!isLastContextChunk())
@@ -1522,14 +1527,15 @@ class GenericLlmRequest

void setContextCurrentPosition(SizeType32 contextCurrentPosition)
{
mContextCurrentPosition = contextCurrentPosition;
mContextCurrentPositionDraft = contextCurrentPosition;
mContextCurrentPositionTarget = contextCurrentPosition;
}

/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
/// or end of the context is returned.
[[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
{
return mContextCurrentPosition;
return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
}

/// Return the length of the context that has not yet been processed.
@@ -1570,14 +1576,16 @@
{
// The number of cached tokens is accounted for in mContextCurrentPosition,
// so the start position of the context is mPrepopulatedPromptLen.
return mContextCurrentPosition == mPrepopulatedPromptLen;
return getContextCurrentPosition() == getPrepopulatedPromptLen();
}

/// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
void moveToNextContextChunk()
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
mContextCurrentPosition += getContextChunkSize();

mContextCurrentPositionDraft += getContextChunkSize();
mContextCurrentPositionTarget += getContextChunkSize();
setContextChunkSize(0);
}

@@ -1843,6 +1851,16 @@ class GenericLlmRequest
return mIsDummyRequest;
}

void setUseDraftModel(bool useDraftModel)
{
mUseDraftModel = useDraftModel;
}

[[nodiscard]] bool useDraftModel() const
{
return mUseDraftModel;
}

RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
@@ -1885,7 +1903,8 @@ class GenericLlmRequest
// Number of tokens already in KV cache before context phase.
// A value > 0 indicates cached KV cache blocks were reused.
// Up to inputLen - 1 tokens can be reused.
SizeType32 mPrepopulatedPromptLen{0};
SizeType32 mPrepopulatedPromptLenTarget{0};
SizeType32 mPrepopulatedPromptLenDraft{0};

SizeType32 mMaxSentTokenLen;

@@ -1916,7 +1935,8 @@ class GenericLlmRequest
// The size of each context chunk must be a multiple of the KV-cache block size, except for the last one.
// Value `0` means Chunked-Context is disabled.
SizeType32 mContextChunkSize{0};
SizeType32 mContextCurrentPosition{0};
SizeType32 mContextCurrentPositionTarget{0};
SizeType32 mContextCurrentPositionDraft{0};

std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
@@ -2017,6 +2037,8 @@ class GenericLlmRequest

bool mIsDummyRequest{false};

bool mUseDraftModel{false};

private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
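Note: in the header change above, reads of the prepopulated prompt length and the current context position now resolve through mUseDraftModel, explicit sets write both sides, and chunk advancement moves both cursors. The block below is a minimal Python sketch of that bookkeeping, not part of this diff; the names are illustrative, and the chunk-size recomputation done in setPrepopulatedPromptLen is omitted.

```python
class RequestPositionsSketch:
    """Illustrative stand-in for the draft/target split in GenericLlmRequest."""

    def __init__(self) -> None:
        self.use_draft_model = False   # mirrors mUseDraftModel
        self._context_pos_target = 0   # mirrors mContextCurrentPositionTarget
        self._context_pos_draft = 0    # mirrors mContextCurrentPositionDraft
        self._prepopulated_target = 0  # mirrors mPrepopulatedPromptLenTarget
        self._prepopulated_draft = 0   # mirrors mPrepopulatedPromptLenDraft

    def get_context_current_position(self) -> int:
        # Reads resolve to the side of whichever model is currently running.
        return self._context_pos_draft if self.use_draft_model else self._context_pos_target

    def set_context_current_position(self, pos: int) -> None:
        # Explicit sets still write both sides, as in setContextCurrentPosition().
        self._context_pos_draft = pos
        self._context_pos_target = pos

    def set_prepopulated_prompt_len(self, n: int) -> None:
        # Writes only the side selected by the flag, as in setPrepopulatedPromptLen()
        # (chunk-size handling omitted here).
        if self.use_draft_model:
            self._prepopulated_draft = n
            self._context_pos_draft = n
        else:
            self._prepopulated_target = n
            self._context_pos_target = n

    def move_to_next_context_chunk(self, chunk_size: int) -> None:
        # Advancing a chunk moves both cursors, as in moveToNextContextChunk().
        self._context_pos_draft += chunk_size
        self._context_pos_target += chunk_size
```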
3 changes: 1 addition & 2 deletions cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp
@@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
continue;
}
auto const seqSlot = llmReq->mSeqSlot.value();
if (llmReq->isContextInitState()
&& llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen())
if (llmReq->isContextInitState() && llmReq->isFirstContextChunk())
{
// The request is in the first context forward step (considering kv cache reuse).
auto const& guideType = guidedDecodingParams->getGuideType();
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -248,7 +248,8 @@ void initBindings(nb::module_& m)
}
})
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);

nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
.def(
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -253,7 +253,8 @@ void initBindings(pybind11::module_& m)
}
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);

py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -314,6 +314,7 @@ def _create_kv_cache_manager(
dtype=kv_cache_dtype,
spec_config=spec_config,
max_beam_width=executor_config.max_beam_width,
is_draft=model_engine.is_draft_model,
)
elif is_nemotron_hybrid(config):
if executor_config.max_beam_width > 1:
@@ -376,6 +377,7 @@ def _create_kv_cache_manager(
max_num_tokens=executor_config.max_num_tokens,
model_config=binding_model_config,
max_beam_width=executor_config.max_beam_width,
is_draft=model_engine.is_draft_model,
)
# The non-draft KVCacheManager modifies the max_seq_len field; propagate the update to executor_config
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -333,6 +333,7 @@ def __init__(
self.py_seq_slot = seq_slot
# If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
self.py_target_seq_slot = target_seq_slot
self.use_draft_model = is_draft

# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.
16 changes: 10 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -17,7 +17,8 @@
except ImportError:
from cuda import cudart

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._torch.pyexecutor.resource_manager import (
ResourceManagerType, request_context)
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
is_trace_enabled, nvtx_range, trace_func)
@@ -937,11 +938,14 @@ def _executor_loop(self):
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
if self.drafter is not None and self.use_spec_decode:
if self.guided_decoder is not None:
self.guided_decoder.rollback_rejected_tokens(
scheduled_batch)
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)
with request_context(
is_draft=True,
scheduled_requests=scheduled_batch):
if self.guided_decoder is not None:
self.guided_decoder.rollback_rejected_tokens(
scheduled_batch)
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)

batch_outputs = self._forward_step(scheduled_batch)
self._execute_guided_decoder(scheduled_batch,
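Note: the request_context wrapper above is what routes the drafter's bookkeeping to the draft-side counters added in llmRequest.h. Below is a small usage sketch, not part of this diff; the helper function and assertions are illustrative only, and scheduled_batch is assumed to be a ScheduledRequests holding at least one request.

```python
from tensorrt_llm._torch.pyexecutor.resource_manager import request_context


def draft_view_demo(scheduled_batch):
    # Pick any scheduled request; all of them are flagged inside the context.
    req = next(iter(scheduled_batch.all_requests()))
    assert not req.use_draft_model  # target-side view by default
    with request_context(is_draft=True, scheduled_requests=scheduled_batch):
        assert req.use_draft_model  # draft-side view while drafting
        # getContextCurrentPosition()/getPrepopulatedPromptLen() on the C++
        # request now return the *Draft counters.
    assert not req.use_draft_model  # restored on exit
```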
85 changes: 58 additions & 27 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -110,6 +110,33 @@ def get_pp_layers(
return pp_layers, total_num_layers


def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):

class RequestContext:

def __init__(self, is_draft: bool,
scheduled_requests: ScheduledRequests):
self.is_draft = is_draft
self.scheduled_requests = scheduled_requests

def __enter__(self):
if not self.is_draft:
return

for req in self.scheduled_requests.all_requests():
req.use_draft_model = True

def __exit__(self, exc_type, exc_val, exc_tb):
if not self.is_draft:
return

# Clean up the state
for req in self.scheduled_requests.all_requests():
req.use_draft_model = False

return RequestContext(is_draft, scheduled_requests)


class KVCacheManager(BaseResourceManager):

def __init__(
@@ -132,6 +159,7 @@ def __init__(
max_num_tokens: int = 8192,
model_config: Optional[ModelConfig] = None,
max_beam_width: int = 1,
is_draft: bool = False,
) -> None:
self.mapping = mapping
self.dtype = dtype
@@ -142,6 +170,7 @@ def __init__(
spec_config=spec_config,
layer_mask=layer_mask,
)
self.is_draft = is_draft
self.num_local_layers = len(self.pp_layers)
self.layer_offsets = {
idx: offset
@@ -366,34 +395,36 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
return need_blocks

def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests
# allocate KV Cache
for req in context_batch:
req_beam_width = req.sampling_config.beam_width
if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
'cp_type']:
if req.ctx_iters == 0:
seq_len = sum(
len(ctx_block) for ctx_block in req.ctx_blocks)
self.impl.add_sequence(
req.py_request_id,
seq_len + (len(req.query_id) if self.mapping.cp_rank
== self.mapping.cp_size - 1 else 0),
req_beam_width, req)
else:
if req.is_first_context_chunk:
self.impl.add_sequence(req.py_request_id, req.prompt_len,
req_beam_width, req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

for req in generation_batch:
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
with request_context(self.is_draft, scheduled_batch):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests
# allocate KV Cache
for req in context_batch:
req_beam_width = req.sampling_config.beam_width
if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
'cp_type']:
if req.ctx_iters == 0:
seq_len = sum(
len(ctx_block) for ctx_block in req.ctx_blocks)
self.impl.add_sequence(
req.py_request_id,
seq_len + (len(req.query_id) if self.mapping.cp_rank
== self.mapping.cp_size - 1 else 0),
req_beam_width, req)
else:
if req.is_first_context_chunk:
self.impl.add_sequence(req.py_request_id,
req.prompt_len, req_beam_width,
req)
for _ in range(self.num_extra_kv_tokens):
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

for req in generation_batch:
self.impl.add_token(req.py_request_id)
for _ in range(get_draft_token_length(req)):
self.impl.add_token(req.py_request_id)

def add_dummy_requests(
self,
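Note: for comparison only, the same enter/exit behavior could be written as a generator-based context manager; the sketch below is not part of this diff. The try/finally clears the flags even if the wrapped draft work raises, matching the class-based __exit__ above.

```python
import contextlib


@contextlib.contextmanager
def request_context_alt(is_draft: bool, scheduled_requests):
    """Same enter/exit behavior as request_context(); scheduled_requests is a ScheduledRequests."""
    if not is_draft:
        yield
        return
    for req in scheduled_requests.all_requests():
        req.use_draft_model = True  # route per-request state to the draft side
    try:
        yield
    finally:
        # Cleared even if the body raises, like __exit__ in the class-based version.
        for req in scheduled_requests.all_requests():
            req.use_draft_model = False
```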