Add basic conversion between ttir and linalg #1558
base: main
Conversation
Could be wrong, but I don't see any
Let me double-check; there were quite a few interleaved commits I tried to manually split out, so it wouldn't surprise me if I missed one. Thanks! Update: yup, I was missing several commits... hopefully fixed now
Minor comments inline, otherwise looks great!
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
|
||
// #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
Stray comment
// Helper func to check which dims need to be broadcast and which need to be
// collapsed. Assumes that inputShape is broadcast-able to targetShape.
static void getDimsToBroadcastAndCollapse(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic is much more complicated than I originally expected: linalg.broadcast asserts something like "input rank + len(dimsToBroadcastAlong) == output rank", but we want to support torch-like broadcasting, where dims of size 1 are implicitly broadcast. That means we need a tensor.collapse_shape in such cases, which requires some hairy logic.
I would appreciate careful review/re-review of this logic, especially this function 🙂
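Since this helper is the crux of the change, here is a rough sketch of the logic being described. This is not the PR's actual code: the signature and variable names are assumed, and corner cases such as collapsing all the way to rank 0 are elided.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using llvm::ArrayRef;
using llvm::SmallVector;

// Hypothetical sketch, not the PR's actual code. Given an inputShape that
// is broadcastable to targetShape, compute:
//  - broadcastDims: result dims the collapsed input lacks, which
//    linalg.broadcast must materialize (it requires
//    input rank + dimensions.size() == output rank), and
//  - reassociation: groups for tensor.collapse_shape that fold away the
//    input's size-1 dims that need broadcasting.
static void getDimsToBroadcastAndCollapse(
    ArrayRef<int64_t> inputShape, ArrayRef<int64_t> targetShape,
    SmallVector<int64_t> &broadcastDims,
    SmallVector<SmallVector<int64_t>> &reassociation) {
  const int64_t rankDiff = targetShape.size() - inputShape.size();
  // Leading target dims with no input counterpart are pure broadcast dims.
  for (int64_t i = 0; i < rankDiff; ++i)
    broadcastDims.push_back(i);
  SmallVector<int64_t> group;
  for (int64_t i = 0, e = inputShape.size(); i < e; ++i) {
    group.push_back(i);
    if (inputShape[i] == 1 && targetShape[i + rankDiff] != 1) {
      // A size-1 dim that torch-style broadcasting would expand: fold it
      // into the current collapse group so linalg.broadcast can re-create
      // it as a fresh result dim.
      broadcastDims.push_back(i + rankDiff);
    } else {
      reassociation.push_back(group);
      group.clear();
    }
  }
  // Trailing size-1 dims fold into the last group (collapsing all the way
  // down to rank 0 is a corner case elided here).
  if (!group.empty() && !reassociation.empty())
    reassociation.back().append(group.begin(), group.end());
}
```

For example, input shape [1, 4, 1, 3] against target [2, 4, 5, 3] would yield reassociation [[0, 1], [2, 3]] (collapse to [4, 3]) and broadcastDims [0, 2].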
using TensorRanks = SmallVector<int64_t, 2>;

static LogicalResult computeBroadcastedShape(SmallVector<Value, 3> inputs,
Use the OpTrait::util::getBroadcastedShape utility function to calculate the broadcasted shape; it will save you all the work you did here. Also, you don't need to check whether the operands are broadcastable, since we already check this via traits for all eltwise operations (see verifyBroadcastable).
Makes sense, I've switched to this func. Let me know if you think I should replace this if check with some sort of assert instead, given that the checks must already have occurred at the TTIR level.
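For reference, a minimal sketch of what the helper might look like after switching to that utility. This is hypothetical code, not the PR's; OpTrait::util::getBroadcastedShape comes from MLIR's mlir/Dialect/Traits.h.

```cpp
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Hypothetical helper, not the PR's actual code: fold
// OpTrait::util::getBroadcastedShape over every ranked-tensor operand.
static LogicalResult computeBroadcastedShape(ValueRange inputs,
                                             SmallVector<int64_t> &result) {
  for (Value input : inputs) {
    auto type = cast<RankedTensorType>(input.getType());
    SmallVector<int64_t> next;
    if (!OpTrait::util::getBroadcastedShape(result, type.getShape(), next))
      return failure(); // Incompatible shapes; TTIR-level verification
                        // should have rejected this already.
    result = std::move(next);
  }
  return success();
}
```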
Goal: integrate an end-to-end path to compile and execute specific ops, or sets of ops, on the CPU.
Context:
The entire task will be split into (tentatively) 7 PRs, as follows:
This PR implements the 2nd subtask above: it converts TTIR ops in the "cpu" module into their linalg equivalents. For now, we only enable a few basic operations, in an attempt to get basic end-to-end conversion working before fleshing out full op coverage. Definitely open to adding other specific ops to the conversion before merging; not sure which ops would make the most sense.
Example
Input:
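(The original example IR isn't preserved in this capture; the snippet below is a hypothetical illustration of the kind of TTIR input this pass consumes, with the exact ttir.add form assumed.)

```mlir
// Hypothetical TTIR input (op form assumed): an eltwise add in the "cpu" module.
func.func @add(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>,
               %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
  %0 = "ttir.add"(%arg0, %arg1, %out)
      : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
  return %0 : tensor<32x32xf32>
}
```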
Output:
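(Likewise hypothetical, showing the corresponding linalg named-op form the conversion would produce:)

```mlir
// Hypothetical linalg output for the snippet above.
func.func @add(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>,
               %out: tensor<32x32xf32>) -> tensor<32x32xf32> {
  %0 = linalg.add ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>)
                  outs(%out : tensor<32x32xf32>) -> tensor<32x32xf32>
  return %0 : tensor<32x32xf32>
}
```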