Allow delegate to consume buffer mutations #4830

Merged
1 commit merged into pytorch:main on Aug 28, 2024

Conversation

angelayi
Contributor

Summary:
Fixing #4209

Edge Program:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])

Partitioned / lowered Exported Program (buffer mutation gets removed):

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)])

Delegate (consumes the buffer mutation):

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])

Differential Revision: D60838243


pytorch-bot bot commented Aug 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4830

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 843f6d0 with merge base e636ef6:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed and fb-exported labels on Aug 22, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D60838243

@kimishpatel
Contributor

@cymbalrush

@angelayi
Contributor Author

angelayi commented Aug 22, 2024

It would require many complicated and possibly error-prone graph manipulations to support all kinds of buffer mutations, especially when a buffer mutation node is also used as an output.

The current design for supporting buffer mutations in delegates is to mark specific nodes as buffer mutations in the delegated EP, and then remove those nodes from the toplevel EP.

However, since buffer mutation nodes are outputs of the toplevel EP, the partitioner/fuser treats tagged buffer mutation nodes as user outputs of the delegate, which are then passed back to the toplevel EP. This is because the partitioner has no notion of "buffer mutation nodes" and is required to preserve the correctness of the input/output nodes of a graph.

So, to support buffer mutations, we need to do some extra graph manipulations (a sketch of steps 1 and 2 follows the list):

  1. Determine which output node in the partitioned EP is the buffer mutation -- this can be done by matching the output node with the user nodes in the toplevel program, and checking if the user node corresponds to a buffer mutation in the toplevel EP.
  2. Update the partitioned EP's graph signature to show that the node is a buffer mutation and no longer a user output. The buffer mutation is something within the delegate and hidden from the toplevel program.
  3. Update the callsite to the delegate in the toplevel EP to remove the buffer mutation as an output, since the buffer mutation is no longer an output of the partitioned EP.
  4. Remove the buffer mutation node from the toplevel EP, since that is consumed by the delegate.
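
A minimal sketch of steps 1 and 2, assuming the public torch.export graph-signature types. The helper name and the toplevel_ep / delegated_ep / call_delegate_node arguments are illustrative only; the PR's actual change lives inside _get_new_signature (see the review hunks further down).

```
from torch.export.graph_signature import OutputKind, OutputSpec

def retag_delegate_buffer_mutations(toplevel_ep, delegated_ep, call_delegate_node):
    # Step 1: find which delegate outputs the toplevel program treats as buffer
    # mutations, by matching the getitem users of the call_delegate node against
    # the toplevel signature's buffers_to_mutate mapping (output name -> buffer).
    toplevel_mutations = toplevel_ep.graph_signature.buffers_to_mutate
    mutated_outputs = {}  # delegate output index -> buffer FQN
    for user in call_delegate_node.users:  # users are the getitem nodes
        if user.name in toplevel_mutations:
            mutated_outputs[user.args[1]] = toplevel_mutations[user.name]

    # Step 2: rewrite the delegate's output specs so those outputs are recorded
    # as BUFFER_MUTATIONs instead of USER_OUTPUTs.
    new_output_specs = []
    for idx, spec in enumerate(delegated_ep.graph_signature.output_specs):
        if idx in mutated_outputs:
            spec = OutputSpec(
                kind=OutputKind.BUFFER_MUTATION,
                arg=spec.arg,
                target=mutated_outputs[idx],
            )
        new_output_specs.append(spec)
    return new_output_specs
```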

In graphical form, before delegation we have a program like the following:

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            add_1 = torch.ops.aten.add.Tensor(x, add_0);
            return (add_0, add_1)

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [add_0: BUFFER_MUTATION, add_1: USER_OUTPUT]

The current behavior of delegation:

Delegated ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            add_1 = torch.ops.aten.add.Tensor(x, add_0);
            return (add_0, add_1)

    Graph signature: 
        input_spec = [b_b: USER_INPUT, x: USER_INPUT]
        output_spec = [add_0: USER_OUTPUT, add_1: USER_OUTPUT]

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            call_delegate = torch.ops.higher_order.call_delegate(lowered_module_0, (b_b, x))  # this has 2 outputs
            getitem_0 = call_delegate[0]
            getitem_1 = call_delegate[1]
            return (getitem_0, getitem_1)

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [getitem_0: BUFFER_MUTATION, getitem_1: USER_OUTPUT]

Updated design which moves the buffer mutation into the delegate, and allows the delegate to swallow the buffer:

Delegated ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            add_1 = torch.ops.aten.add.Tensor(x, add_0);
            return (add_0, add_1)

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [add_0: BUFFER_MUTATION, add_1: USER_OUTPUT]

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x):
            call_delegate = torch.ops.higher_order.call_delegate(lowered_module_0, x)  # this only has 1 output now
            getitem_0 = call_delegate[0]
            return (getitem_0,)

    Graph signature: 
        input_spec = [x: USER_INPUT]
        output_spec = [getitem_0: USER_OUTPUT]

Now, this can be done for very simple cases. But when the buffer mutation node is also a user output, like the example in the test case, it introduces more complexity: we then have to check whether a node is used in multiple ways, and duplicate it so it can be represented as both a buffer mutation and a user output.

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            return (add_0, add_0)

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [add_0: BUFFER_MUTATION, add_0: USER_OUTPUT]

Delegation currently behaves like:

Delegated ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            return (add_0,)

    Graph signature: 
        input_spec = [b_b: USER_INPUT, x: USER_INPUT]
        output_spec = [add_0: USER_OUTPUT]

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            call_delegate = torch.ops.higher_order.call_delegate(lowered_module_0, (b_b, x))
            getitem_0 = call_delegate[0]
            return (getitem_0, getitem_0)

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [getitem_0: BUFFER_MUTATION, getitem_0: USER_OUTPUT]

The updated graph would have to look something like:

Delegated ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b, x):
            add_0 = torch.ops.aten.add.Tensor(b_b, x);
            return (add_0, add_0)  # Duplicate add_0 as an output

    Graph signature: 
        input_spec = [b_b: BUFFER, x: USER_INPUT]
        output_spec = [add_0: BUFFER_MUTATION, add_0: USER_OUTPUT]

Toplevel ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x):
            call_delegate = torch.ops.higher_order.call_delegate(lowered_module_0, x)  # this still has 1 output
            getitem_0 = call_delegate[0]
            return (getitem_0,)  # Remove one of the getitem outputs

    Graph signature: 
        input_spec = [x: USER_INPUT]
        output_spec = [getitem_0: USER_OUTPUT]

It seems too complicated to support all cases of buffer mutation, but I'm open to suggestions!

@kimishpatel
Contributor

@angelayi love the detailed, clear explanation. In the second half of the code snippets, for the coreml test case, there may be some typos, but overall I understood.

So from the response, I see that one complication is returning the mutated buffer as a user output.

Are there other ones that you see as potential issues? Does aliasing make it more complicated? Just asking, since it would be good to at least enumerate the known use cases that complicate handing buffer mutations to the delegate.

@cymbalrush is the coreml test case an actual use case for you where you need to return the buffer mutation as an output? For llama, I don't think this is the case.

@angelayi
Contributor Author

Are there other ones that you see as potential issues? Does aliasing make it more complicated? Just asking, since it would be good to at least enumerate the known use cases that complicate handing buffer mutations to the delegate.

I think in general, any case where the buffer mutation node is used in multiple ways, like updating multiple buffers, or serving as both a buffer mutation and a user output, will introduce more complexity.

Right, and if a buffer is aliased, with one alias consumed by the delegate while the other is not, that is also something we do not handle correctly, even today.

@YifanShenSZ
Collaborator

YifanShenSZ commented Aug 22, 2024

I see, thanks for the detailed investigation & explanation!

The test case is just a random toy case; I can revise it. Having the actual llama usage supported is good enough.

Once this PR is merged, we can start trying in-place kv-cache llama?

@kimishpatel
Contributor

@angelayi is there a way to provide limited support, letting the delegate own the mutable buffer, and have partitioning fail explicitly with a clear error message when that is not possible?

@angelayi
Contributor Author

@kimishpatel yeah, that's fine.
@YifanShenSZ if you could provide me a test case with the llama model you're exporting, that would be great!

@YifanShenSZ
Collaborator

YifanShenSZ commented Aug 23, 2024

if you could provide me a test case with the llama model you're exporting, that would be great!

Sure. This could be a good proxy:

import torch

SHAPE = (2, 3)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))

    def forward(self, q, k_val, input_pos):
        q_T = q.transpose(0, 1)
        # In-place update of the cache buffer at input_pos (kv-cache style)
        k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
        attn = k.mm(q_T)
        return attn

q = torch.rand(1, 3)
k = torch.rand(1, 3)
input_pos = torch.tensor([0])  # example position; not given in the original snippet
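
For reference, a minimal sketch (not part of the original comment) of checking that the cache mutation survives export, using the public torch.export API; the exact node name that shows up in buffers_to_mutate will vary:

```
from torch.export import export

# Assumes Model, q, k, and input_pos from the snippet above.
ep = export(Model(), (q, k, input_pos))
# Expected: a mapping from the mutated output node to the "cache" buffer.
print(ep.graph_signature.buffers_to_mutate)
```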

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D60838243

angelayi added a commit to angelayi/executorch-1 that referenced this pull request Aug 28, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D60838243

lowered_module_node = lowered_module_nodes[0]

# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
Contributor

clarification: I thought lowered_module_node itself is the call_delegate node? Is that not what it is?

Contributor Author

lowered_module_node is the getattr node that points to the actual saved lowered module. The call_delegate_node is the node that calls the delegate.
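
For context, a minimal sketch of that relationship in the FX graph; the toplevel_ep name is illustrative, and the test's own lookup is the hunk quoted above:

```
# The get_attr node points at the saved LoweredBackendModule; its user is the
# higher-order executorch_call_delegate call that actually invokes the delegate.
lowered_module_nodes = [
    node
    for node in toplevel_ep.graph_module.graph.nodes
    if node.op == "get_attr" and "lowered_module" in str(node.target)
]
lowered_module_node = lowered_module_nodes[0]
call_delegate_node = next(iter(lowered_module_node.users))
```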


# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
self.assertEqual(len(call_delegate_node.args), 2)
Contributor

nit: add a comment on what the two args are. I presume the first one is the delegate blob and the second is the input?
(feel free to ignore)

self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

def test_buffer_mutation_unsupported(self):
Contributor

Nice. Thanks for this test case

@@ -551,11 +556,72 @@ def _get_new_signature( # noqa: C901
)

if node.op == "output":
output_nodes = pytree.tree_leaves((node.args, node.kwargs))
buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
for user in call_module_node.users.keys():
Contributor

This is checking whether the partitioned subgraph's call_module node is returning a mutated buffer, right? If so, we plan to remove those from the call signature of the submodule but also from the top level?

Contributor Author

We want to remove them from the toplevel module, and in the submodule, set the signature as a "buffer mutation".
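
A minimal sketch of that toplevel-side cleanup (steps 3 and 4 from the earlier list); the function and argument names are illustrative, and this is not the PR's actual implementation:

```
def drop_consumed_buffer_mutations(toplevel_ep, mutated_getitem_names):
    gm = toplevel_ep.graph_module
    output_node = next(n for n in gm.graph.nodes if n.op == "output")

    # Step 3: remove the buffer-mutation getitems from the toplevel outputs,
    # since the delegate now owns those mutations.
    kept = tuple(
        out
        for out in output_node.args[0]
        if getattr(out, "name", None) not in mutated_getitem_names
    )
    output_node.args = (kept,)

    # Step 4: erase the now-dead getitem nodes and drop their output specs.
    for node in reversed(list(gm.graph.nodes)):
        if node.name in mutated_getitem_names and len(node.users) == 0:
            gm.graph.erase_node(node)
    gm.recompile()

    return [
        spec
        for spec in toplevel_ep.graph_signature.output_specs
        if getattr(spec.arg, "name", None) not in mutated_getitem_names
    ]
```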

@kimishpatel
Contributor

kimishpatel left a comment

looks good to me

@facebook-github-bot merged commit a5157de into pytorch:main on Aug 28, 2024
39 of 43 checks passed