Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions backends/aoti/aoti_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
partition_tags: Dict[str, DelegationSpec] = {}
tag = "tag0"

# Tag torch.cond and other control flow operations
def is_control_flow(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.higher_order.cond,
torch.ops.higher_order.map_impl,
torch.ops.higher_order.while_loop,
]

for node in exported_program.graph.nodes:
if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag
if node.op == "call_function":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use this get_control_flow_submodules

def get_control_flow_submodules(
graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
"""
Returns a list of submodules used for control flow operations
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
into submodules). Specifically, the returned value is a list containing
tuples of (name of the submodule that's stored in the graph module, the
submodule itself, and the fx node that uses this submodule).
"""
return _get_control_flow_submodules(
graph_module,
{torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]},
)
I think we have some examples usages https://github.com/search?q=repo%3Apytorch%2Fexecutorch%20get_control_flow_submodules&type=code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I specifically don't want to use this get_control_flow_submodules. The reason is it recursively lower the branches and leave torch.cond for ET to handle. For a backend like AOTI, it is fully capable to handle torch.cond and the predicate and the branches, so it should be delegated away.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see sounds good

node.meta["delegation_tag"] = tag
# Tag get_attr nodes that are used by control flow operations
elif node.op == "get_attr":
# Check if any user is a control flow operation
for user in node.users:
if is_control_flow(user):
node.meta["delegation_tag"] = tag
break

partition_tags[tag] = self.delegation_spec

Expand Down
Loading