
Commit 538e27c

support mtp eagle with 2 models style

Signed-off-by: qgai <[email protected]>

1 parent 523a17d commit 538e27c

File tree: 12 files changed, +1362 −1162 lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 4 additions & 5 deletions
@@ -170,16 +170,15 @@ def setup_llm(args, **kwargs):

     if spec_decode_algo == 'MTP':
         if not args.use_one_model:
-            print(
-                "MTP only supports one model style spec decode; ignoring default use_one_model=False"
-            )
-
+            print("Running MTP eagle with two model style.")
         spec_config = MTPDecodingConfig(
             num_nextn_predict_layers=args.spec_decode_max_draft_len,
             use_relaxed_acceptance_for_thinking=args.
             use_relaxed_acceptance_for_thinking,
             relaxed_topk=args.relaxed_topk,
-            relaxed_delta=args.relaxed_delta)
+            relaxed_delta=args.relaxed_delta,
+            mtp_eagle_one_model=args.use_one_model,
+            speculative_model_dir=args.model_dir)
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_max_draft_len,
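Note: as a usage sketch, the new two-model path could be requested through the LLM API roughly as below. The MTPDecodingConfig field names come from the diff above; the import path, the LLM constructor arguments, and the model paths are assumptions, not taken from this commit.

# Minimal sketch, not verified against this commit: MTPDecodingConfig field
# names mirror the diff above; the import path, LLM kwargs, and paths are
# assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MTPDecodingConfig

spec_config = MTPDecodingConfig(
    num_nextn_predict_layers=3,        # maps to --spec_decode_max_draft_len
    use_relaxed_acceptance_for_thinking=False,
    relaxed_topk=1,
    relaxed_delta=0.0,
    mtp_eagle_one_model=False,         # new field: request the two-model style
    speculative_model_dir="/path/to/DeepSeek-V3",  # hypothetical path
)
llm = LLM(model="/path/to/DeepSeek-V3", speculative_config=spec_config)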

tensorrt_llm/_torch/models/modeling_auto.py

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ def from_config(
             "") # Strip the appended EAGLE3
         if hasattr(config.pretrained_config, "draft_vocab_size"):
             model_arch = "EAGLE3" + model_arch
+        if model_arch == "DeepseekV3ForCausalLM" and config.spec_config is not None and config.spec_config.max_draft_len == 0:
+            model_arch = "MTPDraftModelForCausalLM"

         cls = MODEL_CLASS_MAPPING.get(model_arch)
         if cls is None:
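In effect, a DeepseekV3 checkpoint loaded by the draft engine (signaled by spec_config.max_draft_len == 0) is rerouted to the MTPDraftModelForCausalLM class registered in modeling_speculative.py below. A self-contained mimic of that dispatch (not the real AutoModel code):

# Self-contained mimic of the dispatch change; names mirror the diff,
# everything else is illustrative.
def resolve_model_arch(model_arch: str, spec_config) -> str:
    if (model_arch == "DeepseekV3ForCausalLM" and spec_config is not None
            and spec_config.max_draft_len == 0):
        return "MTPDraftModelForCausalLM"
    return model_arch

class _ToySpecConfig:
    max_draft_len = 0  # the draft engine's spec config carries no draft tokens

print(resolve_model_arch("DeepseekV3ForCausalLM", _ToySpecConfig()))  # MTPDraftModelForCausalLM
print(resolve_model_arch("DeepseekV3ForCausalLM", None))              # DeepseekV3ForCausalLM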

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1144 additions & 1121 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 116 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, Generic, Optional, Tuple
+from typing import Dict, Generic, List, Optional, Tuple

 import torch
 from torch import nn
@@ -18,6 +18,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..speculative import SpecMetadata, get_spec_worker
+from ..utils import AuxStreamType
 from .checkpoints.base_weight_mapper import BaseWeightMapper
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
                              register_auto_model)
@@ -342,8 +343,8 @@ def __init__(
         from .modeling_deepseekv3 import DeepseekV3MTP

         spec_dec_mode = model_config.spec_config.spec_dec_mode
-        assert spec_dec_mode.is_mtp()
-        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
+        assert spec_dec_mode.is_mtp_one_model()
+        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle_one_model(
         ) else model_config.spec_config.num_nextn_predict_layers

         moe_load_balancer_set_repeated_for_next_layer(
@@ -358,16 +359,127 @@ def __init__(
         self.embed_tokens = model.embed_tokens


+class MTPDraftModel(nn.Module):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
+                                                       torch.cuda.Stream]):
+        super().__init__()
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3MTP
+
+        mtp_layer = DeepseekV3MTP(model_config,
+                                  layer_idx,
+                                  aux_stream_dict,
+                                  is_separate_draft_engine=True)
+        setattr(self, f"layers.{layer_idx}", mtp_layer)
+        self.layers = mtp_layer
+        self.layer_idx = layer_idx
+        self.config = model_config.pretrained_config
+        self.embed_tokens = Embedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            dtype=self.config.torch_dtype,
+        )
+
+    def __repr__(self):
+        """Custom string representation to display layer index"""
+        return f"(layers): ({self.layer_idx}): {repr(self.layers)}"
+
+    def forward(
+        self,
+        input_ids: torch.IntTensor,
+        position_ids: torch.IntTensor,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.layers(
+            input_ids,
+            position_ids,
+            hidden_states,
+            embed_tokens=self.embed_tokens,
+            attn_metadata=attn_metadata,
+            all_rank_num_tokens=all_rank_num_tokens,
+            all_rank_max_num_tokens=all_rank_max_num_tokens,
+        )
+
+        return hidden_states
+
+
+@register_auto_model("MTPDraftModelForCausalLM")
+class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
+                                                       PretrainedConfig]):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        self.model_config = model_config
+        aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
+        self.aux_stream_dict = {
+            AuxStreamType.Attention: aux_stream_list[0],
+            AuxStreamType.MoeShared: aux_stream_list[0],
+            AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
+        }
+        super().__init__(
+            MTPDraftModel(self.model_config,
+                          self.model_config.pretrained_config.num_hidden_layers,
+                          self.aux_stream_dict),
+            config=self.model_config,
+            hidden_size=self.model_config.pretrained_config.hidden_size,
+            vocab_size=self.model_config.pretrained_config.vocab_size)
+
+    def load_weights(self, weights: Dict):
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3WeightLoader
+        weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
+        weight_loader.load_weights(weights)
+
+    def load_weights_from_target_model(self,
+                                       target_model: torch.nn.Module) -> None:
+        if self.model.embed_tokens is None:
+            self.model.embed_tokens = target_model.model.embed_tokens
+        self.lm_head = target_model.lm_head
+
+    def forward(self,
+                attn_metadata: AttentionMetadata,
+                input_ids: torch.IntTensor = None,
+                position_ids: torch.IntTensor = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                return_context_logits: bool = False,
+                spec_metadata: Optional[SpecMetadata] = None,
+                hidden_states: torch.Tensor = None,
+                **kwargs) -> torch.Tensor:
+
+        hidden_states = spec_metadata.get_hidden_states()
+        output = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
+            **kwargs)
+        return self.logits_processor.forward(
+            output,
+            self.lm_head,
+            attn_metadata,
+            return_context_logits,
+        )
+
+
 def get_draft_model(model_config, draft_config, lm_head, model):
     assert getattr(model_config, 'spec_config', None) != None
     spec_dec_mode = model_config.spec_config.spec_dec_mode
     if spec_dec_mode.is_eagle3_one_model():
         return Eagle3ForCausalLM(
             draft_config, model_config.pretrained_config.num_hidden_layers)
-    elif spec_dec_mode.is_mtp():
+    elif spec_dec_mode.is_mtp_one_model():
         return MTPForCausalLM(model_config,
                               model_config.pretrained_config.num_hidden_layers,
                               lm_head, model)
+    elif spec_dec_mode.is_mtp_eagle():
+        return MTPDraftModelForCausalLM(model_config)
     else:
         raise NotImplementedError(
             f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 3 additions & 0 deletions
@@ -330,6 +330,9 @@ def drafting_loop_wrapper(model):
             is_draft_model=True,
             drafting_loop_wrapper=drafting_loop_wrapper,
         )
+        # For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
+        if spec_config.spec_dec_mode.is_mtp_eagle():
+            draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1
         draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
         draft_model_engine.load_weights_from_target_model(
             model_engine.model)

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 2 additions & 2 deletions
@@ -92,18 +92,18 @@ class Eagle3SpecMetadata(SpecMetadata):
     is_draft_model: bool = False
     is_first_draft: bool = False
     eagle3_resource_manager: Optional[Eagle3ResourceManager] = None
+    is_mtp_eagle: bool = False

     def __post_init__(self):
         if self.is_draft_model:
             self.layers_to_capture = (self.num_layers - 1, )
         elif self.layers_to_capture is None:
-            if self.num_layers == 1:
+            if self.num_layers == 1 or self.is_mtp_eagle:
                 self.layers_to_capture = (self.num_layers - 1, )
             else:
                 if self.num_layers <= 5:
                     raise ValueError(
                         "Not enough hidden layers for default EAGLE3 capture")
-
                 self.layers_to_capture = (1, self.num_layers // 2 - 1,
                                           self.num_layers - 4)
         else:
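The net effect for capture-layer selection: an MTP-eagle run captures only the target's final layer regardless of depth, instead of the three-layer EAGLE3 default. A self-contained mimic of the updated __post_init__ logic:

# Mimic of the default capture-layer selection after this change.
def default_layers_to_capture(num_layers: int, is_mtp_eagle: bool):
    if num_layers == 1 or is_mtp_eagle:
        return (num_layers - 1, )
    if num_layers <= 5:
        raise ValueError("Not enough hidden layers for default EAGLE3 capture")
    return (1, num_layers // 2 - 1, num_layers - 4)

print(default_layers_to_capture(61, is_mtp_eagle=True))   # (60,)
print(default_layers_to_capture(32, is_mtp_eagle=False))  # (1, 15, 28)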

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 17 additions & 11 deletions
@@ -12,6 +12,7 @@
 class SpeculativeDecodingMode(IntEnum):
     MTP = auto()
     MTP_EAGLE = auto()
+    MTP_EAGLE_ONE_MODEL = auto()
     EAGLE3 = auto()
     EAGLE3_ONE_MODEL = auto()
     NGRAM = auto()
@@ -20,8 +21,11 @@ class SpeculativeDecodingMode(IntEnum):
     NONE = auto()
     AUTO = auto()

-    def is_mtp(self):
-        return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
+    def is_mtp_one_model(self):
+        return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL
+
+    def is_mtp_eagle_one_model(self):
+        return self == SpeculativeDecodingMode.MTP_EAGLE_ONE_MODEL

     def is_mtp_vanilla(self):
         return self == SpeculativeDecodingMode.MTP
@@ -33,7 +37,7 @@ def is_eagle3(self):
         return self == SpeculativeDecodingMode.EAGLE3

     def use_one_engine(self):
-        return self.is_mtp() or self.is_eagle3_one_model()
+        return self.is_eagle3_one_model() or self.is_mtp_one_model()

     def is_eagle3_one_model(self):
         return self == SpeculativeDecodingMode.EAGLE3_ONE_MODEL
@@ -51,31 +55,32 @@ def is_draft_target(self):
         return self == SpeculativeDecodingMode.DRAFT_TARGET

     def without_logits(self):
-        return self.is_mtp() or self.is_eagle3_one_model()
+        return self.is_mtp_one_model() or self.is_eagle3_one_model()

     def needs_kv_cache_rewind(self):
-        return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram()
+        return self.is_mtp_one_model() or self.is_eagle3_one_model(
+        ) or self.is_ngram()

     def support_overlap_scheduler(self):
-        return self.is_mtp() or self.is_eagle3_one_model(
+        return self.is_mtp_one_model() or self.is_eagle3_one_model(
         ) or self.has_draft_model()

     def support_guided_decoder(self):
         return self.is_none() or self.has_spec_drafter()

     def support_capturable_guided_decoder(self):
-        return self.is_mtp() or self.is_eagle3_one_model()
+        return self.is_mtp_one_model() or self.is_eagle3_one_model()

     def has_draft_model(self):
-        return self.is_eagle3() or self.is_draft_target()
+        return self.is_eagle3() or self.is_draft_target() or self.is_mtp_eagle()

     def needs_kv_cache_recompute(self):
         """
         Whether the draft model needs to recompute the kv cache.
         If true, the 1st draft model forward will recompute the kv cache for
         the accepted draft tokens.
         """
-        return self.is_eagle3()
+        return self.is_eagle3() or self.is_mtp_eagle()

     def need_load_draft_weights(self):
         """
@@ -85,11 +90,12 @@ def need_load_draft_weights(self):
         return self.is_eagle3_one_model()

     def has_spec_decoder(self):
-        return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()
+        return self.is_mtp_one_model() or self.is_mtp_eagle() or self.is_eagle3(
+        ) or self.is_eagle3_one_model()

     def has_spec_drafter(self):
         return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
-        ) or self.is_user_provided()
+        ) or self.is_user_provided() or self.is_mtp_eagle()

     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode,
     Can be used to modify prompts for speculative algorithms that need to update tokens
     before drafting.
     """
-    if spec_dec_mode.is_eagle3():
+    if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle():
         # EAGLE3 always throws away the first token when processing draft inputs
         return input_tokens[1:]
     return input_tokens

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 3 additions & 3 deletions
@@ -158,7 +158,7 @@ def all_rank_num_seqs(self):
     @all_rank_num_seqs.setter
     def all_rank_num_seqs(self, value: List[int]):
         self._all_rank_num_seqs = value
-        if self.spec_dec_mode.is_mtp_eagle():
+        if self.spec_dec_mode.is_mtp_eagle_one_model():
             self.subseq_all_rank_num_tokens = value

     def prepare(self):
@@ -175,7 +175,7 @@ def prepare(self):
         # while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft
         # forward and only one input token in the following draft forward.
         # This num_tokens is used to set the all_rank_num_tokens for attention dp.
-        if not self.spec_dec_mode.is_mtp_eagle():
+        if not self.spec_dec_mode.is_mtp_eagle_one_model():
             self.num_tokens -= self.num_generations

         if self.mtp_hidden_states_manager is not None: # MTP vanilla or use relaxed acceptance
@@ -186,7 +186,7 @@ def prepare(self):
             mtp_slot_ids.append(slot_id)

         # MTP Vanilla: Update mtp hidden states and past tokens
-        if self.spec_dec_mode.is_mtp():
+        if self.spec_dec_mode.is_mtp_one_model():
             mtp_hidden_states_ptrs = []
             mtp_past_tokens_ptrs = []
             for slot_id in mtp_slot_ids:
