@@ -93,11 +93,11 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))
 
     result = []
     # This function assumes cuda_graph_batch_sizes is sorted
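
A rough worked example of the cap above (numbers invented for illustration, not part of the diff): each pure-decode request now occupies one committed token plus up to max_total_draft_tokens speculative tokens, so the largest decode-only batch that still fits in max_num_tokens shrinks as the draft tree grows.

    # Illustrative values only.
    max_batch_size = 64
    max_num_tokens = 256
    max_total_draft_tokens = 7            # e.g. a draft tree with 7 nodes in total
    tokens_per_decode_request = 1 + max_total_draft_tokens
    max_cuda_graph_bs = min(max_batch_size,
                            max_num_tokens // tokens_per_decode_request)
    # min(64, 256 // 8) == 32: the draft tree halves the decode batch that fits.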
@@ -162,11 +162,13 @@ def __init__(
         ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.sparse_attention_config = sparse_attention_config
@@ -277,7 +279,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_draft_len * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
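
For intuition (invented numbers, and assuming the buffer is flattened across the batch): the rename sizes draft_tokens_cuda for the worst case of every request carrying a full draft tree, rather than only a linear chain of max_draft_len tokens.

    import torch

    # Hypothetical sizes; one int slot per draft-tree node per request in a full batch.
    batch_size = 32
    original_max_total_draft_tokens = 7                                  # nodes in the whole tree
    max_num_draft_tokens = original_max_total_draft_tokens * batch_size  # 224 slots
    draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                    dtype=torch.int,
                                    device='cuda')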
@@ -297,9 +299,11 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0
 
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
 
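A hedged reading of why both attributes are kept side by side (assuming tree-style drafting as in Eagle-like modes): max_draft_len counts drafting iterations, i.e. the depth of the draft tree, while max_total_draft_tokens counts every drafted node the target model has to verify; for a plain linear draft chain the two coincide.

    # Invented example tree, for intuition only.
    # Linear chain of 3 drafted tokens:                 depth == 3, total drafted tokens == 3.
    # Tree drafting 2 candidates per step for 2 steps:  depth == 2, total == 2 + 4 == 6.
    max_draft_len = 2                  # drafting iterations (tree depth)
    max_total_draft_tokens = 6         # all drafted nodes to verify
    assert max_total_draft_tokens >= max_draft_len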
@@ -320,7 +324,7 @@ def __init__(
 
         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []
 
@@ -364,7 +368,7 @@ def register_forward_pass_callable(self, callable: Callable):
 
     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0
 
     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -585,20 +589,20 @@ def _capture_generation_cuda_graphs(self,
             if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                     spec_resource_manager, Eagle3ResourceManager):
                 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_draft_len)
+                draft_lengths.append(self.original_max_total_draft_tokens)
             else:
-                draft_lengths.append(self.max_draft_len)
+                draft_lengths.append(self.max_total_draft_tokens)
         else:
             # For non-draft model, we also capture the CUDA graph instance for draft length 0,
             # so that when we disable spec decode at runtime, we can still run the captured graph.
             # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (self.max_draft_len > 0
+            if (self.max_total_draft_tokens > 0
                     and not self.spec_config.spec_dec_mode.use_one_engine()
                     # Assume that speculation is always on if the user didn't give us a max_concurrency
                     # value. This will save on memory.
                     and self.spec_config.max_concurrency is not None):
                 draft_lengths.append(0)
-            draft_lengths = [self.max_draft_len]
+            draft_lengths = [self.max_total_draft_tokens]
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
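
Roughly, and assuming the engine keys its captured graphs by both batch size and draft length (the hunk itself only shows the batch-size loop), the capture set after this change can be pictured as a cross product:

    # Invented values, sketching the capture grid rather than the upstream loop.
    cuda_graph_batch_sizes = [1, 2, 4, 8]
    draft_lengths = [6, 0]     # full draft tree, plus draft length 0 for "spec decode off"
    graph_keys = [(bs, dl) for bs in cuda_graph_batch_sizes for dl in draft_lengths]
    # 8 graphs: (1, 6), (1, 0), (2, 6), (2, 0), (4, 6), (4, 0), ...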
@@ -757,7 +761,7 @@ def _create_warmup_request(
                     num_ctx_requests + num_gen_tokens)),
                 token_nums=[1] * num_gen_tokens,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
                 use_mrope=self.use_mrope)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(
@@ -830,7 +834,7 @@ def _create_cuda_graph_warmup_request(
     def _get_cuda_graph_draft_lengths(
             self, resource_manager: ResourceManager) -> List[int]:
         """Determines the draft lengths for which to capture CUDA graphs."""
-        draft_lengths = [self.max_draft_len]
+        draft_lengths = [self.max_total_draft_tokens]
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER)
 
@@ -1027,7 +1031,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
         """
         if self.enable_spec_decode and not self._disable_overlap_scheduler:
             # When enabling overlap scheduler, the kv cache for draft tokens will
-            # be prepared in advance by using the max_draft_len. But we need to use
+            # be prepared in advance by using the max_total_draft_tokens. But we need to use
             # new_tokens_lens_device to get the real past kv lengths and the
             # correct position ids. And to avoid blocking the async data transfer,
             # we need to preprocess the inputs in forward to update the position_ids and
@@ -2252,7 +2256,7 @@ def forward(
             # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
             is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
-                self.model_is_wrapped)
+                self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
             attn_metadata.update_spec_dec_param(
                 is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
                 spec_metadata.is_spec_dec_dynamic_tree,