@@ -1,4 +1,4 @@
-from typing import Dict, Generic, Optional, Tuple
+from typing import Dict, Generic, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -18,6 +18,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..speculative import SpecMetadata, get_spec_worker
+from ..utils import AuxStreamType
 from .checkpoints.base_weight_mapper import BaseWeightMapper
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
                              register_auto_model)
@@ -342,8 +343,8 @@ def __init__(
         from .modeling_deepseekv3 import DeepseekV3MTP
 
         spec_dec_mode = model_config.spec_config.spec_dec_mode
-        assert spec_dec_mode.is_mtp()
-        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
+        assert spec_dec_mode.is_mtp_one_model()
+        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle_one_model(
         ) else model_config.spec_config.num_nextn_predict_layers
 
         moe_load_balancer_set_repeated_for_next_layer(
@@ -358,16 +359,127 @@ def __init__(
         self.embed_tokens = model.embed_tokens
 
 
+class MTPDraftModel(nn.Module):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
+                                                       torch.cuda.Stream]):
+        super().__init__()
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3MTP
+
+        mtp_layer = DeepseekV3MTP(model_config,
+                                  layer_idx,
+                                  aux_stream_dict,
+                                  is_separate_draft_engine=True)
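+        # Register the MTP layer under "layers.<layer_idx>" so its parameter
+        # names line up with the checkpoint's layer numbering.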
+        setattr(self, f"layers.{layer_idx}", mtp_layer)
+        self.layers = mtp_layer
+        self.layer_idx = layer_idx
+        self.config = model_config.pretrained_config
+        self.embed_tokens = Embedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            dtype=self.config.torch_dtype,
+        )
+
+    def __repr__(self):
+        """Custom string representation to display layer index"""
+        return f"(layers): ({self.layer_idx}): {repr(self.layers)}"
+
+    def forward(
+        self,
+        input_ids: torch.IntTensor,
+        position_ids: torch.IntTensor,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.layers(
+            input_ids,
+            position_ids,
+            hidden_states,
+            embed_tokens=self.embed_tokens,
+            attn_metadata=attn_metadata,
+            all_rank_num_tokens=all_rank_num_tokens,
+            all_rank_max_num_tokens=all_rank_max_num_tokens,
+        )
+
+        return hidden_states
+
+
+@register_auto_model("MTPDraftModelForCausalLM")
+class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
+                                                       PretrainedConfig]):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        self.model_config = model_config
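+        # Auxiliary CUDA streams shared by the MTP layer to overlap attention
+        # and MoE work with the main stream.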
+        aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
+        self.aux_stream_dict = {
+            AuxStreamType.Attention: aux_stream_list[0],
+            AuxStreamType.MoeShared: aux_stream_list[0],
+            AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
+        }
+        super().__init__(
+            MTPDraftModel(self.model_config,
+                          self.model_config.pretrained_config.num_hidden_layers,
+                          self.aux_stream_dict),
+            config=self.model_config,
+            hidden_size=self.model_config.pretrained_config.hidden_size,
+            vocab_size=self.model_config.pretrained_config.vocab_size)
+
+    def load_weights(self, weights: Dict):
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3WeightLoader
+        weight_loader = DeepseekV3WeightLoader(self, is_draft_model=True)
+        weight_loader.load_weights(weights)
+
+    def load_weights_from_target_model(self,
+                                       target_model: torch.nn.Module) -> None:
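+        # Reuse the target model's embedding table (if not already set) and its
+        # LM head instead of keeping separate copies in the draft model.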
+        if self.model.embed_tokens is None:
+            self.model.embed_tokens = target_model.model.embed_tokens
+        self.lm_head = target_model.lm_head
+
+    def forward(self,
+                attn_metadata: AttentionMetadata,
+                input_ids: torch.IntTensor = None,
+                position_ids: torch.IntTensor = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                return_context_logits: bool = False,
+                spec_metadata: Optional[SpecMetadata] = None,
+                hidden_states: torch.Tensor = None,
+                **kwargs) -> torch.Tensor:
+
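+        # Hidden states from the target model are retrieved from spec_metadata
+        # and fed to the MTP draft layer alongside the input tokens.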
+        hidden_states = spec_metadata.get_hidden_states()
+        output = self.model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
+            **kwargs)
+        return self.logits_processor.forward(
+            output,
+            self.lm_head,
+            attn_metadata,
+            return_context_logits,
+        )
+
+
 def get_draft_model(model_config, draft_config, lm_head, model):
     assert getattr(model_config, 'spec_config', None) != None
     spec_dec_mode = model_config.spec_config.spec_dec_mode
     if spec_dec_mode.is_eagle3_one_model():
         return Eagle3ForCausalLM(
             draft_config, model_config.pretrained_config.num_hidden_layers)
-    elif spec_dec_mode.is_mtp():
+    elif spec_dec_mode.is_mtp_one_model():
         return MTPForCausalLM(model_config,
                               model_config.pretrained_config.num_hidden_layers,
                               lm_head, model)
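+    # MTP-Eagle with a separate draft engine uses MTPDraftModelForCausalLM.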
+    elif spec_dec_mode.is_mtp_eagle():
+        return MTPDraftModelForCausalLM(model_config)
     else:
         raise NotImplementedError(
             f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."