How to write in-place custom ops compatible with torch.compile using pallas #8385
Comments
You are right. When I talked to @bdhirsh, my takeaway was that it is better to make the custom op functional; I think functionalization will not run inside a custom op. If you want to enable the buffer aliasing, see xla/test/dynamo/test_dynamo_aliasing.py, line 32 (at commit 102cd48).
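To illustrate the "make it functional" suggestion, here is a minimal sketch (my own example, not code from this thread) of a functional custom op registered with torch.library.custom_op, which torch.compile can functionalize without trouble; the op name and body are placeholders standing in for an actual Pallas kernel call.

```python
import torch
from torch.library import custom_op

# Hypothetical functional op: takes x and returns a NEW tensor instead of
# mutating x in place, so functionalization has nothing to rewrite.
@custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Placeholder body; a real version would call the Pallas kernel here
    # (e.g. one wrapped via torch_xla.experimental.custom_kernel).
    return x * scale + 1.0

# Fake/meta implementation so torch.compile can trace shapes without
# running the real kernel.
@scaled_add.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x)
```

The buffer-aliasing path referenced in the test above is a separate mechanism; I believe the test toggles per-tensor donation with something like torch_xla._XLAC._set_buffer_donation(tensor, True), but treat that exact name as an assumption and check the linked test.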
Let me take a look...
I also ran your code and saw:
This is running with dynamo; at least the HLO we passed to XLA does not have the copy. Let me see at which stage this copy is added.
So I ran your code with
to collect HLOs at different stages, and I saw
in
There is no copy, and the buffer donor was set up correctly in
so it confirms that this
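For anyone reproducing this kind of investigation, here is a small sketch of how HLO can be inspected or dumped at different stages with torch_xla. The environment variables and helper below are the commonly documented debugging hooks, not necessarily the exact flags used in the comment above.

```python
import os

# Set before torch_xla/XLA initialize. XLA_SAVE_TENSORS_FILE/FMT dump the
# lazy-tensor graphs/HLO that torch_xla builds; --xla_dump_to asks the XLA
# compiler itself to dump HLO around its optimization passes.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/save.hlo"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.ones(4, device=xm.xla_device())
y = x + 1

# Print the HLO for specific live tensors without executing the graph.
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))
```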
❓ Questions and Help
I'm trying to implement an in-place operator using Pallas and wrap it as a torch custom op. However, I found it difficult to make it work with torch.compile. More specifically, I'm unclear about how to set donation, input-output aliases, and the op schema. It seems that having an output aliased with an input leads to functionalization problems in the torch compiler. Thanks!
My script is like this:
And it seems it does not change the value of x.
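Since the original script is not preserved in this excerpt, the following is only a hypothetical sketch of the pattern the question describes: an op that mutates its input, declared through torch.library.custom_op's mutates_args so that torch.compile's functionalization knows about the mutation. The op name, signature, and body are placeholders, not taken from the issue.

```python
import torch
from torch.library import custom_op

# Hypothetical in-place op (not the script from the issue). `mutates_args`
# is the schema-level declaration that `x` is written to; without it,
# functionalization cannot know the op mutates its input.
@custom_op("mylib::scale_", mutates_args=("x",))
def scale_(x: torch.Tensor, scale: float) -> None:
    # Placeholder body; a real version would dispatch to a Pallas kernel.
    x.mul_(scale)

@scale_.register_fake
def _(x: torch.Tensor, scale: float) -> None:
    # Mutating op with no outputs: the fake implementation returns nothing.
    return None

x = torch.ones(4)
torch.compile(lambda t: scale_(t, 2.0))(x)
print(x)  # expected to reflect the mutation via the copy-back that
          # functionalization inserts; on XLA, that copy is what buffer
          # donation / input-output aliasing would ideally eliminate.
```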