Commit ef58f15

Clean up the code
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent f785991 commit ef58f15

File tree: 2 files changed, +39 −84 lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 18 additions & 59 deletions
@@ -44,7 +44,7 @@
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse, get_draft_token_length)
+                          LlmResponse)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -1067,14 +1067,9 @@ def _executor_loop(self):
                         scheduled_requests=scheduled_batch):
                     self.drafter.prepare_draft_tokens(
                         scheduled_batch, self.resource_manager)
-                    # Pad draft tokens to the max draft length. This is for CUDA
-                    # graph compatibility.
-                    for req in scheduled_batch.generation_requests:
-                        max_draft_tokens = self.max_draft_len
-                        num_draft_tokens = get_draft_token_length(req)
-                        req.py_draft_tokens.extend(
-                            0 for _ in range(max_draft_tokens -
-                                             num_draft_tokens))
+                    # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
+                    self.drafter.pad_draft_tokens_for_cuda_graph(
+                        scheduled_batch)
                 # add_batch must be called again to restore to target requests with updated draft tokens.
                 if self.guided_decoder is not None:
                     self.guided_decoder.add_batch(scheduled_batch)
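
The padding logic itself moves into the drafter. A minimal sketch of what the new helper presumably does, assuming it simply relocates the inline loop deleted above (the actual method on the drafter may differ):

    # Hypothetical reconstruction, mirroring the removed executor-side loop.
    def pad_draft_tokens_for_cuda_graph(
            self, scheduled_requests: ScheduledRequests) -> None:
        """Pad each request's draft tokens to the max draft length so batch
        shapes stay static across iterations, as CUDA graph replay requires."""
        for req in scheduled_requests.generation_requests:
            num_draft_tokens = get_draft_token_length(req)
            req.py_draft_tokens.extend(
                0 for _ in range(self.max_draft_tokens - num_draft_tokens))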
@@ -1207,6 +1202,8 @@ def _executor_loop_overlap(self):
                     target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors)

+                # Use the draft model's outputs if we've launched the draft model.
+                # Otherwise, use the previous batch's outputs.
                 if target_inputs is not None:
                     previous_tensors_device = target_inputs
                 else:
@@ -1971,17 +1968,10 @@ def _remove_inflight_ids(self, scheduled_requests):
             self.inflight_req_ids.erase(req.request_id)

     def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
-        """
-        Handle speculative decoding logic.
-
-        Args:
-            scheduled_batch: The scheduled batch to process
-            previous_tensors: Previous iteration tensors
-
-        Returns:
-            Tuple of (target_inputs, draft_outputs, draft_batch)
-        """
         with request_context(is_draft=True, scheduled_requests=scheduled_batch):
+            # Check early whether we need to forward the draft model.
+            # If so, the overlap should happen between the target requests and the draft requests.
+            # Otherwise, we can still overlap the previous target requests with the current target requests.
             has_draft_batch = (
                 self.previous_batch is not None
                 and self.drafter.should_forward_draft_model(scheduled_batch))
@@ -2006,60 +1996,29 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
     def _process_draft_results(self, scheduled_batch, draft_outputs,
                                draft_batch):
         """
-        Process the results from draft model execution.
-
-        Args:
-            scheduled_batch: The scheduled batch
-            draft_outputs: The outputs from the draft model
-            draft_batch: The draft batch that was processed
+        Append the draft tokens to the target requests, and clean up the draft resources.
         """
         req_id_to_old_request = {
             req.py_request_id: req
             for req in scheduled_batch.all_requests()
         }

         if self.drafter.use_static_draft_loop:
-            self.process_static_draft_outputs(draft_outputs, draft_batch,
-                                              req_id_to_old_request)
+            self.drafter.process_static_draft_outputs(draft_outputs,
+                                                      draft_batch,
+                                                      req_id_to_old_request)
         elif draft_outputs is not None:
-            self._process_dynamic_draft_outputs(scheduled_batch, draft_outputs,
-                                                req_id_to_old_request)
+            self.drafter.process_dynamic_draft_outputs(draft_outputs,
+                                                       req_id_to_old_request)

-    def process_static_draft_outputs(self, draft_outputs, draft_batch,
-                                     req_id_to_old_request):
-        """
-        Process outputs from static draft loop.
-
-        Args:
-            draft_outputs: The outputs from the draft model
-            draft_batch: The draft batch that was processed
-            req_id_to_old_request: Mapping from request ID to original request
-        """
-        self.drafter.process_static_draft_outputs(draft_outputs, draft_batch,
-                                                  req_id_to_old_request)
-
-    def _process_dynamic_draft_outputs(self, scheduled_batch, draft_outputs,
-                                       req_id_to_old_request):
-        """
-        Process outputs from dynamic draft loop.
-
-        Args:
-            scheduled_batch: The scheduled batch
-            draft_outputs: The outputs from the draft model
-            req_id_to_old_request: Mapping from request ID to original request
-        """
-        self.drafter.update_requests(draft_outputs)
-        self.drafter.process_decoded_tokens(draft_outputs.scheduled_requests,
-                                            req_id_to_old_request)
-
-        # Rollback draft tokens if guided decoder is available
+        # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
+        self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
+        # add_batch must be called again to restore to target requests with updated draft tokens.
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(scheduled_batch)
             if hasattr(self.drafter, "guided_decoder"):
                 self.guided_decoder.rollback_draft_tokens()

-        self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-

 class DisaggPPTerminationHandler:
     """Handles termination synchronization across pipeline parallel ranks under disaggregated serving.

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 21 additions & 25 deletions
@@ -178,8 +178,8 @@ def _add_to_draft_batch(self, draft_batch: ScheduledRequests,
         else:
             draft_batch.context_requests.append(draft_request)

-    @nvtx_range("prepare_draft_batch")
-    def prepare_draft_batch(
+    @nvtx_range("_prepare_draft_batch")
+    def _prepare_draft_batch(
             self, scheduled_requests: ScheduledRequests) -> ScheduledRequests:
         """
         Prepares a batch for the draft model engine. Draft tokens are only produced
@@ -238,7 +238,7 @@ def prepare_draft_batch(
             return draft_batch

         except Exception as e:
-            logger.error(f"Error in prepare_draft_batch: {str(e)}")
+            logger.error(f"Error in _prepare_draft_batch: {str(e)}")
             traceback.print_exc()
             raise e

@@ -461,7 +461,7 @@ def _setup_draft_batch_and_resources(
         if guided_decoder is not None:
             guided_decoder.rollback_rejected_tokens(scheduled_batch)

-        draft_batch = self.prepare_draft_batch(scheduled_batch)
+        draft_batch = self._prepare_draft_batch(scheduled_batch)
         if draft_batch.batch_size == 0:
             return None, None

@@ -482,7 +482,7 @@ def process_static_draft_outputs(
         Args:
             outputs: The outputs from the draft model
             draft_batch: The draft batch that was processed
-            req_id_to_old_request: Mapping from request ID to original request
+            req_id_to_old_request: Mapping from draft request ID to original request
         """
         outputs_host = outputs.cpu()
         for token_idx in range(self.max_draft_tokens):
@@ -500,25 +500,24 @@ def process_static_draft_outputs(
         for req in draft_batch.all_requests():
             self.draft_seq_slot_manager.free_resources(req)

+    def process_dynamic_draft_outputs(
+            self, outputs: Any,
+            req_id_to_old_request: Dict[int, LlmRequest]) -> None:
+        """
+        Process outputs from dynamic draft loop, update target requests, and clean up resources.
+        """
+        self.update_requests(outputs)
+        self.process_decoded_tokens(outputs.scheduled_requests,
+                                    req_id_to_old_request)
+
     def _execute_draft_iteration(
             self,
             draft_batch: ScheduledRequests,
             resource_manager: ResourceManager,
             previous_draft_state: Optional[SampleState],
             guided_decoder: Optional[GuidedDecoder] = None
     ) -> Tuple[Any, Optional[SampleState]]:
-        """
-        Execute a single draft iteration.
-
-        Args:
-            draft_batch: The draft batch to process
-            resource_manager: The resource manager
-            previous_draft_state: The previous draft state
-            guided_decoder: The guided decoder
-
-        Returns:
-            Tuple of (outputs, new_sample_state)
-        """
+        """Forward pass through the draft model."""
         outputs = self.forward_draft_model(
             draft_batch,
             resource_manager,
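
With the new helper in place, both call sites collapse the former update_requests plus process_decoded_tokens pair into a single call; a condensed, illustrative view (not verbatim from either file):

    # In py_executor, _process_draft_results (dynamic-loop path):
    self.drafter.process_dynamic_draft_outputs(draft_outputs, req_id_to_old_request)

    # In model_drafter, prepare_draft_tokens (final cleanup; see the last hunk below):
    self.process_dynamic_draft_outputs(previous_draft_state, req_id_to_old_request)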
@@ -586,7 +585,7 @@ def _execute_draft_loop(
             self._update_target_inputs_with_draft_tokens(
                 target_inputs,
                 draft_tensors,
-                i + 1,
+                draft_position=i + 1,
                 draft_length=1,
                 num_draft_reqs=num_draft_reqs)

@@ -659,7 +658,7 @@ def generate_draft_tokens_with_overlap(
             self._update_target_inputs_with_draft_tokens(
                 target_inputs,
                 outputs,
-                0,
+                draft_position=0,
                 draft_length=self.max_draft_tokens,
                 num_draft_reqs=num_draft_reqs)
             return target_inputs, outputs, draft_batch
@@ -677,7 +676,7 @@ def generate_draft_tokens_with_overlap(
             self._update_target_inputs_with_draft_tokens(
                 target_inputs,
                 draft_tensors,
-                0,
+                draft_position=0,
                 draft_length=1,
                 num_draft_reqs=num_draft_reqs)
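
All three call sites now name draft_position explicitly, so the bare 0 and i + 1 arguments no longer have to be decoded against the signature. A hypothetical signature consistent with these call sites (parameter types are assumptions, not taken from the real method):

    def _update_target_inputs_with_draft_tokens(
            self,
            target_inputs,          # device tensors fed to the target model
            draft_tensors,          # draft-model outputs to copy from
            draft_position: int,    # write offset within the draft-token slots
            draft_length: int,      # draft tokens written per request this call
            num_draft_reqs: int) -> None:
        ...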

@@ -726,7 +725,6 @@ def prepare_draft_tokens(
726725
req_id_to_old_request)
727726
return
728727

729-
# Handle guided decoder and sampling for non-static loop
730728
if self.guided_decoder is not None:
731729
self.guided_decoder.add_batch(draft_batch)
732730
self.guided_decoder.execute(outputs['logits'],
@@ -741,10 +739,8 @@ def prepare_draft_tokens(

             # Final cleanup
             if previous_draft_state is not None:
-                self.update_requests(previous_draft_state)
-                self.process_decoded_tokens(
-                    previous_draft_state.scheduled_requests,
-                    req_id_to_old_request)
+                self.process_dynamic_draft_outputs(previous_draft_state,
+                                                   req_id_to_old_request)

         except Exception as e:
             traceback.print_exc()
