@@ -136,23 +136,31 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

+    @torch.compile(mode="max-autotune-no-cudagraphs")
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)

+        # Disable the TP output all-gather for this call so lm_head returns
+        # rank-local (sharded) logits; restored right after.
+        lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        lm_head.gather_output = True
         return logits

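For reference, a minimal sketch of the indexing `get_last_token_states` performs, using plain PyTorch and toy tensors (`seq_lens` below stands in for `attn_metadata.seq_lens_cuda`): with variable-length sequences packed back-to-back, the cumulative sum of lengths minus one gives each sequence's last-token position.

```python
import torch

# Three packed sequences of lengths 3, 2, 4 -> 9 tokens total.
seq_lens = torch.tensor([3, 2, 4])
hidden_states = torch.arange(9.0).unsqueeze(-1)  # [num_tokens, hidden=1]

# Cumulative sum of lengths, minus one: index of each sequence's last token.
last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
print(last_tokens)                 # tensor([2, 4, 8])
print(hidden_states[last_tokens])  # rows 2, 4, 8 -> one row per sequence
```

Factoring this into its own method presumably lets `torch.compile` specialize the cumsum-plus-gather in isolation instead of tracing the branchy `forward`.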
@@ -913,22 +921,46 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok

+        # Auxiliary stream and events used to overlap enorm/hnorm in forward().
+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)

         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
+        # FIXME: fused norm + all-reduce path is disabled for now.
+        self.fuse_norm_ar = False
+        # eh_proj is row-parallel; when fuse_norm_ar is set, its output
+        # reduction is deferred and folded into the fused all-reduce +
+        # RMSNorm in forward().
+        self.eh_proj = Linear(
+            config.hidden_size * 2,
+            config.hidden_size,
+            bias=False,
+            dtype=config.torch_dtype,
+            tensor_parallel_mode=TensorParallelMode.ROW,
+            mapping=model_config.mapping,
+            reduce_output=not self.fuse_norm_ar,
+            skip_create_weights_in_init=model_config.skip_create_weights_in_init,
+        )

-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )

         self.shared_head = DeepseekV3MTPHead(model_config)

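The auxiliary stream and event pair set up above follow the usual CUDA overlap pattern, sketched below with a hypothetical `run_in_parallel` helper (illustrative only; the repo's actual helper is `maybe_execute_in_parallel`, whose call signature this mirrors but whose internals are not guaranteed to match):

```python
import torch

def run_in_parallel(fn_a, fn_b, event_main, event_aux, aux_stream):
    # Mark the current position on the main stream so the aux stream can
    # wait for all prior main-stream work before starting.
    event_main.record()
    result_a = fn_a()  # runs on the main (current) stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)
        result_b = fn_b()  # runs concurrently on the aux stream
        event_aux.record()
    # The main stream must not consume result_b before the aux stream is done.
    torch.cuda.current_stream().wait_event(event_aux)
    return result_a, result_b
```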
@@ -944,14 +976,41 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        # Run the two independent norms concurrently on the main and aux streams.
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Split hidden_states column-wise so each TP rank feeds only its own
+        # shard to the row-parallel eh_proj.
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+        if tp_size > 1:
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)

         # Input layer norm
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if self.fuse_norm_ar:
+            # One fused kernel: all-reduce eh_proj's partial outputs, add the
+            # (zero) residual, and apply the input RMSNorm.
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=torch.zeros_like(hidden_states),
+                    norm_weight=self.input_layernorm.weight,
+                    eps=self.input_layernorm.variance_epsilon,
+                ),
+            )
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
         hidden_states = self.self_attn(
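To see why chunking the concatenated features column-wise is consistent with a row-parallel `eh_proj`, here is a single-process sketch with toy sizes; the Python `sum` over per-rank partials stands in for the TP all-reduce, and the last lines mirror what the `RESIDUAL_RMS_NORM` fused op computes when the residual is zero:

```python
import torch

torch.manual_seed(0)
hidden, tp_size = 8, 2
x = torch.randn(3, hidden * 2)       # concat([enorm(embeds), hnorm(hidden)])
w = torch.randn(hidden, hidden * 2)  # full eh_proj weight [out, in]

# Row-parallel linear: rank r keeps column-chunk r of x and the matching
# input-dim shard of w; the per-rank partial products sum to the full output.
partials = []
for rank in range(tp_size):
    x_shard = torch.chunk(x, tp_size, dim=-1)[rank]
    w_shard = torch.chunk(w, tp_size, dim=-1)[rank]
    partials.append(x_shard @ w_shard.T)
full = sum(partials)  # in real TP this sum is the all-reduce
assert torch.allclose(full, x @ w.T, atol=1e-5)

# With residual = zeros, the fused all-reduce + residual-add + RMSNorm is
# just the reduction above followed by a plain RMSNorm.
eps, norm_weight = 1e-6, torch.ones(hidden)
pre_norm = full + torch.zeros_like(full)  # residual add; returned as `residual`
normed = pre_norm * torch.rsqrt(
    pre_norm.pow(2).mean(-1, keepdim=True) + eps) * norm_weight
```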
@@ -1084,7 +1143,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                           self.model.aux_stream_dict)
                 self.model.layers.append(mtp_layer)
                 self.epilogue.append(mtp_layer)
-                self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+                self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                                 model_config)
             else:
                 mtp_layers = nn.ModuleList([
                     DeepseekV3MTP(model_config,
@@ -1094,7 +1154,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                 ])
                 self.model.layers.extend(mtp_layers)
                 self.epilogue.extend(mtp_layers)
-                self.mtp_worker = MTPWorker(model_config.spec_config)
+                self.mtp_worker = MTPWorker(model_config.spec_config,
+                                            model_config)
                 # modify the QuantConfig to support duplicated MTP layers
                 if model_config.quant_config.exclude_modules is not None:
                     extend_exclude_modules = []