
Commit 55a0d9e

add max_total_draft_tokens

Signed-off-by: Yue Weng <[email protected]>
1 parent 72d65d0

15 files changed, +142 −66 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 6 additions & 7 deletions
@@ -318,13 +318,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     max_draft_len = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
     )
-    max_total_draft_tokens = 0
-    if ad_config.speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
-        max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (
+        0
+        if ad_config.speculative_config is None
+        else ad_config.speculative_config.max_total_draft_tokens
+    )

     # initialize model engine
     engine = ADEngine.build_from_config(ad_config=ad_config)
@@ -399,6 +397,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
         max_draft_len=max_draft_len,
+        max_total_draft_tokens=max_total_draft_tokens,
         max_beam_width=ad_config.max_beam_width,
     )
     return py_executor
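The rewritten block drops the hasattr fallback and assumes the speculative config always exposes max_total_draft_tokens alongside max_draft_len. A minimal sketch of the two limits, using a hypothetical SpecConfig stand-in rather than the real TensorRT-LLM class:

# Hypothetical stand-in for the speculative config, only to illustrate the two limits.
# For a linear draft chain the two values coincide; for a draft tree,
# max_total_draft_tokens counts every node, so it can exceed the depth max_draft_len.
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecConfig:
    max_draft_len: int           # depth of the draft tree (number of draft layers)
    max_total_draft_tokens: int  # total draft tokens (all tree nodes)

def resolve_draft_limits(spec_config: Optional[SpecConfig]) -> tuple[int, int]:
    # Mirrors the simplified expressions in the diff: both limits are 0 without speculation.
    max_draft_len = 0 if spec_config is None else spec_config.max_draft_len
    max_total_draft_tokens = (0 if spec_config is None
                              else spec_config.max_total_draft_tokens)
    return max_draft_len, max_total_draft_tokens

print(resolve_draft_limits(SpecConfig(max_draft_len=3, max_total_draft_tokens=7)))  # (3, 7)
print(resolve_draft_limits(None))                                                   # (0, 0)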

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ def __init__(
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,
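The number of tokens predicted per sequence per step is now the whole draft budget plus the one token sampled by the target model. A quick worked version of that expression (values are illustrative):

# Worked example of the predicted_tokens_per_seq expression above.
max_total_draft_tokens = 7                               # e.g. a draft tree with 7 nodes
predicted_tokens_per_seq = max_total_draft_tokens + 1    # 7 draft tokens + 1 target token
no_spec_predicted_tokens = 1                             # without a speculative config
print(predicted_tokens_per_seq, no_spec_predicted_tokens)  # 8 1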

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 6 additions & 9 deletions
@@ -202,10 +202,10 @@ def _get_token_num_for_estimation(self) -> int:
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
             if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_draft_len
+                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens

         if spec_cfg is not None:
-            num_extra_tokens_per_seq += spec_cfg.max_draft_len
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
             num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)

         if self._dummy_reqs is None:
@@ -751,6 +751,8 @@ def create_py_executor_instance(
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
        if spec_config is not None else 0,
+        max_total_draft_tokens=spec_config.max_total_draft_tokens
+        if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,
@@ -767,13 +769,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
-    max_total_draft_tokens = 0
-    if speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(speculative_config, 'max_total_draft_tokens'):
-        max_total_draft_tokens = speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (0 if speculative_config is None else
+                              speculative_config.max_total_draft_tokens)

     return TorchSampler.Args(
         max_seq_len=max_seq_len,
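_get_token_num_for_estimation adds the draft budget once for the overlap scheduler's extra in-flight step and once for the KV cache of the current draft tokens, so moving to max_total_draft_tokens enlarges the estimate when a draft tree is used. A standalone sketch of that arithmetic; the function name and the num_extra_kv_tokens argument are stand-ins, not the real helpers:

# Illustrative sketch of the per-sequence extra-token estimate; num_extra_kv_tokens is a
# plain argument here instead of the real get_num_extra_kv_tokens(spec_cfg) call.
def estimate_extra_tokens_per_seq(overlap_scheduler_enabled: bool,
                                  max_total_draft_tokens: int,
                                  num_extra_kv_tokens: int) -> int:
    extra = 0
    if overlap_scheduler_enabled:
        extra += 1                           # one extra decode step kept in flight
        extra += max_total_draft_tokens      # draft tokens speculated for that step
    extra += max_total_draft_tokens          # KV space for the current draft tokens
    extra += num_extra_kv_tokens             # mode-specific extra KV tokens
    return extra

# With a 7-node draft tree (pass 0 when speculation is off):
print(estimate_extra_tokens_per_seq(True, 7, 0))   # 15
print(estimate_extra_tokens_per_seq(False, 7, 0))  # 7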

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 5 additions & 2 deletions
@@ -93,7 +93,8 @@ def enable_spec_decode(self):
     @property
     def max_possible_draft_len(self):
         engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+        return (engine.original_max_total_draft_tokens
+                if self.enable_spec_decode else 0)

     def get_graph_key(
             self,
@@ -102,10 +103,12 @@ def get_graph_key(
         engine = self._get_engine()
         if engine.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
+            # If 'is_first_draft' is True, draft_len is only 'max_draft_len' even with tree decoding, not 'max_total_draft_tokens',
+            # because the input is padded to 'max_draft_len' for the first draft layer.
             draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key
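CUDA graphs are selected by a key of (batch_size, draft_len, is_first_draft); with tree decoding the target model is keyed by the full draft-token budget, while the first draft layer is still padded to max_draft_len. A simplified, hypothetical version of that selection logic (the real runner reads these values from the engine and the Eagle3 resource manager):

# Simplified, hypothetical graph-key selection; not the real ModelEngine/CUDA-graph runner API.
def graph_key(batch_size: int,
              enable_spec_decode: bool,
              is_draft_model: bool,
              is_first_draft: bool,
              max_draft_len: int,
              max_total_draft_tokens: int):
    if is_draft_model:
        # The first draft layer is padded to max_draft_len, so its graph never needs the
        # full tree budget; later autoregressive draft calls run with draft_len == 0.
        draft_len = max_draft_len if is_first_draft else 0
        return (batch_size, draft_len, is_first_draft)
    # The target model scores the whole draft tree at once, so its graph is keyed by the
    # total number of draft tokens (all tree nodes), not just the tree depth.
    draft_len = max_total_draft_tokens if enable_spec_decode else 0
    return (batch_size, draft_len, False)

print(graph_key(8, True, False, False, max_draft_len=3, max_total_draft_tokens=7))  # (8, 7, False)
print(graph_key(8, True, True, True, max_draft_len=3, max_total_draft_tokens=7))    # (8, 3, True)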

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 17 additions & 11 deletions
@@ -90,11 +90,11 @@ def warmup(self, resource_manager: ResourceManager) -> None:

 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))

     result = []
     # This function assumes cuda_graph_batch_sizes is sorted
@@ -157,11 +157,13 @@ def __init__(
         ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.enable_spec_decode = self.is_spec_decode
@@ -267,7 +269,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_draft_len * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -287,9 +289,11 @@ def __init__(
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0

         self.guided_decoder: Optional[CapturableGuidedDecoder] = None

@@ -310,7 +314,7 @@ def __init__(

         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []

@@ -351,7 +355,7 @@ def __init__(

     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0

     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -458,6 +462,8 @@ def warmup(self, resource_manager: ResourceManager) -> None:

         def get_num_extra_decoding_steps():
             if isinstance(self.model, ChainDrafter):
+                # We should use max_draft_len instead of max_total_draft_tokens here,
+                # because max_draft_len indicates the real number of draft layers.
                 return self.model.max_draft_len
             else:
                 assert not self.model_is_wrapped, (
@@ -595,7 +601,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
                     num_ctx_requests + num_gen_tokens)),
                 token_nums=[1] * num_gen_tokens,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
                 use_mrope=self.use_mrope)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(
@@ -610,7 +616,7 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):

         curr_max_num_tokens = min(
             kv_cache_manager.get_num_available_tokens(
-                self.original_max_draft_len), self.max_num_tokens,
+                self.original_max_total_draft_tokens), self.max_num_tokens,
             self.batch_size * (self.max_seq_len - 1))

         def get_autotune_warmup_request():
@@ -700,20 +706,20 @@ def release_batch(result: ScheduledRequests | None):
            if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                    spec_resource_manager, Eagle3ResourceManager):
                # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-               draft_lengths.append(self.original_max_draft_len)
+               draft_lengths.append(self.original_max_total_draft_tokens)
            else:
                draft_lengths.append(self.max_draft_len)
        else:
            # For non-draft model, we also capture the CUDA graph instance for draft length 0,
            # so that when we disable spec decode at runtime, we can still run the captured graph.
            # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-           if (self.max_draft_len > 0
+           if (self.max_total_draft_tokens > 0
                    and not self.spec_config.spec_dec_mode.use_one_engine()
                    # Assume that speculation is always on if the user didn't give us a max_concurrency
                    # value. This will save on memory.
                    and self.spec_config.max_concurrency is not None):
                draft_lengths.append(0)
-           draft_lengths = [self.max_draft_len]
+           draft_lengths = [self.max_total_draft_tokens]

        for bs in cuda_graph_batch_sizes:
            if bs > self.batch_size:
@@ -941,7 +947,7 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
        """
        if self.enable_spec_decode and not self._disable_overlap_scheduler:
            # When enabling overlap scheduler, the kv cache for draft tokens will
-           # be prepared in advance by using the max_draft_len. But we need to use
+           # be prepared in advance by using the max_total_draft_tokens. But we need to use
            # new_tokens_lens_device to get the real past kv lengths and the
            # correct position ids. And to avoid blocking the async data transfer,
            # we need to preprocess the inputs in forward to update the position_ids and
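The batch-size filter above bounds the largest pure-decode batch by how many (1 + max_total_draft_tokens)-token sequences fit into max_num_tokens. A short sketch of that bound:

# Sketch of the pure-decode batch-size bound used when filtering CUDA-graph batch sizes.
# In a pure decode step each sequence contributes 1 target token plus its draft tokens,
# so the token budget is divided by (1 + max_total_draft_tokens).
def max_pure_decode_batch_size(max_batch_size: int, max_num_tokens: int,
                               max_total_draft_tokens: int) -> int:
    return min(max_batch_size, max_num_tokens // (1 + max_total_draft_tokens))

# max_num_tokens=8192 with a 7-node draft tree: at most 1024 sequences per decode step.
print(max_pure_decode_batch_size(256, 8192, 7))   # 256 (capped by max_batch_size)
print(max_pure_decode_batch_size(2048, 8192, 7))  # 1024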

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 8 deletions
@@ -160,6 +160,7 @@ def __init__(self,
                 max_batch_size: int = 8,
                 max_beam_width: int = 1,
                 max_draft_len: int = 0,
+                max_total_draft_tokens: int = 0,
                 kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                 guided_decoder: Optional[GuidedDecoder] = None,
                 garbage_collection_gen0_threshold: Optional[int] = None,
@@ -195,6 +196,7 @@ def __init__(self,
        self.active = True
        self.max_beam_width = max_beam_width
        self.max_draft_len = max_draft_len
+       self.max_total_draft_tokens = max_total_draft_tokens
        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
        self.print_log = model_engine.pytorch_backend_config.print_iter_log
        self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
@@ -1040,7 +1042,7 @@ def _prepare_and_schedule_batch(self):
            self.use_spec_decode = self.drafter.should_use_spec_decode(
                self.active_requests, self.max_batch_size,
                self.model_engine.max_num_tokens,
-               self.model_engine.spec_config.max_draft_len)
+               self.model_engine.spec_config.max_total_draft_tokens)
            logger.debug(f"Use spec decode: {self.use_spec_decode}")
            self.model_engine.enable_spec_decode = self.use_spec_decode

@@ -1050,10 +1052,10 @@ def _prepare_and_schedule_batch(self):
                        LlmRequestState.GENERATION_IN_PROGRESS,
                        LlmRequestState.DISAGG_GENERATION_INIT):
                    continue
-               max_draft_len = self.model_engine.spec_config.max_draft_len
+               max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
                request.draft_tokens = [
                    0
-               ] * max_draft_len if max_draft_len > 0 else []
+               ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []

                # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
                # we don't need to initialize py_draft_tokens at this stage because we haven't appended the accepted tokens to the request yet.
@@ -1224,11 +1226,11 @@ def _prepare_draft_requests(self):
                continue

            req.py_last_draft_tokens = req.py_draft_tokens
-           max_draft_len = self.model_engine.spec_config.max_draft_len
+           max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens

-           if max_draft_len > 0 and self.use_spec_decode:
-               req.py_draft_tokens = [0] * max_draft_len
-               req.py_draft_pages_allocated = max_draft_len
+           if max_total_draft_tokens > 0 and self.use_spec_decode:
+               req.py_draft_tokens = [0] * max_total_draft_tokens
+               req.py_draft_pages_allocated = max_total_draft_tokens
            else:
                req.py_draft_tokens = []
                req.py_draft_pages_allocated = 0
@@ -1616,7 +1618,7 @@ def _pad_attention_dp_dummy_request(self):
                request_ids=[0],
                is_gen=True,
                prepare_resource=True,
-               max_num_draft_tokens=self.max_draft_len,
+               max_num_draft_tokens=self.max_total_draft_tokens,
            )[0]
            llm_request.is_attention_dp_dummy = True
            spec_resource_manager = self.resource_manager.get_resource_manager(
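When speculation is active, each request is padded with placeholder draft tokens so that KV pages for the whole draft budget are reserved up front. A hedged sketch with a hypothetical request object (the real LlmRequest carries much more state):

# Hypothetical stand-in for the draft-token placeholder logic; not the real LlmRequest.
class DummyRequest:
    def __init__(self):
        self.py_draft_tokens: list[int] = []
        self.py_draft_pages_allocated = 0

def prepare_draft_placeholders(req: DummyRequest, use_spec_decode: bool,
                               max_total_draft_tokens: int) -> None:
    if use_spec_decode and max_total_draft_tokens > 0:
        # Reserve one placeholder per possible draft token so pages are allocated
        # for the whole draft tree before the drafter runs.
        req.py_draft_tokens = [0] * max_total_draft_tokens
        req.py_draft_pages_allocated = max_total_draft_tokens
    else:
        req.py_draft_tokens = []
        req.py_draft_pages_allocated = 0

req = DummyRequest()
prepare_draft_placeholders(req, use_spec_decode=True, max_total_draft_tokens=7)
print(len(req.py_draft_tokens), req.py_draft_pages_allocated)  # 7 7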

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 9 additions & 5 deletions
@@ -347,7 +347,8 @@ def create_py_executor(
                guided_decoding_config is None
                and draft_spec_config._allow_chain_drafter
                and draft_spec_config._allow_greedy_draft_tokens
-               and pytorch_backend_config.attn_backend == "TRTLLM")
+               and pytorch_backend_config.attn_backend == "TRTLLM"
+               and spec_config.is_linear_tree)

            logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
            if use_chain_drafter:
@@ -356,7 +357,9 @@ def drafting_loop_wrapper(model):
                    from tensorrt_llm._torch.speculative.drafting_loops import \
                        ChainDrafter

-                   return ChainDrafter(spec_config.max_draft_len, model)
+                   return ChainDrafter(spec_config.max_draft_len,
+                                       spec_config.max_total_draft_tokens,
+                                       model)
            else:
                drafting_loop_wrapper = None

@@ -396,11 +399,11 @@ def drafting_loop_wrapper(model):
    if not pytorch_backend_config.disable_overlap_scheduler:
        model_engine_max_seq_len = model_engine.max_seq_len + 1
        if spec_config is not None:
-           model_engine_max_seq_len += spec_config.max_draft_len
+           model_engine_max_seq_len += spec_config.max_total_draft_tokens

    if spec_config is not None:
        model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
-       model_engine_max_seq_len += spec_config.max_draft_len
+       model_engine_max_seq_len += spec_config.max_total_draft_tokens

    max_seq_len = model_engine_max_seq_len
    max_num_tokens = model_engine.max_num_tokens
@@ -470,7 +473,8 @@ def drafting_loop_wrapper(model):
            "vocab_size_padded": model_engine.model.vocab_size_padded
        }
        if spec_config is not None:
-           kwargs["max_num_draft_tokens"] = spec_config.max_draft_len
+           kwargs[
+               "max_num_draft_tokens"] = spec_config.max_total_draft_tokens

        if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
        ):
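The executor's effective max_seq_len now grows by the full draft budget twice, once for the overlap scheduler's extra in-flight step and once for the draft tokens attached to each request, plus any mode-specific extra KV tokens. An illustrative reconstruction of that arithmetic (get_num_extra_kv_tokens is replaced here by a plain argument):

# Illustrative arithmetic only; the real code reads these values from the model engine
# and spec_config. Pass 0 for both draft and extra-KV values when speculation is off.
def effective_max_seq_len(engine_max_seq_len: int,
                          overlap_scheduler_enabled: bool,
                          max_total_draft_tokens: int,
                          num_extra_kv_tokens: int) -> int:
    max_seq_len = engine_max_seq_len
    if overlap_scheduler_enabled:
        max_seq_len += 1 + max_total_draft_tokens  # one overlapped step plus its draft tokens
    max_seq_len += num_extra_kv_tokens             # mode-specific extra KV tokens
    max_seq_len += max_total_draft_tokens          # draft tokens attached to each request
    return max_seq_len

# 4096-token engine limit, overlap scheduler on, 7-node draft tree, no extra KV tokens:
print(effective_max_seq_len(4096, True, 7, 0))  # 4111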

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -867,7 +867,7 @@ class Args:

    def __init__(self, args: Args):
        self.max_seq_len = args.max_seq_len
-       self.max_tokens = args.max_draft_len + 1
+       self.max_tokens = args.max_total_draft_tokens + 1
        assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
        self.max_num_sequences = args.max_num_sequences

@@ -1002,8 +1002,8 @@ def _process_draft_tokens_tree(self, request: LlmRequest,
        we can find the longest match by comparing all the paths.
        Args:
            request: LlmRequest. The request with draft tokens.
-           new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model.
-           The relationship between [max_draft_len + 1] and the draft token tree:
+           new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model.
+           The relationship between [max_total_draft_tokens + 1] and the draft token tree:
            If the current node is accepted, what is the NEXT token_id that the target model will generate?
            For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root node in the draft token tree if it is accepted.
            We know that the root node in the draft token tree is always accepted. Therefore, new_tokens[0, req_idx, 1] indicates the token_id following the root node,
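The docstring above describes tree acceptance: new_tokens holds, for each of the max_total_draft_tokens + 1 tree nodes (root included), the token the target model would emit next if that node were accepted, and the sampler keeps the longest matching root-to-leaf path. A toy, self-contained sketch of that longest-match idea using flat Python lists instead of the real tensors (indices and names are simplified assumptions):

# Toy illustration of longest-path acceptance. Node 0 is the already-accepted root,
# draft_tokens[i] is the token proposed at tree node i + 1, and target_next[j] is the
# token the target model would emit next if node j were accepted (row j of new_tokens
# for one request). Paths are lists of node indices from the root toward a leaf.
def longest_accepted_path(draft_tokens, target_next, paths):
    best_tokens = [target_next[0]]        # the root is always accepted, so at minimum
    for path in paths:                    # we keep the token that follows it
        accepted = []
        prev_node = 0
        for node in path:
            if draft_tokens[node - 1] != target_next[prev_node]:
                break                     # the target disagrees with this draft token
            accepted.append(draft_tokens[node - 1])
            prev_node = node
        candidate = accepted + [target_next[prev_node]]  # bonus token after the last match
        if len(candidate) > len(best_tokens):
            best_tokens = candidate
    return best_tokens

# Tree with 3 draft nodes: nodes 1 and 2 are children of the root, node 3 is a child of
# node 1, giving the paths [1, 3] and [2]. The target agrees with node 1 and then node 3.
draft_tokens = [11, 42, 12]     # tokens proposed at nodes 1, 2, 3
target_next = [11, 12, 99, 13]  # next token if node 0, 1, 2, 3 is accepted
print(longest_accepted_path(draft_tokens, target_next, [[1, 3], [2]]))  # [11, 12, 13]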
