 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse, get_draft_token_length)
+                          LlmResponse)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -1067,14 +1067,9 @@ def _executor_loop(self):
                                      scheduled_requests=scheduled_batch):
                     self.drafter.prepare_draft_tokens(
                         scheduled_batch, self.resource_manager)
-                    # Pad draft tokens to the max draft length. This is for CUDA
-                    # graph compatibility.
-                    for req in scheduled_batch.generation_requests:
-                        max_draft_tokens = self.max_draft_len
-                        num_draft_tokens = get_draft_token_length(req)
-                        req.py_draft_tokens.extend(
-                            0 for _ in range(max_draft_tokens -
-                                             num_draft_tokens))
+                    # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
+                    self.drafter.pad_draft_tokens_for_cuda_graph(
+                        scheduled_batch)
                 # add_batch must be called again to restore to target requests with updated draft tokens.
                 if self.guided_decoder is not None:
                     self.guided_decoder.add_batch(scheduled_batch)
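
The eight deleted lines above are the whole contract of the new `Drafter.pad_draft_tokens_for_cuda_graph` call. Below is a minimal runnable sketch reconstructed from that loop; the `_Request` stub, the free-function signature, and `len()` standing in for the removed `get_draft_token_length` helper are illustrative assumptions, not the library's API:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class _Request:
    # Stand-in for LlmRequest; only the field the padding loop touches.
    py_draft_tokens: List[int] = field(default_factory=list)


def pad_draft_tokens_for_cuda_graph(generation_requests: List[_Request],
                                    max_draft_len: int) -> None:
    """Zero-pad every request's draft tokens up to max_draft_len.

    CUDA graphs replay with fixed tensor shapes, so each request must expose
    exactly max_draft_len draft tokens no matter how many the drafter produced.
    """
    for req in generation_requests:
        num_draft_tokens = len(req.py_draft_tokens)
        req.py_draft_tokens.extend(
            0 for _ in range(max_draft_len - num_draft_tokens))


# Requests with uneven draft counts all end up at the fixed length.
reqs = [_Request([101, 102]), _Request()]
pad_draft_tokens_for_cuda_graph(reqs, max_draft_len=4)
assert [len(r.py_draft_tokens) for r in reqs] == [4, 4]
```

Padding to a fixed length matters because a batch whose per-request draft-token counts vary would otherwise force a fresh graph capture on every shape change.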
@@ -1207,6 +1202,8 @@ def _executor_loop_overlap(self):
                 target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
                     scheduled_batch, previous_tensors)
 
+            # Use the draft model's outputs if we've launched the draft model.
+            # Otherwise, use the previous batch's outputs.
             if target_inputs is not None:
                 previous_tensors_device = target_inputs
             else:
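
The new comment documents the selection that follows. A one-function sketch of that decision; `pick_device_tensors` is a hypothetical helper, and the fallback value is an assumption since the else-branch lies outside this hunk:

```python
# Illustrative only: the if-branch is visible in the diff; the fallback is
# inferred from the comment ("use the previous batch's outputs").
def pick_device_tensors(target_inputs, previous_tensors):
    # Prefer the tensors produced by the just-launched draft model; otherwise
    # reuse whatever the previous target iteration produced.
    return target_inputs if target_inputs is not None else previous_tensors
```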
@@ -1971,17 +1968,10 @@ def _remove_inflight_ids(self, scheduled_requests):
             self.inflight_req_ids.erase(req.request_id)
 
     def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
-        """
-        Handle speculative decoding logic.
-
-        Args:
-            scheduled_batch: The scheduled batch to process
-            previous_tensors: Previous iteration tensors
-
-        Returns:
-            Tuple of (target_inputs, draft_outputs, draft_batch)
-        """
         with request_context(is_draft=True, scheduled_requests=scheduled_batch):
+            # Check early whether we need to forward the draft model. If so,
+            # the overlap happens between the target and the draft requests;
+            # otherwise we can still overlap the previous and current target requests.
             has_draft_batch = (
                 self.previous_batch is not None
                 and self.drafter.should_forward_draft_model(scheduled_batch))
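
The three new comment lines describe the gate visible at the end of this hunk. A self-contained sketch of that gate; only the `previous_batch is not None and drafter.should_forward_draft_model(...)` conjunction comes from the diff, so the predicate body below is a hypothetical stand-in:

```python
class _SketchDrafter:
    def should_forward_draft_model(self, scheduled_batch) -> bool:
        # Hypothetical predicate: run the draft model only when the batch has
        # generation requests that can accept speculative tokens.
        return bool(getattr(scheduled_batch, "generation_requests", []))


def needs_draft_forward(drafter, previous_batch, scheduled_batch) -> bool:
    # If True, the draft-model forward overlaps with the target requests;
    # if False, no draft batch is launched this iteration and the executor
    # still overlaps the previous target batch with the current one.
    return (previous_batch is not None
            and drafter.should_forward_draft_model(scheduled_batch))
```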
@@ -2006,60 +1996,29 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
     def _process_draft_results(self, scheduled_batch, draft_outputs,
                                draft_batch):
         """
-        Process the results from draft model execution.
-
-        Args:
-            scheduled_batch: The scheduled batch
-            draft_outputs: The outputs from the draft model
-            draft_batch: The draft batch that was processed
+        Append the draft tokens to the target requests, and clean up the draft resources.
         """
         req_id_to_old_request = {
             req.py_request_id: req
             for req in scheduled_batch.all_requests()
         }
 
         if self.drafter.use_static_draft_loop:
-            self.process_static_draft_outputs(draft_outputs, draft_batch,
-                                              req_id_to_old_request)
+            self.drafter.process_static_draft_outputs(draft_outputs,
+                                                      draft_batch,
+                                                      req_id_to_old_request)
         elif draft_outputs is not None:
-            self._process_dynamic_draft_outputs(scheduled_batch, draft_outputs,
-                                                req_id_to_old_request)
+            self.drafter.process_dynamic_draft_outputs(draft_outputs,
+                                                       req_id_to_old_request)
 
-    def process_static_draft_outputs(self, draft_outputs, draft_batch,
-                                     req_id_to_old_request):
-        """
-        Process outputs from static draft loop.
-
-        Args:
-            draft_outputs: The outputs from the draft model
-            draft_batch: The draft batch that was processed
-            req_id_to_old_request: Mapping from request ID to original request
-        """
-        self.drafter.process_static_draft_outputs(draft_outputs, draft_batch,
-                                                  req_id_to_old_request)
-
-    def _process_dynamic_draft_outputs(self, scheduled_batch, draft_outputs,
-                                       req_id_to_old_request):
-        """
-        Process outputs from dynamic draft loop.
-
-        Args:
-            scheduled_batch: The scheduled batch
-            draft_outputs: The outputs from the draft model
-            req_id_to_old_request: Mapping from request ID to original request
-        """
-        self.drafter.update_requests(draft_outputs)
-        self.drafter.process_decoded_tokens(draft_outputs.scheduled_requests,
-                                            req_id_to_old_request)
-
-        # Rollback draft tokens if guided decoder is available
+        # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
+        self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
+        # add_batch must be called again to restore to target requests with updated draft tokens.
         if self.guided_decoder is not None:
             self.guided_decoder.add_batch(scheduled_batch)
             if hasattr(self.drafter, "guided_decoder"):
                 self.guided_decoder.rollback_draft_tokens()
 
-        self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-
 
 class DisaggPPTerminationHandler:
     """Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
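
For reference, the two `_process_*_draft_outputs` wrappers deleted above moved onto the `Drafter`. A sketch of how the relocated dynamic path plausibly looks, reconstructed from the two calls in the deleted body; the stub request types and the token-copying detail inside `process_decoded_tokens` are assumptions:

```python
from typing import Dict, List


class _DraftRequest:
    def __init__(self, py_request_id: int, tokens: List[int]):
        self.py_request_id = py_request_id
        self.py_draft_tokens = tokens


class _TargetRequest:
    def __init__(self, py_request_id: int):
        self.py_request_id = py_request_id
        self.py_draft_tokens: List[int] = []


class _SketchDrafter:
    def update_requests(self, draft_outputs) -> None:
        """Finalize sampling state for the draft batch (stubbed out here)."""

    def process_decoded_tokens(self, draft_requests: List[_DraftRequest],
                               req_id_to_old_request: Dict[int, _TargetRequest]) -> None:
        # Assumption: copy each draft request's decoded tokens back onto the
        # target request it was spawned from.
        for draft_req in draft_requests:
            target = req_id_to_old_request[draft_req.py_request_id]
            target.py_draft_tokens = list(draft_req.py_draft_tokens)

    def process_dynamic_draft_outputs(self, draft_outputs,
                                      req_id_to_old_request) -> None:
        # Mirrors the deleted executor helper: finalize the draft batch, then
        # propagate decoded draft tokens to the original target requests.
        self.update_requests(draft_outputs)
        self.process_decoded_tokens(draft_outputs.scheduled_requests,
                                    req_id_to_old_request)
```

Note that the guided-decoder rollback and the CUDA-graph padding stayed in the executor's `_process_draft_results` rather than moving into the `Drafter`, presumably because they touch executor-owned state such as `self.guided_decoder`.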