Conversation
samuela
left a comment
There was a problem hiding this comment.
Thanks for putting this together @matthieutrs! My main concern is that there seems to be a bug when stride != 1. I tried poking on it briefly, but couldn't figure it out... any idea what might be necessary here?
Would an implementation with assert stride == 1 work for your usecase?
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
|
Thanks a lot for the careful review!
Let me know if there are other things that need to be updated! |
samuela
left a comment
There was a problem hiding this comment.
Figured out a fix such that it works for basically all configuration options to torch.nn.ConvTranspose2d! The one exception being groups != 1, but I'm not worried about that for now.
Additionally,
- Make the test case deterministic by skirting reliance on PyTorch randomness for weights
- Decrease
atol
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
|
Thanks a lot! It indeed works on my architectures. |
|
Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟 |
Thanks for this very nice repo! I needed to convert models containing
torch.nn.ConvTranspose2dbut this is not completely straightforward as torch and lax do not perform the same transposed conv.This PR is essentially a merging of jax-ml/jax#5772 to solve this issue.
Atm it relies on numpy; if this is a problem I think this could be avoided but I didn't have time to remove the dependency yet.