Skip to content

Commit b181568

Browse files
lucasliegreg-kwasniewski1suyoggupta
authored
[TRTLLM-8201][feat] Nemotron H MoE Sharding (#8744)
Signed-off-by: Lucas Liebenwein <[email protected]> Signed-off-by: greg-kwasniewski1 <[email protected]> Co-authored-by: greg-kwasniewski1 <[email protected]> Co-authored-by: Suyog Gupta <[email protected]>
1 parent 222bc91 commit b181568

File tree

13 files changed

+1240
-296
lines changed

13 files changed

+1240
-296
lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ transforms:
7777
detect_sharding:
7878
stage: sharding
7979
simple_shard_only: false
80-
use_sharding_from_factory: false
81-
support_partial_config: false
80+
sharding_source: ['heuristic']
81+
support_partial_config: true
8282
sharding_dims: ['tp', 'ep', 'bmm']
8383
requires_shape_prop: true
8484
# TODO: (hg) need to ensure run_shape_prop after sharding.

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
import torch.nn.functional as F
99
from einops import rearrange
10-
from transformers import AutoModelForCausalLM
10+
from transformers import AutoConfig, AutoModelForCausalLM
1111

1212
from tensorrt_llm._torch.auto_deploy.models.patches.bamba import _bamba_mixer_torch_forward
1313

@@ -144,6 +144,34 @@ def get_model_from_config_patched(config, **kwargs):
144144
# TODO: figure out how this can be incorporated into the export patch system
145145
AutoModelForCausalLM.from_config = get_model_from_config_patched
146146

147+
_config_from_pretrained_original = AutoConfig.from_pretrained
148+
_nemotron_h_base_model_tp_plan = {
149+
"in_proj": "mamba",
150+
"out_proj": "rowwise",
151+
"q_proj": "colwise",
152+
"k_proj": "colwise",
153+
"v_proj": "colwise",
154+
"o_proj": "rowwise",
155+
"up_proj": "colwise",
156+
"down_proj": "rowwise",
157+
# "*": "gather",
158+
}
159+
160+
161+
def get_config_from_pretrained_patched(*args, **kwargs):
    """Wrap ``AutoConfig.from_pretrained`` to attach a TP plan to NemotronH MoE configs.

    All positional and keyword arguments are forwarded unchanged to the
    original ``from_pretrained`` captured in ``_config_from_pretrained_original``.

    The original may return either a config object or a tuple whose first
    element is the config (e.g. when called with ``return_unused_kwargs=True``);
    both shapes are handled and the same shape is returned.

    Returns:
        The config (or a tuple starting with the config). If the config looks
        like a NemotronH MoE model, its ``base_model_tp_plan`` attribute is set
        to a copy of ``_nemotron_h_base_model_tp_plan``.
    """
    ret = _config_from_pretrained_original(*args, **kwargs)
    is_tuple = isinstance(ret, tuple)
    config = ret[0] if is_tuple else ret

    # Heuristic to check if it's a NemotronH MoE model. The model type is
    # tested first because this wrapper runs for *every* config loaded through
    # AutoConfig; unrelated configs must never have `layers_block_type` touched.
    if getattr(config, "model_type", None) == "nemotron_h":
        # `or ()` guards against the attribute being present but set to None.
        block_types = getattr(config, "layers_block_type", None) or ()
        if "moe" in block_types:
            # Hand each config its own dict so that mutating one config's plan
            # cannot leak into the module-level default or into other configs.
            config.base_model_tp_plan = dict(_nemotron_h_base_model_tp_plan)

    return (config, *ret[1:]) if is_tuple else config
170+
171+
172+
# TODO: figure out how this can be incorporated into the export patch system
173+
AutoConfig.from_pretrained = get_config_from_pretrained_patched
174+
147175
# TODO: figure out how this can be incorporated into the export patch system
148176
# Only patch if the module isn't available
149177
_mamba_ssm_module = "mamba_ssm"

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def __and__(self, other: "TransformInfo") -> "TransformInfo":
173173
has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes,
174174
)
175175

176+
# Support the `+` operator for TransformInfo; it is an alias for `&` (delegates to __and__).
177+
def __add__(self, other: "TransformInfo") -> "TransformInfo":
    """Combine two infos with ``+``.

    Delegates to ``__and__``, so ``a + b`` yields the same result as ``a & b``.
    """
    combined = self.__and__(other)
    return combined
179+
176180

177181
TransformHistory = Dict[str, TransformInfo]
178182

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def _find_final_hidden_state_node(
317317
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
318318
return None
319319
index_node = mul_node.args[1]
320-
index_add_node = bfs(
320+
index_add_node, _ = bfs(
321321
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
322322
)
323323
if not index_add_node:
@@ -383,7 +383,7 @@ def target(n: torch.fx.Node) -> bool:
383383
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
384384

385385
try:
386-
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
386+
node_to_remove, _ = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
387387
graph.erase_node(node_to_remove)
388388
return True
389389
except RuntimeError:
@@ -458,7 +458,7 @@ def _apply(
458458
lambda node: is_op(node, torch.ops.aten.one_hot),
459459
attr_next="all_input_nodes",
460460
boundary=start_boundary,
461-
).args[0]
461+
)[0].args[0]
462462
if not selected_experts:
463463
continue
464464

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ...shim.interface import CachedSequenceInterface
1414
from ...utils.cuda_mem_tracker import cuda_memory_tracker
1515
from ...utils.logger import ad_logger
16-
from ...utils.node_utils import extract_param_names_from_lin_node, is_linear_op, is_op
16+
from ...utils.node_utils import extract_param_names_from_node, is_linear_op, is_op
1717
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
1818

1919

@@ -36,7 +36,7 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
3636
y2 = y[:, out1:out1+out2]
3737
"""
3838
# some info we need
39-
keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
39+
keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
4040
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
4141
sizes_unfused = [p.size(0) for p in params_unfused]
4242
key_fused = f"fused_weight_{idx}"
@@ -128,7 +128,7 @@ def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple
128128
def _insert_fused_quant_gemm(
129129
self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]
130130
):
131-
keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes]
131+
keys_unfused = [extract_param_names_from_node(n)[0] for n in linear_nodes]
132132
params_unfused = [gm.get_parameter(k) for k in keys_unfused]
133133
sizes_unfused = [p.size(0) for p in params_unfused]
134134
key_fused = f"fused_weight_{idx}"

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...models.factory import ModelFactory
1515
from ...shim.interface import CachedSequenceInterface
1616
from ...utils.node_utils import (
17-
extract_param_names_from_lin_node,
17+
extract_param_names_from_node,
1818
get_quantization_params_from_linear_node,
1919
is_bmm_op,
2020
is_linear_op,
@@ -136,7 +136,7 @@ def _insert_quantized_linear(
136136
137137
The state_dict is also updated to contain the sharded weights.
138138
"""
139-
param_name, _ = extract_param_names_from_lin_node(node)
139+
param_name, _ = extract_param_names_from_node(node)
140140
original_weight = gm.get_parameter(param_name)
141141
new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False)
142142
modname, _, attrname = param_name.rpartition(".")

0 commit comments

Comments
 (0)