Skip to content

adding torch.nn.ConvTranspose2d#3

Merged
samuela merged 10 commits intosamuela:mainfrom
matthieutrs:main
Oct 31, 2023
Merged

adding torch.nn.ConvTranspose2d#3
samuela merged 10 commits intosamuela:mainfrom
matthieutrs:main

Conversation

@matthieutrs
Copy link
Contributor

Thanks for this very nice repo! I needed to convert models containing torch.nn.ConvTranspose2d but this is not completely straightforward as torch and lax do not perform the same transposed conv.

This PR is essentially a merging of jax-ml/jax#5772 to solve this issue.

Atm it relies on numpy; if this is a problem I think this could be avoided but I didn't have time to remove the dependency yet.

Copy link
Owner

@samuela samuela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for putting this together @matthieutrs! My main concern is that there seems to be a bug when stride != 1. I tried poking on it briefly, but couldn't figure it out... any idea what might be necessary here?

Would an implementation with assert stride == 1 work for your usecase?

matthieutrs and others added 6 commits October 28, 2023 13:36
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
@matthieutrs
Copy link
Contributor Author

Thanks a lot for the careful review!

  1. The dependency to numpy has been removed;
  2. The failling tests for some strides/kernel sizes was due to jax assuming a certain output_padding in torch.nn.ConvTranspose2d, I've added an assertion in the def of conv_transpose2d (see above);

Let me know if there are other things that need to be updated!

Copy link
Owner

@samuela samuela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figured out a fix such that it works for basically all configuration options to torch.nn.ConvTranspose2d! The one exception being groups != 1, but I'm not worried about that for now.

Additionally,

  • Make the test case deterministic by skirting reliance on PyTorch randomness for weights
  • Decrease atol

matthieutrs and others added 2 commits October 30, 2023 12:25
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
Co-authored-by: Samuel Ainsworth <skainsworth@gmail.com>
@matthieutrs
Copy link
Contributor Author

Thanks a lot! It indeed works on my architectures.

@samuela samuela merged commit 4ed99f5 into samuela:main Oct 31, 2023
@samuela
Copy link
Owner

samuela commented Oct 31, 2023

Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments