Adding a new transposed convolution function to lax #5772
yang-song wants to merge 1 commit into jax-ml:main
Conversation
froystig left a comment
Thank you!
The changes to lax.py include many file-wide formatting adjustments. Could you undo those? We probably don't want to take them at the moment, and they obscure the main change in the diff. You'll also want to squash commits so that there isn't one commit that makes formatting changes followed by another that undoes them.
yang-song force-pushed from 89dc631 to 6584a54
Just removed the formatting changes (they were made automatically by my IDE).
Are there any updates on this? It would be awesome to have this function.
I'm not an expert on this, but I'm wondering what the pros and cons are of introducing a new API endpoint vs adding features like output_shape and output_padding to the existing conv_transpose.
Because the meaning of padding is different in the two conventions: the existing conv_transpose applies explicit padding to the fractionally strided input, while in the gradient-based formulation the padding argument refers to the forward convolution and shrinks the transposed output. Folding it into the existing function would change what the argument means.
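A quick shape check makes the mismatch concrete. This is an illustration, not code from this PR; it assumes lax.conv_transpose's default NHC/HIO layouts:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 3))   # (batch, length, channels): default NHC layout
w = jnp.ones((5, 3, 4))   # (spatial, in_channels, out_channels): HIO layout

# lax.conv_transpose treats explicit padding pairs as padding *added* to
# the fractionally strided (dilated) input.
y = lax.conv_transpose(x, w, strides=(2,), padding=[(1, 1)])
print(y.shape)  # (1, 13, 4): ((8 - 1) * 2 + 1) + 1 + 1 - (5 - 1)

# A PyTorch/TF-style padding=1 is the *forward* convolution's padding, so
# a gradient-based transpose would instead produce length
# (8 - 1) * 2 - 2 * 1 + 5 = 17.
```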
Sorry, I wrote the original at a time when no frameworks really existed in JAX. (Nowadays I'd probably not add this function to lax at all, since it's strictly a specialization of general convolutions, and would instead delegate these matters to NN frameworks.) Aside from a pending review of correctness, this mainly comes down to a question of organization:
1. Leave it to the NN frameworks (Flax, Objax, etc.) to implement on top of lax.
2. Add it to lax alongside the existing conv_transpose.
The issue with 1. might be that, at least for Flax and Objax, all the convolution modules are just wrappers of the functions in lax.
Hey, is this issue being actively worked on?
Hi all! Is there a plan to merge this PR? It seems to be the root cause of issues when converting some PyTorch models to JAX/Flax; it would be nice if we could merge it ;)
The Flax docs pointed me here. Is there any update on this PR?
Also came here from the Flax docs. Is there any other way to use transposed convolutions that are compatible with PyTorch's ConvTranspose2d?
This PR implements lax.gradient_based_conv_transpose. Compared to conv_transpose, it provides support for output_shape and output_padding. It matches the APIs for transposed convolutions derived from the gradient of a forward convolution, which are common in other deep learning frameworks such as TensorFlow, PyTorch, and Keras. This additional function can make it much easier to reproduce code written in other (and currently more popular) frameworks.
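For anyone needing this behavior before a merge, here is a minimal sketch of the gradient-based construction. It is an illustration under stated assumptions, not the PR's implementation: the helper name, the 1-D restriction, and the NCH layout are mine; the weight layout follows torch.nn.ConvTranspose1d.

```python
import jax.numpy as jnp
from jax import lax

def _gradient_based_conv_transpose_1d(x, w, stride, padding, output_padding):
  """Hypothetical sketch of PyTorch-style ConvTranspose1d semantics.

  x: (batch, in_channels, length); w: (in_channels, out_channels, k),
  matching torch.nn.ConvTranspose1d's weight layout.
  """
  k = w.shape[-1]
  # The gradient of a forward convolution flips the kernel spatially and
  # swaps its channel axes.
  w = jnp.flip(w, axis=-1).transpose(1, 0, 2)  # -> (out_channels, in_channels, k)
  return lax.conv_general_dilated(
      x, w,
      window_strides=(1,),
      # The forward conv's padding shrinks the output; output_padding only
      # enlarges the trailing edge.
      padding=[(k - 1 - padding, k - 1 - padding + output_padding)],
      lhs_dilation=(stride,),  # fractional striding via input dilation
      dimension_numbers=('NCH', 'OIH', 'NCH'))

x = jnp.ones((1, 3, 8))
w = jnp.ones((3, 5, 4))  # 3 -> 5 channels, kernel size 4
y = _gradient_based_conv_transpose_1d(x, w, stride=2, padding=1, output_padding=1)
print(y.shape)  # (1, 5, 17): (8 - 1) * 2 - 2 * 1 + 4 + 1
```

The output_padding term is what lets the API disambiguate output sizes when stride > 1, which is the role it plays in PyTorch and TensorFlow as well.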