
Commit 6de50c2

Add py_is_draft flag into KVCacheManager
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 8cbbd16 · commit 6de50c2
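
This change replaces the manual `use_draft_model` toggling in `ModelDrafter.prepare_draft_tokens` with a new `request_context` context manager and a `py_is_draft` flag on `KVCacheManager` (set from `model_engine.is_draft_model` at construction). The draft KV cache manager now tags requests itself inside `prepare_resources`, so the `ResourceManager` loop no longer needs to skip `DRAFT_KV_CACHE_MANAGER`, and the executor loop wraps draft-token preparation in the same context.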

File tree: 4 files changed, +69 -49 lines


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,7 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
     elif is_nemotron_hybrid(config):
         if executor_config.max_beam_width > 1:
@@ -376,6 +377,7 @@ def _create_kv_cache_manager(
             max_num_tokens=executor_config.max_num_tokens,
             model_config=binding_model_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
     # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
     if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
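
Both construction sites forward `model_engine.is_draft_model`, so only the `KVCacheManager` built for a draft model engine gets `is_draft=True`; the target-model manager keeps the default `False`.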

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 6 deletions
@@ -17,7 +17,8 @@
 except ImportError:
     from cuda import cudart

-from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm._torch.pyexecutor.resource_manager import (
+    ResourceManagerType, request_context)
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
@@ -940,11 +941,12 @@ def _executor_loop(self):

                 self.resource_manager.prepare_resources(scheduled_batch)
                 if self.drafter is not None and self.use_spec_decode:
-                    if self.guided_decoder is not None:
-                        self.guided_decoder.rollback_rejected_tokens(
-                            scheduled_batch)
-                    self.drafter.prepare_draft_tokens(
-                        scheduled_batch, self.resource_manager)
+                    with request_context(True, scheduled_batch):
+                        if self.guided_decoder is not None:
+                            self.guided_decoder.rollback_rejected_tokens(
+                                scheduled_batch)
+                        self.drafter.prepare_draft_tokens(
+                            scheduled_batch, self.resource_manager)

                 batch_outputs = self._forward_step(scheduled_batch)
                 self._execute_guided_decoder(scheduled_batch,
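
The `request_context` helper imported here is defined in resource_manager.py (next file). To make the pattern concrete, here is a minimal, runnable sketch with stand-in `Request` and `ScheduledRequests` classes; only the `use_draft_model` attribute, the `all_requests()` method, and the shape of `request_context` are taken from this commit, the rest is illustrative:

# Stand-ins for the real request types (the real classes live in
# tensorrt_llm; only use_draft_model and all_requests() mirror the diff).
class Request:

    def __init__(self):
        self.use_draft_model = False


class ScheduledRequests:

    def __init__(self, context_requests, generation_requests):
        self.context_requests = context_requests
        self.generation_requests = generation_requests

    def all_requests(self):
        return self.context_requests + self.generation_requests


def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):
    # Same shape as the helper added in resource_manager.py: tag every
    # request while draft-model work runs, clear the tag afterwards.

    class RequestContext:

        def __init__(self, is_draft, scheduled_requests):
            self.is_draft = is_draft
            self.scheduled_requests = scheduled_requests

        def __enter__(self):
            if self.is_draft:
                for req in self.scheduled_requests.all_requests():
                    req.use_draft_model = True

        def __exit__(self, exc_type, exc_val, exc_tb):
            if self.is_draft:
                for req in self.scheduled_requests.all_requests():
                    req.use_draft_model = False

    return RequestContext(is_draft, scheduled_requests)


batch = ScheduledRequests([Request()], [Request()])
with request_context(True, batch):
    # Inside the block every request is tagged for the draft model...
    assert all(r.use_draft_model for r in batch.all_requests())
# ...and the tag is cleared again on exit.
assert not any(r.use_draft_model for r in batch.all_requests())

Because the reset happens in `__exit__`, the tags are cleared even if the body of the `with` block raises; the manual set/reset removed from model_drafter.py (last file below) only reset the flag on the success path of the lines shown.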

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 59 additions & 32 deletions
@@ -110,6 +110,33 @@ def get_pp_layers(
     return pp_layers, total_num_layers


+def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):
+
+    class RequestContext:
+
+        def __init__(self, is_draft: bool,
+                     scheduled_requests: ScheduledRequests):
+            self.is_draft = is_draft
+            self.scheduled_requests = scheduled_requests
+
+        def __enter__(self):
+            if not self.is_draft:
+                return
+
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if not self.is_draft:
+                return
+
+            # Clean up the state
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = False
+
+    return RequestContext(is_draft, scheduled_requests)
+
+
 class KVCacheManager(BaseResourceManager):

     def __init__(
@@ -132,6 +159,7 @@ def __init__(
         max_num_tokens: int = 8192,
         model_config: Optional[ModelConfig] = None,
         max_beam_width: int = 1,
+        is_draft: bool = False,
     ) -> None:
         self.mapping = mapping
         self.dtype = dtype
@@ -142,6 +170,7 @@ def __init__(
             spec_config=spec_config,
             layer_mask=layer_mask,
         )
+        self.py_is_draft = is_draft
         self.num_local_layers = len(self.pp_layers)
         self.layer_offsets = {
             idx: offset
@@ -366,34 +395,36 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
         return need_blocks

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        context_batch = scheduled_batch.context_requests
-        generation_batch = scheduled_batch.generation_requests
-        # allocate KV Cache
-        for req in context_batch:
-            req_beam_width = req.sampling_config.beam_width
-            if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
-                    'cp_type']:
-                if req.ctx_iters == 0:
-                    seq_len = sum(
-                        len(ctx_block) for ctx_block in req.ctx_blocks)
-                    self.impl.add_sequence(
-                        req.py_request_id,
-                        seq_len + (len(req.query_id) if self.mapping.cp_rank
-                                   == self.mapping.cp_size - 1 else 0),
-                        req_beam_width, req)
-            else:
-                if req.is_first_context_chunk:
-                    self.impl.add_sequence(req.py_request_id, req.prompt_len,
-                                           req_beam_width, req)
-                    for _ in range(self.num_extra_kv_tokens):
-                        self.impl.add_token(req.py_request_id)
-                    for _ in range(get_draft_token_length(req)):
-                        self.impl.add_token(req.py_request_id)
-
-        for req in generation_batch:
-            self.impl.add_token(req.py_request_id)
-            for _ in range(get_draft_token_length(req)):
+        with request_context(self.py_is_draft, scheduled_batch):
+            context_batch = scheduled_batch.context_requests
+            generation_batch = scheduled_batch.generation_requests
+            # allocate KV Cache
+            for req in context_batch:
+                req_beam_width = req.sampling_config.beam_width
+                if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
+                        'cp_type']:
+                    if req.ctx_iters == 0:
+                        seq_len = sum(
+                            len(ctx_block) for ctx_block in req.ctx_blocks)
+                        self.impl.add_sequence(
+                            req.py_request_id,
+                            seq_len + (len(req.query_id) if self.mapping.cp_rank
+                                       == self.mapping.cp_size - 1 else 0),
+                            req_beam_width, req)
+                else:
+                    if req.is_first_context_chunk:
+                        self.impl.add_sequence(req.py_request_id,
+                                               req.prompt_len, req_beam_width,
+                                               req)
+                        for _ in range(self.num_extra_kv_tokens):
+                            self.impl.add_token(req.py_request_id)
+                        for _ in range(get_draft_token_length(req)):
+                            self.impl.add_token(req.py_request_id)
+
+            for req in generation_batch:
                 self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)

@@ -1156,11 +1187,7 @@ def get_resource_manager(self, name: str) -> BaseResourceManager:

     @nvtx_range("prepare_resources")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        for resource_mgr_type, resource_manager in self.resource_managers.items(
-        ):
-            # Delay the preparation of draft kv cache manager to ModelDrafter.prepare_draft_tokens.
-            if resource_mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
-                continue
+        for _, resource_manager in self.resource_managers.items():
             if hasattr(resource_manager, "prepare_resources"):
                 resource_manager.prepare_resources(scheduled_batch)
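
Taken together: the draft `KVCacheManager` is constructed with `is_draft=True` and now tags requests itself inside `prepare_resources`, so the `ResourceManager` loop above no longer special-cases `DRAFT_KV_CACHE_MANAGER`. A toy model of the new dispatch, reusing the stand-in `Request`, `ScheduledRequests`, and `request_context` definitions from the earlier sketch (manager internals are stubbed; only `py_is_draft` and the `with request_context(...)` shape come from the diff):

class ToyKVCacheManager:

    def __init__(self, is_draft: bool = False):
        self.py_is_draft = is_draft
        self.prepared = []  # records (request, tagged_as_draft)

    def prepare_resources(self, scheduled_batch):
        # Mirrors the new KVCacheManager.prepare_resources: the whole body
        # runs under request_context, so a draft manager sees every request
        # with use_draft_model == True while it allocates.
        with request_context(self.py_is_draft, scheduled_batch):
            for req in scheduled_batch.all_requests():
                self.prepared.append((req, req.use_draft_model))


managers = {
    "KV_CACHE_MANAGER": ToyKVCacheManager(is_draft=False),
    "DRAFT_KV_CACHE_MANAGER": ToyKVCacheManager(is_draft=True),
}
batch = ScheduledRequests([Request()], [])
for _, manager in managers.items():  # no special-casing of the draft manager
    manager.prepare_resources(batch)

assert managers["KV_CACHE_MANAGER"].prepared[0][1] is False
assert managers["DRAFT_KV_CACHE_MANAGER"].prepared[0][1] is True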

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 0 additions & 11 deletions
@@ -346,20 +346,9 @@ def prepare_draft_tokens(

         if resource_manager is None:
             raise ValueError("Resource manager is required")
-        kv_cache_manager = resource_manager.get_resource_manager(
-            self.draft_model_engine.kv_cache_manager_key)
-        if kv_cache_manager is not None:
-            # Set the use_draft_model flag for all requests to prepare resources for the draft model
-            for req in scheduled_requests.all_requests():
-                req.use_draft_model = True
-
-            kv_cache_manager.prepare_resources(scheduled_requests)

         try:
             draft_batch = self._prepare_draft_batch(scheduled_requests)
-            # Reset the use_draft_model flag for all requests
-            for req in scheduled_requests.all_requests():
-                req.use_draft_model = False

             if draft_batch.batch_size == 0:
                 return
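
With the executor loop now wrapping `prepare_draft_tokens` in `request_context(True, scheduled_batch)` and the draft KV cache manager preparing its resources in the regular `ResourceManager` loop, neither the explicit `prepare_resources` call nor the manual flag toggling is needed here.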
