
Conversation

@cccclai
Contributor

@cccclai cccclai commented May 30, 2024

Summary:
The decomposition from

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        x[:, :, input_pos] = value
        return x
```

is

```
opcode         name             target                      args                                             kwargs
-------------  ---------------  --------------------------  -----------------------------------------------  --------
placeholder    x                x                           ()                                               {}
placeholder    input_pos        input_pos                   ()                                               {}
placeholder    value            value                       ()                                               {}
call_function  slice_1          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_2          aten.slice.Tensor           (slice_1, 1, 0, 9223372036854775807)             {}
call_function  index_put        aten.index_put.default      (slice_2, [None, None, input_pos], value)        {}
call_function  slice_3          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_scatter    aten.slice_scatter.default  (slice_3, index_put, 1, 0, 9223372036854775807)  {}
call_function  slice_scatter_1  aten.slice_scatter.default  (x, slice_scatter, 0, 0, 9223372036854775807)    {}
output         output           output                      ((slice_scatter_1, slice_scatter_1),)            {}
```

However, `x[:, :, input_pos] = value` really just updates the contents of `x` with `value` in place; it is essentially a single `index_put`.

By replacing `x[:, :, input_pos] = value` with `torch.ops.aten.index_put_(x, [None, None, input_pos], value)`, we reduce the number of operators from 6 to 1.
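The equivalence can be sanity-checked without any compiler machinery: on a plain nested list, writing through `[:, :, input_pos]` touches exactly the positions that a single index_put along the last dimension would. A minimal pure-Python sketch (no torch; `index_put_last_dim` is a hypothetical helper named for illustration):

```python
import copy

def index_put_last_dim(x, input_pos, value):
    """In-place scatter along the last dim of a 3-D nested list:
    the same effect as x[:, :, input_pos] = value."""
    for i in range(len(x)):
        for j in range(len(x[i])):
            for k, p in enumerate(input_pos):
                x[i][j][p] = value[i][j][k]
    return x

# 2 x 2 x 4 "tensor" of zeros, updated at positions 1 and 3 of the last dim
x = [[[0] * 4 for _ in range(2)] for _ in range(2)]
value = [[[7, 8] for _ in range(2)] for _ in range(2)]
updated = index_put_last_dim(copy.deepcopy(x), [1, 3], value)
print(updated[0][0])  # [0, 7, 0, 8]
```

The point is that the whole operation is one scatter; the slice/slice_scatter chain in the decomposition above adds no information.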

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        torch.ops.aten.index_put_(x, [None, None, input_pos], value)
        return x
```

the decomposition is

```
opcode         name       target                  args                                 kwargs
-------------  ---------  ----------------------  -----------------------------------  --------
placeholder    x          x                       ()                                   {}
placeholder    input_pos  input_pos               ()                                   {}
placeholder    value      value                   ()                                   {}
call_function  index_put  aten.index_put.default  (x, [None, None, input_pos], value)  {}
output         output     output                  ((index_put, index_put),)            {}
```

A more robust long-term fix is to use pattern matching to replace the decomposed operator sequence with the simplified pattern.
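As a sketch of that longer-term approach, the rewrite can be expressed as a peephole pass over the operator sequence: find the slice/slice/index_put/slice/slice_scatter/slice_scatter chain from the decomposition above and collapse it to one index_put. This is a simplified stand-in for a real graph-pattern matcher (for example torch.fx's subgraph rewriter); ops here are plain strings rather than graph nodes, and a real pass would also match on data flow, not just op names:

```python
# Decomposed chain produced by x[:, :, input_pos] = value (see table above)
PATTERN = [
    "aten.slice.Tensor",
    "aten.slice.Tensor",
    "aten.index_put.default",
    "aten.slice.Tensor",
    "aten.slice_scatter.default",
    "aten.slice_scatter.default",
]
REPLACEMENT = ["aten.index_put.default"]

def rewrite(ops):
    """Replace every occurrence of PATTERN in ops with REPLACEMENT."""
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + len(PATTERN)] == PATTERN:
            out.extend(REPLACEMENT)
            i += len(PATTERN)
        else:
            out.append(ops[i])
            i += 1
    return out

graph = PATTERN + ["aten.add.Tensor"]
print(rewrite(graph))  # ['aten.index_put.default', 'aten.add.Tensor']
```

The advantage over the manual source change is that the pass catches every occurrence of the pattern, not just the one call site rewritten by hand.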

Differential Revision: D57949659

@pytorch-bot

pytorch-bot bot commented May 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/3786

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 6574788 with merge base 0333390:


@facebook-github-bot added the CLA Signed label on May 30, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D57949659

…ytorch#3786)

Summary:
Pull Request resolved: pytorch#3786

Perf:
For stories, before the diff
```
I 00:00:03.437290 executorch:runner.cpp:419] 	Prompt Tokens: 9    Generated Tokens: 118
I 00:00:03.437295 executorch:runner.cpp:425] 	Model Load Time:		0.763000 (seconds)
I 00:00:03.437301 executorch:runner.cpp:435] 	Total inference time:		2.661000 (seconds)		 Rate: 	44.344231 (tokens/second)
I 00:00:03.437305 executorch:runner.cpp:443] 		Prompt evaluation:	0.185000 (seconds)		 Rate: 	48.648649 (tokens/second)
I 00:00:03.437309 executorch:runner.cpp:454] 		Generated 118 tokens:	2.476000 (seconds)		 Rate: 	47.657512 (tokens/second)
I 00:00:03.437313 executorch:runner.cpp:462] 	Time to first generated token:	0.206000 (seconds)
I 00:00:03.437315 executorch:runner.cpp:469] 	Sampling time over 127 tokens:	0.042000 (seconds)
```
After the diff
```
I 00:00:03.195257 executorch:runner.cpp:419] 	Prompt Tokens: 9    Generated Tokens: 118
I 00:00:03.195295 executorch:runner.cpp:425] 	Model Load Time:		0.683000 (seconds)
I 00:00:03.195314 executorch:runner.cpp:435] 	Total inference time:		2.502000 (seconds)		 Rate: 	47.162270 (tokens/second)
I 00:00:03.195319 executorch:runner.cpp:443] 		Prompt evaluation:	0.175000 (seconds)		 Rate: 	51.428571 (tokens/second)
I 00:00:03.195323 executorch:runner.cpp:454] 		Generated 118 tokens:	2.327000 (seconds)		 Rate: 	50.709067 (tokens/second)
I 00:00:03.195327 executorch:runner.cpp:462] 	Time to first generated token:	0.195000 (seconds)
I 00:00:03.195330 executorch:runner.cpp:469] 	Sampling time over 127 tokens:	0.049000 (seconds)
```
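From the two runs above, the rough speedup can be computed directly (rates copied from the runner logs; this is a single-run comparison, not a rigorous benchmark):

```python
# tokens/second from the runner logs, before vs. after the diff
before_total = 44.344231   # total inference rate, before
after_total = 47.162270    # total inference rate, after
before_decode = 47.657512  # generated-token rate, before
after_decode = 50.709067   # generated-token rate, after

total_speedup = (after_total / before_total - 1) * 100
decode_speedup = (after_decode / before_decode - 1) * 100
print(f"total: +{total_speedup:.1f}%, decode: +{decode_speedup:.1f}%")
# total: +6.4%, decode: +6.4%
```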

Differential Revision: D57949659
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D57949659

@cccclai force-pushed the export-D57949659 branch from d9aed5a to 6574788 on July 18, 2024
```
v_out = torch.ops.aten.index_put_(
    self.v_cache, [None, None, input_pos], v_val
)
```
Collaborator
Is the 2nd v_out redundant?

@larryliu0820 larryliu0820 (Contributor) left a comment:

Can you address the comment, then land?

@iseeyuan
Contributor

@cccclai Do you have any performance numbers for before and after comparison?

@cccclai
Contributor Author

cccclai commented Feb 11, 2025

> @cccclai Do you have any performance numbers for before and after comparison?

about 5% IIRC.

@cccclai
Contributor Author

cccclai commented Jun 6, 2025

Stale PR

@cccclai cccclai closed this Jun 6, 2025

Labels

ciflow/trunk, CLA Signed, fb-exported

5 participants