diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 14ff651c936..764efffa18f 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Callable, List, Optional + import torch from torch._export.utils import get_buffer, get_param, is_buffer, is_param from torch._guards import detect_fake_mode @@ -11,11 +13,27 @@ from torch.export.exported_program import InputKind, InputSpec, TensorArgument -def is_const(arg, exported_program, const_data_list) -> bool: +_PRIMITIVE_TYPES = ( + float, + int, + bool, + str, + torch.Tensor, + torch.device, + torch.dtype, + torch.layout, +) + + +def is_const( + arg: object, exported_program: ExportedProgram, const_data_list: List[str] +) -> bool: if isinstance(arg, (tuple, list)): return all(is_const(x, exported_program, const_data_list) for x in arg) elif isinstance(arg, dict): return all(is_const(x, exported_program, const_data_list) for x in arg.values()) + elif isinstance(arg, _PRIMITIVE_TYPES): + return True elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder": return False elif ( @@ -27,9 +45,11 @@ def is_const(arg, exported_program, const_data_list) -> bool: return False -def get_data(exported_program, arg): +def get_data(exported_program: ExportedProgram, arg): if isinstance(arg, (tuple, list)): return [get_data(exported_program, x) for x in arg] + elif isinstance(arg, _PRIMITIVE_TYPES): + return arg elif is_param(exported_program, arg): return get_param(exported_program, arg) elif is_buffer(exported_program, arg): @@ -37,7 +57,10 @@ def get_data(exported_program, arg): return None -def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: +def constant_prop_pass( + exported_program: ExportedProgram, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> ExportedProgram: """ This pass is for constant propagation for Exported Program with lifted parameters, as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph. @@ -56,12 +79,14 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: if len(has_cond) > 0: raise RuntimeError("constant_prop_pass for control flow is not supported yet.") + first_user_input_idx = -1 first_user_input = None - for node in exported_program.graph.nodes: + for i, node in enumerate(exported_program.graph.nodes): if ( node.op == "placeholder" and node.name in exported_program.graph_signature.user_inputs ): + first_user_input_idx = i first_user_input = node break @@ -79,6 +104,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: assert fake_mode is not None for node in exported_program.graph.nodes: + if skip_folding_node_fn is not None and skip_folding_node_fn(node): + # Do not process this node if we were told to skip it. + continue if node.op == "call_function": constant_data_name_list = [ input_spec.target for input_spec in prop_constant_data @@ -115,9 +143,11 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: exported_program.state_dict[prop_constant_tensor_fqn] = ( prop_constant_tensor ) - exported_program.graph_signature.input_specs.append( - prop_constant_node_input_spec + # Insert new buffers before the first user input. + exported_program.graph_signature.input_specs.insert( + first_user_input_idx, prop_constant_node_input_spec ) + first_user_input_idx += 1 # Remove the propogated buffer from the state dict for node in exported_program.graph.nodes: @@ -128,6 +158,16 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram: ): exported_program.state_dict.pop(node.name, None) exported_program.graph.erase_node(node) + # Delete the input spec for this deleted buffer. + to_erase_idx = [] + for i, spec in enumerate(exported_program.graph_signature.input_specs): + if spec.arg.name == node.name: + to_erase_idx.append(i) + assert ( + len(to_erase_idx) == 1 + ), f"Should only delete one spec per node, but deleting multiple: {to_erase_idx} {exported_program.graph_signature.input_specs}" + for i in reversed(to_erase_idx): + exported_program.graph_signature.input_specs.pop(i) exported_program.graph_module.recompile() return exported_program