
Add support for 1d convolution in ttir and ttnn mlir dialects #1438

Merged
4 commits merged from ajakovljevic/fix_1d_conv into main on Dec 9, 2024

Conversation

ajakovljevicTT (Contributor)

As described in issue #1342, ttnn does not currently have support for 1D convolution (bug opened: tenstorrent/tt-metal#15452). Hence, in talks with @LPanosTT, we agreed that it is OK to add reshapes before and after the convolution, so that a 1D convolution is transformed into shapes suitable for the conv2d op in ttnn.

This change is added in the ConvolutionToConv2dPattern in 'TTIRToTTIRDecomposition', with different paths based on whether the convolution is 2D or 1D.
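
At shape level, the rewrite amounts to appending a unit spatial dimension to the operands (e.g. tensor<1x256x512xf32> becomes tensor<1x256x512x1xf32>) and stripping it from the result again. A minimal sketch of that type bookkeeping, as a hypothetical standalone helper (the actual pattern in the PR additionally fixes up the convolution layout attributes):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Append a trailing spatial dimension of size 1, so a 1D-convolution operand
// type can be fed to a 2D convolution. The helper name is illustrative, not
// the one used in this PR.
static mlir::RankedTensorType appendUnitSpatialDim(mlir::RankedTensorType type) {
  llvm::SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
  shape.push_back(1); // the extra, size-1 spatial dimension
  return mlir::RankedTensorType::get(shape, type.getElementType());
}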

The change was tested on the convolution given in the issue above:

module {
  func.func @main(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> {
    %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0]x[o, i, 0]->[b, f, 0], window = {stride = [1], pad = [[0, 0]], rhs_dilate = [1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>) -> tensor<1x1024x512xf32>
    return %0 : tensor<1x1024x512xf32>
  }
}

It works and runs on silicon.

@svuckovicTT (Contributor) left a comment:

Hey Andrej, I'd like to see some more comments and some simplifications before we land this, hence requesting changes.


if (failed(isConv2d(op))) {
LogicalResult matchAndRewrite1d(ttir::ConvolutionOp op, OpAdaptor adaptor,
svuckovicTT:
Can you add commentary outlining what's going on in this fn? It's hard to follow thru code only.

ajakovljevicTT (Author):
I have added comments at the header of the newly added pattern for 1D convolution. Please let me know if you think I should go into more detail in the code itself.

Comment on lines 256 to 257
LogicalResult isValidConv(ttir::ConvolutionOp op,
uint32_t numSpatialDims) const {
svuckovicTT:
This method is kind of broken now. The intent is to check whether it's a valid conv, but it only compares the spatial-dim count against the provided numSpatialDims, so if someone pushes a 3D conv through, it will deem it valid. It's a strange pattern to have to provide the intended number of spatial dims to the fn. In my eyes, it makes more sense for the fn to check whether the conv is 1D/2D vs. something else.

Btw, with this change you've shadowed the more global constexpr static uint32_t numSpatialDims = 2;, which was intended to signal that only 2D convs were valid.

Contributor:
I guess we should think in terms of support, not validity. isSupportedConv would better convey the intent of the function after the above-mentioned changes.
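
For illustration, the check being suggested might look roughly like this (a sketch only; the spatial-dims accessors are assumed, not taken from the actual ttir::ConvolutionOp API):

// NOTE: getConvolutionLayout() / getInputSpatialDimensions() are assumed
// accessor names, not necessarily the real ttir::ConvolutionOp API.
static bool isSupportedConv(ttir::ConvolutionOp op) {
  size_t numSpatialDims =
      op.getConvolutionLayout().getInputSpatialDimensions().size();
  // Accept only 1D and 2D convolutions; a 3D conv is rejected instead of
  // being silently deemed valid.
  return numSpatialDims == 1 || numSpatialDims == 2;
}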

ajakovljevicTT (Author):
With regard to the change of structure outlined in the comment below, I have moved both of the validity/support checks (for numSpatialDims and otherwise) into the base ConvolutionDecompositionPattern class, which both ConvolutionToConv2dPattern and Legalize1DConvolutionPattern inherit from. Please go through the changes regarding comments one more time and let me know if you think I missed something. Thanks!

LPanosTT (Contributor) commented Dec 5, 2024

Sorry for being late to take a look.

I think it might be better to make a “ttir.conv1d” operation and a separate ConvolutionToConv1D pattern rewriter. This way the op is clearly differentiated in the IR. We can add the reshapes that insert the second spatial dim during lowering to TTNN. And if they ever do expose conv1d as a C++ API (which shouldn’t be hard), we can just bring that up between the TTNN dialect and runtime.

I’m saying this because my original intent with ConvolutionToConv2D was to convert any Convolution with two spatial dimensions to a Conv2D op by inserting whatever transposes are necessary on the input/output. A 1D convolution is semantically not a conv2d, although it can be mimicked by one.

@svuckovicTT what do you think?


ajakovljevicTT (Author) commented Dec 6, 2024

Thanks for looking into this @LPanosTT! I thought about what you propose when I first started implementing this, and discussed the best way to do it with the TTIR dialect team (@sdjordjevicTT can chime in if he has the time). We settled on this solution because, at least to me, there is no considerable benefit to making a ttir.conv1d currently, without adequate ttnn/metal support. Furthermore, the conversion from ttir.convolution to ttnn.conv2d is non-trivial (as can be seen from the ConvolutionToConv2dPattern in lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp), and most of that logic would need to be replicated in the conv1d decomposition pass, only to be deleted later if the ttnn folks expose a conv1d C++ ttnn op. If you feel strongly about moving this conversion to the TTIR->TTNN pass, we can discuss it further.

Update:
I discussed the design with @LPanosTT in more detail offline, and we agreed to move the rewrite of a 1D ttir.convolution into a 2D ttir.convolution into a separate Legalize1DConvolutionPattern. That way, the op this pattern produces is picked up by the already implemented ConvolutionToConv2dPattern, as sketched below. Later, when ttnn introduces support for conv1d, the pattern can be altered to just call the corresponding ttnn C++ op.
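
For reference, the two patterns would then compose in the decomposition pass roughly as follows (registration sketch only; the surrounding pass setup is omitted and the constructor arguments are assumed to follow the usual OpConversionPattern convention):

// Legalize1DConvolutionPattern rewrites a 1D ttir.convolution into a 2D one
// (reshape in, convolve, reshape out); ConvolutionToConv2dPattern then
// lowers the resulting 2D convolution exactly as before.
patterns.add<Legalize1DConvolutionPattern, ConvolutionToConv2dPattern>(
    typeConverter, patterns.getContext());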

ajakovljevicTT force-pushed the ajakovljevic/fix_1d_conv branch 3 times, most recently from d920742 to 9a32d91 on December 6, 2024
ajakovljevicTT force-pushed the ajakovljevic/fix_1d_conv branch from e9ac024 to 220725a on December 9, 2024
ajakovljevicTT merged commit d8cc464 into main on Dec 9, 2024
19 checks passed
azecevicTT pushed a commit that referenced this pull request Dec 17, 2024
* Added support for 1d convolution

* Added tests

* Addressed comments

* Addressed more comments