diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS
index f8bf9c0d208..c56eaca8d4b 100644
--- a/backends/arm/_passes/TARGETS
+++ b/backends/arm/_passes/TARGETS
@@ -9,5 +9,6 @@ python_library(
         "//executorch/backends/transforms:replace_scalar_with_tensor",
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
+        "//executorch/backends/transforms:utils",
     ],
 )
diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py
index cb43acc7fdb..3445886ffa7 100644
--- a/backends/arm/_passes/arm_pass_utils.py
+++ b/backends/arm/_passes/arm_pass_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
diff --git a/backends/arm/_passes/fuse_batchnorm2d_pass.py b/backends/arm/_passes/fuse_batchnorm2d_pass.py
index 6cb7548a70c..9eb74aca145 100644
--- a/backends/arm/_passes/fuse_batchnorm2d_pass.py
+++ b/backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -6,10 +6,15 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 
@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()
 
-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into a parent convolution."""
 
         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if conv parameters have single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True
 
+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
                 )
 
             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]
 
             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")
 
             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )
 
-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first.
                Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
+            # Create fused weights and bias to conv and replace conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )
+
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
                     )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
 
-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # pyre-ignore[60]
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
-
-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs.
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True
 
         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
diff --git a/backends/arm/test/passes/test_fuse_batchnorm_pass.py b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
index 45b3253f848..415aa9f6132 100644
--- a/backends/arm/test/passes/test_fuse_batchnorm_pass.py
+++ b/backends/arm/test/passes/test_fuse_batchnorm_pass.py
@@ -85,13 +85,13 @@ def forward(self, x):
         return x
 
 
-class MergeNoBN(torch.nn.Module):
+class MergeMultipleUsersBN(torch.nn.Module):
     ops_before_pass = {
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
     ops_after_pass = {
-        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 2,
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default": 1,
         "executorch_exir_dialects_edge__ops_aten_convolution_default": 3,
     }
 
@@ -122,7 +122,7 @@ def forward(self, x):
         z = self.conv2d2(x)
         a = self.batch_norm2d(
             y
-        )  # Can't be fused since paramters of conv2d2 have multiple users.
+        )  # Can be fused despite parameters of conv2d2 having multiple users.
         return z, a
 
 
@@ -131,7 +131,7 @@ def forward(self, x):
     "merge_one_of_two_bn_affine": MergeOneOfTwoBN(True),
     "merge_one_of_two_bn": MergeOneOfTwoBN(False),
     "merge_two_of_two_bn_affine": MergeTwosOfTwoBN(True),
-    "merge_no_bn_affine": MergeNoBN(True),
+    "merge_multiple_users_bn_affine": MergeMultipleUsersBN(True),
 }
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index ec4e1412862..66ff9111f52 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -149,6 +149,9 @@ def define_common_targets():
     runtime.python_library(
         name = "utils",
         srcs = ["utils.py"],
+        visibility = [
+            "//executorch/backends/...",
+        ],
         deps = [
             "//caffe2:torch",
             "//executorch/exir:lib",
diff --git a/backends/transforms/test/test_create_delete_constant_placeholder.py b/backends/transforms/test/test_create_delete_constant_placeholder.py
new file mode 100644
index 00000000000..0e1f5224b44
--- /dev/null
+++ b/backends/transforms/test/test_create_delete_constant_placeholder.py
@@ -0,0 +1,123 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
+from executorch.exir import to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import export
+from torch.export.graph_signature import InputKind
+
+
+class EmptyNetwork(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+    test_data: tuple[torch.Tensor] = (torch.zeros(1),)
+
+
+def _test_create_delete(kind: InputKind, persistent_buffer: bool | None = None):
+    """
+    Tests the utility functions create_constant_placeholder and delete_constant_placeholder.
+    """
+
+    # Toy network with two nodes, input and output
+    # The result should be 0 = 0
+    module = EmptyNetwork()
+    exported_program = export(module, args=module.test_data)
+    exported_program = to_edge(exported_program).exported_program()
+    graph = exported_program.graph_module.graph
+    assert len(graph.nodes) == 2
+    assert exported_program.module()(torch.zeros(1)) == 0
+    assert len(exported_program.graph_signature.input_specs) == 1
+    assert len(exported_program.state_dict) == 0
+    assert len(exported_program.constants) == 0
+
+    const_name = "test_node"
+
+    # Create one const node with value 1 and add it to the input
+    input_node = list(graph.nodes)[0]
+    with graph.inserting_before(input_node):
+        const_node = create_constant_placeholder(
+            exp_program=exported_program,
+            graph=graph,
+            kind=kind,
+            name=const_name,
+            data=torch.ones(1),
+            persistent_buffer=persistent_buffer,
+        )
+    assert "val" in const_node.meta
+
+    with graph.inserting_after(input_node):
+        add_node = graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.add.Tensor,
+            args=(input_node, const_node),
+            kwargs={},
+        )
+
+    output_node = list(graph.nodes)[-1]
+    output_node.replace_input_with(input_node, add_node)
+
+    # We should now have four nodes: test_node, input, add, output
+    # The result should be 0 + 1 = 1
+    assert exported_program.module()(torch.zeros(1)) == 1
+    assert len(graph.nodes) == 4
+
+    if kind == InputKind.PARAMETER:
+        assert const_name in exported_program.graph_signature.inputs_to_parameters
+        assert const_name in exported_program.state_dict
+        assert len(exported_program.constants) == 0
+    elif kind == InputKind.BUFFER and persistent_buffer:
+        assert const_name in exported_program.graph_signature.inputs_to_buffers
+        assert const_name in exported_program.state_dict
+        assert len(exported_program.constants) == 0
+    elif kind == InputKind.BUFFER and not persistent_buffer:
+        assert const_name in exported_program.graph_signature.inputs_to_buffers
+        assert len(exported_program.state_dict) == 0
+        assert const_name in exported_program.constants
+    elif kind == InputKind.CONSTANT_TENSOR:
+        assert (
+            const_name
+            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        )
+        assert len(exported_program.state_dict) == 0
+        assert const_name in exported_program.constants
+    else:
+        raise RuntimeError("Wrong input kind")
+
+    # Replacing the add op and using eliminate_dead_code() deletes the add op but not the input op
+    output_node.replace_input_with(add_node, input_node)
+    graph.eliminate_dead_code()
+    assert len(graph.nodes) == 3
+
+    # Delete the input op manually
+    # The result should again be 0 = 0
+    delete_constant_placeholder(exported_program, const_node)
+    assert exported_program.module()(torch.zeros(1)) == 0
+    assert len(graph.nodes) == 2
+    assert len(exported_program.graph_signature.input_specs) == 1
+    assert len(exported_program.state_dict) == 0
+    assert len(exported_program.constants) == 0
+
+
+def test_create_delete_parameter():
+    _test_create_delete(InputKind.PARAMETER)
+
+
+def test_create_delete_persistent_buffer():
+    _test_create_delete(InputKind.BUFFER, True)
+
+
+def test_create_delete_non_persistent_buffer():
+    _test_create_delete(InputKind.BUFFER, False)
+
+
+def test_create_delete_constant_tensor():
+    _test_create_delete(InputKind.CONSTANT_TENSOR)
diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py
index 03c48039b93..4e451928ee4 100644
--- a/backends/transforms/utils.py
+++ b/backends/transforms/utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -8,7 +9,6 @@
 
 import torch
 from executorch.exir import ExportedProgram
-
 from torch._export.utils import (
     get_buffer,
     get_lifted_tensor_constant,
@@ -17,6 +17,13 @@
     is_lifted_tensor_constant,
     is_param,
 )
+from torch._subclasses.fake_tensor import FakeTensorConverter
+from torch.export.graph_signature import (
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    TensorArgument,
+)
 
 
 def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -53,3 +60,130 @@ def get_param_tensor(
         except AttributeError:
             return getattr(exp_prog.graph_module, node.target)
     raise RuntimeError(f"unsupported param type, {node.op}.")
+
+
+def create_constant_placeholder(
+    exp_program: ExportedProgram,
+    graph: torch.fx.Graph,
+    name: str,
+    kind: InputKind,
+    data: torch.Tensor,
+    persistent_buffer: Optional[bool] = None,
+) -> torch.fx.Node:
+    """
+    Creates and returns a constant placeholder node, i.e. a parameter, buffer, or
+    lifted constant tensor. Use graph.inserting_before/after() before the call to
+    choose the insertion point, which must be before the first user input node.
+    """
+
+    target = name
+
+    # Add data to state_dict/constants
+    match kind:
+        case InputKind.PARAMETER:
+            exp_program.state_dict[target] = torch.nn.Parameter(
+                data, requires_grad=False
+            )
+        case InputKind.BUFFER:
+            if persistent_buffer is None:
+                raise RuntimeError(
+                    "Must set persistent_buffer when creating a new buffer."
+                )
+            elif persistent_buffer:
+                exp_program.state_dict[target] = data
+            else:
+                exp_program.constants[target] = data
+        case InputKind.CONSTANT_TENSOR:
+            exp_program.constants[target] = data
+        case _:
+            raise RuntimeError("Can only create constant input nodes.")
+
+    # Create a fake tensor using the same fake_mode as the other fake tensors in the graph
+    example_node = list(graph.nodes)[0]
+    if isinstance(
+        example_node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        example_fake_tensor = example_node.meta["val"][0]
+    else:
+        example_fake_tensor = example_node.meta["val"]
+    fake_tensor = FakeTensorConverter().from_real_tensor(
+        example_fake_tensor.fake_mode, t=data
+    )
+
+    # Create node
+    node = graph.create_node(op="placeholder", name=name, target=name)
+    node.meta["val"] = fake_tensor
+
+    # Add tensor to graph_signature in the same order as nodes in the graph
+    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
+    node_index = node_names.index(name)
+
+    input_specs = exp_program.graph_signature.input_specs
+    user_input_indices = [
+        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
+    ]
+    if not all(
+        user_input_index >= node_index for user_input_index in user_input_indices
+    ):
+        raise RuntimeError(
+            f"Failed to insert {name}; const placeholder nodes must be inserted before user input nodes in the graph."
+        )
+
+    arg_spec = TensorArgument(name)
+    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
+    input_specs.insert(node_index, input_spec)
+
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    return node
+
+
+def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
+    """
+    Deletes a node of type parameter, buffer, or lifted constant tensor and its related
+    graph signature and state_dict/constants entries. The node must not have any users.
+    """
+    if len(node.users) != 0:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it has users in the graph."
+        )
+
+    # Remove tensor from state_dict/constants
+    if node.name in exp_program.graph_signature.inputs_to_parameters:
+        target = exp_program.graph_signature.inputs_to_parameters[node.name]
+        del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_buffers:
+        target = exp_program.graph_signature.inputs_to_buffers[node.name]
+
+        if target in exp_program.graph_signature.non_persistent_buffers:
+            del exp_program.constants[target]
+        else:
+            del exp_program.state_dict[target]
+
+    elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
+        target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
+            node.name
+        ]
+        del exp_program.constants[target]
+    else:
+        raise RuntimeError(
+            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
+        )
+
+    # Remove input from graph signature
+    input_specs = [
+        spec
+        for spec in exp_program.graph_signature.input_specs
+        if spec.arg.name != node.name
+    ]
+    new_graph_signature = ExportGraphSignature(
+        input_specs, exp_program.graph_signature.output_specs
+    )
+    exp_program._graph_signature = new_graph_signature
+
+    # Remove node from graph
+    node.graph.erase_node(node)
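
Usage sketch (illustrative note, not part of the patch): the snippet below shows the calling pattern for the two utilities added in backends/transforms/utils.py, mirroring the new test above. The toy module and the placeholder name are hypothetical.

    import torch
    from executorch.backends.transforms.utils import (
        create_constant_placeholder,
        delete_constant_placeholder,
    )
    from executorch.exir import to_edge
    from torch.export import export
    from torch.export.graph_signature import InputKind


    class Identity(torch.nn.Module):  # hypothetical toy module
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x


    exported_program = to_edge(export(Identity(), (torch.zeros(1),))).exported_program()
    graph = exported_program.graph_module.graph

    # Constant placeholders must be inserted before the first user input node
    input_node = list(graph.nodes)[0]
    with graph.inserting_before(input_node):
        const_node = create_constant_placeholder(
            exp_program=exported_program,
            graph=graph,
            kind=InputKind.PARAMETER,
            name="my_const",  # hypothetical name
            data=torch.ones(1),
        )

    # ... use const_node as an argument to ops in the graph ...

    # Once the node has no users it can be removed again, which also cleans up
    # its graph signature and state_dict/constants entries
    delete_constant_placeholder(exported_program, const_node)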