Adds HOWTO "Convert PyTorch Models to Flax" to docs#1848
Adds HOWTO "Convert PyTorch Models to Flax" to docs#1848copybara-service[bot] merged 4 commits intogoogle:mainfrom
Conversation
|
This is super useful! Thanks so much for writing this up... |
andsteing
left a comment
There was a problem hiding this comment.
Great HOWTO, added some comments...
| from flax import linen as nn | ||
| from flax.core import freeze | ||
|
|
||
| import torch |
There was a problem hiding this comment.
this is currently failing because torch is not listed as a [testing] dependency.
could you add it to to setup.py in the same PR?
(we'd have to look at overall check setup and run time to see if this causes a significant change before deciding if we want to add it for real)
There was a problem hiding this comment.
Note: adding torch adds about 30 seconds to the Install Dependencies step in the build. The build time variance is rather large (12 min ... 20 min)
| Transposed Convolutions | ||
| -------------------------------- | ||
|
|
||
| ``torch.nn.ConvTranspose2d`` and |flax.linen.ConvTranspose|_ are not compatible. |
There was a problem hiding this comment.
This paragraph will make the reader wonder how a torch.nn.ConvTranspose2d should be converted. Do you know how?
There was a problem hiding this comment.
Thanks for adding the reference. I created #1872 to track this.
|
Thanks for the quick review! Regarding the paragraph on |
Codecov Report
@@ Coverage Diff @@
## main #1848 +/- ##
==========================================
+ Coverage 70.36% 70.50% +0.14%
==========================================
Files 58 58
Lines 4872 4886 +14
==========================================
+ Hits 3428 3445 +17
+ Misses 1444 1441 -3
Continue to review full report at Codecov.
|
| from flax import linen as nn | ||
| from flax.core import freeze | ||
|
|
||
| import torch |
There was a problem hiding this comment.
Note: adding torch adds about 30 seconds to the Install Dependencies step in the build. The build time variance is rather large (12 min ... 20 min)
andsteing
left a comment
There was a problem hiding this comment.
Thanks for the updates! One pending comment about the PyTorch compatible avg_pool() implementation, otherwise PR looks good to me.
|
Thanks for the feedback! |
This HOWTO shows how to convert PyTorch models to Flax.
The idea for this HOWTO originated in #1770, specifically this comment by @marcvanzee.