@@ -136,23 +136,31 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)

+    @torch.compile(mode="max-autotune-no-cudagraphs")
+    def get_last_token_states(self, hidden_states, attn_metadata):
+        last_tokens = torch.cumsum(
+            attn_metadata.seq_lens_cuda,
+            dim=0,
+            dtype=torch.long,
+        ) - 1
+        return hidden_states[last_tokens]
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 lm_head: Linear,
                 attn_metadata: AttentionMetadata,
                 return_context_logits: bool = False) -> torch.Tensor:
         if not return_context_logits:
             if attn_metadata is not None:
-                last_tokens = torch.cumsum(
-                    attn_metadata.seq_lens_cuda,
-                    dim=0,
-                    dtype=torch.long,
-                ) - 1
-                hidden_states = hidden_states[last_tokens]
+                hidden_states = self.get_last_token_states(
+                    hidden_states, attn_metadata)
             else:
                 hidden_states = hidden_states[-1].unsqueeze(0)

+        # Compute rank-local logits without the TP all-gather, then restore
+        # the flag.
+        lm_head.gather_output = False
         logits = lm_head(hidden_states)
+        lm_head.gather_output = True
         return logits


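Note: a minimal sketch of the indexing `get_last_token_states` performs on packed sequences; the lengths below are made up for illustration.

```python
import torch

# Three packed sequences of lengths 3, 2 and 4 -> 9 token rows in total.
seq_lens = torch.tensor([3, 2, 4])
hidden_states = torch.arange(9, dtype=torch.float32).unsqueeze(-1)

# cumsum gives the end offset of each sequence; subtracting 1 yields the
# index of each sequence's last token: [2, 4, 8].
last_tokens = torch.cumsum(seq_lens, dim=0, dtype=torch.long) - 1
print(hidden_states[last_tokens].squeeze(-1))  # tensor([2., 4., 8.])
```

Wrapping this in `torch.compile` lets Inductor fuse the cumsum and gather; `mode="max-autotune-no-cudagraphs"` presumably avoids CUDA-graph capture because `seq_lens_cuda` changes shape from batch to batch.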
@@ -911,22 +919,46 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok

+        self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
         self.enorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)

         self.hnorm = RMSNorm(hidden_size=config.hidden_size,
                              eps=config.rms_norm_eps,
                              dtype=config.torch_dtype)
+        # FIXME: fuse_norm_ar is hard-coded off; make it configurable.
+        self.fuse_norm_ar = False
+        # When the fused path is enabled, eh_proj defers its all-reduce so it
+        # can be folded into the residual-add + RMSNorm kernel in forward().
+        self.eh_proj = Linear(
+            config.hidden_size * 2,
+            config.hidden_size,
+            bias=False,
+            dtype=config.torch_dtype,
+            tensor_parallel_mode=TensorParallelMode.ROW,
+            mapping=model_config.mapping,
+            reduce_output=not self.fuse_norm_ar,
+            skip_create_weights_in_init=model_config.
+            skip_create_weights_in_init,
+        )

-        self.eh_proj = Linear(
-            config.hidden_size * 2,
-            config.hidden_size,
-            bias=False,
-            dtype=config.torch_dtype,
-            skip_create_weights_in_init=model_config.
-            skip_create_weights_in_init,
-        )

         self.shared_head = DeepseekV3MTPHead(model_config)

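Note: the new `eh_proj` is row-parallel, so each rank multiplies its column shard of the input (the `torch.chunk` split done in `forward()` below) by its row shard of the weight, and the partial outputs are summed across ranks. A toy single-process sketch of that arithmetic in plain PyTorch, not the TensorRT-LLM `Linear` API:

```python
import torch

tp_size, in_dim, out_dim = 2, 8, 4
x = torch.randn(1, in_dim)
w = torch.randn(in_dim, out_dim)

partials = []
for rank in range(tp_size):
    x_shard = torch.chunk(x, tp_size, dim=-1)[rank]   # input column shard
    w_shard = torch.chunk(w, tp_size, dim=0)[rank]    # weight row shard
    partials.append(x_shard @ w_shard)                # rank-local partial

# reduce_output=True performs this sum via all-reduce inside the layer;
# reduce_output=False leaves the partial so a later fused kernel can do it.
assert torch.allclose(sum(partials), x @ w, atol=1e-5)
```

With `fuse_norm_ar` enabled, that deferred sum is instead carried out by the `RESIDUAL_RMS_NORM` fused all-reduce in `forward()`.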
@@ -942,14 +974,41 @@ def forward(
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

-        inputs_embeds = self.enorm(embed_tokens(input_ids))
-        hidden_states = self.hnorm(hidden_states)
+        def norm_embeds():
+            return self.enorm(embed_tokens(input_ids))  # embedding norm
+
+        def norm_hidden():
+            return self.hnorm(hidden_states)
+
+        # Overlap the two independent norms on separate CUDA streams.
+        inputs_embeds, hidden_states = maybe_execute_in_parallel(
+            norm_embeds,
+            norm_hidden,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         hidden_states = torch.concat([inputs_embeds, hidden_states], dim=-1)
+        # Shard hidden_states column-wise so each TP rank feeds its slice to
+        # the row-parallel eh_proj.
+        tp_size = self.model_config.mapping.tp_size
+        tp_rank = self.model_config.mapping.tp_rank
+        if tp_size > 1:
+            hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank]
         hidden_states = self.eh_proj(hidden_states)

         # Input layer norm
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if self.fuse_norm_ar:
+            # Fuse eh_proj's deferred all-reduce with the residual add and
+            # RMSNorm; with a zero residual this equals all-reduce then norm.
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=torch.zeros_like(hidden_states),
+                    norm_weight=self.input_layernorm.weight,
+                    eps=self.input_layernorm.variance_epsilon,
+                ),
+            )
+        else:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
         hidden_states = self.self_attn(
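Note: a sketch of the event/stream handshake a helper like `maybe_execute_in_parallel` typically performs; this is an assumption about the pattern, not the helper's actual source.

```python
import torch

def run_in_parallel(fn_a, fn_b, event_main, event_aux, aux_stream):
    # Mark the main stream's current position so the aux stream can sync to it.
    event_main.record()
    out_a = fn_a()                       # enqueued on the current (main) stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)
        out_b = fn_b()                   # runs concurrently on the aux stream
        event_aux.record()
    # Don't let the main stream consume out_b before the aux stream is done.
    torch.cuda.current_stream().wait_event(event_aux)
    return out_a, out_b
```

The `MoeShared` stream and events created in `__init__` are reused here, so the two small norms cost roughly one kernel's latency instead of two.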
@@ -1082,7 +1141,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                                           self.model.aux_stream_dict)
                 self.model.layers.append(mtp_layer)
                 self.epilogue.append(mtp_layer)
-                self.mtp_worker = MTPEagleWorker(model_config.spec_config)
+                self.mtp_worker = MTPEagleWorker(model_config.spec_config,
+                                                 model_config)
             else:
                 mtp_layers = nn.ModuleList([
                     DeepseekV3MTP(model_config,
@@ -1092,7 +1152,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                 ])
                 self.model.layers.extend(mtp_layers)
                 self.epilogue.extend(mtp_layers)
-                self.mtp_worker = MTPWorker(model_config.spec_config)
+                self.mtp_worker = MTPWorker(model_config.spec_config,
+                                            model_config)
                 # modify the QuantConfig to support duplicated mtp layers
                 if model_config.quant_config.exclude_modules is not None:
                     extend_exclude_modules = []