From 75fc7d3d96e42bdf6fe8355048ff05677d3265b0 Mon Sep 17 00:00:00 2001 From: Janet Yang Date: Mon, 12 Aug 2024 17:02:09 -0700 Subject: [PATCH] [lowering] Add max_acc_splits (#133041) Summary: X-link: https://github.com/facebookincubator/AITemplate/pull/1017 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133041 Model owners can set the lower_settings with max_acc_splits=2, and lowering will fail during model iteration, to alert them of possible performance degradation from increased fragmentation. Test Plan: Added unit tests Reviewed By: frank-wei Differential Revision: D60133589 --- torch/fx/passes/splitter_base.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 53ad10309f43e..e97441e39ba54 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -41,7 +41,8 @@ def __init__( self, min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, skip_fusion=DEFAULT_SKIP_FUSION, - allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + max_acc_splits: int = -1, ): parser = argparse.ArgumentParser() parser.add_argument( @@ -51,6 +52,13 @@ def __init__( type=int, help="Minimum size limit of an accelerator subgraph.", ) + parser.add_argument( + "--max-acc_splits", + "--max-acc_splits", + required=False, + type=int, + help="Enforce a maximum number of split subgraphs.", + ) parser.add_argument( "--skip-fusion", "--skip_fusion", @@ -78,6 +86,7 @@ def __init__( self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + self.max_acc_splits: int = max_acc_splits @compatibility(is_backward_compatible=False) @@ -876,5 +885,15 @@ def generate_split_results(self) -> SplitResult: submodule_names = [] for name, mod in split_module.named_children(): submodule_names.append(name) + if ( + self.settings.max_acc_splits > 0 + and len(submodule_names) > self.settings.max_acc_splits + ): + raise ValueError( + "Cannot fulfill max_acc_splits limit. " + "This may cause split fragmentation and " + "result in performance issues." + ) + submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)