
Commit 7905d6c

[#9098][feat] Simple sharding latent experts (#9099)
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent fbf6c16 · commit 7905d6c

2 files changed: +9 -3 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 0 additions & 2 deletions
@@ -73,15 +73,13 @@ transforms:
     stage: pattern_matcher
   quantize_mxfp4_moe:
     stage: pattern_matcher
-  # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
   detect_sharding:
     stage: sharding
     simple_shard_only: false
     sharding_source: ['factory','heuristic']
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
     requires_shape_prop: true
-  # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
     stage: sharding
     run_shape_prop: true
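
For reference, a hypothetical override sketch (not part of this commit) of the detect_sharding options shown above, forcing the simple (gather) shard and restricting sharding to the TP dimension. It assumes the same transforms config schema as default.yaml, reuses only keys that appear in the hunk above, and the inline comments are interpretations rather than documented semantics:

transforms:
  detect_sharding:
    stage: sharding
    simple_shard_only: true          # assumption: always fall back to the simple (gather) shard
    sharding_source: ['heuristic']   # assumption: rely on heuristics only, not the factory tp_plan
    support_partial_config: true
    sharding_dims: ['tp']            # assumption: disable EP and BMM sharding
    requires_shape_prop: true
  sharding_transform_executor:
    stage: sharding
    run_shape_prop: true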

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 9 additions & 1 deletion
@@ -187,15 +187,23 @@ def get_model_from_config_patched(config, **kwargs):
 
 _config_from_pretrained_original = AutoConfig.from_pretrained
 _nemotron_h_base_model_tp_plan = {
+    # mamba SSM layer
     "in_proj": "mamba",
     "out_proj": "rowwise",
+    # attention layer
     "q_proj": "colwise",
     "k_proj": "colwise",
     "v_proj": "colwise",
     "o_proj": "rowwise",
+    # NOTE: consider not sharding shared experts and/or
+    # latent projections at all, keeping them replicated.
+    # To do so, comment out the corresponding entries.
+    # moe layer: SHARED experts
     "up_proj": "colwise",
     "down_proj": "rowwise",
-    # "*": "gather",
+    # MoLE: latent projections: simple shard
+    "fc1_latent_proj": "gather",
+    "fc2_latent_proj": "gather",
 }
 
 