[TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. #3759
Hi @sjarus, tagging you as you recently reviewed a PR for changes in the |
a76c138 to 49199da
I'll take a look within a day.
  target.addLegalOp<PrimTupleConstructOp>();
}

void torch::populateTorchToTosaConversionIllegalOps(ConversionTarget &target) {
While rebasing onto main I realized that it's easy to miss updating this list. The list is probably not required, though, since target.addIllegalDialect<Torch::TorchDialect>() is also present in this pass. There's also a check during VerifyBackend* that ensures no torch ops are present after a full pipeline runs. Thoughts?
Adding the Torch dialect would probably get in the way of partial legalization, if that's the goal.
Isn't there a way to make the aten op illegal conditionally, so that a successful rewrite makes it illegal but a failed one does not? I recall there was infrastructure proposed to do this, but it can be tricky when some instances of the pattern replacement succeed and others do not.
The alternative is to have the macro append to a list rather than make the op illegal, and to parameterize the pass to attempt either a full conversion with a pass/fail result or a partial conversion. You'd then apply target.addIllegalOp<> or skip it, depending on whether the pass is meant to fail on an unconverted op. Does that work?
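The list-plus-mode alternative described above can be sketched as a toy model in plain C++. This is not MLIR code: `ToyTarget`, `ConvertedOpList`, `ConversionMode`, and `configureTarget` are hypothetical names made up for illustration, and the real pass would use `ConversionTarget` and the existing pattern-insertion macro instead.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for mlir::ConversionTarget: it only records illegal op names.
struct ToyTarget {
  std::set<std::string> illegalOps;
  void addIllegalOp(const std::string &name) { illegalOps.insert(name); }
  bool isIllegal(const std::string &name) const {
    return illegalOps.count(name) != 0;
  }
};

// Hypothetical list that the pattern-insertion macro appends to, instead of
// marking each op illegal right away.
struct ConvertedOpList {
  std::vector<std::string> names;
  void append(const std::string &name) {
    assert(!name.empty() && "op name must not be empty");
    names.push_back(name);
  }
};

// Hypothetical pass option.
enum class ConversionMode { Full, Partial };

// Full: every op that has a pattern becomes illegal, so a failed rewrite
// fails the pass. Partial: nothing is marked illegal, so unconverted ops
// are left in place for a later pass.
void configureTarget(const ConvertedOpList &list, ConversionMode mode,
                     ToyTarget &target) {
  if (mode != ConversionMode::Full)
    return;
  for (const std::string &name : list.names)
    target.addIllegalOp(name);
}
```

The point of the design is that the list is still built in one place by the macro, so nobody has to maintain it by hand, while the decision to enforce it moves to the pass.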
Thanks for the review, @sjarus.
Yes, it's possible to parameterize the pass to bypass the illegal ops check, but I'd like to understand your concern better as I am still learning MLIR. For the existing pass, since there is already target.addIllegalDialect<Torch::TorchDialect>(), are the individual addIllegalOp<> calls adding any benefit, or are they redundant?
For the tosa+linalg pass that I'm prototyping, the end goal is to fail if any torch op is present at the end of the full pipeline, so adding target.addIllegalDialect<Torch::TorchDialect>() there works as well (I've used populateTorchToTosaConversionPatterns in that new pipeline to perform partial conversion of the torch->tosa ops in the pattern list).
So one immediate concern is that manually adding to the Aten ops list here is not a tenable solution. The list used to be abstracted behind the macro and worked cleanly, though within the constraints applied originally. The new approach should not add a construct that you've correctly recognized as easily breakable.
Secondly, I'm trying to understand the design goal here. Let me describe my own understanding: the current pass makes all legalized Torch ops illegal so that a conversion failure manifests as a pass failure. This is OK if the goal is to layer two passes: TorchToTosa handling subset A of Torch ops, and subsequently TorchToLinalg for all remaining ones of interest (let's call that B).
However, this breaks if A is a combination of op variants supported by TorchToTosa (A') and variants of those same ops that happen to be unsupported (A"). For example, let's trivially presume that conv2d with unit strides is supported but conv2d with non-unit strides is not.
The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B. This won't currently work because A" would be handled in TorchToTosa and emit a pass failure. The goal here is to mark A" legal in TorchToTosa so those ops can be consumed by TorchToLinalg later within this intended pipeline. Is that correct?
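To make the A' / A" / B split concrete, here is a small toy classification in plain C++. It reuses the presumed conv2d example from the comment (unit strides supported, non-unit strides not). That support rule is an assumption for illustration only, and `ToyOp`, `OpClass`, and the functions below are hypothetical names rather than torch-mlir APIs.

```cpp
#include <cassert>
#include <string>

// Toy op instance: just a name and one attribute. Not an MLIR type.
struct ToyOp {
  std::string name;
  int stride;
};

enum class OpClass { APrime, ADoublePrime, B };

// Assumption for illustration: TorchToTosa has a pattern for conv2d only.
bool torchToTosaHasPattern(const ToyOp &op) {
  return op.name == "aten.conv2d";
}

// Assumption for illustration: that pattern only succeeds for unit strides.
bool torchToTosaPatternSucceeds(const ToyOp &op) {
  assert(op.stride >= 1 && "stride must be positive");
  return torchToTosaHasPattern(op) && op.stride == 1;
}

OpClass classify(const ToyOp &op) {
  if (!torchToTosaHasPattern(op))
    return OpClass::B; // no TorchToTosa pattern at all
  return torchToTosaPatternSucceeds(op) ? OpClass::APrime
                                        : OpClass::ADoublePrime;
}

// With every op in A marked illegal, an A" op survives the rewrite while
// still being illegal, which is what fails the pass.
bool failsCurrentTorchToTosa(const ToyOp &op) {
  return classify(op) == OpClass::ADoublePrime;
}
```

In this model a strided conv2d is the only case that makes the pass fail, while an op with no pattern (class B) is simply left alone.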
Thanks for the elaborate response. I agree with your understanding, but here are some clarifications:
"So one immediate concern would be that the manual addition to the Aten ops list here would not be a tenable solution."
Yes, I agree. My thought is that since this pass pipeline already specifies torch as an illegal dialect, we don't need to maintain this individual list of illegal ops. If any op from A'' is present in the current TorchToTosa pipeline, the pass will still fail even though the op itself is not marked illegal, because the torch dialect is illegal after the pass completes.
"The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B."
Yes, the end goal is to let the TorchToTosa pass handle as many ops as it can and let TorchToLinalg handle the rest. So ops in A' will be handled by the conversions in TorchToTosa, and ops in A'' and B will be handled by the conversions in TorchToLinalg, as you mentioned. Here is how I've set up the new pipeline: https://github.com/sahas3/torch-mlir/blob/b0468a9ec367da0fb2c2e813f74437b9fa9ff7e8/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp#L63. Instead of relying on the existing TorchToTosa and TorchToLinalg passes as individual passes, I'm applying their rewrite patterns back to back in the same pass. In this case, my understanding is that populateTorchToTosaConversionPatterns will handle A' but leave any ops in A'' as they are. Since we have not marked any ops in A as illegal, there won't be any pass failure even if ops in A'' are present after running the TorchToTosa patterns. Assuming A'' is supported by the TorchToLinalg pass, it will then be lowered to linalg + other dialects as part of populateTorchToLinalgOnTensorsPatternsAndLegality. The whole pipeline will fail if A'' is not handled by populateTorchToLinalgOnTensorsPatternsAndLegality, since torch is an illegal dialect in the new pass as well. As a sanity check, I verified that this new pass does handle the AvgPool2D op with count_pad set to true, which TorchToTosa cannot support correctly (#3822): I see the same IR generated by TorchToLinalg and by the new pass.
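The combined pass described above can be modelled roughly as follows. This is a deliberately simplified toy in plain C++, not the linked implementation: `ToyOp` and `runCombinedPass` are hypothetical, each op carries two flags saying whether a TOSA or a Linalg pattern would succeed on it, and the real dialect conversion driver does not necessarily try patterns in this fixed order.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy op: the dialect it currently lives in, plus two flags that stand in
// for "a TorchToTosa / TorchToLinalg pattern would succeed on this op".
struct ToyOp {
  std::string dialect; // "torch", "tosa" or "linalg"
  std::string name;
  bool tosaPatternSucceeds;
  bool linalgPatternSucceeds;
};

// Tries the TOSA patterns first and the Linalg patterns second, all within
// one pass. Returns true when no torch op is left, which mirrors marking the
// whole torch dialect illegal in the combined pass.
bool runCombinedPass(std::vector<ToyOp> &ops) {
  for (ToyOp &op : ops) {
    assert(op.dialect == "torch" && "toy input is expected to be torch-only");
    if (op.tosaPatternSucceeds)
      op.dialect = "tosa";   // class A'
    else if (op.linalgPatternSucceeds)
      op.dialect = "linalg"; // class A'' or B, picked up by the second set
  }
  return std::none_of(ops.begin(), ops.end(), [](const ToyOp &op) {
    return op.dialect == "torch";
  });
}
```

Under these assumptions, an op that only the Linalg patterns can handle ends up in linalg without failing the pass, and the pass fails only when neither pattern set applies.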
"My thought is that since this pass pipeline already specifies torch as an illegal dialect, we don't need to maintain this individual list of illegal ops."
That would not quite be the same behavior. Using the classes defined earlier, i.e.
A': the op variants supported by TorchToTosa
A" + B: variants of A not supported by TorchToTosa, plus additional ops B not supported in any manner (e.g. aten.sort)
Right now, the pass will return TOSA+Torch when presented with a model containing A' (which would convert to TOSA) + B (left alone). It will not fail but will simply partially convert. If the Torch dialect is made illegal, B would be illegal and the pass would fail. That's materially different behavior.
So you'd want to retain the explicit list A. The main problem is that we currently do not disambiguate A' from A" when doing addIllegalOp(). That affects your ability to put things into one pipeline. If you add controllability of addIllegalOp such that it's only done if the conversion succeeded, then you can craft your pipeline simply by sequencing TorchToTosa before TorchToLinalg, since it'll just work.
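The three behaviors being compared here can be laid side by side in a toy model. The sketch below is plain C++ with hypothetical names (`OpClass`, `LegalityPolicy`, `torchToTosaSucceeds`). It assumes that ops in A' are always rewritten and that the pass fails exactly when an op that remains is illegal.

```cpp
#include <cassert>
#include <vector>

enum class OpClass { APrime, ADoublePrime, B };

enum class LegalityPolicy {
  ExplicitListA,     // current behavior: every op in A (A' and A") is illegal
  WholeTorchDialect, // proposal: the entire torch dialect is illegal
  OnlyOnSuccess      // suggestion: illegal only where the conversion succeeds
};

bool isIllegal(OpClass c, LegalityPolicy policy) {
  switch (policy) {
  case LegalityPolicy::ExplicitListA:
    return c != OpClass::B;
  case LegalityPolicy::WholeTorchDialect:
    return true;
  case LegalityPolicy::OnlyOnSuccess:
    return c == OpClass::APrime;
  }
  assert(false && "unhandled policy");
  return false;
}

// Models one run of TorchToTosa over a model given as a list of op classes.
bool torchToTosaSucceeds(const std::vector<OpClass> &model,
                         LegalityPolicy policy) {
  for (OpClass c : model) {
    if (c == OpClass::APrime)
      continue; // rewritten to TOSA, nothing left to check
    if (isIllegal(c, policy))
      return false; // an illegal torch op remains
  }
  return true;
}
```

In this model, a program with A' + B converts partially under the explicit list but fails once the whole dialect is illegal, and a program with A' + A" only gets through TorchToTosa under the third policy.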
Hi @sjarus, a gentle reminder to review this PR when you get a chance. Thanks!
e62285c to a197fdf
a197fdf to 14e3cea
Hi @sjarus, any thoughts on this PR?
14e3cea to 6760b46
Terribly sorry - I completely missed this PR! Please ping me on discord if you don't get a timely response from me. Reviewing this now.
This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns in their own functions:
Currently the (il)legality of the ops after the conversion pass runs is embedded within the conversion pattern. Our end goal is to write a new pass pipeline that converts torch ops to a mix of tosa, linalg, tensor, etc. dialect ops. The reason we also want to emit tosa ops (instead of using the existing TorchToLinalg to emit linalg + tensor + ...) is that some operations like conv2d encode the padding behavior in the op in tosa, unlike the linalg version. This helps in lowering tosa.conv2d to a custom implementation that does padding on the fly.
To implement this new pipeline we need to be able to separate out the illegal tosa ops from the conversion pattern itself. Otherwise we will hit an issue for ops like AtenMaxDimOp, which can be lowered to both tosa and linalg + others dialects. Not every AtenMaxDimOp can be lowered successfully to tosa, as the implementation uses tosa.reshape, which cannot handle multiple dynamic dimensions, but the TorchToLinalg lowering can handle it. In the current behavior the pipeline will stop as soon as the existing TorchToTosa conversion runs, as AtenMaxDimOp will be marked as an illegal op.
Essentially we want to be able to control the legality of the ops independently of the conversion pattern. This is also in line with the conversion patterns in the llvm-mlir repo, such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718