@@ -131,28 +131,38 @@ class DeepseekV3MTPHead(nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         super().__init__()
         config = model_config.pretrained_config
+        self.model_config = model_config
 
         self.norm = RMSNorm(hidden_size=config.hidden_size,
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
 
+    @torch.compile(options={"max-autotune": True})
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)
 
+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        if not (self.model_config.mapping.enable_attention_dp):
+            lm_head.gather_output = True
         return logits
 
 
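When context logits are not requested, only each sequence's last token feeds the LM head, and the hunk above moves that gather into a `torch.compile`'d helper. Below is a minimal, self-contained sketch of the indexing that `get_last_token_states` wraps; the sequence lengths and tensor shapes are made up for illustration.

```python
import torch

# Hypothetical packed batch: three sequences of lengths 2, 3, and 1,
# flattened into a single [num_tokens, hidden] tensor.
seq_lens = torch.tensor([2, 3, 1])
hidden = torch.arange(6, dtype=torch.float32).unsqueeze(-1).expand(6, 4)

# Same indexing as get_last_token_states: the cumulative sum of the
# sequence lengths minus one gives the flat index of each sequence's
# last token (here 1, 4, 5).
last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
last_hidden = hidden[last_tokens]  # shape [batch, hidden]
assert last_tokens.tolist() == [1, 4, 5]
```

The surrounding `gather_output` toggling appears to skip the all-gather of vocab-parallel logits for this MTP head when attention DP is disabled, and restores the flag afterwards so other callers of the shared `lm_head` see the original behavior.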
@@ -903,22 +913,40 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok
 
+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
 
         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
-
-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )
+        if model_config.mapping.enable_attention_dp:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
+        else:
+            self.eh_proj = Linear(
+                config.hidden_size * 2,
+                config.hidden_size,
+                bias=False,
+                dtype=config.torch_dtype,
+                tensor_parallel_mode=TensorParallelMode.ROW,
+                mapping=model_config.mapping,
+                reduce_output=True,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+            )
 
         self.shared_head = DeepseekV3MTPHead(model_config)
 
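In the non-attention-DP branch, `eh_proj` becomes a row-parallel linear with `reduce_output=True`: each rank multiplies its slice of the concatenated `[embeds, hidden]` features by the matching rows of the weight, and the partial products are all-reduced. The single-process sketch below checks that math; `tp_size` and the shapes are illustrative, the weight uses the mathematical `[in, out]` layout rather than `nn.Linear`'s storage layout, and the all-reduce is modeled as a plain sum.

```python
import torch

torch.manual_seed(0)
tp_size, hidden = 4, 16
x = torch.randn(3, 2 * hidden)        # concat([embeds, hidden]) per token
w = torch.randn(2 * hidden, hidden)   # full eh_proj weight ([in, out] for clarity)

# Reference: the unsharded projection.
ref = x @ w

# Row-parallel: each rank owns a slice of the input features and the
# matching rows of the weight; summing the partial products (what the
# all-reduce behind reduce_output=True performs) recovers the result.
x_shards = torch.chunk(x, tp_size, dim=-1)  # mirrors torch.chunk in forward()
w_shards = torch.chunk(w, tp_size, dim=0)
out = sum(xs @ ws for xs, ws in zip(x_shards, w_shards))
assert torch.allclose(out, ref, atol=1e-5)
```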
@@ -934,9 +962,26 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states column-wise based on the TP rank
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+
+        if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp):
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)
 
         # Input layer norm
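`maybe_execute_in_parallel` overlaps the two independent RMSNorms (`enorm` over the token embeddings, `hnorm` over the incoming hidden states) by pushing one of them onto the `MoeShared` auxiliary stream and synchronizing with CUDA events. The sketch below shows that stream/event pattern in isolation; `run_in_parallel` is a hypothetical stand-in, not the helper from the codebase, and the tensors are placeholders.

```python
import torch

def run_in_parallel(fn_main, fn_aux, aux_stream):
    """Sketch of the pattern: run fn_aux on a side stream while fn_main runs
    on the current stream, then make the current stream wait for fn_aux."""
    if not torch.cuda.is_available():
        return fn_main(), fn_aux()      # fall back to sequential on CPU
    ready = torch.cuda.Event()
    done = torch.cuda.Event()
    ready.record()                      # inputs are produced on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(ready)    # don't read inputs too early
        out_aux = fn_aux()
        done.record()
    out_main = fn_main()                # overlaps with the aux-stream work
    torch.cuda.current_stream().wait_event(done)
    return out_main, out_aux

# Illustrative usage with stand-in tensors for the embeddings and hidden states.
if torch.cuda.is_available():
    aux = torch.cuda.Stream()
    e = torch.randn(8, 16, device="cuda")
    h = torch.randn(8, 16, device="cuda")
    embeds, hidden = run_in_parallel(lambda: e * 2, lambda: h + 1, aux)
```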
@@ -1074,7 +1119,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                           self.model.aux_stream_dict)
                 self.model.layers.append(mtp_layer)
                 self.epilogue.append(mtp_layer)
-                self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+                self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                                 model_config)
             else:
                 mtp_layers = nn.ModuleList([
                     DeepseekV3MTP(model_config,
@@ -1084,7 +1130,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                 ])
                 self.model.layers.extend(mtp_layers)
                 self.epilogue.extend(mtp_layers)
-                self.mtp_worker = MTPWorker(model_config.spec_config)
+                self.mtp_worker = MTPWorker(model_config.spec_config,
+                                            model_config)
                 # modify the QuantConfig to support duplicated mtp layers
                 if model_config.quant_config.exclude_modules is not None:
                     extend_exclude_modules = []