Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn module as base class for density estimator #965

Merged
merged 14 commits into from
Mar 4, 2024

Conversation

manuelgloeckler
Copy link
Contributor

Density Estimators inherit nn.Module.
Also added sample_and_log_prob, with a naive implementation that can be easily overriden (following Zuko API).

Limitations:

  • net must be a nn.Module (which is fine and consistent with the type hints).

Copy link

codecov bot commented Feb 28, 2024

Codecov Report

Attention: Patch coverage is 90.38462% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 76.17%. Comparing base (c3c2b6e) to head (082dd8d).

Files Patch % Lines
sbi/neural_nets/density_estimators/base.py 66.66% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #965      +/-   ##
==========================================
+ Coverage   76.08%   76.17%   +0.08%     
==========================================
  Files          83       83              
  Lines        6361     6406      +45     
==========================================
+ Hits         4840     4880      +40     
- Misses       1521     1526       +5     
Flag Coverage Δ
unittests 76.17% <90.38%> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@manuelgloeckler
Copy link
Contributor Author

A quick question:

The sample for NFlowsFlow with e.g. sample shape (n,) will given a batched condition (b, d) return a tensor of shape (10, b*d). Is this intended? To be consistent with torch we should return (10, b,d), no?

@gmoss13
Copy link
Contributor

gmoss13 commented Feb 29, 2024

A quick question:

The sample for NFlowsFlow with e.g. sample shape (n,) will given a batched condition (b, d) return a tensor of shape (10, b*d). Is this intended? To be consistent with torch we should return (10, b,d), no?

I agree that it should be batched! But I believe that nflows in this case would return the condition batch dimension first in such a case, i.e. (b, n ,d2) where d2 is the dimension of the inputs (and d is the dimension of the condition). I would be in favour of sticking to this convention.

@manuelgloeckler
Copy link
Contributor Author

I see and I agree, I will add (b, n, d2) support.

Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I left a few small comments, happy to discuss.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks a lot for adding this!
A added a couple of comments and questions.

sbi/neural_nets/density_estimators/flow.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/flow.py Outdated Show resolved Hide resolved
tests/density_estimator_test.py Show resolved Hide resolved
@manuelgloeckler
Copy link
Contributor Author

Sampling should now work with arbitrary context ndim and added a bunch of tests.

One question @gmoss13 @janfb, the log_prob as it was implemented currently only works if input.shape[0] == condition.shape[0] (by nflows).

This is quite restrictive and always requires manually reshaping and repeating. This is fine for the loss, but we might want to make this more general for the log_prob method.

  • If we have an input of shape (n, d1) and condition (d2,) -> (n,) (automate repeat)
  • If we have an input of shape (n, d1) and condition (m,d2) -> (m,n) ? (this is in conflict with previous behavior for m=n)

I think to mirror PyTorch behavior one would adapt "broadcasting" semantics to resolve the second point i.e.:

  • (n,1,d1), (m, d2) -> (n,m)
  • (n, d1), (m,d2) -> error (n != m) or (n,) for (m == n)

@manuelgloeckler
Copy link
Contributor Author

manuelgloeckler commented Feb 29, 2024

Alright, I added to the DensityEstimator

  • x_shape to density estimator + a basic check if inputs are consistent with this attribute

To NFlowFlows:

  • log_prob through broadcasting (a bit ugly due to nflow limitaitons).
  • Using x_shape to squeeze/unsqueeze/reshape the right dimensions.
  • Extended test cases.

Have to add a bit more documentation i.e. what we expect each density estimator can do.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!
Added a couple of comment about variable names.

sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/flow.py Outdated Show resolved Hide resolved
tests/density_estimator_test.py Show resolved Hide resolved
tests/density_estimator_test.py Show resolved Hide resolved
tests/density_estimator_test.py Outdated Show resolved Hide resolved
tests/density_estimator_test.py Show resolved Hide resolved
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! almost ready I think.
Just needs clarification for the broadcasting and some more docstrings 🙄 sorry 😄

sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Show resolved Hide resolved
Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the clarifications!
Feel free to re-assign me for a final review once everything is addressed.

sbi/neural_nets/density_estimators/base.py Outdated Show resolved Hide resolved
sbi/neural_nets/density_estimators/base.py Show resolved Hide resolved
@manuelgloeckler
Copy link
Contributor Author

Thanks for your comments. I tried to fix all left-over inconsistencies now.

Copy link
Contributor

@janfb janfb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

@janfb janfb merged commit 17f3033 into main Mar 4, 2024
3 checks passed
@janfb janfb deleted the density_estimator_base_class_as_nnmodule branch March 4, 2024 13:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants