Allow delegate to consume buffer mutations #4830
Conversation
This pull request was exported from Phabricator. Differential Revision: D60838243
@angelayi suggests that this test case, https://github.com/YifanShenSZ/executorch/blob/coreml-state/backends/apple/coreml/test/test_coreml_partitioner.py#L84-L94, cannot be supported. Angela, can you say why? cc: @YifanShenSZ
It would require many complicated and possibly error-prone graph manipulations to support all kinds of buffer mutations, especially in the case where a buffer mutation node is also being used as an output. The current design for supporting buffer mutations in delegates is that we want to mark specific nodes as being a buffer mutation in the delegated EP, and then remove those nodes from the toplevel EP. However, since buffer mutation nodes are outputs of the toplevel EP, the behavior of the partitioner/fuser is that tagged buffer mutation nodes are considered user outputs of the delegate, and are then passed back to the toplevel EP. This is because the partitioner does not have an understanding of "buffer mutation nodes" and has a requirement to preserve the correctness of the input/output nodes of a graph. So, to support buffer mutations, we need to do some extra graph manipulations:
In graphical form, before delegation we have a program like the following:
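A minimal sketch of such a program (assuming the same toy pattern as the test in the summary below; the class and variable names are illustrative): a buffer `b` is mutated in place and then read by the user output.

```
import torch
from torch.export import export

class MutatedBufferModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.zeros(3, 3))

    def forward(self, x):
        self.b.add_(x)     # in-place buffer mutation -> BUFFER_MUTATION output in the EP
        return x + self.b  # user output that reads the mutated buffer

ep = export(MutatedBufferModule(), (torch.randn(3, 3),))
# The exported signature records both roles:
#   inputs : b (BUFFER), x (USER_INPUT)
#   outputs: b + x (BUFFER_MUTATION of b), x + (b + x) (USER_OUTPUT)
print(ep.graph_signature)
```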
The current behavior of delegation:
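A rough schematic of that behavior (inferred from the description above, not actual tool output): the buffer stays a toplevel input, and the mutation flows back out of the delegate as an ordinary output.

```
Toplevel EP (current):
    inputs : b (BUFFER), x (USER_INPUT)
    out    = executorch_call_delegate(lowered_module, b, x)
    outputs: out[0] -> BUFFER_MUTATION of b, out[1] -> USER_OUTPUT

Delegate EP (current):
    inputs : b, x                 # plain inputs; the delegate doesn't know b is a buffer
    outputs: b + x, x + (b + x)   # both treated as plain user outputs
```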
Updated design which moves the buffer mutation into the delegate, and allows the delegate to swallow the buffer:
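A rough schematic of the updated design, matching the program dumps in the summary below: the toplevel EP no longer sees the buffer at all, and the delegate owns both the buffer and its mutation.

```
Toplevel EP (updated):
    inputs : x (USER_INPUT)       # b no longer appears at the toplevel
    out    = executorch_call_delegate(lowered_module, x)
    outputs: out[0] -> USER_OUTPUT

Delegate EP (updated):
    inputs : b (BUFFER), x (USER_INPUT)
    outputs: b + x -> BUFFER_MUTATION of b, x + (b + x) -> USER_OUTPUT
```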
Now this can be done for very simple cases, but for cases where the buffer mutation node is also a user output, like the example in the test case, this introduces more complexity, because we then have to check whether a node is used in multiple ways and duplicate the node to represent it as both a buffer mutation and a user output.
Delegation currently behaves like:
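A rough schematic (again inferred, not actual tool output) for the case where the mutated buffer is also returned to the user, as in the CoreML test case:

```
Toplevel EP (current, mutated buffer also a user output):
    inputs : b (BUFFER), x (USER_INPUT)
    out    = executorch_call_delegate(lowered_module, b, x)
    outputs: out[0] -> BUFFER_MUTATION of b, out[0] -> USER_OUTPUT   # same value, two roles

Delegate EP (current):
    inputs : b, x
    outputs: b + x                # plain user output
```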
The updated graph would have to look something like:
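A rough schematic of what that would require: the delegate swallows the buffer mutation, but the mutated value has to be duplicated so the toplevel EP can still return it to the user.

```
Toplevel EP (updated):
    inputs : x (USER_INPUT)
    out    = executorch_call_delegate(lowered_module, x)
    outputs: out[0] -> USER_OUTPUT                 # duplicated copy of the mutated value

Delegate EP (updated):
    inputs : b (BUFFER), x (USER_INPUT)
    outputs: b + x -> BUFFER_MUTATION of b,
             b + x -> USER_OUTPUT                  # same node represented in both roles
```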
It seems too complicated to support all cases of buffer mutation, but I'm open to suggestions!
@angelayi love the detailed, clear explanation. In the second half of the code snippets, for supporting the coreml test case, there may be some typos, but overall I understood it. So from the response I see one complication is returning the mutated buffer as a user output. Are there other ones that you see as potential issues? Does aliasing make it more complicated? Just asking, since it would be good to at least enumerate the known use cases that complicate handing buffer mutations to the delegate. @cymbalrush is the coreml test case an actual use case for you guys where you need to return the buffer mutation as an output? For llama, I don't think this is the case.
I think, in general, if there are cases where the buffer mutation node is being used in multiple ways (like updating multiple buffers, or being both a buffer mutation and a user output), it will introduce more complexity. Right: if a buffer is aliased, and one alias is consumed by the delegate while the other is not, that is also something we do not handle correctly, even today.
I see, thanks for the detailed investigation & explanation! The test case is just a random toy case; I can revise it. Having the actual llama usage supported is good enough. Once this PR is merged, we can start trying in-place kv-cache llama?
@angelayi is there a way to provide limited support for letting the delegate own the mutable buffer, and have partitioning fail explicitly with a clear error message when that's not possible?
@kimishpatel yeah, that's fine.
Sure. This could be a good proxy.
Summary: Pull Request resolved: pytorch#4830. Fixing pytorch#4209.

Edge Program:

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x); b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor); x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Partitioned / lowered Exported Program (buffer mutation gets removed):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x); lowered_module_0 = x = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0]; executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)])
```

Delegate (consumes the buffer mutation):

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x); b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor); x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Differential Revision: D60838243
lowered_module_node = lowered_module_nodes[0]

# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
clarification: I thought lowered_module_node itself is the call_delegate node? Is that not what it is?
lowered_module_node is the getattr node that points to the actual saved lowered module. The call_delegate_node is the node that calls the delegate.
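A hedged sketch of that relationship in code (`toplevel_gm` is an assumed name for the partitioned toplevel graph module; the pattern mirrors the test snippet quoted above):

```
# Walk the toplevel graph: the get_attr node holds the saved LoweredBackendModule,
# and its user is the executorch_call_delegate node that actually invokes the delegate.
for node in toplevel_gm.graph.nodes:
    if node.op == "get_attr" and "lowered_module" in str(node.target):
        lowered_module_node = node
        call_delegate_node = next(iter(node.users))
```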
# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
self.assertEqual(len(call_delegate_node.args), 2)
nit: add a comment on what the two args are. I presume the first one is the delegate blob and the second is the input?
(feel free to ignore)
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

def test_buffer_mutation_unsupported(self): |
Nice. Thanks for this test case
@@ -551,11 +556,72 @@ def _get_new_signature( # noqa: C901
    )

    if node.op == "output":
        output_nodes = pytree.tree_leaves((node.args, node.kwargs))
        buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
        for user in call_module_node.users.keys():
This is checking if the partitioned subgraph's call_module node is returning a mutated buffer, right? If so, we plan to remove those from the call signature of the submodule, but also from the top level?
We want to remove them from the toplevel module, and in the submodule, set the signature as a "buffer mutation".
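A hedged illustration (not the PR's actual implementation) of what that signature change amounts to in torch.export terms, reusing the names from the program dumps in the summary:

```
from torch.export.graph_signature import OutputKind, OutputSpec, TensorArgument

# In the submodule's signature, the output that writes buffer `b` is tagged as a
# buffer mutation rather than a user output...
delegate_mutation_spec = OutputSpec(
    kind=OutputKind.BUFFER_MUTATION,
    arg=TensorArgument(name="aten_add_tensor"),
    target="b",
)
# ...while the toplevel signature simply drops this spec (and the matching buffer
# input spec), so the buffer no longer appears at the toplevel at all.
```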
looks good to me