
Commit 3b13608

add max_total_draft_tokens
Signed-off-by: Yue Weng <[email protected]>
1 parent 4a8ac8d commit 3b13608

File tree

18 files changed: +156, -80 lines


tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 6 additions & 7 deletions
@@ -318,13 +318,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     max_draft_len = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
     )
-    max_total_draft_tokens = 0
-    if ad_config.speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
-        max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (
+        0
+        if ad_config.speculative_config is None
+        else ad_config.speculative_config.max_total_draft_tokens
+    )

     # initialize model engine
     engine = ADEngine.build_from_config(ad_config=ad_config)
@@ -399,6 +397,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
         max_draft_len=max_draft_len,
+        max_total_draft_tokens=max_total_draft_tokens,
         max_beam_width=ad_config.max_beam_width,
     )
     return py_executor
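The rewrite above assumes every speculative_config now exposes max_total_draft_tokens directly, so the previous hasattr fallback to max_draft_len is gone. A minimal sketch of the resulting selection logic, using a hypothetical SpecConfig stand-in rather than the real config class:

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class SpecConfig:  # hypothetical stand-in for speculative_config
    max_draft_len: int            # depth of the draft chain (longest path)
    max_total_draft_tokens: int   # every draft token, e.g. all nodes of a draft tree

def resolve_draft_limits(spec: Optional[SpecConfig]) -> Tuple[int, int]:
    # Mirrors create_autodeploy_executor after this commit: both limits come
    # straight from the config, or default to 0 when speculation is disabled.
    max_draft_len = 0 if spec is None else spec.max_draft_len
    max_total_draft_tokens = 0 if spec is None else spec.max_total_draft_tokens
    return max_draft_len, max_total_draft_tokens

print(resolve_draft_limits(SpecConfig(3, 3)))  # (3, 3)  linear chain
print(resolve_draft_limits(SpecConfig(3, 7)))  # (3, 7)  3-deep tree with 7 nodes
print(resolve_draft_limits(None))              # (0, 0)  no speculation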

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ def __init__(
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
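The +1 here is the token the target model itself emits in the verification step; with tree drafting, a sequence can additionally carry every node of the draft tree, not just the longest chain. A small sketch of the per-step token budget this line computes, under the assumption that the full tree is attached to the request:

from typing import Optional

def predicted_tokens_per_seq(max_total_draft_tokens: Optional[int]) -> int:
    # One token is always produced by the target model; each attached draft
    # token may also be accepted in the same verification step.
    if max_total_draft_tokens is None:  # no speculative config
        return 1
    return max_total_draft_tokens + 1

assert predicted_tokens_per_seq(None) == 1  # no speculation
assert predicted_tokens_per_seq(3) == 4     # linear chain of 3 drafts
assert predicted_tokens_per_seq(7) == 8     # 7-node draft tree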

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 9 deletions
@@ -250,10 +250,10 @@ def _get_token_num_for_estimation(self) -> int:
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
             if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_draft_len
+                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens

         if spec_cfg is not None:
-            num_extra_tokens_per_seq += spec_cfg.max_draft_len
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
             num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)

         if self._dummy_reqs is None:
@@ -808,6 +808,8 @@ def create_py_executor_instance(
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
         if spec_config is not None else 0,
+        max_total_draft_tokens=spec_config.max_total_draft_tokens
+        if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
@@ -824,13 +826,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
-    max_total_draft_tokens = 0
-    if speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(speculative_config, 'max_total_draft_tokens'):
-        max_total_draft_tokens = speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (0 if speculative_config is None else
+                              speculative_config.max_total_draft_tokens)

     return TorchSampler.Args(
         max_seq_len=max_seq_len,
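For KV-cache size estimation, each sequence must now reserve room for every draft token in the tree rather than only the chain depth. A rough sketch of the per-sequence accounting, treating the extra-KV token count as an opaque input (its value depends on the speculative mode):

def extra_tokens_per_seq(overlap_scheduler_enabled: bool,
                         max_total_draft_tokens: int,
                         num_extra_kv_tokens: int) -> int:
    # Mirrors _get_token_num_for_estimation: the overlap scheduler keeps one
    # extra in-flight token (plus its draft tokens), and speculation itself
    # reserves the full draft budget plus any mode-specific KV tokens.
    extra = 0
    if overlap_scheduler_enabled:
        extra += 1 + max_total_draft_tokens
    extra += max_total_draft_tokens + num_extra_kv_tokens
    return extra

# Overlap scheduler on, 7-node draft tree, no extra KV tokens: 1 + 7 + 7 = 15
print(extra_tokens_per_seq(True, 7, 0))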

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 5 additions & 2 deletions
@@ -93,7 +93,8 @@ def enable_spec_decode(self):
     @property
     def max_possible_draft_len(self):
         engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+        return (engine.original_max_total_draft_tokens
+                if self.enable_spec_decode else 0)

     def get_graph_key(
             self,
@@ -102,10 +103,12 @@ def get_graph_key(
         engine = self._get_engine()
         if engine.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
+            # When 'is_first_draft' is True, draft_len is only 'max_draft_len' (not 'max_total_draft_tokens'),
+            # even with tree decoding, because the input to the first draft layer is padded to 'max_draft_len'.
             draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key
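The graph key therefore distinguishes the first draft pass (padded to max_draft_len) from every other pass, which is sized by the full draft tree. A condensed sketch of the selection, with the engine and resource manager reduced to plain flags and the Eagle3-specific checks omitted:

def graph_key(batch_size: int, is_draft_model: bool, is_first_draft: bool,
              enable_spec_decode: bool, max_draft_len: int,
              max_total_draft_tokens: int) -> tuple:
    if is_draft_model:
        # The first draft layer sees inputs padded to max_draft_len; later
        # autoregressive draft passes run with draft_len 0.
        draft_len = max_draft_len if is_first_draft else 0
        return (batch_size, draft_len, is_first_draft)
    # The target model captures a graph sized for the whole draft tree.
    draft_len = max_total_draft_tokens if enable_spec_decode else 0
    return (batch_size, draft_len, False)

# 3-deep, 7-node draft tree:
print(graph_key(8, True, True, True, 3, 7))    # (8, 3, True)  first draft pass
print(graph_key(8, False, False, True, 3, 7))  # (8, 7, False) target verification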

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 17 additions & 13 deletions
@@ -93,11 +93,11 @@ def warmup(self, resource_manager: ResourceManager) -> None:

 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))

     result = []
     # This function assumes cuda_graph_batch_sizes is sorted
@@ -162,11 +162,13 @@ def __init__(
         ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.sparse_attention_config = sparse_attention_config
@@ -277,7 +279,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_draft_len * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -297,9 +299,11 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0

         self.guided_decoder: Optional[CapturableGuidedDecoder] = None

@@ -320,7 +324,7 @@ def __init__(

         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []

@@ -364,7 +368,7 @@ def register_forward_pass_callable(self, callable: Callable):

     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0

     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -585,20 +589,20 @@ def _capture_generation_cuda_graphs(self,
             if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                     spec_resource_manager, Eagle3ResourceManager):
                 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_draft_len)
+                draft_lengths.append(self.original_max_total_draft_tokens)
             else:
-                draft_lengths.append(self.max_draft_len)
+                draft_lengths.append(self.max_total_draft_tokens)
         else:
             # For non-draft model, we also capture the CUDA graph instance for draft length 0,
             # so that when we disable spec decode at runtime, we can still run the captured graph.
             # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (self.max_draft_len > 0
+            if (self.max_total_draft_tokens > 0
                     and not self.spec_config.spec_dec_mode.use_one_engine()
                     # Assume that speculation is always on if the user didn't give us a max_concurrency
                     # value. This will save on memory.
                     and self.spec_config.max_concurrency is not None):
                 draft_lengths.append(0)
-            draft_lengths = [self.max_draft_len]
+            draft_lengths = [self.max_total_draft_tokens]

         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
@@ -757,7 +761,7 @@ def _create_warmup_request(
                     num_ctx_requests + num_gen_tokens)),
                 token_nums=[1] * num_gen_tokens,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
                 use_mrope=self.use_mrope)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(
@@ -830,7 +834,7 @@ def _create_cuda_graph_warmup_request(

     def _get_cuda_graph_draft_lengths(
             self, resource_manager: ResourceManager) -> List[int]:
         """Determines the draft lengths for which to capture CUDA graphs."""
-        draft_lengths = [self.max_draft_len]
+        draft_lengths = [self.max_total_draft_tokens]
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -1027,7 +1031,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
         """
         if self.enable_spec_decode and not self._disable_overlap_scheduler:
             # When enabling overlap scheduler, the kv cache for draft tokens will
-            # be prepared in advance by using the max_draft_len. But we need to use
+            # be prepared in advance by using the max_total_draft_tokens. But we need to use
             # new_tokens_lens_device to get the real past kv lengths and the
             # correct position ids. And to avoid blocking the async data transfer,
             # we need to preprocess the inputs in forward to update the position_ids and
@@ -2252,7 +2256,7 @@ def forward(
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
             spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
         attn_metadata.update_spec_dec_param(
             is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
             spec_metadata.is_spec_dec_dynamic_tree,
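Sizing CUDA graph batch sizes by max_total_draft_tokens matters because a pure decoding batch now carries the whole draft tree per sequence, not just the chain depth. A minimal sketch of the filtering rule in _filter_cuda_graph_batch_sizes, assuming a sorted candidate list (as the original comment states) and omitting the padding handling:

def filter_cuda_graph_batch_sizes(candidates: list[int],
                                  max_batch_size: int,
                                  max_num_tokens: int,
                                  max_total_draft_tokens: int) -> list[int]:
    # Each decoding sequence contributes 1 target token plus every draft token,
    # so the largest feasible pure-decode batch is bounded by this ratio.
    max_cuda_graph_bs = min(max_batch_size,
                            max_num_tokens // (1 + max_total_draft_tokens))
    return [bs for bs in candidates if bs <= max_cuda_graph_bs]

# 8192-token budget with 7-node draft trees: only batches up to 1024 fit.
print(filter_cuda_graph_batch_sizes([1, 2, 4, 8, 2048], 2048, 8192, 7))  # [1, 2, 4, 8]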

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 8 deletions
@@ -160,6 +160,7 @@ def __init__(self,
                  max_batch_size: int = 8,
                  max_beam_width: int = 1,
                  max_draft_len: int = 0,
+                 max_total_draft_tokens: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
@@ -195,6 +196,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens
         self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
@@ -1040,7 +1042,7 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests, self.max_batch_size,
                 self.model_engine.max_num_tokens,
-                self.model_engine.spec_config.max_draft_len)
+                self.model_engine.spec_config.max_total_draft_tokens)
             logger.debug(f"Use spec decode: {self.use_spec_decode}")
             self.model_engine.enable_spec_decode = self.use_spec_decode

@@ -1050,10 +1052,10 @@ def _prepare_and_schedule_batch(self):
                         LlmRequestState.GENERATION_IN_PROGRESS,
                         LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
-                max_draft_len = self.model_engine.spec_config.max_draft_len
+                max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
                 request.draft_tokens = [
                     0
-                ] * max_draft_len if max_draft_len > 0 else []
+                ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

                 # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
                 # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
@@ -1224,11 +1226,11 @@ def _prepare_draft_requests(self):
                 continue

             req.py_last_draft_tokens = req.py_draft_tokens
-            max_draft_len = self.model_engine.spec_config.max_draft_len
+            max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens

-            if max_draft_len > 0 and self.use_spec_decode:
-                req.py_draft_tokens = [0] * max_draft_len
-                req.py_draft_pages_allocated = max_draft_len
+            if max_total_draft_tokens > 0 and self.use_spec_decode:
+                req.py_draft_tokens = [0] * max_total_draft_tokens
+                req.py_draft_pages_allocated = max_total_draft_tokens
             else:
                 req.py_draft_tokens = []
                 req.py_draft_pages_allocated = 0
@@ -1616,7 +1618,7 @@ def _pad_attention_dp_dummy_request(self):
             request_ids=[0],
             is_gen=True,
             prepare_resource=True,
-            max_num_draft_tokens=self.max_draft_len,
+            max_num_draft_tokens=self.max_total_draft_tokens,
         )[0]
         llm_request.is_attention_dp_dummy = True
         spec_resource_manager = self.resource_manager.get_resource_manager(
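The executor pre-fills each generation request with placeholder draft tokens so that scheduling and page allocation reserve the worst case; with tree drafting that worst case is the entire tree. A small sketch of the padding step, with the request reduced to a plain object for illustration:

class _Req:  # hypothetical stand-in for LlmRequest
    def __init__(self):
        self.py_draft_tokens = []
        self.py_draft_pages_allocated = 0

def prepare_draft_request(req: _Req, max_total_draft_tokens: int,
                          use_spec_decode: bool) -> None:
    # Mirrors _prepare_draft_requests: reserve one slot per possible draft token
    # (every node of the draft tree), or nothing when speculation is off.
    if max_total_draft_tokens > 0 and use_spec_decode:
        req.py_draft_tokens = [0] * max_total_draft_tokens
        req.py_draft_pages_allocated = max_total_draft_tokens
    else:
        req.py_draft_tokens = []
        req.py_draft_pages_allocated = 0

req = _Req()
prepare_draft_request(req, max_total_draft_tokens=7, use_spec_decode=True)
print(len(req.py_draft_tokens))  # 7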

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 4 deletions
@@ -357,7 +357,9 @@ def drafting_loop_wrapper(model):
             from tensorrt_llm._torch.speculative.drafting_loops import \
                 ChainDrafter

-            return ChainDrafter(spec_config.max_draft_len, model)
+            return ChainDrafter(spec_config.max_draft_len,
+                                spec_config.max_total_draft_tokens,
+                                model)
     else:
         drafting_loop_wrapper = None

@@ -397,11 +399,11 @@ def drafting_loop_wrapper(model):
     if not pytorch_backend_config.disable_overlap_scheduler:
         model_engine_max_seq_len = model_engine.max_seq_len + 1
         if spec_config is not None:
-            model_engine_max_seq_len += spec_config.max_draft_len
+            model_engine_max_seq_len += spec_config.max_total_draft_tokens

     if spec_config is not None:
         model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
-        model_engine_max_seq_len += spec_config.max_draft_len
+        model_engine_max_seq_len += spec_config.max_total_draft_tokens

     max_seq_len = model_engine_max_seq_len
     max_num_tokens = model_engine.max_num_tokens
@@ -471,7 +473,8 @@ def drafting_loop_wrapper(model):
         "vocab_size_padded": model_engine.model.vocab_size_padded
     }
     if spec_config is not None:
-        kwargs["max_num_draft_tokens"] = spec_config.max_draft_len
+        kwargs[
+            "max_num_draft_tokens"] = spec_config.max_total_draft_tokens

     if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
     ):
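Because the overlap scheduler keeps one extra step in flight, the engine's sequence-length budget grows by one target token plus a full draft budget for that step, on top of the draft and extra-KV tokens that speculation itself reserves. A back-of-the-envelope sketch of the accounting, treating the extra-KV token count as an input:

def effective_max_seq_len(base_max_seq_len: int,
                          overlap_scheduler_enabled: bool,
                          max_total_draft_tokens: int,
                          num_extra_kv_tokens: int) -> int:
    # Mirrors the max_seq_len bookkeeping in py_executor_creator after this commit.
    seq_len = base_max_seq_len
    if overlap_scheduler_enabled:
        # One extra in-flight step plus its draft tree.
        seq_len += 1 + max_total_draft_tokens
    # Speculation's own reservation.
    seq_len += num_extra_kv_tokens + max_total_draft_tokens
    return seq_len

# base 4096, overlap on, 7-node draft tree, no extra KV tokens: 4096 + 8 + 7 = 4111
print(effective_max_seq_len(4096, True, 7, 0))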

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -867,7 +867,7 @@ class Args:

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
-        self.max_tokens = args.max_draft_len + 1
+        self.max_tokens = args.max_total_draft_tokens + 1
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
         self.max_num_sequences = args.max_num_sequences

@@ -1002,8 +1002,8 @@ def _process_draft_tokens_tree(self, request: LlmRequest,
        we can find the longest match by comparing all the paths.
        Args:
            request: LlmRequest. The request with draft tokens.
-            new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model
-                The relationship between [max_draft_len + 1] and the draft token tree:
+            new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model
+                The relationship between [max_total_draft_tokens + 1] and the draft token tree:
                If the current node is accepted, what is the NEXT token_id that the target model will generate?
                For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root node in the draft token tree if it is accepted.
                We know that the root node in the draft token tree is always accepted. Therefore, new_tokens[0, req_idx, 1] indicates the token_id following the root node,
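The docstring describes matching the target model's tokens against every root-to-leaf path of the draft tree and accepting the longest agreeing prefix. A toy sketch of that matching rule, with the tree given as explicit paths of node indices (a simplification; the real implementation works on flattened buffers):

def longest_accepted_path(paths: list[list[int]],
                          draft_tokens: dict[int, int],
                          next_token_after: dict[int, int]) -> list[int]:
    # For each root-to-leaf path (node indices), walk it while the target model's
    # next token after the parent node equals the drafted token at the child node;
    # keep the longest accepted prefix over all paths. Node 0 is the root and is
    # always accepted.
    best: list[int] = []
    for path in paths:
        accepted, parent = [], 0
        for node in path:
            if next_token_after[parent] == draft_tokens[node]:
                accepted.append(node)
                parent = node
            else:
                break
        if len(accepted) > len(best):
            best = accepted
    return best

# Two branches below the root: 0 -> 1 -> 3 and 0 -> 2.
draft_tokens = {1: 11, 2: 22, 3: 33}
paths = [[1, 3], [2]]
next_token_after = {0: 11, 1: 33, 3: 44}  # target agrees with the 0 -> 1 -> 3 branch
print(longest_accepted_path(paths, draft_tokens, next_token_after))  # [1, 3]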
