
Commit 11c3ef4

Support for partial config from factory
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 3349e0f commit 11c3ef4

4 files changed: +65 −8 lines

setup.py
Lines changed: 1 addition & 1 deletion

@@ -215,7 +215,7 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
         precompiled_location = download_precompiled(tempdir, version)
         extract_from_precompiled(precompiled_location, package_data, tempdir)
 
-    # sanity_check()
+    sanity_check()
 
 # https://setuptools.pypa.io/en/latest/references/keywords.html
 setup(

tensorrt_llm/_torch/auto_deploy/config/default.yaml
Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ transforms:
     stage: sharding
     simple_shard_only: false
     use_sharding_from_factory: false
-    support_partial_config: true
+    support_partial_config: false
     sharding_dims: ['tp', 'ep', 'bmm']
     # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
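For context, the keys touched here sit under a sharding transform entry in default.yaml. Below is a minimal sketch (not part of the commit) of reading and overriding these keys with PyYAML; the parent key name `detect_sharding` is an assumption, since the hunk only shows the nested keys.

```python
# Sketch only: inspect / override the sharding keys shown in the hunk above.
# The "detect_sharding" parent key is a hypothetical name not visible in the diff.
import yaml  # requires PyYAML

with open("tensorrt_llm/_torch/auto_deploy/config/default.yaml") as f:
    cfg = yaml.safe_load(f)

sharding_cfg = cfg["transforms"]["detect_sharding"]  # assumed parent key
print(sharding_cfg["support_partial_config"])        # False after this commit
print(sharding_cfg["sharding_dims"])                 # ['tp', 'ep', 'bmm']

# Re-enable partial-config support locally for experimentation:
sharding_cfg["support_partial_config"] = True
```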

tensorrt_llm/_torch/auto_deploy/llm_args.py
Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     )
 
     model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = Field(
-        default="AutoModelForCausalLM",
+        default="AutoModelForImageTextToText",
         description="The model factory to use for loading the model.",
     )
 

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
Lines changed: 62 additions & 5 deletions

@@ -186,6 +186,8 @@ def _apply(
         )
         shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
         shared_config.sharding_config.support_partial_config = self.config.support_partial_config
+        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
+
         shared_config.sharding_config.use_sharding_from_factory = (
             self.config.use_sharding_from_factory
         )
@@ -201,8 +203,6 @@ def _apply(
             factory_info = detect_sharding_from_factory_config(gm, sharding_config)
             return gm, factory_info
 
-        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
-
         ad_logger.info(
             f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
         )
@@ -339,8 +339,39 @@ def detect_sharding_from_factory_config(
            # TODO: Sequence parallelism is not supported yet.
            ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
        elif "local" in config:
-            # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-            ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+            # Check if this applies to shared experts in EP parallelism.
+            # If yes, apply the TP col-row shard.
+            if "shared" in module_name:
+                col_row_action = config.replace("local_", "")
+                if col_row_action == "colwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.COLUMN,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op=None,
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                elif col_row_action == "rowwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.ROW,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op="all_reduce",
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                    num_row_col_shards += 1
+                else:
+                    ad_logger.warning("Invalid sharding config. Skipping.")
+            else:
+                # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
+                ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+
        elif "gather" in config:
            # Simple shard (row + all_gather)
            sharding_config.tp_transforms.append(
@@ -363,9 +394,35 @@ def detect_sharding_from_factory_config(
         f"Applied {num_shards} TP shards (simple: {num_simple_shards}, "
         f"row-col pattern: {num_row_col_shards})"
     )
+
+    num_matches = len(sharding_config.tp_transforms)
+
+    if sharding_config.support_partial_config:
+        ad_logger.info(
+            f"Partial factory config applied only for TP. "
+            f"Applying heuristics for {sharding_config.sharding_dims}."
+        )
+
+        # run EP sharding across ranks
+        if "ep" in sharding_config.sharding_dims:
+            ep_info = detect_ep_shard(gm, sharding_config)
+        else:
+            ep_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        # run BMM sharding across ranks
+        if "bmm" in sharding_config.sharding_dims:
+            dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
+        else:
+            dp_bmm_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        num_matches += ep_info.num_matches + dp_bmm_info.num_matches
+
     return TransformInfo(
         skipped=False,
-        num_matches=len(sharding_config.tp_transforms),
+        num_matches=num_matches,
         is_clean=False,
         has_valid_shapes=False,
     )
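Taken together, the sharding.py changes mean: TP transforms are still derived from the factory config, and when `support_partial_config` is set, the EP and BMM heuristics additionally run for whatever appears in `sharding_dims`, with every match folded into a single `TransformInfo`. Below is a condensed, self-contained sketch of that control flow. It is an illustration only: the real code operates on an FX GraphModule and calls `detect_ep_shard` / `detect_dp_bmm_shard`, whereas the match counts here are stand-in values.

```python
# Condensed sketch of the partial-config fallback introduced by this commit.
# Not the real implementation: graph traversal and the EP/BMM detectors are
# replaced by stand-in values so the flow is runnable on its own.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TransformInfo:
    skipped: bool
    num_matches: int
    is_clean: bool
    has_valid_shapes: bool


@dataclass
class ShardingConfig:
    support_partial_config: bool = False
    sharding_dims: List[str] = field(default_factory=lambda: ["tp", "ep", "bmm"])
    tp_transforms: List[str] = field(default_factory=list)


def detect_sharding_from_factory_config(cfg: ShardingConfig) -> TransformInfo:
    # 1) TP transforms come from the factory sharding config (details elided).
    num_matches = len(cfg.tp_transforms)

    # 2) New in this commit: a partial factory config only covers TP, so fall
    #    back to the EP / BMM heuristics for the remaining sharding dimensions.
    if cfg.support_partial_config:
        ep_matches = 3 if "ep" in cfg.sharding_dims else 0    # stand-in for detect_ep_shard
        bmm_matches = 1 if "bmm" in cfg.sharding_dims else 0  # stand-in for detect_dp_bmm_shard
        num_matches += ep_matches + bmm_matches

    # 3) All matches are reported through one TransformInfo.
    return TransformInfo(skipped=False, num_matches=num_matches,
                         is_clean=False, has_valid_shapes=False)


print(detect_sharding_from_factory_config(
    ShardingConfig(support_partial_config=True, tp_transforms=["q_proj", "k_proj"])
))
```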
