
Commit 91325f8

[FMDL-1328][feat] Add support for nano-v3 and super-v3 with pytorch backend
Signed-off-by: Wanli Jiang <[email protected]>
1 parent 96cfdd8 commit 91325f8

File tree: 11 files changed, +492 −142 lines changed


cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 15 additions & 9 deletions
@@ -435,7 +435,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
-        auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+        auto const quant_params
+            = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
         kernels::MoeMinLatencyParams min_latency_params{};
 
         // TODO: support lora in the future
@@ -613,7 +614,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
         WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
             static_cast<int>(experts_per_token), base_activation_type, parallelism_config, min_latency_mode, stream);
 
-        auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+        auto const quant_params
+            = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
 
         // TODO: support lora in the future
         ::tensorrt_llm::kernels::LoraParams lora_params{};
@@ -859,7 +861,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
-        int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales) const
+        int64_t const inter_size, torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
+        ActivationType base_activation_type) const
     {
         if (isFp8Quant())
         {
@@ -921,16 +924,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
             TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
             // Check shapes
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
@@ -974,16 +978,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
             TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
             TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank
@@ -1040,16 +1045,17 @@ class FusedMoeRunner : public torch::CustomClassHolder
             // Check shapes
             TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.sizes()[0] == num_experts_on_rank,
                 "fc1 act global must be scalar or (num_experts_on_rank,)");
+            int expand_ratio = isGatedActivation(base_activation_type) ? 2 : 1;
             TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
                     && fc1_weight_block.sizes()[1]
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                                inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)
-                            * 2
+                            * expand_ratio
                     && fc1_weight_block.sizes()[2] * FP8_PER_INT32
                             * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
                         == TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
                             hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4),
-                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * expand_ratio, hidden_size // 4 // "
                 "block_scale_vector_size)");
             TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
             TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
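
Note: the recurring change above replaces the hard-coded factor of 2 with an expand_ratio derived from the base activation type, so the fc1 weight-block shape check also accepts non-gated activations (presumably what the new nano-v3/super-v3 MoE layers use). A minimal Python sketch of the shape rule, assuming a hypothetical is_gated_activation flag and ignoring the alignment and block-scale details from the C++ checks:

def expected_fc1_rows(inter_size: int, is_gated_activation: bool) -> int:
    # Gated activations (e.g. SwiGLU) fuse gate and up projections into fc1,
    # doubling the row count; non-gated activations use inter_size directly.
    expand_ratio = 2 if is_gated_activation else 1
    return inter_size * expand_ratio

assert expected_fc1_rows(1024, is_gated_activation=True) == 2048
assert expected_fc1_rows(1024, is_gated_activation=False) == 1024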

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 36 additions & 2 deletions
@@ -2,8 +2,8 @@
 
 from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
     HfWeightMapper
-from tensorrt_llm._torch.models.modeling_nemotron_h import split
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm._torch.utils import split
 
 
 @register_mapper("HF", "NemotronHForCausalLM")
@@ -34,7 +34,8 @@ def preprocess_weights(self, weights: dict) -> dict:
             if "A_log" in key:
                 key = key.replace("A_log", "A")
 
-            if "_scale" in key:
+            if ("mixer.in_proj" in key
+                    or "mixer.out_proj" in key) and "_scale" in key:
                 new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
@@ -94,6 +95,39 @@ def preprocess_weights(self, weights: dict) -> dict:
             elif "mixer.norm.weight" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 new_weights[key] = w
+            # Remap MoE expert weights.
+            elif "mixer.experts." in key:
+                if self.config.moe_backend == 'VANILLA':
+                    new_weights[key] = weights[name]
+                else:
+                    if "up_proj" in key:
+                        w1_key = key.replace("up_proj", "w1")
+                        w3_key = key.replace("up_proj", "w3")
+                        # Don't need to handle with input_scale and weight_scale_2 since they are scalar for fp8 and nvfp4 models.
+                        if "input_scale" in key or "weight_scale_2" in key:
+                            new_weights[w3_key] = weights[name]
+                            new_weights[w1_key] = weights[name]
+                        elif "weight_scale" in key:
+                            # NVFP4 case.
+                            if weights[name].shape:
+                                new_weights[w3_key] = weights[
+                                    name][:weights[name].shape[0] // 2]
+                                new_weights[w1_key] = weights[name][
+                                    weights[name].shape[0] // 2:]
+                            # FP8 case.
+                            else:
+                                new_weights[w3_key] = weights[name]
+                                new_weights[w1_key] = weights[name]
+                        else:
+                            new_weights[w3_key] = weights[name][:weights[name].
+                                                                shape[0] // 2]
+                            new_weights[w1_key] = weights[name][weights[name].
+                                                                shape[0] // 2:]
+                    elif "down_proj" in key:
+                        key = key.replace("down_proj", "w2")
+                        new_weights[key] = weights[name]
+                    else:
+                        raise ValueError(f"Unknown MoE weight: {key}")
             else:
                 new_weights[key] = weights[name]
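
Note: for the non-VANILLA MoE backends, the mapper above splits each expert's fused HF up_proj tensor in half along dim 0 (first half becomes w3, second half becomes w1), renames down_proj to w2, and duplicates scalar scales for both halves. A toy, self-contained sketch of that remapping, with an illustrative key name; the real code additionally handles NVFP4 block scales and tensor-parallel splitting:

import torch

def remap_expert_weight(key: str, tensor: torch.Tensor) -> dict:
    # Sketch of the up_proj/down_proj remapping; not the actual mapper.
    out = {}
    if "up_proj" in key:
        half = tensor.shape[0] // 2
        out[key.replace("up_proj", "w3")] = tensor[:half]  # first half -> w3
        out[key.replace("up_proj", "w1")] = tensor[half:]  # second half -> w1
    elif "down_proj" in key:
        out[key.replace("down_proj", "w2")] = tensor
    return out

# Example: a fused up_proj of shape (2 * inter_size, hidden_size).
fused = torch.randn(8, 4)
mapped = remap_expert_weight("backbone.layers.0.mixer.experts.0.up_proj.weight", fused)
print({k: tuple(v.shape) for k, v in mapped.items()})  # each half is (4, 4)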

tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py

Lines changed: 1 addition & 1 deletion
@@ -6,8 +6,8 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
     Qwen2MoeHfWeightMapper
-from tensorrt_llm._torch.models.modeling_nemotron_h import split
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
+from tensorrt_llm._torch.utils import split
 from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM
 
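Note: both mappers now import split from tensorrt_llm._torch.utils instead of modeling_nemotron_h. Based only on how it is called in the diffs (split(weights[name], tp_size, tp_rank)), a minimal sketch of what a tensor-parallel shard helper with that signature typically does; this is an assumption, not the actual utility:

import torch

def split(tensor: torch.Tensor, tp_size: int, tp_rank: int, dim: int = 0) -> torch.Tensor:
    # Return this rank's contiguous shard of the weight along `dim`.
    if tp_size == 1:
        return tensor
    return torch.chunk(tensor, tp_size, dim=dim)[tp_rank].contiguous()

w = torch.arange(8.0)
print(split(w, tp_size=4, tp_rank=1))  # tensor([2., 3.])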