# limitations under the License.

import re
-from typing import Optional
+from typing import Dict, Optional

import torch
from torch import nn
-from torch.nn import functional as F
from transformers import AutoConfig, PretrainedConfig

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm._torch.utils import ActivationType, relu2

from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
+from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.linear import Linear
from ..modules.mamba.mamba2_mixer import Mamba2Mixer
from ..modules.mlp import MLP
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
+from ..utils import AuxStreamType, EventType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                             register_auto_model)


-def split(x: torch.Tensor,
-          tp_size: int,
-          idx: int,
-          dim: int = 0) -> torch.Tensor:
-    assert x.shape[dim] % tp_size == 0
-    split_size = x.shape[dim] // tp_size
-    if tp_size == 1:
-        return x
-    return torch.split(x, split_size, dim=dim)[idx]
-
-
-def relu2(x: torch.Tensor) -> torch.Tensor:
-    return torch.square(F.relu(x))
-
-
class NemotronHConfig(PretrainedConfig):
    model_type = "nemotron_h"

@@ -120,6 +109,163 @@ def forward(
            attn_metadata=attn_metadata)


+class NemotronHMOE(nn.Module):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: int,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+    ):
+        super().__init__()
+
+        # Import here to avoid circular dependency.
+        from .modeling_deepseekv3 import DeepseekV3Gate
+
+        self.activation_type = ActivationType.Relu2
+        self.reduce_results = True
+
+        config = model_config.pretrained_config
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.layer_idx = layer_idx
+        self.moe_intermediate_size = config.moe_intermediate_size[0] \
+            if isinstance(config.moe_intermediate_size, list) \
+            else config.moe_intermediate_size
+        self.use_latent_moe: bool = getattr(config, "moe_latent_size",
+                                            None) is not None
+        self.moe_hidden_size: int = config.moe_latent_size \
+            if self.use_latent_moe else config.hidden_size
+        self.mlp_bias = config.mlp_bias if hasattr(config,
+                                                   'mlp_bias') else False
+        self.moe_n_group = config.n_group
+        self.num_experts = config.n_routed_experts
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.n_shared_experts
+        self.top_k = config.num_experts_per_tok
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.routed_scaling_factor = config.routed_scaling_factor
+
+        # Setup shared expert MLP.
+        if config.n_shared_experts is None or config.n_shared_experts == 0:
+            self.shared_experts = None
+        else:
+            shared_expert_intermediate_size = (
+                config.moe_shared_expert_intermediate_size *
+                config.n_shared_experts)
+            self.shared_experts = MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                bias=self.mlp_bias,
+                activation=relu2,
+                dtype=config.torch_dtype,
+                config=model_config,
+                layer_idx=self.layer_idx,
+            )
+        # Setup MoE gate.
+        self.gate = DeepseekV3Gate(
+            self.hidden_size,
+            self.num_experts,
+            top_k=self.top_k,
+            n_group=self.moe_n_group,
+            topk_group=config.topk_group,
+            routed_scaling_factor=self.routed_scaling_factor,
+            dtype=config.torch_dtype,
+            fuse_routing_kernel=True,
+            apply_routing=False,
+            moe_backend=model_config.moe_backend)
+
+        # Setup MoE experts.
+        self.experts = create_moe(
+            routing_method=self.gate.routing_method,
+            num_experts=self.num_experts,
+            hidden_size=self.moe_hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            aux_stream_dict=aux_stream_dict,
+            dtype=config.torch_dtype,
+            reduce_results=self.reduce_results,
+            model_config=model_config,
+            layer_idx=self.layer_idx,
+            weight_loading_mode=MoEWeightLoadingMode.VANILLA,
+            bias=self.mlp_bias,
+            activation_type=self.activation_type,
+            # Default values
+            override_quant_config=None,
+            apply_router_weight_on_input=False,
+            swiglu_alpha=None,
+            swiglu_beta=None,
+            swiglu_limit=None,
+        )
+
+        # Setup latent projection layers.
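+        # When moe_latent_size is set, tokens are projected down into the
+        # smaller latent space before the routed experts and projected back
+        # to hidden_size afterwards (see forward()).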
+        if self.use_latent_moe:
+            self.fc1_latent_proj = Linear(
+                in_features=self.hidden_size,
+                out_features=self.moe_hidden_size,
+                bias=self.mlp_bias,
+                dtype=config.torch_dtype,
+            )
+            self.fc2_latent_proj = Linear(
+                in_features=self.moe_hidden_size,
+                out_features=self.hidden_size,
+                bias=self.mlp_bias,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.fc1_latent_proj = None
+            self.fc2_latent_proj = None
+
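+        # Stream and events used to overlap the shared-expert MLP with the
+        # routed experts in forward().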
+        self.aux_stream_shared = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> torch.Tensor:
+        assert hidden_states.shape[-1] == self.hidden_dim
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
+
+        def _compute_shared_output():
+            if self.shared_experts is not None:
+                shared_expert_output = self.shared_experts(hidden_states)
+            else:
+                shared_expert_output = 0
+            return shared_expert_output
+
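+        # Routed path: gate -> optional latent down-projection -> routed
+        # experts -> optional projection back to hidden_size.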
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+
+            routed_hidden_states = hidden_states
+            if self.use_latent_moe:
+                routed_hidden_states = self.fc1_latent_proj(
+                    routed_hidden_states)
+
+            all_rank_num_tokens = attn_metadata.all_rank_num_tokens
+            final_hidden_states = self.experts(
+                routed_hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=False)
+
+            if self.use_latent_moe:
+                final_hidden_states = self.fc2_latent_proj(final_hidden_states)
+
+            return final_hidden_states
+
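+        # Run the routed and shared expert paths on separate CUDA streams
+        # when possible, synchronizing via the recorded events.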
+        routed_output, shared_output = maybe_execute_in_parallel(
+            _compute_routed_output, _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared], self.aux_stream_shared)
+
+        final_hidden_states = shared_output + routed_output
+
+        return final_hidden_states.view(orig_shape)
+
+
class NemotronHLayer(DecoderLayer):

    def __init__(
@@ -130,6 +276,7 @@ def __init__(
        # - -> MLPLayer
        # * -> TransformerLayer
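+        # E -> NemotronHMOE (MoE layer)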
        layer_type: str,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
    ):
        super().__init__()

@@ -160,6 +307,10 @@ def __init__(
            self.mixer = MLPLayer(model_config, layer_idx)
        elif layer_type == "*":
            self.mixer = TransformerLayer(model_config, layer_idx)
+        elif layer_type == "E":
+            self.mixer = NemotronHMOE(model_config,
+                                      layer_idx=layer_idx,
+                                      aux_stream_dict=aux_stream_dict)
        else:
            raise ValueError(f"{layer_type} is not supported")

@@ -186,6 +337,18 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
        super().__init__(model_config)
        config = self.model_config.pretrained_config

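+        # Auxiliary CUDA streams shared across layers to overlap MoE work.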
+        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
+        self.aux_stream_dict = {
+            # TODO: add attention stream.
+            # AuxStreamType.Attention: aux_stream_list[0],
+            AuxStreamType.MoeShared: aux_stream_list[0],
+            AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
+            AuxStreamType.MoeBalancer: aux_stream_list[2],
+        }
+
        # calculate embeddings
        self.embed_tokens = Embedding(
            config.vocab_size,
@@ -196,7 +359,11 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
        # create layers
        layers = []
        for layer_idx, layer_type in enumerate(config.hybrid_override_pattern):
-            layers.append(NemotronHLayer(model_config, layer_idx, layer_type))
+            layers.append(
+                NemotronHLayer(model_config,
+                               layer_idx,
+                               layer_type,
+                               aux_stream_dict=self.aux_stream_dict))
        self.layers = nn.ModuleList(layers)

        # final norm
@@ -251,6 +418,15 @@ def __init__(
        self,
        model_config: ModelConfig[NemotronHConfig],
    ):
+        # rms_norm_eps might be named differently in the config.
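+        # Normalize the name so downstream modules can always read rms_norm_eps.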
+        if hasattr(model_config.pretrained_config, "rms_norm_eps"):
+            rms_epsilon = model_config.pretrained_config.rms_norm_eps
+        elif hasattr(model_config.pretrained_config, "layer_norm_epsilon"):
+            rms_epsilon = model_config.pretrained_config.layer_norm_epsilon
+        else:
+            raise ValueError("layer_norm_epsilon or rms_norm_eps is not set")
+        model_config.pretrained_config.rms_norm_eps = rms_epsilon
+
        if not model_config.mapping.tp_size in [1, 2, 4, 8]:
            raise ValueError("TP has to be either 1, 2, 4 or 8")
