
Conversation

@larryliu0820
Contributor

This change lets us tag away the `submodule`s introduced by `torch.cond`.

Let's say we have some eager code like this:

```py
# Tensor predicate: True if any element is non-zero.
# Result is a 0-dim bool tensor suitable for torch.cond.
cache_is_initialized = (cached_keys != 0).any()

# Use torch.cond to select a branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
    cache_is_initialized,
    use_cached_kv,
    recompute_kv,
    operands=(cached_keys, cached_values, key_value_states),
)
```
Basically, we check whether the KV cache is all zeros: if it is, we compute the KV projections; otherwise we read the KV states from the cache.
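
For context, the two branch functions might look roughly like this. This is a hedged sketch, not the PR's actual model code: `k_proj`, `v_proj`, and `hidden_dim` are illustrative stand-ins for the cross-attention projections.

```py
import torch
import torch.nn as nn

hidden_dim = 16
k_proj = nn.Linear(hidden_dim, hidden_dim)  # stand-in for the cross-attention K projection
v_proj = nn.Linear(hidden_dim, hidden_dim)  # stand-in for the cross-attention V projection

def use_cached_kv(cached_keys, cached_values, key_value_states):
    # The cache already holds the projected KV states; return them.
    # .clone() keeps the outputs alias-free, which torch.cond requires.
    return cached_keys.clone(), cached_values.clone()

def recompute_kv(cached_keys, cached_values, key_value_states):
    # The cache is uninitialized: project the encoder hidden states into K/V.
    # Both branches must return tensors of matching shape and dtype.
    return k_proj(key_value_states), v_proj(key_value_states)
```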

After running `torch.export`, the `torch.cond` call becomes:

```
    %any_1 : [num_users=1] = call_function[target=torch.ops.aten.any.default](args = (%ne,), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%b_cache_cross_attention_cache_layers_0_keys, %b_cache_cross_attention_cache_layers_0_values, %p_decoder_layers_0_encoder_attn_k_proj_weight, %p_decoder_layers_0_encoder_attn_v_proj_bias, %p_decoder_layers_0_encoder_attn_v_proj_weight, %encoder_hidden_states)), kwargs = {})
```
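
For reference, a graph like this can be obtained by exporting the model and printing its graph; a minimal sketch, assuming `model` and the example inputs are defined elsewhere:

```py
import torch

# torch.export traces torch.cond into a higher_order.cond node whose
# branches become the true_graph_0 / false_graph_0 get_attr submodules.
ep = torch.export.export(model, (decoder_input_ids, encoder_hidden_states, cache_position))
print(ep.graph_module.graph)
```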

After tagging and delegation, it becomes:

```
graph():
    %decoder_input_ids : [num_users=1] = placeholder[target=decoder_input_ids]
    %encoder_hidden_states : [num_users=1] = placeholder[target=encoder_hidden_states]
    %cache_position : [num_users=1] = placeholder[target=cache_position]
    %submodule_0 : [num_users=1] = get_attr[target=submodule_0]
    %submodule_1 : [num_users=1] = get_attr[target=submodule_1]
    %submodule_2 : [num_users=1] = get_attr[target=submodule_2]
    %submodule_3 : [num_users=1] = get_attr[target=submodule_3]
    %submodule_4 : [num_users=1] = get_attr[target=submodule_4]
    %submodule_5 : [num_users=1] = get_attr[target=submodule_5]
    %submodule_6 : [num_users=1] = get_attr[target=submodule_6]
    %submodule_7 : [num_users=1] = get_attr[target=submodule_7]
    %fused_tag0 : [num_users=17] = call_module[target=fused_tag0](args = (%decoder_input_ids, %cache_position, %submodule_0, %submodule_1, %encoder_hidden_states, %submodule_2, %submodule_3, %submodule_4, %submodule_5, %submodule_6, %submodule_7), kwargs = {})
    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 0), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 1), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 2), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 3), kwargs = {})
    %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 4), kwargs = {})
    %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 5), kwargs = {})
    %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 6), kwargs = {})
    %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 7), kwargs = {})
    %getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 8), kwargs = {})
    %getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 9), kwargs = {})
    %getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 10), kwargs = {})
    %getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 11), kwargs = {})
    %getitem_20 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 12), kwargs = {})
    %getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 13), kwargs = {})
    %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 14), kwargs = {})
    %getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 15), kwargs = {})
    %getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 16), kwargs = {})
    return (getitem_16, getitem_17, getitem_8, getitem_9, getitem_18, getitem_19, getitem_10, getitem_11, getitem_20, getitem_21, getitem_12, getitem_13, getitem_22, getitem_23, getitem_14, getitem_15, getitem_24)
```

But those submodules can actually be delegated away to AOTI. This PR makes sure we tag them properly.
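
Conceptually the fix is small; a minimal sketch of the idea (illustrative names, not the exact PR code): when tagging a partition, include `get_attr` nodes instead of skipping everything that is not a `call_function`, so the `torch.cond` branch submodules land in the same delegated partition.

```py
import torch

def tag_partition(graph_module: torch.fx.GraphModule, tag: str) -> None:
    # Tag call_function nodes *and* the get_attr nodes that carry the
    # torch.cond branch submodules, so the whole conditional is fused
    # into a single partition and delegated together.
    for node in graph_module.graph.nodes:
        if node.op not in ("call_function", "get_attr"):
            continue
        node.meta["delegation_tag"] = tag
```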
@pytorch-bot

pytorch-bot bot commented Nov 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15868

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 730dc8b with merge base 763a474:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Nov 18, 2025
@larryliu0820 larryliu0820 marked this pull request as ready for review November 18, 2025 08:37
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@mergennachin
Contributor

@angelayi @cccclai didn't we solve this a while back in low level backend partitioner logic?

@mergennachin
Contributor

Also @larryliu0820 - see this PR that's doing something similar now -- #15849

if node.op != "call_function":
continue
node.meta["delegation_tag"] = tag
if node.op == "call_function":
Contributor

Can you use this `get_control_flow_submodules`?

```py
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.fx.GraphModule, torch.fx.Node]]:
    """
    Returns a list of submodules used for control flow operations
    (torch.ops.higher_order.cond/map) that are in the given toplevel graph
    (does not look into submodules). Specifically, the returned value is a
    list containing tuples of (name of the submodule that's stored in the
    graph module, the submodule itself, and the fx node that uses this
    submodule).
    """
    return _get_control_flow_submodules(
        graph_module,
        {torch.ops.higher_order.cond: [1, 2], torch.ops.higher_order.map_impl: [0]},
    )
```
I think we have some example usages: https://github.com/search?q=repo%3Apytorch%2Fexecutorch%20get_control_flow_submodules&type=code
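
For reference, a typical usage pattern iterates over the returned `(name, submodule, node)` tuples; a rough sketch (the import path is an assumption):

```py
# Assumed import path; the helper lives next to the lowering utilities
# in the ExecuTorch tree.
from executorch.exir.lowered_backend_module import get_control_flow_submodules

for name, submodule, use_node in get_control_flow_submodules(ep.graph_module):
    # Each tuple gives the branch submodule and the cond/map node using it,
    # e.g. to recursively partition or tag inside the branch.
    print(name, use_node.target)
```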

Contributor Author

I specifically don't want to use this `get_control_flow_submodules`. The reason is that it recursively lowers the branches and leaves `torch.cond` for ET to handle. A backend like AOTI is fully capable of handling `torch.cond`, the predicate, and the branches, so the whole thing should be delegated away.

Contributor

Oh I see, sounds good.

@cccclai
Contributor

cccclai commented Nov 18, 2025

> Also @larryliu0820 - see this PR that's doing something similar now -- #15849

I think they're also using `get_control_flow_submodules` (though they moved to `get_cond_while_submodules`).

> @angelayi @cccclai didn't we solve this a while back in low level backend partitioner logic?

It was resolved. We need the partitioner to handle the logic inside and offer a util solution, and we need the backend to decide whether it handles only the subgraphs or also the higher-order op in the partitioner, because some backends might support control flow and some might not.

@larryliu0820 larryliu0820 merged commit 584d39b into main Nov 18, 2025
172 of 179 checks passed
@larryliu0820 larryliu0820 deleted the cuda_tag_get_attr branch November 18, 2025 22:29