-
Notifications
You must be signed in to change notification settings - Fork 806
[LinalgExt] Initial support for aten::flex_attention and rewrite to linalg_ext.attention #22441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9280d3f
857c180
1cd5d58
eea83bd
53796c6
336f6f5
b6a6bb6
fca1fe8
f860ccc
06e9499
eaa0b93
1f5107d
3383d55
8e929ce
f527211
2b54533
abdc3da
1036526
64aae75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -789,6 +789,12 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", | |
| If an additional mask argument M is included, the result of the first matmul is modified according to: | ||
|
|
||
| Q @ K.T += M | ||
|
|
||
| Region: | ||
| The region body can receive either 1 or 5 block arguments: | ||
| - 1 argument (legacy): score (element type of output) | ||
| - 5 arguments: score, b (batch index), h (head index), m (query seq index), n (key/value seq index) | ||
IanWood1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| The region should yield a single value (the modified score). | ||
| }]; | ||
|
|
||
| let arguments = (ins AnyShaped:$query, | ||
|
|
@@ -914,6 +920,15 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", | |
| it over the entire softmax reduction dimension by: | ||
| x, _, sum : results | ||
| x = (1 / sum) * x | ||
|
|
||
| Region: | ||
| The region body receives the following block arguments: | ||
| - score: the computed score value from Q @ K.T (element type of output) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Can you make this such that the index operands come first and then the score.... Not Nit: The way the attention operation is setup it can increase dimensionality of the operation. So there isnt necessarily one batch dimension, there could be multiple batch dimensions. Same for head/m/n etc. That probably needs to be accounted for in the change. Just having a single batch dimension is not going to work. |
||
| - b: batch index (index type) | ||
| - h: head index (index type) | ||
| - m: query sequence index (index type) | ||
| - n: key/value sequence index (index type) | ||
| The region should yield a single value (the modified score). | ||
| }]; | ||
|
|
||
| let arguments = (ins AnyShaped:$query, | ||
|
|
@@ -998,6 +1013,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention", | |
| // Attributes to set on QK and PV matmul after decomposition. | ||
| static StringRef getQKAttrStr() { return "qk_attrs"; } | ||
| static StringRef getPVAttrStr() { return "pv_attrs"; } | ||
| // Flag to control whether to use exp2 (with log2(e) scaling) or exp. | ||
| static StringRef getUseExp2AttrStr() { return "use_exp2"; } | ||
| }]; | ||
|
|
||
| let hasCanonicalizer = 1; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.