Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. #3759

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

sahas3
Copy link
Contributor

@sahas3 sahas3 commented Oct 3, 2024

This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns in their own functions:

  1. populateTorchToTosaConversionLegalOps -- populate any ops that are legal after the conversion pass
  2. populateTorchToTosaConversionIllegalOps -- populate any ops that are illegal after the conversion pass
  3. populateTorchToTosaConversionPatterns -- populate the ops conversion patterns

Currently the (il)legality of the ops that are (il)legal after the conversion pass runs is embedded within the conversion pattern. Our end goal is to write a new pass pipeline that converts torch ops to a mix of tosa, linalg, tensor, etc dialect ops. The reason we want to also emit tosa ops (instead of using the existing TorchToLinalg to emit linalg+tensor+...) is because some operations like conv2d encodes the padding behavior in the op in tosa unlike the linalg version -- this helps in lowering the tosa.conv2d to a custom implementation that does padding on the fly.

To implement this new pipeline we need to be able to separate out the illegal tosa ops from the conversion pattern itself. Otherwise we will hit an issue for ops like AtenMaxDimOp which can be lowered to both tosa and linalg + others dialects. Not all AtenMaxDimOp can be lowered successfully to tosa as the implementation uses tosa.reshape which cannot handle multiple dynamic dimensions but the TorchToLinalg lowering can handle it. In the current behavior the pipeline will stop as soon as the existing TorchToTosa conversion runs as AtenMaxDimOp will be marked as an illegal op.

Essentially we want to be able to control what the legality of the ops should be independent of the conversion pattern. This is also inline with the conversion patterns in the llvm-mlir repo such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718

"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY."

@sahas3
Copy link
Contributor Author

sahas3 commented Oct 7, 2024

Hi @sjarus, tagging you as you recently reviewed a PR for changes in the TorchToTosa files. I think since I do not have write access yet, I couldn't directly request your feedback. Also, for context this PR is a requirement for adding a new torch to tosa+linalg pipeline that we (MathWorks) mentioned in the last TOSA meeting. It'll be great if you can take a look at these changes. Thanks!

@sjarus
Copy link
Collaborator

sjarus commented Oct 8, 2024

I'll take a look within a day.

@sjarus sjarus self-requested a review October 8, 2024 23:47
target.addLegalOp<PrimTupleConstructOp>();
}

void torch::populateTorchToTosaConversionIllegalOps(ConversionTarget &target) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During rebasing main branch I realized that it's easy to miss updating the list but this list is probably not required since target.addIllegalDialect<Torch::TorchDialect>() is also present in this pass. There's also a check during VerifyBackend* that ensures no torch ops are present after a full pipeline runs. Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the Torch dialect would probably get in the way of partial legalization if that's the goal.

The handling Isn't there a way to make the aten op illegal conditionally - a successful rewrite ought to make it illegal but not otherwise ? I recall there was infrastructure proposed to do this, but this can be tricky when some instances of the pattern replacement succeed and others do not.

The alternative is to have the macro append to a list rather than make the op illegal and parameterize the pass behavior to either attempt a full conversion with a pass/fail or a partial conversion. You'd then either apply target.IllegalOp<> depending on whether the intended behavior is to have the pass fail on conversion or not. Does that work ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review, @sjarus.

Yes it's possible to parameterize the pass to bypass the illegal ops check but I'd like to understand your concern better as I am still learning MLIR. For the existing pass, since there is already target.addIllegalDialect<Torch::TorchDialect>() are the individual iilegalOps<> adding any benefit or are they redundant?

For the tosa+linalg pass that I'm prototyping the end-goal is to fail if any torch op is present at the end of the full pipeline, so adding target.addIllegalDialect<Torch::TorchDialect>() there is working as well (I've used the populateTorchToTosaConversionPatterns in that new pipeline to perform partial conversion of the torch->tosa ops in the pattern list).

Copy link
Collaborator

@sjarus sjarus Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So one immediate concern would be that the manual addition to the Aten ops list here would not be a tenable solution. It used to be abstracted behind the macro and worked cleanly, but within the constraints applied originally. The new approach should not add a construct that you've correctly recognized as easily breakable.

