Skip to content

Commit 23d89b5

Browse files
authored
Move apply_torch_ops_passes() to _lower_ep_to_edge().
Differential Revision: D78852189 Pull Request resolved: #12786
1 parent c310efe commit 23d89b5

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

backends/cadence/aot/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def _lower_ep_to_edge(
228228
"""
229229
Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
230230
"""
231+
# Apply passes which transform the ExportedProgram before it gets lowered to edge.
232+
expo_program = apply_torch_ops_passes(expo_program)
233+
231234
# Call to_edge to convert the graph to edge IR.
232235
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
233236
edge_prog_manager = to_edge(
@@ -263,9 +266,6 @@ def export_to_edge(
263266
# Export the model into an ExportedProgram.
264267
expo_program = trace(model, inputs)
265268

266-
# Apply passes which transform the ExportedProgram before it gets lowered to edge.
267-
expo_program = apply_torch_ops_passes(expo_program)
268-
269269
# Lower the model to edge IR.
270270
edge_prog_manager = _lower_ep_to_edge(
271271
expo_program, dump_graphs, constant_methods, core_aten_exceptions

backends/cadence/aot/replace_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,12 +2328,15 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
23282328

23292329
# Extract an argument to a separate full op.
23302330
with graph_module.graph.inserting_before(mul_node):
2331-
full_tensor = graph_module.graph.call_function(
2331+
full_node = graph_module.graph.call_function(
23322332
torch.ops.aten.full.default, args=([1], full_arg)
23332333
)
2334+
full_node.meta = mul_node.meta
2335+
full_node.meta["val"] = [1]
23342336
new_mul_node = graph_module.graph.call_function(
2335-
torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor)
2337+
torch.ops.aten.mul.Tensor, args=(x_arg, full_node)
23362338
)
2339+
new_mul_node.meta = mul_node.meta
23372340
# Replace the old mul with a newly created mul.
23382341
mul_node.replace_all_uses_with(new_mul_node)
23392342
graph_module.graph.erase_node(mul_node)

0 commit comments

Comments
 (0)