-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for 1d convolution in ttir and ttnn mlir dialects #1438
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hej Andrej, I'd like to see some more comments and some simplifications before we land this, hence requesting changes.
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
|
||
if (failed(isConv2d(op))) { | ||
LogicalResult matchAndRewrite1d(ttir::ConvolutionOp op, OpAdaptor adaptor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add commentary outlining what's going on in this fn? It's hard to follow thru code only.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added the comments at the header of a newly added pattern for 1d convolution. Please let me know if you think I should go into more details in the code itself.
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
LogicalResult isValidConv(ttir::ConvolutionOp op, | ||
uint32_t numSpatialDims) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is kind of broken now. The intent is to check whether it's a valid conv, but it only compares spatial dims cnt vs the provided numSpatialDims
, so if someone pushes a 3d conv thru, it will deem it valid. It's a strange pattern to have to provide the intended number of spatial dims to the fn. In my eyes, it makes more sense for the fn to check if the conv is 1d/2d vs something else.
Btw with this change you've overshadowed the more global constexpr static uint32_t numSpatialDims = 2;
which was intended to signal that only 2d convs were valid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we should think in terms of support, not validity. isSupportedConv
would better convey the intent of the function after above-mentioned changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With regards to change change of structure outlined in the comment written below, I have changed both of the validity/support checks (for numSpatialDims
and otherwise) to live in the base ConvolutionDecompositionPattern
class, which both ConvolutionToConv2dPattern
and Legalize1DConvolutionPattern
inherit from.
Please go through the changes regarding comments one more time and let me know if you think I missed something. Thanks!
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
Sorry for being late to take a look. I think it might be better to make a “ttir.conv1d” operation, and a separate ConvolutionToConv1D pattern rewriter. This way the op is clearly differentiated in the IR. We can add the reshapes that insert the second spatial dim during lowering to TTNN. And if they ever do expose conv1d as a c++ api (which shouldn’t be hard) we can just bring that up between TTNN dialect and runtime. I’m saying this because my original intent with ConvolutionToConv2D was to convert any Convolution with two spatial dimensions to a Conv2D op by inserting whatever transposes necessary on the input/output. This is semantically not a conv2D although it can be mimicked by one. @svuckovicTT what do you think? |
Thanks for looking into this @LPanosTT ! I have thought about the thing that you proposed when I first started implementing this, and had discussions with the TTIR dialect team (@sdjordjevicTT can chime in if has the time) about the best way to do this. We have settled on this solution, as, at least to me, there is no considerable benefit to making a |
Update: |
d920742
to
9a32d91
Compare
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Outdated
Show resolved
Hide resolved
e9ac024
to
220725a
Compare
* Added support for 1d convolution * Added tests * Addressed comments * Addressed more comments
As described in issue #1342, ttnn does not currently have support for 1D convolution (bug opened tenstorrent/tt-metal#15452). Hence, in talks with @LPanosTT, we agreed that it is ok to add reshapes before and after convolutions, so that we transform 1D convolution into shapes suitable for conv2d op in ttnn.
This change is added in the
ConvolutionToConv2dPattern
in the 'TTIRToTTIRDecomposition', with different paths based on wheather the convolution is 2D or 1D.The change was tested on convolution given in the issue above, which is:
And it works and runs on silicon.