-
Notifications
You must be signed in to change notification settings - Fork 151
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* base density estimator class and simple flow wrapper * NFlowsNSF Density Estimator and API tests * formatting and batchwise sampling for NFlows * final formatting and change name to NFlowsFlow
- Loading branch information
Showing
4 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from sbi.neural_nets.density_estimators.flow import NFlowsFlow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
from torch import Tensor, nn | ||
|
||
|
||
class DensityEstimator: | ||
r"""Base class for density estimators. | ||
The density estimator class is a wrapper around neural networks that | ||
allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $theta,x$ | ||
pairs. | ||
""" | ||
|
||
def __init__(self, net: nn.Module) -> None: | ||
r"""Base class for density estimators. | ||
Args: | ||
net: Neural network. | ||
""" | ||
|
||
self.net = net | ||
|
||
def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: | ||
r"""Return the batched log probabilities of the inputs given the conditions. | ||
Args: | ||
input: Inputs to evaluate the log probability of. Must have batch dimension. | ||
x: Conditions. Must have batch dimension. | ||
Returns: | ||
Sample-wise log probabilities. | ||
""" | ||
|
||
raise NotImplementedError | ||
|
||
def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: | ||
r"""Return the loss for training the density estimator. | ||
Args: | ||
input: Inputs to evaluate the loss on. | ||
condition: Conditions. | ||
Returns: | ||
Loss. | ||
""" | ||
|
||
raise NotImplementedError | ||
|
||
def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: | ||
r"""Return samples from the density estimator. | ||
Args: | ||
sample_shape: Shape of the samples to return. | ||
Returns: | ||
Samples. | ||
""" | ||
|
||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
from pyknos.nflows import flows | ||
from torch import Tensor | ||
|
||
from sbi.neural_nets.density_estimators.base import DensityEstimator | ||
from sbi.types import Shape | ||
|
||
|
||
class NFlowsFlow(DensityEstimator): | ||
r"""`nflows`- based normalizing flow density estimator. | ||
Flow type objects already have a .log_prob() and .sample() method, so here we just | ||
wrap them and add the .loss() method. | ||
""" | ||
|
||
def __init__(self, net: flows.Flow): | ||
|
||
super().__init__(net) | ||
|
||
def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: | ||
r"""Return the batched log probabilities of the inputs given the conditions. | ||
Args: | ||
input: Inputs to evaluate the log probability of. Must have batch dimension. | ||
condition: Conditions. Must have batch dimension. | ||
Returns: | ||
Sample-wise log probabilities. | ||
""" | ||
return self.net.log_prob(input, context=condition) | ||
|
||
def loss(self, input: Tensor, condition: Tensor) -> Tensor: | ||
r"""Return the loss for training the density estimator. | ||
Args: | ||
input: Inputs to evaluate the loss on. Must be batched. | ||
condition: Conditions. Must be batched. | ||
Returns: | ||
Negative log-probability. | ||
""" | ||
|
||
return -self.log_prob(input, condition) | ||
|
||
def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: | ||
r"""Return samples from the density estimator. | ||
Args: | ||
sample_shape: Batch dimensions of the samples to return | ||
condition: Condition. | ||
Returns: | ||
Samples. | ||
""" | ||
|
||
num_samples = torch.Size(sample_shape).numel() | ||
|
||
# nflows.sample() expects conditions to be batched. | ||
if len(condition.shape) == 1: | ||
condition = condition.unsqueeze(0) | ||
return self.net.sample(num_samples, context=condition).reshape( | ||
(*sample_shape, -1) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed | ||
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
import pytest | ||
from torch import eye, zeros | ||
from torch.distributions import MultivariateNormal | ||
|
||
from sbi.neural_nets.density_estimators.flow import NFlowsFlow | ||
from sbi.neural_nets.flow import build_nsf | ||
|
||
|
||
@pytest.mark.parametrize("density_estimator", (NFlowsFlow,)) | ||
@pytest.mark.parametrize("input_dim", (1, 2)) | ||
@pytest.mark.parametrize("context_dim", (1, 2)) | ||
def test_api_density_estimator(density_estimator, input_dim, context_dim): | ||
r"""Checks whether we can evaluate and sample from density estimators correctly. | ||
Args: | ||
density_estimator: DensityEstimator subclass. | ||
input_dim: Dimensionality of the input. | ||
context_dim: Dimensionality of the context. | ||
""" | ||
|
||
nsamples = 10 | ||
nsamples_test = 5 | ||
|
||
input_mvn = MultivariateNormal( | ||
loc=zeros(input_dim), covariance_matrix=eye(input_dim) | ||
) | ||
batch_input = input_mvn.sample((nsamples,)) | ||
context_mvn = MultivariateNormal( | ||
loc=zeros(context_dim), covariance_matrix=eye(context_dim) | ||
) | ||
batch_context = context_mvn.sample((nsamples,)) | ||
|
||
net = build_nsf(batch_input, batch_context, hidden_features=10, num_transforms=2) | ||
estimator = density_estimator(net) | ||
|
||
log_probs = estimator.log_prob(batch_input, batch_context) | ||
assert log_probs.shape == (nsamples,), "log_prob shape is not correct" | ||
|
||
loss = estimator.loss(batch_input, batch_context) | ||
assert loss.shape == (nsamples,), "loss shape is not correct" | ||
|
||
samples = estimator.sample((nsamples_test,), batch_context[0]) | ||
assert samples.shape == (nsamples_test, input_dim), "samples shape is not correct" | ||
|
||
samples = estimator.sample((2, nsamples_test), batch_context[0]) | ||
assert samples.shape == ( | ||
2, | ||
nsamples_test, | ||
input_dim, | ||
), "samples shape is not correct" |