From 54acf1b426a0eb7f962b0d4d5e9fab75b9666109 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Wed, 5 Feb 2025 17:11:07 +0100
Subject: [PATCH 1/4] [ARM backend] Update fuse_batchnorm_pass to create new placeholders

- This allows fusing bn+convs with multiple users of the same weights
- Adds new util functions create/delete_constant_placeholder to take care of
  updating the GraphSignature and state_dict/constants dict when handling
  constant placeholders.
- Adds and updates related tests

Change-Id: I8e550614d9741de840786d9dca9f30af9eb95a64
---
 backends/arm/_passes/arm_pass_utils.py        | 128 ++++++++++++++++-
 backends/arm/_passes/fuse_batchnorm2d_pass.py | 129 +++++++++++-------
 ...test_create_delete_constant_placeholder.py |  93 +++++++++++++
 .../test/passes/test_fuse_batchnorm_pass.py   |   8 +-
 4 files changed, 305 insertions(+), 53 deletions(-)
 create mode 100644 backends/arm/test/misc/test_create_delete_constant_placeholder.py

diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index cb43acc7fdb..0680b9dc86b 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -25,7 +25,13 @@
     is_param,
 )
 from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    TensorArgument,
+)


 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -64,6 +70,124 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")


+def create_constant_placeholder(
+    exp_program: ExportedProgram,
+    graph: torch.fx.Graph,
+    name: str,
+    kind: InputKind,
+    data: torch.Tensor,
+    persistent_buffer: Optional[bool] = None,
+) -> torch.fx.Node:
+    """
+    Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor.
+    graph.inserting_before/after() should be used before the call to decide where to insert the node.
+    """
+
+    target = name
+
+    # Add data to state_dict/ constants
+    match kind:
+        case InputKind.PARAMETER:
+            exp_program.state_dict[target] = torch.nn.Parameter(
+                data, requires_grad=False
+            )
+        case InputKind.BUFFER:
+            if persistent_buffer is None:
+                raise RuntimeError(
+                    "Must set persistent_buffer when creating a new buffer."
+                )
+            elif persistent_buffer:
+                exp_program.state_dict[target] = data
+            else:
+                exp_program.constants[target] = data
+
+        case InputKind.CONSTANT_TENSOR:
+            exp_program.constants[target] = data
+        case _:
+            raise RuntimeError("Can only create constant input nodes.")
+
+    # Create node
+    fake_tensor_mode = get_first_fake_tensor(
+        list(graph.nodes)[0]
+    ).fake_mode  # Use the same fake_tensor_mode as all other fake tensors in the graph
+    node = graph.create_node(op="placeholder", name=name, target=name)
+    node.meta["val"] = FakeTensorConverter().from_real_tensor(fake_tensor_mode, t=data)
+
+    # Add tensor to graph_signature in the same order as nodes in the graph
+    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
+    node_index = node_names.index(name)
+
+    input_specs = exp_program.graph_signature.input_specs
+    user_input_indices = [
+        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
+    ]
+    if not all(
+        (user_input_index > node_index for user_input_index in user_input_indices)
+    ):
+        raise RuntimeError(
+            f"Failed to insert {name}; Const placeholder nodes must be inserted before user input nodes in the graph."
+        )
+
+    arg_spec = TensorArgument(name)
+    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
+    input_specs.insert(node_index, input_spec)
+
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    return node
+
+
+def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
+    """
+    Deletes a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor,
+    if the node does not have any users.
+    """
+    if not len(node.users) == 0:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it has users in the graph."
+        )
+
+    # Remove tensor from state_dict/ constants
+    if node.name in exp_program.graph_signature.inputs_to_parameters:
+        target = exp_program.graph_signature.inputs_to_parameters[node.name]
+        del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_buffers:
+        target = exp_program.graph_signature.inputs_to_buffers[node.name]
+
+        if target in exp_program.graph_signature.non_persistent_buffers:
+            del exp_program.constants[target]
+        else:
+            del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
+        target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
+            node.name
+        ]
+        del exp_program.constants[target]
+    else:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
+        )
+
+    # Remove input from graph signature
+    input_specs = [
+        spec
+        for spec in exp_program.graph_signature.input_specs
+        if spec.arg.name != node.name
+    ]
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    # Remove node from graph
+    node.graph.erase_node(node)
+
+
 def create_node(
     graph: torch.fx.Graph,
     op_target: OpOverload,
diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py
index 8675b340af4..7b238c02ded 100644
--- a/backends/arm/_passes/fuse_batchnorm2d_pass.py
+++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -6,10 +6,15 @@
 # pyre-unsafe

 import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()

-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into a parent convolution."""

         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False

         return True

+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
@@ -64,67 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
             )

             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]

             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")

             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )

-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
-                    )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+            # Create fused weight and bias nodes for the conv and replace the conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )

-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )

-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination. After that, we may remove their constant placeholder inputs.
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True

         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
diff --git a/backends/arm/test/misc/test_create_delete_constant_placeholder.py b/backends/arm/test/misc/test_create_delete_constant_placeholder.py
new file mode 100644
index 00000000000..378f7cf6994
--- /dev/null
+++ b/backends/arm/test/misc/test_create_delete_constant_placeholder.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_constant_placeholder,
+    create_node,
+    delete_constant_placeholder,
+)
+from executorch.exir import to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import export
+from torch.export.graph_signature import InputKind
+
+
+class EmptyNetwork(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+    test_data: torch.Tensor = (torch.zeros(1),)
+
+
+def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
+    """
+    Tests the utility functions create_constant_placeholder and delete_constant_placeholder
+    """
+
+    # Toy network with two nodes, input and output
+    # The result should be 0 = 0
+    module = EmptyNetwork()
+    exported_program = export(module, args=module.test_data)
+    exported_program = to_edge(exported_program).exported_program()
+    graph = exported_program.graph_module.graph
+    assert len(graph.nodes) == 2
+    assert exported_program.module()(torch.zeros(1)) == 0
+
+    # Create one const node with value 1 and add it to the input
+    input_node = list(graph.nodes)[0]
+    with graph.inserting_before(input_node):
+        const_node = create_constant_placeholder(
+            exp_program=exported_program,
+            graph=graph,
+            kind=kind,
+            name="test_node",
+            data=torch.ones(1),
+            persistent_buffer=persistent_buffer,
+        )
+        assert "val" in const_node.meta
+
+    with graph.inserting_after(input_node):
+        add_node = create_node(
+            graph=graph,
+            op_target=exir_ops.edge.aten.add.Tensor,
+            args=(input_node, const_node),
+        )
+
+    output_node = list(graph.nodes)[-1]
+    output_node.replace_input_with(input_node, add_node)
+
+    # We should now have four nodes: test_node, input, add, output
+    # The result should be 0 + 1 = 1
+    assert exported_program.module()(torch.zeros(1)) == 1
+    assert len(graph.nodes) == 4
+
+    # Replacing the add op and using eliminate_dead_code() deletes the add op but not the input op
+    output_node.replace_input_with(add_node, input_node)
+    graph.eliminate_dead_code()
+    assert len(graph.nodes) == 3
+
+    # Delete the input op manually
+    # The result should again be 0 = 0
+    delete_constant_placeholder(exported_program, const_node)
+    assert exported_program.module()(torch.zeros(1)) == 0
+    assert len(graph.nodes) == 2
+
+
+def test_create_delete_parameter():
+    _test_create_delete(InputKind.PARAMETER)
+
+
+def test_create_delete_persistent_buffer():
+    _test_create_delete(InputKind.BUFFER, True)
+
+
+def test_create_delete_non_persistent_buffer():
+    _test_create_delete(InputKind.BUFFER, False)
+
+
+def test_create_delete_constant_tensor():
+    _test_create_delete(InputKind.CONSTANT_TENSOR)
diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
index b18e536b155..1021423b222 100644
--- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py
+++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -85,13 +85,13 @@ def forward(self, x):
         return x


-class MergeNoBN(torch.nn.Module):
+class MergeMultipleUsersBN(torch.nn.Module):
     ops_before_pass = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
     ops_after_pass = {
-        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }

@@ -122,7 +122,7 @@ def forward(self, x):
         z = self.conv2d2(x)
         a = self.batch_norm2d(
             y
-        )  # Can't be fused since paramters of conv2d2 have multiple users.
+        )  # Can be fused despite parameters of conv2d2 having multiple users.
         return z, a


@@ -131,7 +131,7 @@ def forward(self, x):
     "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
     "merge_one_of_two_bn": MergeOneOfTwoBN(False),
     "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
-    "merge_no_bn_affine": MergeNoBN(True),
+    "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
 }


From 71e791e171f4dbfc8b85a7d222868bef9b228aa9 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 17 Feb 2025 09:48:04 +0100
Subject: [PATCH 2/4] Move create/delete_constant_placeholder utils to shared folder

Change-Id: I3a82f58f9796e421bd205f030f7d79d72a2f7ed9
---
 backends/arm/_passes/arm_pass_utils.py        | 126 +---------------
 backends/arm/_passes/fuse_batchnorm2d_pass.py |   2 +-
 ...test_create_delete_constant_placeholder.py |  42 +++++-
 backends/transforms/utils.py                  | 136 +++++++++++++++++-
 4 files changed, 173 insertions(+), 133 deletions(-)
 rename backends/{arm/test/misc => transforms/test}/test_create_delete_constant_placeholder.py (63%)

diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index 0680b9dc86b..3445886ffa7 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -25,13 +25,7 @@
     is_param,
 )
 from torch._ops import OpOverload
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
-from torch.export.graph_signature import (
-    ExportGraphSignature,
-    InputKind,
-    InputSpec,
-    TensorArgument,
-)
+from torch._subclasses.fake_tensor import FakeTensor


 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -70,124 +64,6 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")


-def create_constant_placeholder(
-    exp_program: ExportedProgram,
-    graph: torch.fx.Graph,
-    name: str,
-    kind: InputKind,
-    data: torch.Tensor,
-    persistent_buffer: Optional[bool] = None,
-) -> torch.fx.Node:
-    """
-    Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor.
-    graph.inserting_before/after() should be used before the call to decide where to insert the node.
-    """
-
-    target = name
-
-    # Add data to state_dict/ constants
-    match kind:
-        case InputKind.PARAMETER:
-            exp_program.state_dict[target] = torch.nn.Parameter(
-                data, requires_grad=False
-            )
-        case InputKind.BUFFER:
-            if persistent_buffer is None:
-                raise RuntimeError(
-                    "Must set persistent_buffer when creating a new buffer."
-                )
-            elif persistent_buffer:
-                exp_program.state_dict[target] = data
-            else:
-                exp_program.constants[target] = data
-
-        case InputKind.CONSTANT_TENSOR:
-            exp_program.constants[target] = data
-        case _:
-            raise RuntimeError("Can only create constant input nodes.")
-
-    # Create node
-    fake_tensor_mode = get_first_fake_tensor(
-        list(graph.nodes)[0]
-    ).fake_mode  # Use the same fake_tensor_mode as all other fake tensors in the graph
-    node = graph.create_node(op="placeholder", name=name, target=name)
-    node.meta["val"] = FakeTensorConverter().from_real_tensor(fake_tensor_mode, t=data)
-
-    # Add tensor to graph_signature in the same order as nodes in the graph
-    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
-    node_index = node_names.index(name)
-
-    input_specs = exp_program.graph_signature.input_specs
-    user_input_indices = [
-        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
-    ]
-    if not all(
-        (user_input_index > node_index for user_input_index in user_input_indices)
-    ):
-        raise RuntimeError(
-            f"Failed to insert {name}; Const placeholder nodes must be inserted before user input nodes in the graph."
-        )
-
-    arg_spec = TensorArgument(name)
-    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
-    input_specs.insert(node_index, input_spec)
-
-    new_graph_signature = ExportGraphSignature(
-        input_specs, exp_program.graph_signature.output_specs
-    )
-    exp_program._graph_signature = new_graph_signature
-
-    return node
-
-
-def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
-    """
-    Deletes a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor,
-    if the node does not have any users.
-    """
-    if not len(node.users) == 0:
-        raise RuntimeError(
-            f"Cannot delete input node {node.name} since it has users in the graph."
-        )
-
-    # Remove tensor from state_dict/ constants
-    if node.name in exp_program.graph_signature.inputs_to_parameters:
-        target = exp_program.graph_signature.inputs_to_parameters[node.name]
-        del exp_program.state_dict[target]
-
-    elif node.name in exp_program.graph_signature.inputs_to_buffers:
-        target = exp_program.graph_signature.inputs_to_buffers[node.name]
-
-        if target in exp_program.graph_signature.non_persistent_buffers:
-            del exp_program.constants[target]
-        else:
-            del exp_program.state_dict[target]
-
-    elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
-        target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
-            node.name
-        ]
-        del exp_program.constants[target]
-    else:
-        raise RuntimeError(
-            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
-        )
-
-    # Remove input from graph signature
-    input_specs = [
-        spec
-        for spec in exp_program.graph_signature.input_specs
-        if spec.arg.name != node.name
-    ]
-    new_graph_signature = ExportGraphSignature(
-        input_specs, exp_program.graph_signature.output_specs
-    )
-    exp_program._graph_signature = new_graph_signature
-
-    # Remove node from graph
-    node.graph.erase_node(node)
-
-
 def create_node(
     graph: torch.fx.Graph,
     op_target: OpOverload,
diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py
index 7b238c02ded..9eb74aca145 100644
--- a/backends/arm/_passes/fuse_batchnorm2d_pass.py
+++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -6,7 +6,7 @@
 # pyre-unsafe

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import (
+from executorch.backends.transforms.utils import (
     create_constant_placeholder,
     delete_constant_placeholder,
 )
diff --git a/backends/arm/test/misc/test_create_delete_constant_placeholder.py b/backends/transforms/test/test_create_delete_constant_placeholder.py
similarity index 63%
rename from backends/arm/test/misc/test_create_delete_constant_placeholder.py
rename to backends/transforms/test/test_create_delete_constant_placeholder.py
index 378f7cf6994..0e1f5224b44 100644
--- a/backends/arm/test/misc/test_create_delete_constant_placeholder.py
+++ b/backends/transforms/test/test_create_delete_constant_placeholder.py
@@ -4,9 +4,8 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from executorch.backends.arm._passes.arm_pass_utils import (
+from executorch.backends.transforms.utils import (
     create_constant_placeholder,
-    create_node,
     delete_constant_placeholder,
 )
 from executorch.exir import to_edge
@@ -36,6 +35,11 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
     graph = exported_program.graph_module.graph
     assert len(graph.nodes) == 2
     assert exported_program.module()(torch.zeros(1)) == 0
+    assert len(exported_program.graph_signature.input_specs) == 1
+    assert len(exported_program.state_dict) == 0
+    assert len(exported_program.constants) == 0
+
+    const_name = "test_node"

     # Create one const node with value 1 and add it to the input
     input_node = list(graph.nodes)[0]
@@ -44,17 +48,18 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
             exp_program=exported_program,
             graph=graph,
             kind=kind,
-            name="test_node",
+            name=const_name,
             data=torch.ones(1),
             persistent_buffer=persistent_buffer,
         )
         assert "val" in const_node.meta

     with graph.inserting_after(input_node):
-        add_node = create_node(
-            graph=graph,
-            op_target=exir_ops.edge.aten.add.Tensor,
+        add_node = graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.add.Tensor,
             args=(input_node, const_node),
+            kwargs={},
         )

     output_node = list(graph.nodes)[-1]
@@ -65,6 +70,28 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
     assert exported_program.module()(torch.zeros(1)) == 1
     assert len(graph.nodes) == 4

+    if kind == InputKind.PARAMETER:
+        assert const_name in exported_program.graph_signature.inputs_to_parameters
+        assert const_name in exported_program.state_dict
+        assert len(exported_program.constants) == 0
+    elif kind == InputKind.BUFFER and persistent_buffer:
+        assert const_name in exported_program.graph_signature.inputs_to_buffers
+        assert const_name in exported_program.state_dict
+        assert len(exported_program.constants) == 0
+    elif kind == InputKind.BUFFER and not persistent_buffer:
+        assert const_name in exported_program.graph_signature.inputs_to_buffers
+        assert len(exported_program.state_dict) == 0
+        assert const_name in exported_program.constants
+    elif kind == InputKind.CONSTANT_TENSOR:
+        assert (
+            const_name
+            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        )
+        assert len(exported_program.state_dict) == 0
+        assert const_name in exported_program.constants
+    else:
+        raise RuntimeError("Wrong input kind")
+
     # Replacing the add op and using eliminate_dead_code() deletes the add op but not the input op
     output_node.replace_input_with(add_node, input_node)
     graph.eliminate_dead_code()
@@ -75,6 +102,9 @@ def _test_create_delete(kind: InputKind, persistent_buffer: bool = None):
     delete_constant_placeholder(exported_program, const_node)
     assert exported_program.module()(torch.zeros(1)) == 0
     assert len(graph.nodes) == 2
+    assert len(exported_program.graph_signature.input_specs) == 1
+    assert len(exported_program.state_dict) == 0
+    assert len(exported_program.constants) == 0


 def test_create_delete_parameter():
diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py
index 03c48039b93..4e451928ee4 100644
--- a/backends/transforms/utils.py
+++ b/backends/transforms/utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -8,7 +9,6 @@
 import torch

 from executorch.exir import ExportedProgram
-
 from torch._export.utils import (
     get_buffer,
     get_lifted_tensor_constant,
@@ -17,6 +17,13 @@
     is_lifted_tensor_constant,
     is_param,
 )
+from torch._subclasses.fake_tensor import FakeTensorConverter
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    TensorArgument,
+)


 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -53,3 +60,130 @@ def get_param_tensor(
         except AttributeError:
             return getattr(exp_prog.graph_module, node.target)
     raise RuntimeError(f"unsupported param type, {node.op}.")
+
+
+def create_constant_placeholder(
+    exp_program: ExportedProgram,
+    graph: torch.fx.Graph,
+    name: str,
+    kind: InputKind,
+    data: torch.Tensor,
+    persistent_buffer: Optional[bool] = None,
+) -> torch.fx.Node:
+    """
+    Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer,
+    or lifted constant tensor. graph.inserting_before/after() should be used before the call to
+    decide where to insert the node, at an insertion point before the first user input node.
+    """
+
+    target = name
+
+    # Add data to state_dict/ constants
+    match kind:
+        case InputKind.PARAMETER:
+            exp_program.state_dict[target] = torch.nn.Parameter(
+                data, requires_grad=False
+            )
+        case InputKind.BUFFER:
+            if persistent_buffer is None:
+                raise RuntimeError(
+                    "Must set persistent_buffer when creating a new buffer."
+                )
+            elif persistent_buffer:
+                exp_program.state_dict[target] = data
+            else:
+                exp_program.constants[target] = data
+        case InputKind.CONSTANT_TENSOR:
+            exp_program.constants[target] = data
+        case _:
+            raise RuntimeError("Can only create constant input nodes.")
+
+    # Create fake tensor using the same fake_mode as the other fake tensors in the graph
+    example_node = list(graph.nodes)[0]
+    if isinstance(
+        example_node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        example_fake_tensor = example_node.meta["val"][0]
+    else:
+        example_fake_tensor = example_node.meta["val"]
+    fake_tensor = FakeTensorConverter().from_real_tensor(
+        example_fake_tensor.fake_mode, t=data
+    )
+
+    # Create node
+    node = graph.create_node(op="placeholder", name=name, target=name)
+    node.meta["val"] = fake_tensor
+
+    # Add tensor to graph_signature in the same order as nodes in the graph
+    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
+    node_index = node_names.index(name)
+
+    input_specs = exp_program.graph_signature.input_specs
+    user_input_indices = [
+        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
+    ]
+    if not all(
+        (user_input_index >= node_index for user_input_index in user_input_indices)
+    ):
+        raise RuntimeError(
+            f"Failed to insert {name}; Const placeholder nodes must be inserted before user input nodes in the graph."
+        )
+
+    arg_spec = TensorArgument(name)
+    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
+    input_specs.insert(node_index, input_spec)
+
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    return node
+
+
+def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
+    """
+    Deletes a node of type parameter, buffer, or lifted constant tensor and its related
+    graph signature and state_dict/constant entries. The node must not have any users.
+    """
+    if not len(node.users) == 0:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it has users in the graph."
+        )
+
+    # Remove tensor from state_dict/ constants
+    if node.name in exp_program.graph_signature.inputs_to_parameters:
+        target = exp_program.graph_signature.inputs_to_parameters[node.name]
+        del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_buffers:
+        target = exp_program.graph_signature.inputs_to_buffers[node.name]
+
+        if target in exp_program.graph_signature.non_persistent_buffers:
+            del exp_program.constants[target]
+        else:
+            del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
+        target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
+            node.name
+        ]
+        del exp_program.constants[target]
+    else:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
+        )
+
+    # Remove input from graph signature
+    input_specs = [
+        spec
+        for spec in exp_program.graph_signature.input_specs
+        if spec.arg.name != node.name
+    ]
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    # Remove node from graph
+    node.graph.erase_node(node)

From 67bc6cd18688ae24fc4882671719d0f884ac2ba2 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Fri, 21 Feb 2025 10:21:34 +0100
Subject: [PATCH 3/4] Add buck dependency

---
 backends/arm/_passes/TARGETS | 1 +
 1 file changed, 1 insertion(+)

diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS
index 6ca59cfee27..151e8f6d8ad 100644
--- a/backends/arm/_passes/TARGETS
+++ b/backends/arm/_passes/TARGETS
@@ -9,5 +9,6 @@ python_library(
         "//executorch/backends/arm:tosa_utils",
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
+        "//executorch/backends/transforms:utils",
     ],
 )

From 6f719674f13eb927e00f56b4978e7f8469972f89 Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Mon, 24 Feb 2025 09:45:50 +0100
Subject: [PATCH 4/4] Fix bazel build

---
 backends/transforms/targets.bzl | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index c532798546d..30ace744077 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -149,6 +149,9 @@ def define_common_targets():
     runtime.python_library(
         name = "utils",
         srcs = ["utils.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
         deps = [
             "//caffe2:torch",
             "//executorch/exir:lib",
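
Taken together, the series replaces in-place mutation of conv parameters with a create-rewire-delete pattern on constant placeholders. The minimal sketch below illustrates that pattern against the backends/transforms/utils.py API added in PATCH 2; the TinyConv module, the "_scaled" name, and the weight-doubling are illustrative assumptions, not part of the patches.

    import torch
    from executorch.backends.transforms.utils import (
        create_constant_placeholder,
        delete_constant_placeholder,
    )
    from executorch.exir import to_edge
    from torch.export import export
    from torch.export.graph_signature import InputKind


    class TinyConv(torch.nn.Module):
        # Hypothetical one-conv model, for illustration only.
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv(x)


    exported_program = to_edge(
        export(TinyConv(), (torch.ones(1, 1, 8, 8),))
    ).exported_program()
    graph = exported_program.graph_module.graph

    # Look up the conv weight placeholder via the graph signature
    # (assumes the weight is listed first; a real pass would match by node).
    weight_name = next(iter(exported_program.graph_signature.inputs_to_parameters))
    weight_node = next(n for n in graph.nodes if n.name == weight_name)
    target = exported_program.graph_signature.inputs_to_parameters[weight_name]

    # Insert the replacement before the old weight so it stays ahead of all
    # user input placeholders, as create_constant_placeholder requires.
    with graph.inserting_before(weight_node):
        new_weight_node = create_constant_placeholder(
            exp_program=exported_program,
            graph=graph,
            kind=InputKind.PARAMETER,
            name=weight_name + "_scaled",
            data=(2.0 * exported_program.state_dict[target]).detach(),
        )

    # Rewire all users to the new placeholder; the old node is then unused and
    # can be deleted together with its state_dict and graph-signature entries.
    weight_node.replace_all_uses_with(new_weight_node)
    delete_constant_placeholder(exported_program, weight_node)

This is the same ordering the fuse_batchnorm2d pass follows: the new placeholders are created while the old nodes still exist (so insertion points and names can be derived from them), users are rewired, and only then are the dead placeholders removed so the graph signature and state_dict never reference a live node.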