Tag get_attr in AOTI partitioner #15868
This PR tags `get_attr` nodes in the AOTI partitioner so that we can tag away the `submodule`s created by `torch.cond`.
Let's say we have some eager code like this:
```py
# Tensor predicate: True if any element is non-zero
# Result is a 0-dim bool tensor suitable for torch.cond
cache_is_initialized = (cached_keys != 0).any()
# Use torch.cond to select branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
cache_is_initialized,
use_cached_kv,
recompute_kv,
operands=(cached_keys, cached_values, key_value_states),
)
```
Basically, we check whether the KV cache is all zeros; if so, we compute the KV
projections, otherwise we read the KV states from the cache.
After `torch.export` traces the `torch.cond`, the graph becomes:
```
%any_1 : [num_users=1] = call_function[target=torch.ops.aten.any.default](args = (%ne,), kwargs = {})
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%b_cache_cross_attention_cache_layers_0_keys, %b_cache_cross_attention_cache_layers_0_values, %p_decoder_layers_0_encoder_attn_k_proj_weight, %p_decoder_layers_0_encoder_attn_v_proj_bias, %p_decoder_layers_0_encoder_attn_v_proj_weight, %encoder_hidden_states)), kwargs = {})
```
After tagging and delegation, it becomes:
```
graph():
%decoder_input_ids : [num_users=1] = placeholder[target=decoder_input_ids]
%encoder_hidden_states : [num_users=1] = placeholder[target=encoder_hidden_states]
%cache_position : [num_users=1] = placeholder[target=cache_position]
%submodule_0 : [num_users=1] = get_attr[target=submodule_0]
%submodule_1 : [num_users=1] = get_attr[target=submodule_1]
%submodule_2 : [num_users=1] = get_attr[target=submodule_2]
%submodule_3 : [num_users=1] = get_attr[target=submodule_3]
%submodule_4 : [num_users=1] = get_attr[target=submodule_4]
%submodule_5 : [num_users=1] = get_attr[target=submodule_5]
%submodule_6 : [num_users=1] = get_attr[target=submodule_6]
%submodule_7 : [num_users=1] = get_attr[target=submodule_7]
%fused_tag0 : [num_users=17] = call_module[target=fused_tag0](args = (%decoder_input_ids, %cache_position, %submodule_0, %submodule_1, %encoder_hidden_states, %submodule_2, %submodule_3, %submodule_4, %submodule_5, %submodule_6, %submodule_7), kwargs = {})
%getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 0), kwargs = {})
%getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 1), kwargs = {})
%getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 2), kwargs = {})
%getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 3), kwargs = {})
%getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 4), kwargs = {})
%getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 5), kwargs = {})
%getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 6), kwargs = {})
%getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 7), kwargs = {})
%getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 8), kwargs = {})
%getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 9), kwargs = {})
%getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 10), kwargs = {})
%getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 11), kwargs = {})
%getitem_20 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 12), kwargs = {})
%getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 13), kwargs = {})
%getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 14), kwargs = {})
%getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 15), kwargs = {})
%getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 16), kwargs = {})
return (getitem_16, getitem_17, getitem_8, getitem_9, getitem_18, getitem_19, getitem_10, getitem_11, getitem_20, getitem_21, getitem_12, getitem_13, getitem_22, getitem_23, getitem_14, getitem_15, getitem_24)
```
But those submodules can actually be delegated away to AOTI as well. This PR
makes sure we tag them properly.
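A minimal sketch of the idea, not the actual partitioner code (`tag`, `M`, and the traced module are made up): tag `get_attr` nodes alongside `call_function` nodes so the branch submodules travel into the delegate together with the `cond` node.

```py
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        # Accessing self.w produces a get_attr node under fx tracing,
        # analogous to the true_graph_0 / false_graph_0 nodes above.
        return x + self.w

gm = torch.fx.symbolic_trace(M())
tag = "fused_tag0"
for node in gm.graph.nodes:
    # Previously only call_function nodes were tagged; including get_attr
    # ensures cond's branch submodules get the same delegation tag.
    if node.op in ("call_function", "get_attr"):
        node.meta["delegation_tag"] = tag

tagged_ops = {n.op for n in gm.graph.nodes if "delegation_tag" in n.meta}
```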
Also @larryliu0820 - see this PR that's doing something similar now -- #15849
```py
if node.op != "call_function":
    continue
node.meta["delegation_tag"] = tag
if node.op == "call_function":
```
Can you use this `get_control_flow_submodules`?

executorch/exir/graph_module.py, lines 76 to 89 in ca4e363:

```py
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
    into submodules). Specifically, the returned value is a list containing
    tuples of (name of the submodule that's stored in the graph module, the
    submodule itself, and the fx node that uses this submodule).
    """
    return _get_control_flow_submodules(
        graph_module,
        {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]},
    )
```
I specifically don't want to use `get_control_flow_submodules`. The reason is that it recursively lowers the branches and leaves `torch.cond` for ET to handle. A backend like AOTI is fully capable of handling `torch.cond`, the predicate, and the branches, so they should all be delegated away.
Oh I see, sounds good.
I think they're also using
It was resolved. We need the partitioner to handle the logic inside and offer a util solution. We need the backend to decide, in the partitioner, whether it handles only the subgraphs or also the higher-order op itself, because some backends support control flow and some do not.