Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
from .insert_rescales_pass import InsertRescalePass # noqa
from .insert_table_ops import InsertTableOpsPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FoldAndAnnotateQParamsPass,
FuseBatchnorm2DPass,
FuseConstantArgsPass,
FuseEqualPlaceholdersPass,
FuseQuantizedActivationPass,
InsertRescalePass,
InsertTableOpsPass,
Expand Down Expand Up @@ -108,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(FuseConstantArgsPass(exported_program))

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(InsertRescalePass())

Expand Down Expand Up @@ -155,6 +157,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(InsertRescalePass())

Expand Down
83 changes: 83 additions & 0 deletions backends/arm/_passes/fuse_equal_placeholders_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
get_constant_placeholder_kind,
get_param_tensor,
is_param_node,
)
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
)
from executorch.exir import ExportedProgram
from executorch.exir.pass_base import ExportPass, PassResult


class FuseEqualPlaceholdersPass(ExportPass):
    """
    This pass optimizes memory usage by finding constant placeholders
    pointing to identical tensors and fusing them to one single placeholder
    with multiple users.

    Two placeholders are considered identical only when their tensors share
    the same dtype and compare equal with ``torch.equal``.
    """

    def __init__(self, exported_program: ExportedProgram):
        """Keep a reference to the exported program whose placeholders are fused."""
        self.exported_program = exported_program
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace each group of equal constant placeholders with one placeholder.

        Args:
            graph_module: Graph module whose placeholder nodes are inspected.

        Returns:
            A PassResult with ``modified`` set to True when at least one group
            was fused; in that case the graph module is the one returned by
            the base-class ``call`` after recompiling.
        """
        modified = False

        # Resolve every tensor once up front instead of once per pairwise
        # comparison; nodes without a resolvable tensor are never fused.
        candidates = []
        for node in graph_module.graph.nodes:
            if not is_param_node(self.exported_program, node):
                continue
            tensor = get_param_tensor(self.exported_program, node)
            if tensor is not None:
                candidates.append((node, tensor))

        while candidates:
            node1, tensor1 = candidates.pop()

            # torch.equal is documented as a size-and-element comparison, so
            # the dtype is checked explicitly: fusing e.g. an int32 and a
            # float32 constant holding the same values would silently change
            # the dtype seen by the users of one of them.
            duplicates = [
                node2
                for node2, tensor2 in candidates
                if tensor1.dtype == tensor2.dtype and torch.equal(tensor1, tensor2)
            ]
            if not duplicates:
                continue

            common_name = node1.name + "_common"
            common_kind = get_constant_placeholder_kind(self.exported_program, node1)
            common_persistent_buffer = True

            with graph_module.graph.inserting_before(node1):
                common_node = create_constant_placeholder(
                    self.exported_program,
                    graph_module.graph,
                    common_name,
                    common_kind,
                    tensor1,
                    common_persistent_buffer,
                )

            # Redirect all users to the shared placeholder, then drop the
            # now-unused originals (node1 included).
            for eq_node in (node1, *duplicates):
                eq_node.replace_all_uses_with(common_node)
                delete_constant_placeholder(self.exported_program, eq_node)

            # Deleted nodes must not be considered again in later rounds.
            fused = set(duplicates)
            candidates = [entry for entry in candidates if entry[0] not in fused]
            modified = True

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module=graph_module, modified=modified)
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
from typing import Tuple

import torch
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
FuseEqualPlaceholdersPass,
)
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

# Type used to parametrize PassPipeline in the tests below: a one-element
# tuple holding the input tensor x.
input_t = Tuple[torch.Tensor]  # Input x


class FuseWeightsConstants(torch.nn.Module):
    """
    Module holding duplicated weights and biases as plain tensor attributes.

    ``weights2`` is a deep copy of ``weights1``; ``bias2`` and ``bias3`` are
    deep copies of ``bias1``.
    """

    # ops_before_pass / ops_after_pass are forwarded to PassPipeline by the
    # test using this module; all three are intentionally empty.
    ops_before_pass = {}
    ops_after_pass = {}
    ops_not_after_pass = []

    def __init__(self):
        super().__init__()
        # Random draws happen in the order weights, then bias.
        weights = torch.rand(1, 2, 1)
        self.weights1 = weights
        self.weights2 = deepcopy(weights)

        bias = torch.rand(1)
        self.bias1 = bias
        self.bias2 = deepcopy(bias)
        self.bias3 = deepcopy(bias)

    def forward(self, x):
        """Return the sum of two conv1d results over x plus ``bias3``."""
        first = torch.conv1d(x, self.weights1, self.bias1)
        second = torch.conv1d(x, self.weights2, self.bias2)
        return first + second + self.bias3


class FuseWeightsStateDict(torch.nn.Module):
    """
    Module with two Linear submodules, the second a deep copy of the first.
    """

    # ops_before_pass / ops_after_pass are forwarded to PassPipeline by the
    # test using this module; all three are intentionally empty.
    ops_before_pass = {}
    ops_after_pass = {}
    ops_not_after_pass = []

    def __init__(self):
        super().__init__()
        linear = torch.nn.Linear(8, 2, bias=True)
        self.fc1 = linear
        self.fc2 = deepcopy(linear)

    def forward(self, x):
        """Return the sum of both linear layers applied to x."""
        first = self.fc1(x)
        second = self.fc2(x)
        return first + second


def test_fuse_equal_placeholders_constants_tosa_MI():
    """Run the pass on FuseWeightsConstants and check the remaining constants."""
    module = FuseWeightsConstants()
    test_data = (torch.rand(1, 2, 8),)
    pipeline = PassPipeline[input_t](
        module,
        test_data,
        tosa_version="TOSA-0.80+MI",
        ops_before_pass=module.ops_before_pass,
        ops_after_pass=module.ops_after_pass,
        passes_with_exported_program=[FuseEqualPlaceholdersPass],
    )
    pipeline.run()

    # The weights and the biases should each have been merged into one
    # "_common" constant, leaving exactly two entries.
    failure_msg = "FuseEqualPlaceholders constants failed"
    exported = pipeline.tester.get_artifact().exported_program()
    remaining = [*exported.constants.keys()]
    assert len(remaining) == 2, failure_msg
    for key in remaining:
        assert "_common" in key, failure_msg


def test_fuse_equal_placeholders_state_dict_tosa_MI():
    """Run the pass on FuseWeightsStateDict and check the remaining state_dict."""
    module = FuseWeightsStateDict()
    test_data = (torch.rand(1, 2, 8),)
    pipeline = PassPipeline[input_t](
        module,
        test_data,
        tosa_version="TOSA-0.80+MI",
        ops_before_pass=module.ops_before_pass,
        ops_after_pass=module.ops_after_pass,
        passes_with_exported_program=[FuseEqualPlaceholdersPass],
    )
    pipeline.run()

    # The weights and the biases should each have been merged into one
    # "_common" entry, leaving exactly two keys.
    failure_msg = "FuseEqualPlaceholders state_dict failed"
    exported = pipeline.tester.get_artifact().exported_program()
    remaining = [*exported.state_dict.keys()]
    assert len(remaining) == 2, failure_msg
    for key in remaining:
        assert "_common" in key, failure_msg
2 changes: 1 addition & 1 deletion examples/arm/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ tosa_reference_model_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"

# vela
vela_repo_url="https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela"
vela_rev="425541302c7e4b6fbeca7c0061286b131ee507c3"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the heads-up, @zingo / @AdrianLundell. I think we will pull this one in.

vela_rev="859cc066178a87ff28230c1ce9bd370f1e98aa5a"

########
### Optional user args
Expand Down
Loading