
Commit 8dc4aac

[TRTLLM-8160][feat] Add max_total_draft_tokens (#8366)
Signed-off-by: Yue Weng <[email protected]>
1 parent a0024f4 commit 8dc4aac

18 files changed: +156 additions, -80 deletions

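This commit separates the draft-token budget into two quantities: `max_draft_len` (the per-path draft length, i.e. the depth of the drafting loop) and `max_total_draft_tokens` (the total number of draft tokens attached to a request, which for tree-style drafting counts every node of the draft token tree). The sketch below is illustrative only; the tree-choices layout and the `summarize()` helper are hypothetical and not part of this commit. For chain drafting the two values coincide.

```python
# Illustrative only: how max_draft_len and max_total_draft_tokens can differ
# once drafting follows a tree instead of a single chain. The tree-choices
# layout and summarize() are hypothetical, not taken from this commit.
linear_chain = [[0], [0, 0], [0, 0, 0]]          # 3 draft positions on one path
small_tree = [[0], [1], [0, 0], [0, 1], [1, 0]]  # 5 nodes, maximum depth 2

def summarize(choices):
    max_draft_len = max(len(path) for path in choices)    # deepest root-to-node path
    max_total_draft_tokens = len(choices)                 # every node needs a token slot
    return max_draft_len, max_total_draft_tokens

print(summarize(linear_chain))  # (3, 3): chain drafting, the two budgets coincide
print(summarize(small_tree))    # (2, 5): tree drafting, the totals diverge
```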

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 6 additions & 7 deletions
@@ -320,13 +320,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     max_draft_len = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
     )
-    max_total_draft_tokens = 0
-    if ad_config.speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
-        max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (
+        0
+        if ad_config.speculative_config is None
+        else ad_config.speculative_config.max_total_draft_tokens
+    )
 
     # initialize model engine
     engine = ADEngine.build_from_config(ad_config=ad_config)
@@ -417,6 +415,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
         max_draft_len=max_draft_len,
+        max_total_draft_tokens=max_total_draft_tokens,
         max_beam_width=ad_config.max_beam_width,
     )
     return py_executor
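For reference, a minimal sketch of what the simplified resolution above now does, assuming every speculative config exposes `max_total_draft_tokens` (the `hasattr` fallback is gone). `SpecCfg` and `resolve_totals` are illustrative stand-ins, not names from the codebase.

```python
# A minimal sketch of the simplified config resolution, assuming every
# speculative config defines max_total_draft_tokens. SpecCfg and
# resolve_totals are illustrative stand-ins.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class SpecCfg:
    max_draft_len: int
    max_total_draft_tokens: int

def resolve_totals(spec: Optional[SpecCfg]) -> Tuple[int, int]:
    max_draft_len = 0 if spec is None else spec.max_draft_len
    max_total_draft_tokens = 0 if spec is None else spec.max_total_draft_tokens
    return max_draft_len, max_total_draft_tokens

assert resolve_totals(None) == (0, 0)
assert resolve_totals(SpecCfg(max_draft_len=3, max_total_draft_tokens=7)) == (3, 7)
```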

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ def __init__(
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
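A rough illustration of why `predicted_tokens_per_seq` switches to the total budget: the target model verifies one token per draft node plus the root token in a single pass, so a tree needs more slots than its depth suggests. The numbers below are made up.

```python
# Illustrative numbers only: predicted tokens per sequence for a draft tree.
max_draft_len = 3            # depth of the draft tree
max_total_draft_tokens = 7   # all nodes in the draft tree

old_predicted = max_draft_len + 1            # 4: undercounts a 7-node tree
new_predicted = max_total_draft_tokens + 1   # 8: one slot per node, plus the root
print(old_predicted, new_predicted)
```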

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 9 deletions
@@ -250,10 +250,10 @@ def _get_token_num_for_estimation(self) -> int:
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
             if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_draft_len
+                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
 
         if spec_cfg is not None:
-            num_extra_tokens_per_seq += spec_cfg.max_draft_len
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
             num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
 
         if self._dummy_reqs is None:
@@ -808,6 +808,8 @@ def create_py_executor_instance(
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
        if spec_config is not None else 0,
+        max_total_draft_tokens=spec_config.max_total_draft_tokens
+        if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
@@ -824,13 +826,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
-    max_total_draft_tokens = 0
-    if speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(speculative_config, 'max_total_draft_tokens'):
-        max_total_draft_tokens = speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (0 if speculative_config is None else
+                              speculative_config.max_total_draft_tokens)
 
     return TorchSampler.Args(
         max_seq_len=max_seq_len,
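A sketch of the per-sequence extra-token estimate touched above, assuming the running count starts at one token and folding `get_num_extra_kv_tokens()` into a plain argument. Only the structure mirrors `_get_token_num_for_estimation`; the function name and defaults are illustrative.

```python
# Sketch of the extra-token estimate per sequence, with hypothetical defaults.
def extra_tokens_per_seq(overlap_scheduler: bool,
                         max_total_draft_tokens: int,
                         num_extra_kv_tokens: int = 0) -> int:
    extra = 1                                # the token produced this step
    if overlap_scheduler:
        extra += 1 + max_total_draft_tokens  # one look-ahead step and its drafts
    extra += max_total_draft_tokens          # drafts awaiting verification
    extra += num_extra_kv_tokens
    return extra

print(extra_tokens_per_seq(True, 7))   # 16 for a 7-node draft tree
print(extra_tokens_per_seq(False, 3))  # 4 for a 3-token draft chain
```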

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 5 additions & 2 deletions
@@ -93,7 +93,8 @@ def enable_spec_decode(self):
 
     @property
     def max_possible_draft_len(self):
         engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+        return (engine.original_max_total_draft_tokens
+                if self.enable_spec_decode else 0)
 
     def get_graph_key(
         self,
@@ -102,10 +103,12 @@ def get_graph_key(
         engine = self._get_engine()
         if engine.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
+            # If 'is_first_draft' is True, draft_len will only be 'max_draft_len', not 'max_total_draft_tokens',
+            # even with tree decoding, because we pad the input to 'max_draft_len' for the first draft layer.
             draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key
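A hedged sketch of the graph-key rule above: the draft dimension of the key is the total draft-token budget, except for the first pass of an EAGLE3-style draft model, which is padded to `max_draft_len`. The free function and its toy arguments below are illustrative, not the runner's actual API.

```python
# Illustrative graph-key selection; not the CUDA graph runner's real signature.
def graph_key(batch_size, is_draft_model, is_first_draft,
              max_draft_len, max_total_draft_tokens, spec_decode_enabled=True):
    if is_draft_model:
        draft_len = max_draft_len if is_first_draft else 0
        return (batch_size, draft_len, is_first_draft)
    draft_len = max_total_draft_tokens if spec_decode_enabled else 0
    return (batch_size, draft_len, False)

print(graph_key(8, False, False, 3, 7))  # (8, 7, False): target model, tree budget
print(graph_key(8, True, True, 3, 7))    # (8, 3, True): first draft pass, padded
```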

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 17 additions & 13 deletions
@@ -93,11 +93,11 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))
 
     result = []
     # This function assumes cuda_graph_batch_sizes is sorted
@@ -162,11 +162,13 @@ def __init__(
             ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.sparse_attention_config = sparse_attention_config
@@ -277,7 +279,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_draft_len * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -297,9 +299,11 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0
 
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None
 
@@ -320,7 +324,7 @@ def __init__(
 
         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []
 
@@ -364,7 +368,7 @@ def register_forward_pass_callable(self, callable: Callable):
 
     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0
 
     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -585,20 +589,20 @@ def _capture_generation_cuda_graphs(self,
             if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                     spec_resource_manager, Eagle3ResourceManager):
                 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_draft_len)
+                draft_lengths.append(self.original_max_total_draft_tokens)
             else:
-                draft_lengths.append(self.max_draft_len)
+                draft_lengths.append(self.max_total_draft_tokens)
         else:
             # For non-draft model, we also capture the CUDA graph instance for draft length 0,
             # so that when we disable spec decode at runtime, we can still run the captured graph.
             # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (self.max_draft_len > 0
+            if (self.max_total_draft_tokens > 0
                     and not self.spec_config.spec_dec_mode.use_one_engine()
                     # Assume that speculation is always on if the user didn't give us a max_concurrency
                    # value. This will save on memory.
                     and self.spec_config.max_concurrency is not None):
                 draft_lengths.append(0)
-            draft_lengths = [self.max_draft_len]
+            draft_lengths = [self.max_total_draft_tokens]
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
@@ -757,7 +761,7 @@ def _create_warmup_request(
                                  num_ctx_requests + num_gen_tokens)),
             token_nums=[1] * num_gen_tokens,
             is_gen=True,
-            max_num_draft_tokens=self.max_draft_len,
+            max_num_draft_tokens=self.max_total_draft_tokens,
             use_mrope=self.use_mrope)
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(request_ids=list(
@@ -830,7 +834,7 @@ def _create_cuda_graph_warmup_request(
     def _get_cuda_graph_draft_lengths(
             self, resource_manager: ResourceManager) -> List[int]:
         """Determines the draft lengths for which to capture CUDA graphs."""
-        draft_lengths = [self.max_draft_len]
+        draft_lengths = [self.max_total_draft_tokens]
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER)
 
@@ -1027,7 +1031,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
         """
         if self.enable_spec_decode and not self._disable_overlap_scheduler:
             # When enabling overlap scheduler, the kv cache for draft tokens will
-            # be prepared in advance by using the max_draft_len. But we need to use
+            # be prepared in advance by using the max_total_draft_tokens. But we need to use
             # new_tokens_lens_device to get the real past kv lengths and the
             # correct position ids. And to avoid blocking the async data transfer,
             # we need to preprocess the inputs in forward to update the position_ids and
@@ -2252,7 +2256,7 @@ def forward(
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
             spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
         attn_metadata.update_spec_dec_param(
             is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
             spec_metadata.is_spec_dec_dynamic_tree,
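The pure-decode bound used by `_filter_cuda_graph_batch_sizes` is worth spelling out: each decoding request now occupies `1 + max_total_draft_tokens` token slots, so a larger draft tree lowers the biggest CUDA-graph batch size. The values below are illustrative.

```python
# Largest batch size that still fits a pure decoding step in max_num_tokens.
def max_cuda_graph_batch_size(max_batch_size: int, max_num_tokens: int,
                              max_total_draft_tokens: int) -> int:
    return min(max_batch_size, max_num_tokens // (1 + max_total_draft_tokens))

print(max_cuda_graph_batch_size(256, 8192, 0))   # 256: no speculation
print(max_cuda_graph_batch_size(256, 8192, 3))   # 256: a 3-token chain still fits
print(max_cuda_graph_batch_size(256, 8192, 63))  # 128: a 63-node tree halves it
```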

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 8 deletions
@@ -160,6 +160,7 @@ def __init__(self,
                  max_batch_size: int = 8,
                  max_beam_width: int = 1,
                  max_draft_len: int = 0,
+                 max_total_draft_tokens: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
@@ -195,6 +196,7 @@ def __init__(self,
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens
         self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
@@ -1040,7 +1042,7 @@ def _prepare_and_schedule_batch(self):
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests, self.max_batch_size,
                 self.model_engine.max_num_tokens,
-                self.model_engine.spec_config.max_draft_len)
+                self.model_engine.spec_config.max_total_draft_tokens)
             logger.debug(f"Use spec decode: {self.use_spec_decode}")
             self.model_engine.enable_spec_decode = self.use_spec_decode
 
@@ -1050,10 +1052,10 @@ def _prepare_and_schedule_batch(self):
                         LlmRequestState.GENERATION_IN_PROGRESS,
                         LlmRequestState.DISAGG_GENERATION_INIT):
                     continue
-                max_draft_len = self.model_engine.spec_config.max_draft_len
+                max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
                 request.draft_tokens = [
                     0
-                ] * max_draft_len if max_draft_len > 0 else []
+                ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
             # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
             # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
@@ -1223,11 +1225,11 @@ def _prepare_draft_requests(self):
                 continue
 
             req.py_last_draft_tokens = req.py_draft_tokens
-            max_draft_len = self.model_engine.spec_config.max_draft_len
+            max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
 
-            if max_draft_len > 0 and self.use_spec_decode:
-                req.py_draft_tokens = [0] * max_draft_len
-                req.py_draft_pages_allocated = max_draft_len
+            if max_total_draft_tokens > 0 and self.use_spec_decode:
+                req.py_draft_tokens = [0] * max_total_draft_tokens
+                req.py_draft_pages_allocated = max_total_draft_tokens
             else:
                 req.py_draft_tokens = []
                 req.py_draft_pages_allocated = 0
@@ -1615,7 +1617,7 @@ def _pad_attention_dp_dummy_request(self):
                 request_ids=[0],
                 is_gen=True,
                 prepare_resource=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
             )[0]
             llm_request.is_attention_dp_dummy = True
             spec_resource_manager = self.resource_manager.get_resource_manager(
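A minimal sketch of the placeholder allocation in `_prepare_draft_requests`: each request reserves one token slot and one page-budget entry per possible draft token, which is now the tree-wide total. `Request` below is a stand-in for the real `LlmRequest`.

```python
# Placeholder allocation per request; Request is a hypothetical stand-in.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Request:
    py_draft_tokens: List[int] = field(default_factory=list)
    py_draft_pages_allocated: int = 0

def prepare_draft_request(req: Request, use_spec_decode: bool,
                          max_total_draft_tokens: int) -> None:
    if use_spec_decode and max_total_draft_tokens > 0:
        req.py_draft_tokens = [0] * max_total_draft_tokens
        req.py_draft_pages_allocated = max_total_draft_tokens
    else:
        req.py_draft_tokens = []
        req.py_draft_pages_allocated = 0

req = Request()
prepare_draft_request(req, use_spec_decode=True, max_total_draft_tokens=7)
print(len(req.py_draft_tokens), req.py_draft_pages_allocated)  # 7 7
```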

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 4 deletions
@@ -357,7 +357,9 @@ def drafting_loop_wrapper(model):
             from tensorrt_llm._torch.speculative.drafting_loops import \
                 ChainDrafter
 
-            return ChainDrafter(spec_config.max_draft_len, model)
+            return ChainDrafter(spec_config.max_draft_len,
+                                spec_config.max_total_draft_tokens,
+                                model)
     else:
         drafting_loop_wrapper = None
 
@@ -397,11 +399,11 @@ def drafting_loop_wrapper(model):
     if not pytorch_backend_config.disable_overlap_scheduler:
         model_engine_max_seq_len = model_engine.max_seq_len + 1
         if spec_config is not None:
-            model_engine_max_seq_len += spec_config.max_draft_len
+            model_engine_max_seq_len += spec_config.max_total_draft_tokens
 
     if spec_config is not None:
         model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
-        model_engine_max_seq_len += spec_config.max_draft_len
+        model_engine_max_seq_len += spec_config.max_total_draft_tokens
 
     max_seq_len = model_engine_max_seq_len
     max_num_tokens = model_engine.max_num_tokens
@@ -471,7 +473,8 @@ def drafting_loop_wrapper(model):
             "vocab_size_padded": model_engine.model.vocab_size_padded
         }
         if spec_config is not None:
-            kwargs["max_num_draft_tokens"] = spec_config.max_draft_len
+            kwargs[
+                "max_num_draft_tokens"] = spec_config.max_total_draft_tokens
 
         if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
         ):
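A rough sketch of the sequence-length budget assembled above: the overlap scheduler adds one look-ahead step plus its draft tokens, and speculation adds the draft-token budget again for verification plus any extra KV tokens. `get_num_extra_kv_tokens()` is folded into a plain argument here, and the function name is illustrative.

```python
# Sequence-length budget sketch; executor_max_seq_len is a hypothetical helper.
def executor_max_seq_len(model_max_seq_len: int, overlap_scheduler: bool,
                         max_total_draft_tokens: int,
                         num_extra_kv_tokens: int = 0) -> int:
    max_seq_len = model_max_seq_len
    if overlap_scheduler:
        max_seq_len += 1 + max_total_draft_tokens
    max_seq_len += num_extra_kv_tokens + max_total_draft_tokens
    return max_seq_len

print(executor_max_seq_len(4096, True, 7))   # 4111 with a 7-node draft tree
print(executor_max_seq_len(4096, False, 0))  # 4096 without speculation or overlap
```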

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -589,7 +589,7 @@ class Args:
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
-        self.max_tokens = args.max_draft_len + 1
+        self.max_tokens = args.max_total_draft_tokens + 1
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, (
             "TorchSampler only supports beam_width = 1"
         )
@@ -738,9 +738,9 @@ def _process_draft_tokens_tree(
         we can find the longest match by comparing all the paths.
         Args:
             request: LlmRequest. The request with draft tokens.
-            new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
+            new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
                 The tokens generated by the target model
-                The relationship between [max_draft_len + 1] and the draft token tree:
+                The relationship between [max_total_draft_tokens + 1] and the draft token tree:
                 If the current node is accepted, what is the NEXT token_id that the target model will generate?
                 For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root
                 node in the draft token tree if it is accepted.
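A hedged sketch of the tree acceptance described in the docstring above: with one verified continuation per draft node (plus the root), acceptance walks every root-to-leaf path and keeps the longest run of draft tokens that the target model reproduced. The flat-list layout below is a simplification of the real `[max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH]` buffer, and the function is illustrative.

```python
# Simplified path-matching over a draft token tree; not TorchSampler's API.
def longest_accepted_path(paths, draft_tokens, target_next_tokens):
    """paths: root-to-leaf node-index paths; draft_tokens[i] is the token
    proposed at node i; target_next_tokens[i] is the target model's next
    token if node i is accepted (index 0 is the root)."""
    best = []
    for path in paths:
        accepted, prev = [], 0
        for node in path:
            if draft_tokens[node] != target_next_tokens[prev]:
                break
            accepted.append(draft_tokens[node])
            prev = node
        if len(accepted) > len(best):
            best = accepted
    return best

#                root  n1  n2  n3  n4
draft_tokens  = [None, 11, 12, 13, 14]   # tokens proposed at each tree node
target_next   = [11,   13, 99, 7,  5]    # target's next token after each node
paths         = [[1, 2], [1, 3], [4]]    # root-to-leaf node-index paths
print(longest_accepted_path(paths, draft_tokens, target_next))  # [11, 13]
```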
