Skip to content

Comments

Adds HOWTO "Convert PyTorch Models to Flax" to docs#1848

Merged
copybara-service[bot] merged 4 commits intogoogle:mainfrom
matthias-wright:howto-torch-to-flax
Feb 9, 2022
Merged

Adds HOWTO "Convert PyTorch Models to Flax" to docs#1848
copybara-service[bot] merged 4 commits intogoogle:mainfrom
matthias-wright:howto-torch-to-flax

Conversation

@matthias-wright
Copy link
Contributor

This HOWTO shows how to convert PyTorch models to Flax.
The idea for this HOWTO originated in #1770, specifically this comment by @marcvanzee.

@andsteing andsteing self-requested a review February 3, 2022 10:24
@andsteing andsteing self-assigned this Feb 3, 2022
@andsteing
Copy link
Contributor

This is super useful! Thanks so much for writing this up...

Copy link
Contributor

@andsteing andsteing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great HOWTO, added some comments...

from flax import linen as nn
from flax.core import freeze

import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is currently failing because torch is not listed as a [testing] dependency.

could you add it to to setup.py in the same PR?
(we'd have to look at overall check setup and run time to see if this causes a significant change before deciding if we want to add it for real)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: adding torch adds about 30 seconds to the Install Dependencies step in the build. The build time variance is rather large (12 min ... 20 min)

Transposed Convolutions
--------------------------------

``torch.nn.ConvTranspose2d`` and |flax.linen.ConvTranspose|_ are not compatible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This paragraph will make the reader wonder how a torch.nn.ConvTranspose2d should be converted. Do you know how?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the reference. I created #1872 to track this.

@matthias-wright
Copy link
Contributor Author

Thanks for the quick review!

Regarding the paragraph on torch.nn.ConvTranspose2d: there is an implementation of a gradient based transposed convolution but the PR is currently pending. For now, I just linked to the PR in the HOWTO. In order to create a code example, I would have to copy the implementation from the PR into the HOWTO. But the implementation consists of 207 lines of code. I think that might make the HOWTO a bit messy. What do you think?

@codecov-commenter
Copy link

codecov-commenter commented Feb 8, 2022

Codecov Report

Merging #1848 (99fcbfa) into main (811084f) will increase coverage by 0.14%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1848      +/-   ##
==========================================
+ Coverage   70.36%   70.50%   +0.14%     
==========================================
  Files          58       58              
  Lines        4872     4886      +14     
==========================================
+ Hits         3428     3445      +17     
+ Misses       1444     1441       -3     
Impacted Files Coverage Δ
flax/linen/__init__.py 100.00% <100.00%> (ø)
flax/linen/transforms.py 95.89% <0.00%> (+1.74%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 811084f...99fcbfa. Read the comment docs.

from flax import linen as nn
from flax.core import freeze

import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: adding torch adds about 30 seconds to the Install Dependencies step in the build. The build time variance is rather large (12 min ... 20 min)

Copy link
Contributor

@andsteing andsteing left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! One pending comment about the PyTorch compatible avg_pool() implementation, otherwise PR looks good to me.

@matthias-wright
Copy link
Contributor Author

Thanks for the feedback!

@copybara-service copybara-service bot merged commit cb5ebcd into google:main Feb 9, 2022
@matthias-wright matthias-wright deleted the howto-torch-to-flax branch February 9, 2022 17:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants