Skip to content

Commit b74340f

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Update ReplaceAddMMWithLinearPass to use new pass interface
Summary: As titled, now correctly setting the modified bit. Differential Revision: D86909981
1 parent e9b7170 commit b74340f

File tree

2 files changed

+76
-76
lines changed

2 files changed

+76
-76
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 70 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def call_operator(self, op, args, kwargs, meta):
320320

321321

322322
@register_cadence_pass(CadencePassAttribute(opt_level=1))
323-
class ReplaceAddMMWithLinearPass(ExportPass):
323+
class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface):
324324
"""
325325
This pass replaces addmm with linear op.
326326
"""
@@ -329,81 +329,79 @@ def __init__(self):
329329
super().__init__()
330330
self.counter = 0
331331

332-
def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
333-
graph = graph_module.graph
334-
for node in graph.nodes:
335-
# We are only interested in admm nodes
336-
if node.target != exir_ops.edge.aten.addmm.default:
337-
continue
338-
339-
# The addmm op has three concrete args: input, mat1, mat2
340-
assert len(node.args) >= 3
341-
(bias, mat1, mat2) = node.args[0:3]
342-
# The other two args are optional scale args
343-
beta = node.kwargs.get("beta", 1.0)
344-
alpha = node.kwargs.get("alpha", 1.0)
345-
346-
# AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
347-
# it to linear op by multiplying beta to bias, and alpha to mat2.t().
348-
# However, the following two conditions must hold:
349-
# a. If bias is not a param, then beta must be 1.0
350-
# b. If mat2 is not a param, then mat2 must be a transpose op. Also,
351-
# the input to the transpose must be a param, or alpha must be 1.0.
352-
fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
353-
fit_mat2 = is_node_with_op(mat2, "get_attr")
354-
transposed_mat2 = False
355-
if (
356-
not fit_mat2
357-
and is_node_with_op(mat2, "call_function")
358-
and mat2.target == exir_ops.edge.aten.transpose_copy.int
359-
):
360-
mat2, transposed_mat2 = mat2.args[0], True
361-
fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
332+
@property
333+
def targets(self) -> list[EdgeOpOverload]:
334+
return [exir_ops.edge.aten.addmm.default]
362335

363-
if not fit_bias or not fit_mat2:
364-
continue
336+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
337+
# The addmm op has three concrete args: input, mat1, mat2
338+
assert len(node.args) >= 3
339+
(bias, mat1, mat2) = node.args[0:3]
340+
# The other two args are optional scale args
341+
beta = float(node.kwargs.get("beta", 1.0))
342+
alpha = float(node.kwargs.get("alpha", 1.0))
343+
344+
bias = cast(torch.fx.Node, bias)
345+
mat1 = cast(torch.fx.Node, mat1)
346+
mat2 = cast(torch.fx.Node, mat2)
347+
348+
# AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
349+
# it to linear op by multiplying beta to bias, and alpha to mat2.t().
350+
# However, the following two conditions must hold:
351+
# a. If bias is not a param, then beta must be 1.0
352+
# b. If mat2 is not a param, then mat2 must be a transpose op. Also,
353+
# the input to the transpose must be a param, or alpha must be 1.0.
354+
fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
355+
fit_mat2 = is_node_with_op(mat2, "get_attr")
356+
transposed_mat2 = False
357+
if (
358+
not fit_mat2
359+
and is_node_with_op(mat2, "call_function")
360+
and mat2.target == exir_ops.edge.aten.transpose_copy.int
361+
):
362+
mat2, transposed_mat2 = cast(torch.fx.Node, mat2.args[0]), True
363+
fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
365364

366-
# Multiply bias by beta
367-
if beta != 1.0:
368-
assert is_node_with_op(bias, "get_attr")
369-
bias_tensor = get_tensor_from_attr(graph_module, bias)
370-
assert isinstance(bias_tensor, torch.Tensor)
371-
bias_tensor = beta * bias_tensor
372-
with graph.inserting_before(node):
373-
bias_name = f"_bias_addmm_to_linear_{self.counter}"
374-
graph_module.register_buffer(bias_name, bias_tensor)
375-
bias = graph.get_attr(bias_name)
376-
377-
# Use associativity of scalar multiplication, and multiply alpha to mat2
378-
if is_node_with_op(mat2, "get_attr"):
379-
mat2_tensor = get_tensor_from_attr(graph_module, mat2)
380-
assert isinstance(mat2_tensor, torch.Tensor)
381-
mat2_tensor = alpha * mat2_tensor
382-
# transpose mat2
383-
mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
384-
with graph.inserting_before(node):
385-
mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
386-
graph_module.register_buffer(mat2_name, mat2_tensor)
387-
mat2 = graph.get_attr(mat2_name)
388-
389-
# Construct the linear node
390-
linear_args = (mat1, mat2, bias)
391-
with graph.inserting_before(node):
392-
linear_node = graph.call_function(
393-
exir_ops.edge.aten.linear.default, args=linear_args
394-
)
395-
linear_node.meta = node.meta
396-
# Replace all the uses of the addmm op with linear op
397-
node.replace_all_uses_with(linear_node)
398-
self.counter += 1
365+
if not fit_bias or not fit_mat2:
366+
return False
399367

400-
graph_module.recompile()
401-
graph_module.graph.eliminate_dead_code()
368+
graph = node.graph
369+
graph_module = graph.owning_module
402370

403-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
404-
self.replace_addmm_with_linear(graph_module)
405-
result = super().call(graph_module)
406-
return result
371+
# Multiply bias by beta
372+
if beta != 1.0:
373+
assert is_node_with_op(bias, "get_attr")
374+
bias_tensor = get_tensor_from_attr(graph_module, bias)
375+
assert isinstance(bias_tensor, torch.Tensor)
376+
bias_tensor = beta * bias_tensor
377+
with graph.inserting_before(node):
378+
bias_name = f"_bias_addmm_to_linear_{self.counter}"
379+
graph_module.register_buffer(bias_name, bias_tensor)
380+
bias = graph.get_attr(bias_name)
381+
382+
# Use associativity of scalar multiplication, and multiply alpha to mat2
383+
if is_node_with_op(mat2, "get_attr"):
384+
mat2_tensor = get_tensor_from_attr(graph_module, mat2)
385+
assert isinstance(mat2_tensor, torch.Tensor)
386+
mat2_tensor = alpha * mat2_tensor
387+
# transpose mat2
388+
mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
389+
with graph.inserting_before(node):
390+
mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
391+
graph_module.register_buffer(mat2_name, mat2_tensor)
392+
mat2 = graph.get_attr(mat2_name)
393+
394+
# Construct the linear node
395+
linear_args = (mat1, mat2, bias)
396+
with graph.inserting_before(node):
397+
linear_node = graph.call_function(
398+
exir_ops.edge.aten.linear.default, args=linear_args
399+
)
400+
linear_node.meta = node.meta
401+
# Replace all the uses of the addmm op with linear op
402+
node.replace_all_uses_with(linear_node)
403+
self.counter += 1
404+
return True
407405

408406

409407
@register_cadence_pass(CadencePassAttribute(opt_level=1))

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,9 @@ def test_replace_linear_with_fully_connected(self) -> None:
875875
PassResult, ReplacePermuteWithTransposePass()(original_gm)
876876
).graph_module
877877
gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
878-
gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
878+
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
879+
self.assertTrue(pass_result.modified)
880+
gm = pass_result.graph_module
879881
graph_after_passes = cast(
880882
PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm)
881883
).graph_module
@@ -924,9 +926,9 @@ def test_replace_addmm_with_linear(
924926
gm = cast(
925927
PassResult, ReplacePermuteWithTransposePass()(original_gm)
926928
).graph_module
927-
graph_after_passes = cast(
928-
PassResult, ReplaceAddMMWithLinearPass()(gm)
929-
).graph_module
929+
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
930+
self.assertTrue(pass_result.modified)
931+
graph_after_passes = pass_result.graph_module
930932
self.assertIsNotNone(graph_after_passes)
931933
self.assertEqual(
932934
count_node(graph_after_passes, exir_ops.edge.aten.linear.default),

0 commit comments

Comments
 (0)