Commit e4e42e0

[FMDL-1328][feat] Add support for nano-v3 and super-v3 with pytorch backend
Signed-off-by: Wanli Jiang <[email protected]>
1 parent 96cfdd8 commit e4e42e0

File tree: 9 files changed, +329 −55 lines changed


tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 36 additions & 2 deletions
@@ -2,8 +2,8 @@

 from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
     HfWeightMapper
-from tensorrt_llm._torch.models.modeling_nemotron_h import split
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm._torch.utils import split


 @register_mapper("HF", "NemotronHForCausalLM")
@@ -34,7 +34,8 @@ def preprocess_weights(self, weights: dict) -> dict:
             if "A_log" in key:
                 key = key.replace("A_log", "A")

-            if "_scale" in key:
+            if ("mixer.in_proj" in key
+                    or "mixer.out_proj" in key) and "_scale" in key:
                 new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
@@ -94,6 +95,39 @@ def preprocess_weights(self, weights: dict) -> dict:
             elif "mixer.norm.weight" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 new_weights[key] = w
+            # Remap MoE expert weights.
+            elif "mixer.experts." in key:
+                if self.config.moe_backend == 'VANILLA':
+                    new_weights[key] = weights[name]
+                else:
+                    if "up_proj" in key:
+                        w1_key = key.replace("up_proj", "w1")
+                        w3_key = key.replace("up_proj", "w3")
+                        # Don't need to handle with input_scale and weight_scale_2 since they are scalar for fp8 and nvfp4 models.
+                        if "input_scale" in key or "weight_scale_2" in key:
+                            new_weights[w3_key] = weights[name]
+                            new_weights[w1_key] = weights[name]
+                        elif "weight_scale" in key:
+                            # NVFP4 case.
+                            if weights[name].shape:
+                                new_weights[w3_key] = weights[
+                                    name][:weights[name].shape[0] // 2]
+                                new_weights[w1_key] = weights[name][
+                                    weights[name].shape[0] // 2:]
+                            # FP8 case.
+                            else:
+                                new_weights[w3_key] = weights[name]
+                                new_weights[w1_key] = weights[name]
+                        else:
+                            new_weights[w3_key] = weights[name][:weights[name].
+                                                                shape[0] // 2]
+                            new_weights[w1_key] = weights[name][weights[name].
+                                                                shape[0] // 2:]
+                    elif "down_proj" in key:
+                        key = key.replace("down_proj", "w2")
+                        new_weights[key] = weights[name]
+                    else:
+                        raise ValueError(f"Unknown MoE weight: {key}")
             else:
                 new_weights[key] = weights[name]
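
For orientation, here is how the up_proj/down_proj remapping above behaves on a toy expert weight dict: the fused up_proj tensor is split along dim 0 into w3 (first half) and w1 (second half), while down_proj is simply renamed to w2. This is a minimal sketch; the key names and shapes are illustrative only, not taken from a real checkpoint.

import torch

hidden, inter = 8, 16
weights = {
    "layers.3.mixer.experts.0.up_proj.weight": torch.randn(2 * inter, hidden),
    "layers.3.mixer.experts.0.down_proj.weight": torch.randn(hidden, inter),
}

new_weights = {}
for name, w in weights.items():
    if "up_proj" in name:
        half = w.shape[0] // 2
        new_weights[name.replace("up_proj", "w3")] = w[:half]   # first half -> w3
        new_weights[name.replace("up_proj", "w1")] = w[half:]   # second half -> w1
    elif "down_proj" in name:
        new_weights[name.replace("down_proj", "w2")] = w        # rename only

for k, v in new_weights.items():
    print(k, tuple(v.shape))
# layers.3.mixer.experts.0.w3.weight (16, 8)
# layers.3.mixer.experts.0.w1.weight (16, 8)
# layers.3.mixer.experts.0.w2.weight (8, 16)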

tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py

Lines changed: 1 addition & 1 deletion
@@ -6,8 +6,8 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
     Qwen2MoeHfWeightMapper
-from tensorrt_llm._torch.models.modeling_nemotron_h import split
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm._torch.utils import split
 from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM

tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 194 additions & 18 deletions
@@ -14,44 +14,33 @@
 # limitations under the License.

 import re
-from typing import Optional
+from typing import Dict, Optional

 import torch
 from torch import nn
-from torch.nn import functional as F
 from transformers import AutoConfig, PretrainedConfig

 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm._torch.utils import ActivationType, relu2

 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
+from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
+from ..modules.linear import Linear
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
+from ..utils import AuxStreamType, EventType
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              register_auto_model)


-def split(x: torch.Tensor,
-          tp_size: int,
-          idx: int,
-          dim: int = 0) -> torch.Tensor:
-    assert x.shape[dim] % tp_size == 0
-    split_size = x.shape[dim] // tp_size
-    if tp_size == 1:
-        return x
-    return torch.split(x, split_size, dim=dim)[idx]
-
-
-def relu2(x: torch.Tensor) -> torch.Tensor:
-    return torch.square(F.relu(x))
-
-
 class NemotronHConfig(PretrainedConfig):
     model_type = "nemotron_h"

@@ -120,6 +109,163 @@ def forward(
                          attn_metadata=attn_metadata)


+class NemotronHMOE(nn.Module):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: int,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+    ):
+        super().__init__()
+
+        # Import here to avoid circular dependency.
+        from .modeling_deepseekv3 import DeepseekV3Gate
+
+        self.activation_type = ActivationType.Relu2
+        self.reduce_results = True
+
+        config = model_config.pretrained_config
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.layer_idx = layer_idx
+        self.moe_intermediate_size = config.moe_intermediate_size[0] \
+            if isinstance(config.moe_intermediate_size, list) else config.moe_intermediate_size
+        self.use_latent_moe: bool = getattr(config, "moe_latent_size",
+                                            None) is not None
+        self.moe_hidden_size: int = config.moe_latent_size if self.use_latent_moe else config.hidden_size
+        self.mlp_bias = config.mlp_bias if hasattr(config,
+                                                   'mlp_bias') else False
+        self.moe_n_group = config.n_group
+        self.num_experts = config.n_routed_experts
+        self.hidden_size = config.hidden_size
+        self.num_shared_experts = config.n_shared_experts
+        self.top_k = config.num_experts_per_tok
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.routed_scaling_factor = config.routed_scaling_factor
+
+        # Setup shared expert MLP.
+        if config.n_shared_experts is None or config.n_shared_experts == 0:
+            self.shared_experts = None
+        else:
+            shared_expert_intermediate_size = (
+                config.moe_shared_expert_intermediate_size *
+                config.n_shared_experts)
+            self.shared_experts = MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                bias=self.mlp_bias,
+                activation=relu2,
+                dtype=config.torch_dtype,
+                config=model_config,
+                layer_idx=self.layer_idx,
+            )
+        # Setup MoE gate.
+        self.gate = DeepseekV3Gate(
+            self.hidden_size,
+            self.num_experts,
+            top_k=self.top_k,
+            n_group=self.moe_n_group,
+            topk_group=config.topk_group,
+            routed_scaling_factor=self.routed_scaling_factor,
+            dtype=config.torch_dtype,
+            fuse_routing_kernel=True,
+            apply_routing=False,
+            moe_backend=model_config.moe_backend)
+
+        # Setup MoE experts.
+        self.experts = create_moe(
+            routing_method=self.gate.routing_method,
+            num_experts=self.num_experts,
+            hidden_size=self.moe_hidden_size,
+            intermediate_size=self.moe_intermediate_size,
+            aux_stream_dict=aux_stream_dict,
+            dtype=config.torch_dtype,
+            reduce_results=self.reduce_results,
+            model_config=model_config,
+            layer_idx=self.layer_idx,
+            weight_loading_mode=MoEWeightLoadingMode.VANILLA,
+            bias=self.mlp_bias,
+            activation_type=self.activation_type,
+            # Default values
+            override_quant_config=None,
+            apply_router_weight_on_input=False,
+            swiglu_alpha=None,
+            swiglu_beta=None,
+            swiglu_limit=None,
+        )
+
+        # Setup latent projection layers.
+        if self.use_latent_moe:
+            self.fc1_latent_proj = Linear(
+                in_features=self.hidden_size,
+                out_features=self.moe_hidden_size,
+                bias=self.mlp_bias,
+                dtype=config.torch_dtype,
+            )
+            self.fc2_latent_proj = Linear(
+                in_features=self.moe_hidden_size,
+                out_features=self.hidden_size,
+                bias=self.mlp_bias,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.fc1_latent_proj = None
+            self.fc2_latent_proj = None
+
+        self.aux_stream_shared = aux_stream_dict[AuxStreamType.MoeShared]
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> torch.Tensor:
+        assert hidden_states.shape[-1] == self.hidden_dim
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
+
+        def _compute_shared_output():
+            if self.shared_experts is not None:
+                shared_expert_output = self.shared_experts(hidden_states)
+            else:
+                shared_expert_output = 0
+            return shared_expert_output
+
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+
+            routed_hidden_states = hidden_states
+            if self.use_latent_moe:
+                routed_hidden_states = self.fc1_latent_proj(
+                    routed_hidden_states)
+
+            all_rank_num_tokens = attn_metadata.all_rank_num_tokens
+            final_hidden_states = self.experts(
+                routed_hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=False)
+
+            if self.use_latent_moe:
+                final_hidden_states = self.fc2_latent_proj(final_hidden_states)
+
+            return final_hidden_states
+
+        routed_output, shared_output = maybe_execute_in_parallel(
+            _compute_routed_output, _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared], self.aux_stream_shared)
+
+        final_hidden_states = shared_output + routed_output
+
+        return final_hidden_states.view(orig_shape)
+
+
 class NemotronHLayer(DecoderLayer):

     def __init__(
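
For readers new to this module, a minimal, self-contained sketch of the data flow that NemotronHMOE.forward implements: router logits gate a set of routed experts (optionally wrapped in latent down/up projections when moe_latent_size is set), a shared-expert MLP runs on the same tokens, and the two results are summed. The toy gate and experts below are stand-ins for DeepseekV3Gate and the fused-MoE backend; they mirror only the shapes and control flow, not the real kernels, quantization, or multi-stream overlap.

import torch
from torch import nn


class ToyNemotronHMOE(nn.Module):
    # Toy stand-in: shared MLP + top-k routed experts with optional latent projection.

    def __init__(self, hidden, ffn, n_experts, top_k, latent=None):
        super().__init__()
        self.top_k = top_k
        moe_hidden = latent if latent is not None else hidden
        self.gate = nn.Linear(hidden, n_experts, bias=False)          # router
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(moe_hidden, ffn), nn.ReLU(),
                          nn.Linear(ffn, moe_hidden))
            for _ in range(n_experts))
        self.shared = nn.Sequential(nn.Linear(hidden, ffn), nn.ReLU(),
                                    nn.Linear(ffn, hidden))           # shared experts
        self.fc1_latent = nn.Linear(hidden, moe_hidden) if latent else nn.Identity()
        self.fc2_latent = nn.Linear(moe_hidden, hidden) if latent else nn.Identity()

    def forward(self, x):                          # x: [tokens, hidden]
        shared_out = self.shared(x)                # shared path sees full hidden size
        weights, idx = torch.topk(self.gate(x).softmax(-1), self.top_k, dim=-1)
        routed_in = self.fc1_latent(x)             # optional latent down-projection
        routed_out = torch.zeros_like(routed_in)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[:, k] == e
                if mask.any():
                    routed_out[mask] += weights[mask, k, None] * expert(routed_in[mask])
        routed_out = self.fc2_latent(routed_out)   # optional latent up-projection
        return shared_out + routed_out


x = torch.randn(4, 32)
print(ToyNemotronHMOE(hidden=32, ffn=64, n_experts=8, top_k=2, latent=16)(x).shape)
# torch.Size([4, 32])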
@@ -130,6 +276,7 @@ def __init__(
         # - -> MLPLayer
         # * -> TransformerLayer
         layer_type: str,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
     ):
         super().__init__()

@@ -160,6 +307,10 @@ def __init__(
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
             self.mixer = TransformerLayer(model_config, layer_idx)
+        elif layer_type == "E":
+            self.mixer = NemotronHMOE(model_config,
+                                      layer_idx=layer_idx,
+                                      aux_stream_dict=aux_stream_dict)
         else:
             raise ValueError(f"{layer_type} is not supported")
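
The hybrid_override_pattern string drives this dispatch one character per layer, and with this commit "E" selects the new MoE mixer. A small illustrative sketch (assuming "M" maps to the Mamba2Mixer imported above; the pattern string itself is made up, not taken from a released config):

LAYER_TYPES = {
    "M": "Mamba2Mixer",       # SSM mixer (assumed mapping)
    "-": "MLPLayer",          # dense MLP
    "*": "TransformerLayer",  # attention
    "E": "NemotronHMOE",      # MoE mixer, new in this commit
}

pattern = "M-M*M-ME"  # made-up hybrid_override_pattern
print([LAYER_TYPES[c] for c in pattern])
# ['Mamba2Mixer', 'MLPLayer', 'Mamba2Mixer', 'TransformerLayer',
#  'Mamba2Mixer', 'MLPLayer', 'Mamba2Mixer', 'NemotronHMOE']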

@@ -186,6 +337,18 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
         super().__init__(model_config)
         config = self.model_config.pretrained_config

+        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
+        self.aux_stream_dict = {
+            # TODO: add attention stream.
+            # AuxStreamType.Attention: aux_stream_list[0],
+            AuxStreamType.MoeShared:
+            aux_stream_list[0],
+            AuxStreamType.MoeChunkingOverlap:
+            aux_stream_list[1],
+            AuxStreamType.MoeBalancer:
+            aux_stream_list[2],
+        }
+
         # calculate embeddings
         self.embed_tokens = Embedding(
             config.vocab_size,
@@ -196,7 +359,11 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
         # create layers
         layers = []
         for layer_idx, layer_type in enumerate(config.hybrid_override_pattern):
-            layers.append(NemotronHLayer(model_config, layer_idx, layer_type))
+            layers.append(
+                NemotronHLayer(model_config,
+                               layer_idx,
+                               layer_type,
+                               aux_stream_dict=self.aux_stream_dict))
         self.layers = nn.ModuleList(layers)

         # final norm
@@ -251,6 +418,15 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
     ):
+        # rms_norm_eps might be named differently in the config.
+        if hasattr(model_config.pretrained_config, "rms_norm_eps"):
+            rms_epsilon = model_config.pretrained_config.rms_norm_eps
+        elif hasattr(model_config.pretrained_config, "layer_norm_epsilon"):
+            rms_epsilon = model_config.pretrained_config.layer_norm_epsilon
+        else:
+            raise ValueError("layer_norm_epsilon or rms_norm_eps is not set")
+        model_config.pretrained_config.rms_norm_eps = rms_epsilon
+
         if not model_config.mapping.tp_size in [1, 2, 4, 8]:
             raise ValueError("TP has to be either 1, 2, 4 or 8")

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,7 @@
 from tensorrt_llm.models.modeling_utils import QuantConfig

 from ...model_config import ModelConfig
-from ...utils import AuxStreamType
+from ...utils import ActivationType, AuxStreamType
 from .fused_moe_cute_dsl import CuteDslFusedMoE
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_deepgemm import DeepGemmFusedMoE
@@ -74,6 +74,7 @@ def create_moe(
     swiglu_alpha: Optional[torch.Tensor] = None,
     swiglu_beta: Optional[torch.Tensor] = None,
     swiglu_limit: Optional[torch.Tensor] = None,
+    activation_type: ActivationType = ActivationType.Swiglu,
 ) -> MoE:
     moe_cls = get_moe_cls(model_config, override_quant_config)

@@ -133,6 +134,7 @@ def create_moe(
             swiglu_alpha=swiglu_alpha,
             swiglu_beta=swiglu_beta,
             swiglu_limit=swiglu_limit,
+            activation_type=activation_type,
         )
     elif moe_cls == WideEPMoE:
         return moe_cls(
@@ -161,6 +163,8 @@
             model_config=model_config,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
+            layer_idx=layer_idx,
+            activation_type=activation_type,
         )
     elif moe_cls == CuteDslFusedMoE:
         return moe_cls(
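
The new activation_type argument is what lets NemotronHMOE request squared-ReLU experts instead of the default SwiGLU. For reference, squared ReLU ("relu2") is the same helper this commit moves from modeling_nemotron_h.py into tensorrt_llm._torch.utils; a standalone sketch of the function:

import torch
from torch.nn import functional as F


def relu2(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU: zero for x <= 0, x**2 otherwise.
    return torch.square(F.relu(x))


x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
print(relu2(x))  # tensor([0.0000, 0.0000, 0.0000, 2.2500, 9.0000])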
