@@ -110,6 +110,33 @@ def get_pp_layers(
     return pp_layers, total_num_layers


+def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):
+
+    class RequestContext:
+
+        def __init__(self, is_draft: bool,
+                     scheduled_requests: ScheduledRequests):
+            self.is_draft = is_draft
+            self.scheduled_requests = scheduled_requests
+
+        def __enter__(self):
+            if not self.is_draft:
+                return
+
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if not self.is_draft:
+                return
+
+            # Clean up the state
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = False
+
+    return RequestContext(is_draft, scheduled_requests)
+
+
 class KVCacheManager(BaseResourceManager):

     def __init__(
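The request_context helper added above builds a small context manager: __enter__ flags every request in the batch with use_draft_model = True, but only when is_draft is set, and __exit__ clears the flag again on the way out. A minimal sketch of that behavior, using hypothetical DummyRequest/DummyBatch stand-ins for LlmRequest and ScheduledRequests:

# Hedged sketch of the helper's behavior. DummyRequest and DummyBatch are
# hypothetical stand-ins for LlmRequest and ScheduledRequests; only the
# all_requests()/use_draft_model surface used by request_context is modeled.
class DummyRequest:
    use_draft_model = False


class DummyBatch:

    def __init__(self, requests):
        self._requests = requests

    def all_requests(self):
        return self._requests


batch = DummyBatch([DummyRequest(), DummyRequest()])
with request_context(is_draft=True, scheduled_requests=batch):
    # Inside the block, every request is marked for the draft model.
    assert all(r.use_draft_model for r in batch.all_requests())
# On exit the flag is cleared, so target-model passes are unaffected.
assert not any(r.use_draft_model for r in batch.all_requests())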
@@ -132,6 +159,7 @@ def __init__(
         max_num_tokens: int = 8192,
         model_config: Optional[ModelConfig] = None,
         max_beam_width: int = 1,
+        is_draft: bool = False,
     ) -> None:
         self.mapping = mapping
         self.dtype = dtype
@@ -142,6 +170,7 @@ def __init__(
             spec_config=spec_config,
             layer_mask=layer_mask,
         )
+        self.py_is_draft = is_draft
         self.num_local_layers = len(self.pp_layers)
         self.layer_offsets = {
             idx: offset
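The constructor change is minimal: is_draft defaults to False, so existing call sites keep their behavior, and the value is stashed as py_is_draft (matching the py_ naming used for Python-side attributes such as py_request_id). A stub illustration of the flag, with everything except is_draft/py_is_draft hypothetical:

# Stub sketch: only the new flag is real; the class name and the omission
# of all other constructor arguments are hypothetical simplifications.
class StubKVCacheManager:

    def __init__(self, is_draft: bool = False):
        self.py_is_draft = is_draft


assert StubKVCacheManager().py_is_draft is False  # default: target-model cache
assert StubKVCacheManager(is_draft=True).py_is_draft is True  # drafter's cache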
@@ -366,34 +395,36 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
         return need_blocks

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        context_batch = scheduled_batch.context_requests
-        generation_batch = scheduled_batch.generation_requests
-        # allocate KV Cache
-        for req in context_batch:
-            req_beam_width = req.sampling_config.beam_width
-            if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
-                    'cp_type']:
-                if req.ctx_iters == 0:
-                    seq_len = sum(
-                        len(ctx_block) for ctx_block in req.ctx_blocks)
-                    self.impl.add_sequence(
-                        req.py_request_id,
-                        seq_len + (len(req.query_id) if self.mapping.cp_rank
-                                   == self.mapping.cp_size - 1 else 0),
-                        req_beam_width, req)
-            else:
-                if req.is_first_context_chunk:
-                    self.impl.add_sequence(req.py_request_id, req.prompt_len,
-                                           req_beam_width, req)
-                    for _ in range(self.num_extra_kv_tokens):
-                        self.impl.add_token(req.py_request_id)
-                    for _ in range(get_draft_token_length(req)):
-                        self.impl.add_token(req.py_request_id)
-
-        for req in generation_batch:
-            self.impl.add_token(req.py_request_id)
-            for _ in range(get_draft_token_length(req)):
+        with request_context(self.py_is_draft, scheduled_batch):
+            context_batch = scheduled_batch.context_requests
+            generation_batch = scheduled_batch.generation_requests
+            # allocate KV Cache
+            for req in context_batch:
+                req_beam_width = req.sampling_config.beam_width
+                if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
+                        'cp_type']:
+                    if req.ctx_iters == 0:
+                        seq_len = sum(
+                            len(ctx_block) for ctx_block in req.ctx_blocks)
+                        self.impl.add_sequence(
+                            req.py_request_id,
+                            seq_len + (len(req.query_id) if self.mapping.cp_rank
+                                       == self.mapping.cp_size - 1 else 0),
+                            req_beam_width, req)
+                else:
+                    if req.is_first_context_chunk:
+                        self.impl.add_sequence(req.py_request_id,
+                                               req.prompt_len, req_beam_width,
+                                               req)
+                        for _ in range(self.num_extra_kv_tokens):
+                            self.impl.add_token(req.py_request_id)
+                        for _ in range(get_draft_token_length(req)):
+                            self.impl.add_token(req.py_request_id)
+
+            for req in generation_batch:
                 self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)

     def add_dummy_requests(
         self,
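Wrapping the whole allocation path in request_context means that while a draft KVCacheManager (py_is_draft=True) is adding sequences and tokens, each request temporarily reports use_draft_model = True, and the flag is reliably cleared afterwards because __exit__ runs even when the body raises. A small sketch of that exception-safety property, reusing the DummyRequest/DummyBatch stand-ins from the earlier sketch:

# Even if allocation fails mid-batch, __exit__ still resets the flag.
# The RuntimeError below is a hypothetical stand-in for an allocation error.
batch = DummyBatch([DummyRequest()])
try:
    with request_context(is_draft=True, scheduled_requests=batch):
        raise RuntimeError("allocation failed")
except RuntimeError:
    pass
assert not batch.all_requests()[0].use_draft_model  # flag was still cleared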
@@ -1156,11 +1187,7 @@ def get_resource_manager(self, name: str) -> BaseResourceManager:

     @nvtx_range("prepare_resources")
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        for resource_mgr_type, resource_manager in self.resource_managers.items(
-        ):
-            # Delay the preparation of draft kv cache manager to ModelDrafter.prepare_draft_tokens.
-            if resource_mgr_type == ResourceManagerType.DRAFT_KV_CACHE_MANAGER:
-                continue
+        for _, resource_manager in self.resource_managers.items():
             if hasattr(resource_manager, "prepare_resources"):
                 resource_manager.prepare_resources(scheduled_batch)