128 changes: 126 additions & 2 deletions backends/arm/_passes/arm_pass_utils.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -25,7 +25,13 @@
is_param,
)
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
from torch.export.graph_signature import (
ExportGraphSignature,
InputKind,
InputSpec,
TensorArgument,
)


def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -64,6 +70,124 @@ def get_param_tensor(
raise RuntimeError(f"unsupported param type, {node.op}.")


def create_constant_placeholder(
exp_program: ExportedProgram,
graph: torch.fx.Graph,
name: str,
kind: InputKind,
data: torch.Tensor,
persistent_buffer: Optional[bool] = None,
) -> torch.fx.Node:
"""
Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor.
graph.inserting_before/after() should be used before the call to decide where to insert the node.
"""

target = name

# Add data to state_dict/ constants
match kind:
case InputKind.PARAMETER:
exp_program.state_dict[target] = torch.nn.Parameter(
data, requires_grad=False
)
case InputKind.BUFFER:
if persistent_buffer is None:
raise RuntimeError(
"Must set persistent_buffer when creating a new buffer."
)
elif persistent_buffer:
exp_program.state_dict[target] = data
else:
exp_program.constants[target] = data

case InputKind.CONSTANT_TENSOR:
exp_program.constants[target] = data
case _:
raise RuntimeError("Can only create constant input nodes.")

# Create node
fake_tensor_mode = get_first_fake_tensor(
list(graph.nodes)[0]
).fake_mode # Use the same fake_tensor_mode as all other fake tensors in the graph
node = graph.create_node(op="placeholder", name=name, target=name)
node.meta["val"] = FakeTensorConverter().from_real_tensor(fake_tensor_mode, t=data)

# Add tensor to graph_signature in the same order as nodes in the graph
node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
node_index = node_names.index(name)

input_specs = exp_program.graph_signature.input_specs
user_input_indices = [
i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
]
if not all(
(user_input_index > node_index for user_input_index in user_input_indices)
):
raise RuntimeError(
f"Failed to insert {name}; Const placeholder nodes must be inserted before user input nodes in the graph."
)

arg_spec = TensorArgument(name)
input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
input_specs.insert(node_index, input_spec)

new_graph_signature = ExportGraphSignature(
input_specs, exp_program.graph_signature.output_specs
)
exp_program._graph_signature = new_graph_signature

return node
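
As a usage sketch (the variable names below are hypothetical, not part of this patch), a pass holding an ExportedProgram can add a new parameter with the helper above; per the docstring, an inserting_before()/inserting_after() context selects where the node lands, and it must precede all user-input placeholders:

# Hypothetical example: add a PARAMETER placeholder ahead of the first existing
# placeholder so that it also precedes every user input.
first_placeholder = next(
    n for n in exported_program.graph.nodes if n.op == "placeholder"
)
with exported_program.graph.inserting_before(first_placeholder):
    scale_node = create_constant_placeholder(
        exp_program=exported_program,
        graph=exported_program.graph,
        name="my_scale",
        kind=InputKind.PARAMETER,
        data=torch.ones(8),
    )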


def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
"""
Deletes a constant placeholder node, meaning that it is of type parameter, buffer, or lifted constant tensor,
if the node does not have any users.
"""
if not len(node.users) == 0:
raise RuntimeError(
f"Cannot delete input node {node.name} since it has users in the graph."
)

# Remove tensor from state_dict/ constants
if node.name in exp_program.graph_signature.inputs_to_parameters:
target = exp_program.graph_signature.inputs_to_parameters[node.name]
del exp_program.state_dict[target]

elif node.name in exp_program.graph_signature.inputs_to_buffers:
target = exp_program.graph_signature.inputs_to_buffers[node.name]

if target in exp_program.graph_signature.non_persistent_buffers:
del exp_program.constants[target]
else:
del exp_program.state_dict[target]

elif node.name in exp_program.graph_signature.inputs_to_lifted_tensor_constants:
target = exp_program.graph_signature.inputs_to_lifted_tensor_constants[
node.name
]
del exp_program.constants[target]
else:
raise RuntimeError(
f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
)

# Remove input from graph signature
input_specs = [
spec
for spec in exp_program.graph_signature.input_specs
if spec.arg.name != node.name
]
new_graph_signature = ExportGraphSignature(
input_specs, exp_program.graph_signature.output_specs
)
exp_program._graph_signature = new_graph_signature

# Remove node from graph
node.graph.erase_node(node)
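
Deletion is the mirror image (again with hypothetical names): first redirect or remove every use of the placeholder, then call the helper, which also erases the node from the graph:

# Hypothetical example: retire an old weight placeholder after rerouting its users.
old_weight_node.replace_all_uses_with(new_weight_node)
if len(old_weight_node.users) == 0:
    delete_constant_placeholder(exported_program, old_weight_node)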


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload,
129 changes: 82 additions & 47 deletions backends/arm/_passes/fuse_batchnorm2d_pass.py
@@ -6,10 +6,15 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._export.utils import get_buffer, get_param
from torch.export.graph_signature import InputKind
from torch.fx import Node
from torch.nn.utils.fusion import fuse_conv_bn_weights

@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
self.exported_program = exported_program
super().__init__()

def is_fuseable_conv_bn(self, node: Node):
def is_fuseable_conv_bn(self, node: Node) -> bool:
"""Returns True if node is a batchnorm that can be fused into
a parent convolution."""
if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
# Since we change the output of the conv, fuse only if it has single user.
if len(conv.users) > 1:
return False
# For similar reasons, only fuse if conv parameters have single user.
if len(conv.all_input_nodes[1].users) > 1:
return False
if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
return False
return True

def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
if conv_bias_node:
return conv_bias_node.name + "_fused_bn"
elif "weight" in conv_weight_node.name:
return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
else:
return conv_weight_node.name + "_bias_fused_bn"

def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
modified = False
constant_placeholders_to_delete = set()
for node in graph_module.graph.nodes:
if not self.is_fuseable_conv_bn(node):
continue
@@ -64,67 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
)

# Get weight, bias, mean, var and epsilon from the batchnorm
bn = node
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
bn_weight = get_param_or_none(bn_weight_node)
bn_bias = get_param_or_none(bn_bias_node)

running_mean = get_buffer(self.exported_program, bn_mean_node)
running_var = get_buffer(self.exported_program, bn_var_node)
if running_mean is None or running_var is None:
bn_node = node
conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
bn_node.args[0:5]
)
bn_weight_tensor = get_param_or_none(bn_weight_node)
bn_bias_tensor = get_param_or_none(bn_bias_node)
bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
if bn_mean_tensor is None or bn_var_tensor is None:
raise ValueError(
"Parameters running_mean and running_var of batchnorm can't be None."
)
epsilon = bn.args[-1]
epsilon = bn_node.args[-1]

# Get weight and bias from conv
conv_weight_node, conv_bias_node = conv.args[1:3]
conv_weight = get_param(self.exported_program, conv_weight_node)
conv_bias = get_param_or_none(conv_bias_node)
if conv_weight is None:
conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
conv_bias_tensor = get_param_or_none(conv_bias_node)
if conv_weight_tensor is None:
raise ValueError("Parameter weight of convolution can't be None.")

# Compute conv parameters folded with batchnorm
fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
conv_weight,
conv_bias,
running_mean,
running_var,
conv_weight_tensor,
conv_bias_tensor,
bn_mean_tensor,
bn_var_tensor,
epsilon,
bn_weight,
bn_bias,
bn_weight_tensor,
bn_bias_tensor,
)
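
For reference, the folding that torch.nn.utils.fusion.fuse_conv_bn_weights performs in the call above follows the standard conv/batch-norm constant-folding identity; the following self-contained snippet (with made-up shapes) spells it out and checks it numerically:

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

w = torch.randn(4, 3, 3, 3)                      # conv weight (out_ch, in_ch, kH, kW)
b = torch.randn(4)                               # conv bias
mean, var = torch.randn(4), torch.rand(4) + 0.1  # batch-norm running stats
gamma, beta, eps = torch.randn(4), torch.randn(4), 1e-5

# Per output channel c: scale_c = gamma_c / sqrt(var_c + eps),
# W'_c = W_c * scale_c, b'_c = (b_c - mean_c) * scale_c + beta_c
scale = gamma / torch.sqrt(var + eps)
w_ref = w * scale.reshape(-1, 1, 1, 1)
b_ref = (b - mean) * scale + beta

w_fused, b_fused = fuse_conv_bn_weights(w, b, mean, var, eps, gamma, beta)
torch.testing.assert_close(w_fused.detach(), w_ref)
torch.testing.assert_close(b_fused.detach(), b_ref)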

# Set the conv parameters to fused value
def try_set_param(
param_node: Node | None, param_value: torch.nn.Parameter
) -> bool:
"""set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
if param_node is not None:
param_name = (
self.exported_program.graph_signature.inputs_to_parameters[
param_node.name
]
)
self.exported_program.state_dict[param_name] = param_value
return True
return False
# Create fused weights and bias to conv and replace conv args
with graph_module.graph.inserting_before(conv_weight_node):
fused_conv_weight_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=conv_weight_node.name + "_fused_bn",
data=fused_conv_weight,
)

try_set_param(conv_weight_node, fused_conv_weight)
if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
bn_bias_node, fused_conv_bias
):
# Conv didn't have bias but batchnorm did, steal bias from batchnorm.
conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
conv.args = conv_args
if fused_conv_bias is not None:
fused_conv_bias_node = create_constant_placeholder(
exp_program=self.exported_program,
graph=graph_module.graph,
kind=InputKind.PARAMETER,
name=self.get_bias_name(conv_weight_node, conv_bias_node),
data=fused_conv_bias,
)
else:
fused_conv_bias_node = None

conv.args = (
conv.args[0],
fused_conv_weight_node,
fused_conv_bias_node,
*conv.args[3:],
)

# Erasing nodes is handled by dead-code elimination.
for user in bn.users:
# Erasing batch-norm nodes is handled by dead-code elimination. After that we may remove their constant placeholder inputs
for user in bn_node.users:
user.replace_all_uses_with(conv)

constant_placeholders_to_delete.update(
[
bn_weight_node,
bn_bias_node,
bn_mean_node,
bn_var_node,
conv_weight_node,
conv_bias_node,
]
)
modified = True

if modified:
graph_module.graph.eliminate_dead_code()
for constant_placeholder in constant_placeholders_to_delete:
if (constant_placeholder is not None) and (
len(constant_placeholder.users) == 0
):
delete_constant_placeholder(
self.exported_program, constant_placeholder
)

graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module=graph_module, modified=modified)
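
A minimal end-to-end sketch of exercising the pass (everything here is an assumption for illustration: the pass class is taken to be FuseBatchnorm2DPass based on the file name, the toy module and shapes are made up, and the program is lowered to the edge dialect first since the pass matches edge-dialect batch-norm ops):

import torch
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
from executorch.exir import to_edge


class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


exported = torch.export.export(ConvBN().eval(), (torch.randn(1, 3, 16, 16),))
edge_program = to_edge(exported).exported_program()

# Run the pass directly on the edge program's graph module and inspect the result.
result = FuseBatchnorm2DPass(edge_program)(edge_program.graph_module)
print("fused:", result.modified)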