Tag `get_attr` nodes in the AOTI partitioner (#15868) so that we can tag away the `submodule`s produced by `torch.cond`.
Let's say we have some eager code like this:
```py
# Tensor predicate: True if any element is non-zero
# Result is a 0-dim bool tensor suitable for torch.cond
cache_is_initialized = (cached_keys != 0).any()
# Use torch.cond to select branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
cache_is_initialized,
use_cached_kv,
recompute_kv,
operands=(cached_keys, cached_values, key_value_states),
)
```
Basically, we check whether the KV cache is all zeros; if so, we compute the KV projections, otherwise we read the KV states from the cache.
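The two branch callables are not shown here; a plausible sketch is below (the dimensions and module wiring are illustrative, only the weight/bias layout mirrors the parameters visible in the exported graph further down):
```py
import torch

# Illustrative projection layers; in the real model these are the
# cross-attention k_proj (weight only) and v_proj (weight + bias) Linears
# that show up as parameters in the exported graph below.
k_proj = torch.nn.Linear(1024, 1024, bias=False)
v_proj = torch.nn.Linear(1024, 1024)

def use_cached_kv(cached_keys, cached_values, key_value_states):
    # Cache already holds the projected KV states: just return them.
    return cached_keys.clone(), cached_values.clone()

def recompute_kv(cached_keys, cached_values, key_value_states):
    # Cache is still all zeros: project the encoder hidden states.
    return k_proj(key_value_states), v_proj(key_value_states)
```
Both branches of `torch.cond` must accept the same operands and return tensors with matching shapes and dtypes, which is why `recompute_kv` also receives the cache tensors it never reads.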
After `torch.export`-ing the model, the `torch.cond` call shows up in the exported graph as:
```
%any_1 : [num_users=1] = call_function[target=torch.ops.aten.any.default](args = (%ne,), kwargs = {})
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%b_cache_cross_attention_cache_layers_0_keys, %b_cache_cross_attention_cache_layers_0_values, %p_decoder_layers_0_encoder_attn_k_proj_weight, %p_decoder_layers_0_encoder_attn_v_proj_bias, %p_decoder_layers_0_encoder_attn_v_proj_weight, %encoder_hidden_states)), kwargs = {})
```
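For reference, a minimal standalone module (unrelated to the real decoder) reproduces this structure and shows where `true_graph_0` / `false_graph_0` come from:
```py
import torch

class ToyCond(torch.nn.Module):
    # Minimal stand-in for the cache check above, just to show the exported IR.
    def forward(self, cache, x):
        def use_cache(cache, x):
            return cache.clone()

        def recompute(cache, x):
            return x * 2.0

        return torch.cond((cache != 0).any(), use_cache, recompute, (cache, x))

ep = torch.export.export(ToyCond(), (torch.zeros(4), torch.ones(4)))
# The branch functions are lifted into true_graph_0 / false_graph_0, which
# appear as get_attr nodes feeding the higher_order.cond call.
print(ep.graph_module.graph)
```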
After tagging and delegation, the top-level graph becomes:
```
graph():
%decoder_input_ids : [num_users=1] = placeholder[target=decoder_input_ids]
%encoder_hidden_states : [num_users=1] = placeholder[target=encoder_hidden_states]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%submodule_0 : [num_users=1] = get_attr[target=submodule_0]
%submodule_1 : [num_users=1] = get_attr[target=submodule_1]
%submodule_2 : [num_users=1] = get_attr[target=submodule_2]
%submodule_3 : [num_users=1] = get_attr[target=submodule_3]
%submodule_4 : [num_users=1] = get_attr[target=submodule_4]
%submodule_5 : [num_users=1] = get_attr[target=submodule_5]
%submodule_6 : [num_users=1] = get_attr[target=submodule_6]
%submodule_7 : [num_users=1] = get_attr[target=submodule_7]
%fused_tag0 : [num_users=17] = call_module[target=fused_tag0](args = (%decoder_input_ids, %cache_position, %submodule_0, %submodule_1, %encoder_hidden_states, %submodule_2, %submodule_3, %submodule_4, %submodule_5, %submodule_6, %submodule_7), kwargs = {})
%getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 0), kwargs = {})
%getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 1), kwargs = {})
%getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 2), kwargs = {})
%getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 3), kwargs = {})
%getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 4), kwargs = {})
%getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 5), kwargs = {})
%getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 6), kwargs = {})
%getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 7), kwargs = {})
%getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 8), kwargs = {})
%getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 9), kwargs = {})
%getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 10), kwargs = {})
%getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 11), kwargs = {})
%getitem_20 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 12), kwargs = {})
%getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 13), kwargs = {})
%getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 14), kwargs = {})
%getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 15), kwargs = {})
%getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 16), kwargs = {})
return (getitem_16, getitem_17, getitem_8, getitem_9, getitem_18, getitem_19, getitem_10, getitem_11, getitem_20, getitem_21, getitem_12, getitem_13, getitem_22, getitem_23, getitem_14, getitem_15, getitem_24)
```
But those submodules can in fact be delegated to AOTI along with the rest of the graph; this PR makes sure we tag them properly.
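Conceptually, the fix boils down to propagating the partition tag onto `get_attr` nodes whose users are already tagged. A rough sketch of the idea (a hypothetical helper, not the PR's code, assuming the `delegation_tag` node-meta convention used by ExecuTorch partitioners):
```py
from torch.fx import GraphModule

def tag_cond_get_attrs(gm: GraphModule) -> None:
    # Hypothetical helper: give each get_attr node the same delegation tag as
    # its users, so the true_graph_* / false_graph_* submodules referenced by
    # torch.cond land in the same AOTI partition instead of being left behind.
    for node in gm.graph.nodes:
        if node.op != "get_attr":
            continue
        user_tags = {
            user.meta.get("delegation_tag")
            for user in node.users
            if user.meta.get("delegation_tag") is not None
        }
        if len(user_tags) == 1:
            node.meta["delegation_tag"] = user_tags.pop()
```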