-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nn module as base class for density estimator #965
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #965 +/- ##
==========================================
+ Coverage 76.08% 76.17% +0.08%
==========================================
Files 83 83
Lines 6361 6406 +45
==========================================
+ Hits 4840 4880 +40
- Misses 1521 1526 +5
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
A quick question: The |
I agree that it should be batched! But I believe that |
I see and I agree, I will add (b, n, d2) support. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! I left a few small comments, happy to discuss.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks a lot for adding this!
A added a couple of comments and questions.
Sampling should now work with arbitrary context ndim and added a bunch of tests. One question @gmoss13 @janfb, the This is quite restrictive and always requires manually reshaping and repeating. This is fine for the
I think to mirror PyTorch behavior one would adapt "broadcasting" semantics to resolve the second point i.e.:
|
Alright, I added to the
To
Have to add a bit more documentation i.e. what we expect each density estimator can do. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!
Added a couple of comment about variable names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! almost ready I think.
Just needs clarification for the broadcasting and some more docstrings 🙄 sorry 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the clarifications!
Feel free to re-assign me for a final review once everything is addressed.
Thanks for your comments. I tried to fix all left-over inconsistencies now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
Density Estimators inherit nn.Module.
Also added sample_and_log_prob, with a naive implementation that can be easily overriden (following Zuko API).
Limitations:
net
must be a nn.Module (which is fine and consistent with the type hints).