Allow custom optimizers, improve JaxTrainingPlan
#1747
Conversation
adamgayoso commented Oct 17, 2022 (edited)
- Easily add custom optax/pytorch optimizers (see the sketch after this list)
- Breaking change: arguments to the training plans are now keyword-only
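For illustration, a minimal sketch of what a custom optax optimizer could look like. The optax calls below are real library functions; the commented-out integration line with an `optimizer_creator` keyword is an assumption about how the training plan might accept it, not API confirmed by this PR.

import optax

# Build a custom optimizer with optax: clip gradients, then apply AdamW.
custom_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),   # cap the global gradient norm at 1.0
    optax.adamw(learning_rate=1e-3),  # AdamW update rule
)

# Hypothetical usage with the training plan; the `optimizer_creator` keyword
# is an illustrative assumption only:
# plan = JaxTrainingPlan(module, optimizer_creator=lambda: custom_optimizer)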
Codecov Report: Base: 90.96% // Head: 90.87% // Decreases project coverage by 0.09%.
Additional details and impacted files:
@@            Coverage Diff             @@
##             main    #1747      +/-   ##
==========================================
- Coverage   90.96%   90.87%    -0.09%
==========================================
  Files         116       116
  Lines        9757      9764        +7
==========================================
- Hits         8875      8873        -2
- Misses        882       891        +9
☔ View full report at Codecov.
# Allclose because on GPU, the values are not exactly the same
# as latents are moved to cpu in latent mode
np.testing.assert_allclose(
    params_latent[k], params_orig[k], rtol=3e-1, atol=5e-1
)
cc @watiss, needed to change this. LMK if the comment makes sense!
LGTM