@@ -90,11 +90,11 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))
 
     result = []
     # This function assumes cuda_graph_batch_sizes is sorted
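
With tree-based speculation each decode request can carry up to 1 + max_total_draft_tokens tokens per step (the newly accepted token plus every node of the draft tree), so the token budget is divided by the total tree size rather than the tree depth. A minimal standalone sketch of the same cap, with illustrative numbers not taken from the PR:

def max_pure_decode_batch(max_batch_size: int, max_num_tokens: int,
                          max_total_draft_tokens: int) -> int:
    # Each pure-decode request consumes 1 target token plus all draft tokens.
    tokens_per_request = 1 + max_total_draft_tokens
    return min(max_batch_size, max_num_tokens // tokens_per_request)

# E.g. a depth-3 draft tree with branching may hold 7 draft tokens even
# though max_draft_len is only 3: 256 tokens / 8 per request -> batch 32.
assert max_pure_decode_batch(64, 256, 7) == 32
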
@@ -157,11 +157,13 @@ def __init__(
         ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.enable_spec_decode = self.is_spec_decode
@@ -267,7 +269,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_draft_len * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -287,9 +289,11 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0
 
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
 
@@ -310,7 +314,7 @@ def __init__(
 
         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []
 
@@ -351,7 +355,7 @@ def __init__(
 
     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0
 
     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -458,6 +462,8 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 
         def get_num_extra_decoding_steps():
             if isinstance(self.model, ChainDrafter):
+                # Use max_draft_len rather than max_total_draft_tokens here:
+                # max_draft_len is the actual number of draft layers, i.e. how
+                # many extra decoding steps the drafting loop runs.
                 return self.model.max_draft_len
             else:
                 assert not self.model_is_wrapped, (
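
The comment above rests on the distinction between the two limits: max_draft_len is the depth of the draft tree (the number of drafting iterations), while max_total_draft_tokens counts every draft token the tree can hold. A hedged illustration with a made-up per-level token count, not taken from the PR:

tokens_per_level = [4, 2, 1]                    # hypothetical draft tree, 3 levels deep
max_draft_len = len(tokens_per_level)           # 3 extra decoding steps for the drafter
max_total_draft_tokens = sum(tokens_per_level)  # 7 draft tokens to verify per request

# For plain chain drafting (one token per level) the two limits coincide.
assert max_total_draft_tokens >= max_draft_len
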
@@ -595,7 +601,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
                           num_ctx_requests + num_gen_tokens)),
                 token_nums=[1] * num_gen_tokens,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
                 use_mrope=self.use_mrope)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(
@@ -610,7 +616,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
 
         curr_max_num_tokens = min(
             kv_cache_manager.get_num_available_tokens(
-                self.original_max_draft_len), self.max_num_tokens,
+                self.original_max_total_draft_tokens), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))
 
         def get_autotune_warmup_request():
@@ -700,20 +706,20 @@ def release_batch(result: ScheduledRequests | None):
             if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                     spec_resource_manager, Eagle3ResourceManager):
                 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_draft_len)
+                draft_lengths.append(self.original_max_total_draft_tokens)
             else:
                 draft_lengths.append(self.max_draft_len)
         else:
             # For non-draft model, we also capture the CUDA graph instance for draft length 0,
             # so that when we disable spec decode at runtime, we can still run the captured graph.
             # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (self.max_draft_len > 0
+            if (self.max_total_draft_tokens > 0
                     and not self.spec_config.spec_dec_mode.use_one_engine()
                     # Assume that speculation is always on if the user didn't give us a max_concurrency
                     # value. This will save on memory.
                     and self.spec_config.max_concurrency is not None):
                 draft_lengths.append(0)
-            draft_lengths = [self.max_draft_len]
+            draft_lengths = [self.max_total_draft_tokens]
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
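
As the comment in this hunk describes, the non-draft model captures a CUDA graph for the full speculative shape and, when speculation can be disabled at runtime, an extra graph for draft length 0. A hedged sketch of the resulting (batch_size, draft_len) capture keys; the sizes are invented and the helper is hypothetical, not part of the PR:

def capture_keys(cuda_graph_batch_sizes: list[int],
                 max_total_draft_tokens: int,
                 can_disable_spec: bool) -> list[tuple[int, int]]:
    # Hypothetical helper: enumerate the shapes a graph would be captured for.
    draft_lengths = [max_total_draft_tokens]
    if max_total_draft_tokens > 0 and can_disable_spec:
        draft_lengths.append(0)
    return [(bs, dl) for bs in cuda_graph_batch_sizes for dl in draft_lengths]

assert capture_keys([1, 2, 4], 7, True) == [(1, 7), (1, 0), (2, 7), (2, 0), (4, 7), (4, 0)]
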
@@ -941,7 +947,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
         """
         if self.enable_spec_decode and not self._disable_overlap_scheduler:
             # When enabling overlap scheduler, the kv cache for draft tokens will
-            # be prepared in advance by using the max_draft_len. But we need to use
+            # be prepared in advance by using the max_total_draft_tokens. But we need to use
             # new_tokens_lens_device to get the real past kv lengths and the
             # correct position ids. And to avoid blocking the async data transfer,
             # we need to preprocess the inputs in forward to update the position_ids and