Skip to content

Commit

Permalink
[lowering] Add max_acc_splits (pytorch#133041)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookincubator/AITemplate#1017

Pull Request resolved: pytorch#133041

Model owners can set the lower_settings with max_acc_splits=2, and lowering will fail during model iteration, to alert them of possible performance degradation from increased fragmentation.

Test Plan: Added unit tests

Reviewed By: frank-wei

Differential Revision: D60133589
  • Loading branch information
qxy11 authored and facebook-github-bot committed Aug 13, 2024
1 parent 00aa086 commit 75fc7d3
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion torch/fx/passes/splitter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(
self,
min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
skip_fusion=DEFAULT_SKIP_FUSION,
allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
max_acc_splits: int = -1,
):
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -51,6 +52,13 @@ def __init__(
type=int,
help="Minimum size limit of an accelerator subgraph.",
)
parser.add_argument(
"--max-acc_splits",
"--max-acc_splits",
required=False,
type=int,
help="Enforce a maximum number of split subgraphs.",
)
parser.add_argument(
"--skip-fusion",
"--skip_fusion",
Expand Down Expand Up @@ -78,6 +86,7 @@ def __init__(
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
self.max_acc_splits: int = max_acc_splits


@compatibility(is_backward_compatible=False)
Expand Down Expand Up @@ -876,5 +885,15 @@ def generate_split_results(self) -> SplitResult:
submodule_names = []
for name, mod in split_module.named_children():
submodule_names.append(name)
if (
self.settings.max_acc_splits > 0
and len(submodule_names) > self.settings.max_acc_splits
):
raise ValueError(
"Cannot fulfill max_acc_splits limit. "
"This may cause split fragmentation and "
"result in performance issues."
)

submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)

0 comments on commit 75fc7d3

Please sign in to comment.