Secondly I'm trying to understand the design goal here. Let me try to describe my own understanding of the goal: the current pass makes all legalized Torch ops invalid so that a conversion failure manifests itself as a pass failure. This is ok if the goal is to layer two passes - TorchToTosa handling subset A of Torch ops, and subsequently TorchToLinalg for all remaining ones of interest (let's call that B).

However this breaks if there exists variants of ops in A that such that A is a combination of ops supported by TorchToTosa (A') and variants of those same ops that happen to be unsupported (A") . For example let's trivially presume conv2d with unit strides are supported but non unit strides are not.

The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B. This won't currently work because A" would be handled in TorchToTosa and emit a pass failure. The goal here is to mark A" valid in TorchToTosa so they can be consumed by TorchToLinalg later within this intended pipeline. Is that correct ?

Copy link
Contributor Author

@sahas3 sahas3 Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the elaborate response. I agree with your understanding but here are some clarifications:

So one immediate concern would be that the manual addition to the Aten ops list here would not be a tenable solution.

Yes, I agree. My thought is that since this pass pipeline already specifies torch as an illegal dialect, we don't need to maintain this individual list of illegal ops. If any op from A'' is present for the current TorchToTosa pipeline while the op itself will not be marked as illegal the pass will fail because torch dialect is illegal after the pass completion.

The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B.

Yes, the end goal is let the TorchToTosa pass handle as many op it can handle and let TorchToLinalg handle the rest. So ops in A' will be handled by the conversions in TorchToTosa and ops in A'' and B will be handled by conversions in TorchToLinalg as you mentioned. Here is how I've setup the new pipeline https://github.com/sahas3/torch-mlir/blob/b0468a9ec367da0fb2c2e813f74437b9fa9ff7e8/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp#L63. Instead of relying on the existing TorchToTosa and TorchToLinalg passes as individual passes I'm calling the rewrite patterns back to back in the same pass. In this case, my understanding is that populateTorchToTosaConversionPatterns will handle A' but leave any ops in A'' as it is. Since we have not marked any ops in A to be illegal, there won't be any pass failure even if ops in A'' is present after running the TorchToTosa pass patterns. Assuming A'' is supported by TorchToLinalg pass, it will then be processed correctly and lowered to linalg+other dialects correctly as part of populateTorchToLinalgOnTensorsPatternsAndLegality. The whole pipeline will fail if A'' is not handled by populateTorchToLinalgOnTensorsPatternsAndLegality since we have torch as illegal dialect in the new pass as well. For sanity, I verified that this new pass does handle the AvgPool2D op with count_pad set to true that TorchToTosa cannot support correctly (#3822) -- I see same IR generated for TorchToLinalg and the new pass.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thought is that since this pass pipeline already specifies torch as an illegal dialect, we don't need to maintain this individual list of illegal ops.

That would not quite be the same behavior. Leveraging the originally defined classes i.e.
A' : the op variants supported by TorchToTosa
A" + B: variants of A not supported by TorchToTosa plus additional ops B not supported in any manner (e.g. aten.sort)

Right now, the pass will return TOSA+Torch when presented with a model containing A' (which would convert to TOSA) + B (left alone) . It will not fail but will simply partially convert. If Torch dialect is made illegal, B would be illegal and would fail. That's materially different behavior.

So you'd want to retain the explicit list A . The main problem is that we currently do not disambiguate A' from A" when doing addIllegalOp() . That affects your ability to put things into one pipeline. If you add controllability of addIllegalOp such that it's only done if the conversion succeeded, then you can craft your pipeline synthetically by just sequencing TorchToTosa before TorchToLinalg since it'll just work.

@sahas3
Copy link
Contributor Author

sahas3 commented Oct 16, 2024

Hi @sjarus , a gentle reminder to review this PR when you get a chance. Thanks!

@sahas3
Copy link
Contributor Author

sahas3 commented Oct 29, 2024

Hi @sjarus any thoughts on this PR?

@sjarus
Copy link
Collaborator

sjarus commented Oct 31, 2024

Terribly sorry - I completely missed this PR! Please ping me on discord if you don't get a timely response from me. Reviewing this now.

sjarus
sjarus previously approved these changes Oct 31, 2024
Copy link
Collaborator

@sjarus sjarus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

@sjarus sjarus dismissed their stale review October 31, 2024 05:23

updating comment

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants