Density Estimator Class (#952)
* base density estimator class and simple flow wrapper

* NFlowsNSF Density Estimator and API tests

* formatting and batchwise sampling for NFlows

* final formatting and change name to NFlowsFlow
gmoss13 authored Feb 28, 2024
1 parent 3aeb775 commit 9d6d13c
Showing 4 changed files with 177 additions and 0 deletions.
1 change: 1 addition & 0 deletions sbi/neural_nets/density_estimators/__init__.py
@@ -0,0 +1 @@
from sbi.neural_nets.density_estimators.flow import NFlowsFlow
58 changes: 58 additions & 0 deletions sbi/neural_nets/density_estimators/base.py
@@ -0,0 +1,58 @@
import torch
from torch import Tensor, nn


class DensityEstimator:
    r"""Base class for density estimators.

    The density estimator class is a wrapper around neural networks that allows one
    to evaluate `log_prob` and `sample`, and to compute the `loss`, for $(\theta, x)$
    pairs.
    """

    def __init__(self, net: nn.Module) -> None:
        r"""Base class for density estimators.

        Args:
            net: Neural network.
        """
        self.net = net

    def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
        r"""Return the batched log probabilities of the inputs given the conditions.

        Args:
            input: Inputs to evaluate the log probability of. Must have a batch
                dimension.
            condition: Conditions. Must have a batch dimension.

        Returns:
            Sample-wise log probabilities.
        """
        raise NotImplementedError

    def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
        r"""Return the loss for training the density estimator.

        Args:
            input: Inputs to evaluate the loss on.
            condition: Conditions.

        Returns:
            Loss.
        """
        raise NotImplementedError

    def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
        r"""Return samples from the density estimator.

        Args:
            sample_shape: Shape of the samples to return.
            condition: Condition.

        Returns:
            Samples.
        """
        raise NotImplementedError
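
To illustrate the contract that subclasses are expected to fulfil, here is a minimal, hypothetical sketch (not part of this commit) of a DensityEstimator subclass that wraps a fixed standard normal and ignores the condition; the class name and constructor argument are illustrative only:

import torch
from torch import Tensor, nn
from torch.distributions import MultivariateNormal

from sbi.neural_nets.density_estimators.base import DensityEstimator


class StandardNormalEstimator(DensityEstimator):
    """Toy estimator: a fixed standard normal that ignores the condition."""

    def __init__(self, dim: int) -> None:
        # No learnable network is needed for this toy example.
        super().__init__(net=nn.Identity())
        self.dist = MultivariateNormal(torch.zeros(dim), torch.eye(dim))

    def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
        return self.dist.log_prob(input)

    def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
        return -self.log_prob(input, condition)

    def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
        return self.dist.sample(sample_shape)
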
63 changes: 63 additions & 0 deletions sbi/neural_nets/density_estimators/flow.py
@@ -0,0 +1,63 @@
import torch
from pyknos.nflows import flows
from torch import Tensor

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.types import Shape


class NFlowsFlow(DensityEstimator):
    r"""`nflows`-based normalizing flow density estimator.

    Flow objects already have `.log_prob()` and `.sample()` methods, so here we just
    wrap them and add the `.loss()` method.
    """

    def __init__(self, net: flows.Flow):
        super().__init__(net)

    def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
        r"""Return the batched log probabilities of the inputs given the conditions.

        Args:
            input: Inputs to evaluate the log probability of. Must have a batch
                dimension.
            condition: Conditions. Must have a batch dimension.

        Returns:
            Sample-wise log probabilities.
        """
        return self.net.log_prob(input, context=condition)

    def loss(self, input: Tensor, condition: Tensor) -> Tensor:
        r"""Return the loss for training the density estimator.

        Args:
            input: Inputs to evaluate the loss on. Must be batched.
            condition: Conditions. Must be batched.

        Returns:
            Negative log-probability.
        """
        return -self.log_prob(input, condition)

    def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
        r"""Return samples from the density estimator.

        Args:
            sample_shape: Batch dimensions of the samples to return.
            condition: Condition.

        Returns:
            Samples.
        """
        num_samples = torch.Size(sample_shape).numel()

        # nflows.sample() expects conditions to be batched.
        if len(condition.shape) == 1:
            condition = condition.unsqueeze(0)

        return self.net.sample(num_samples, context=condition).reshape(
            (*sample_shape, -1)
        )
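
A minimal usage sketch of the new wrapper, assuming the `build_nsf` factory from `sbi.neural_nets.flow` with the same arguments as in the API test below; the toy tensors and variable names are illustrative only:

import torch

from sbi.neural_nets.density_estimators import NFlowsFlow
from sbi.neural_nets.flow import build_nsf

# Toy data: 2-D inputs conditioned on 3-D contexts.
theta = torch.randn(100, 2)
x = torch.randn(100, 3)

# Build an nflows neural spline flow and wrap it in the new estimator class.
net = build_nsf(theta, x, hidden_features=10, num_transforms=2)
estimator = NFlowsFlow(net)

log_probs = estimator.log_prob(theta, x)   # shape (100,)
loss = estimator.loss(theta, x)            # shape (100,), equals -log_prob
samples = estimator.sample((5,), x[0])     # shape (5, 2)
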
55 changes: 55 additions & 0 deletions tests/density_estimator_test.py
@@ -0,0 +1,55 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import pytest
from torch import eye, zeros
from torch.distributions import MultivariateNormal

from sbi.neural_nets.density_estimators.flow import NFlowsFlow
from sbi.neural_nets.flow import build_nsf


@pytest.mark.parametrize("density_estimator", (NFlowsFlow,))
@pytest.mark.parametrize("input_dim", (1, 2))
@pytest.mark.parametrize("context_dim", (1, 2))
def test_api_density_estimator(density_estimator, input_dim, context_dim):
    r"""Checks whether we can evaluate and sample from density estimators correctly.

    Args:
        density_estimator: DensityEstimator subclass.
        input_dim: Dimensionality of the input.
        context_dim: Dimensionality of the context.
    """
    nsamples = 10
    nsamples_test = 5

    input_mvn = MultivariateNormal(
        loc=zeros(input_dim), covariance_matrix=eye(input_dim)
    )
    batch_input = input_mvn.sample((nsamples,))
    context_mvn = MultivariateNormal(
        loc=zeros(context_dim), covariance_matrix=eye(context_dim)
    )
    batch_context = context_mvn.sample((nsamples,))

    net = build_nsf(batch_input, batch_context, hidden_features=10, num_transforms=2)
    estimator = density_estimator(net)

    log_probs = estimator.log_prob(batch_input, batch_context)
    assert log_probs.shape == (nsamples,), "log_prob shape is not correct"

    loss = estimator.loss(batch_input, batch_context)
    assert loss.shape == (nsamples,), "loss shape is not correct"

    samples = estimator.sample((nsamples_test,), batch_context[0])
    assert samples.shape == (nsamples_test, input_dim), "samples shape is not correct"

    samples = estimator.sample((2, nsamples_test), batch_context[0])
    assert samples.shape == (
        2,
        nsamples_test,
        input_dim,
    ), "samples shape is not correct"
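
The final assertion relies on how `NFlowsFlow.sample` flattens and restores the sample shape; a quick illustrative check of that arithmetic (not part of the commit):

import torch

sample_shape = (2, 5)
num_samples = torch.Size(sample_shape).numel()  # 10 draws requested from nflows
print(num_samples)  # 10
# For a single (unsqueezed) condition, nflows returns samples of shape
# (1, 10, input_dim); reshape((*sample_shape, -1)) turns this into (2, 5, input_dim).
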
