Commit 3349e0f

wip partial sharding
Signed-off-by: greg-kwasniewski1 <[email protected]>
1 parent 80cbe2a

File tree

5 files changed: +14 -3 lines


setup.py

Lines changed: 1 addition & 1 deletion
@@ -215,7 +215,7 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
     precompiled_location = download_precompiled(tempdir, version)
     extract_from_precompiled(precompiled_location, package_data, tempdir)
 
-    sanity_check()
+    # sanity_check()
 
 # https://setuptools.pypa.io/en/latest/references/keywords.html
 setup(

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ transforms:
       stage: sharding
       simple_shard_only: false
       use_sharding_from_factory: false
+      support_partial_config: true
       sharding_dims: ['tp', 'ep', 'bmm']
       # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
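
For context, a minimal sketch of how an override carrying the new flag parses. The detect_sharding key is hypothetical (the hunk header only shows transforms:); the field names and values come from the diff, everything else is illustrative:

import yaml

override = yaml.safe_load("""
transforms:
  detect_sharding:
    stage: sharding
    simple_shard_only: false
    use_sharding_from_factory: false
    support_partial_config: true
    sharding_dims: ['tp', 'ep', 'bmm']
""")
# The new flag is read like any other transform option.
print(override["transforms"]["detect_sharding"]["support_partial_config"])  # True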

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 7 additions & 0 deletions
@@ -165,6 +165,13 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "AutoDeployConfig.",
     )
 
+    support_partial_config: bool = Field(
+        default=False,
+        description="If True, factory sharding will be applied to the subset of transformations "
+        "that are currently supported. If False, sharding from factory will be performed only if "
+        "all provided transformations are supported.",
+    )
+
     sharding_dims: List[str] = Field(
         default=["tp", "ep", "dp"],
         description="The sharding methods to apply by the heuristic sharding stage.",

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,7 @@ class ShardingTransformConfig(TransformConfig):
 
     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = Field(default=False)
+    support_partial_config: bool = Field(default=False)
     # Which sharding families to run: any subset of {"tp", "ep", "bmm"}
     sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"])
 
@@ -184,6 +185,7 @@ def _apply(
             else ShardingConfigSource.UNKNOWN
         )
         shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
+        shared_config.sharding_config.support_partial_config = self.config.support_partial_config
         shared_config.sharding_config.use_sharding_from_factory = (
             self.config.use_sharding_from_factory
         )

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 3 additions & 2 deletions
@@ -488,6 +488,7 @@ class ShardingConfig(BaseModel):
     predefined_config: Optional[Dict[str, Any]] = None
     simple_shard_only: bool = Field(default=False)
     use_sharding_from_factory: bool = False
+    support_partial_config: bool = False
     sharding_dims: List[str] = Field(default_factory=list)
     tp_transforms: List[TPShardingInfo] = Field(default_factory=list)
     bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list)
 
@@ -532,7 +533,7 @@ def validate_config(self) -> bool:
         tp_plan = self.predefined_config["tp_plan"]
 
         values = set(tp_plan.values())
-        allowed_values = {
+        supported_modes = {
             "colwise",  # row split and no collective
             "rowwise",  # column split and all-reduce
             "gather",  # simple shard (row + all_gather)
 
@@ -544,7 +545,7 @@ def validate_config(self) -> bool:
             # "local_packed_rowwise",
             # "local",
         }
-        if not values.issubset(allowed_values):
+        if not self.support_partial_config and not values.issubset(supported_modes):
             ad_logger.warning("Sharding config contains invalid values. Skipping.")
             # invalidate the config
             self.predefined_config = {}
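
The behavioral change is easiest to see in isolation. A self-contained sketch of the gate above, with names mirroring the diff (the standalone accepts helper and the abbreviated mode set are illustrative, not the repo's API):

# Supported tp_plan modes, abbreviated from the full set above.
SUPPORTED_MODES = {"colwise", "rowwise", "gather"}


def accepts(tp_plan: dict, support_partial_config: bool) -> bool:
    """Return False when the plan would be invalidated, True otherwise."""
    values = set(tp_plan.values())
    # Old behavior: any unsupported mode invalidates the whole plan.
    # New behavior: with support_partial_config=True the plan is kept, and
    # only the supported subset of transformations is applied downstream.
    if not support_partial_config and not values.issubset(SUPPORTED_MODES):
        return False
    return True


plan = {"layers.*.attn": "colwise", "layers.*.mlp": "local"}
print(accepts(plan, support_partial_config=False))  # False: "local" unsupported
print(accepts(plan, support_partial_config=True))   # True: partial sharding allowed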
