Skip to content

Commit 584d39b

Browse files
authored
Tag get_attr in AOTI partitioner (#15868)
so that we can tag away `submodule`s from `torch.cond`. Let's say we have some eager code like this: ```py # Tensor predicate: True if any element is non-zero # Result is a 0-dim bool tensor suitable for torch.cond cache_is_initialized = (cached_keys != 0).any() # Use torch.cond to select branch in a traceable way. # All operands must be (nested) tensors or simple Python values. key_states, value_states = torch.cond( cache_is_initialized, use_cached_kv, recompute_kv, operands=(cached_keys, cached_values, key_value_states), ) ``` Basically we check if KV cache is all zero, if so, we compute KV projections, otherwise we read KV states from KV cache. After torch.export'ing torch.cond, the graph becomes: ``` %any_1 : [num_users=1] = call_function[target=torch.ops.aten.any.default](args = (%ne,), kwargs = {}) %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0] %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%b_cache_cross_attention_cache_layers_0_keys, %b_cache_cross_attention_cache_layers_0_values, %p_decoder_layers_0_encoder_attn_k_proj_weight, %p_decoder_layers_0_encoder_attn_v_proj_bias, %p_decoder_layers_0_encoder_attn_v_proj_weight, %encoder_hidden_states)), kwargs = {}) ``` After tagging and delegate it becomes: ``` graph(): %decoder_input_ids : [num_users=1] = placeholder[target=decoder_input_ids] %encoder_hidden_states : [num_users=1] = placeholder[target=encoder_hidden_states] %cache_position : [num_users=1] = placeholder[target=cache_position] %submodule_0 : [num_users=1] = get_attr[target=submodule_0] %submodule_1 : [num_users=1] = get_attr[target=submodule_1] %submodule_2 : [num_users=1] = get_attr[target=submodule_2] %submodule_3 : [num_users=1] = get_attr[target=submodule_3] %submodule_4 : [num_users=1] = get_attr[target=submodule_4] %submodule_5 : [num_users=1] = get_attr[target=submodule_5] %submodule_6 : [num_users=1] = get_attr[target=submodule_6] %submodule_7 : [num_users=1] = get_attr[target=submodule_7] %fused_tag0 : [num_users=17] = call_module[target=fused_tag0](args = (%decoder_input_ids, %cache_position, %submodule_0, %submodule_1, %encoder_hidden_states, %submodule_2, %submodule_3, %submodule_4, %submodule_5, %submodule_6, %submodule_7), kwargs = {}) %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 0), kwargs = {}) %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 1), kwargs = {}) %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 2), kwargs = {}) %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 3), kwargs = {}) %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 4), kwargs = {}) %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 5), kwargs = {}) %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 6), kwargs = {}) %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 7), kwargs = {}) %getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 8), kwargs = {}) %getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 9), kwargs = {}) %getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 10), kwargs = {}) %getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 11), kwargs = {}) %getitem_20 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 12), kwargs = {}) %getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 13), kwargs = {}) %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 14), kwargs = {}) %getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 15), kwargs = {}) %getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 16), kwargs = {}) return (getitem_16, getitem_17, getitem_8, getitem_9, getitem_18, getitem_19, getitem_10, getitem_11, getitem_20, getitem_21, getitem_12, getitem_13, getitem_22, getitem_23, getitem_14, getitem_15, getitem_24) ``` But actually those submodules can be delegated away to AOTI. This PR makes sure we tag them properly. ### Summary [PLEASE REMOVE] See [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests) for ExecuTorch PR guidelines. [PLEASE REMOVE] If this PR closes an issue, please add a `Fixes #<issue-id>` line. [PLEASE REMOVE] If this PR introduces a fix or feature that should be the upcoming release notes, please add a "Release notes: <area>" label. For a list of available release notes labels, check out [CONTRIBUTING.md's Pull Requests](https://github.com/pytorch/executorch/blob/main/CONTRIBUTING.md#pull-requests). ### Test plan [PLEASE REMOVE] How did you test this PR? Please write down any manual commands you used and note down tests that you have written if applicable.
1 parent ca4e363 commit 584d39b

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

backends/aoti/aoti_partitioner.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5252
partition_tags: Dict[str, DelegationSpec] = {}
5353
tag = "tag0"
5454

55+
# Tag torch.cond and other control flow operations
56+
def is_control_flow(node: torch.fx.Node) -> bool:
57+
return node.op == "call_function" and node.target in [
58+
torch.ops.higher_order.cond,
59+
torch.ops.higher_order.map_impl,
60+
torch.ops.higher_order.while_loop,
61+
]
62+
5563
for node in exported_program.graph.nodes:
56-
if node.op != "call_function":
57-
continue
58-
node.meta["delegation_tag"] = tag
64+
if node.op == "call_function":
65+
node.meta["delegation_tag"] = tag
66+
# Tag get_attr nodes that are used by control flow operations
67+
elif node.op == "get_attr":
68+
# Check if any user is a control flow operation
69+
for user in node.users:
70+
if is_control_flow(user):
71+
node.meta["delegation_tag"] = tag
72+
break
5973

6074
partition_tags[tag] = self.delegation_spec
6175

0 commit comments

Comments
 (0)