From 9d6d13ced8f56c77b27eb04ce5bffbd569ff8124 Mon Sep 17 00:00:00 2001 From: Guy Moss <91739128+gmoss13@users.noreply.github.com> Date: Wed, 28 Feb 2024 09:18:24 +0100 Subject: [PATCH] Density Estimator Class (#952) * base density estimator class and simple flow wrapper * NFlowsNSF Density Estimator and API tests * formatting and batchwise sampling for NFlows * final formatting and change name to NFlowsFlow --- .../density_estimators/__init__.py | 1 + sbi/neural_nets/density_estimators/base.py | 58 +++++++++++++++++ sbi/neural_nets/density_estimators/flow.py | 63 +++++++++++++++++++ tests/density_estimator_test.py | 55 ++++++++++++++++ 4 files changed, 177 insertions(+) create mode 100644 sbi/neural_nets/density_estimators/__init__.py create mode 100644 sbi/neural_nets/density_estimators/base.py create mode 100644 sbi/neural_nets/density_estimators/flow.py create mode 100644 tests/density_estimator_test.py diff --git a/sbi/neural_nets/density_estimators/__init__.py b/sbi/neural_nets/density_estimators/__init__.py new file mode 100644 index 000000000..ac9d35085 --- /dev/null +++ b/sbi/neural_nets/density_estimators/__init__.py @@ -0,0 +1 @@ +from sbi.neural_nets.density_estimators.flow import NFlowsFlow diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py new file mode 100644 index 000000000..4999faf65 --- /dev/null +++ b/sbi/neural_nets/density_estimators/base.py @@ -0,0 +1,58 @@ +import torch +from torch import Tensor, nn + + +class DensityEstimator: + r"""Base class for density estimators. + + The density estimator class is a wrapper around neural networks that + allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $theta,x$ + pairs. + """ + + def __init__(self, net: nn.Module) -> None: + r"""Base class for density estimators. + + Args: + net: Neural network. + """ + + self.net = net + + def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the batched log probabilities of the inputs given the conditions. + + Args: + input: Inputs to evaluate the log probability of. Must have batch dimension. + x: Conditions. Must have batch dimension. + + Returns: + Sample-wise log probabilities. + """ + + raise NotImplementedError + + def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: + r"""Return the loss for training the density estimator. + + Args: + input: Inputs to evaluate the loss on. + condition: Conditions. + + Returns: + Loss. + """ + + raise NotImplementedError + + def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: + r"""Return samples from the density estimator. + + Args: + sample_shape: Shape of the samples to return. + + Returns: + Samples. + """ + + raise NotImplementedError diff --git a/sbi/neural_nets/density_estimators/flow.py b/sbi/neural_nets/density_estimators/flow.py new file mode 100644 index 000000000..10a1979b4 --- /dev/null +++ b/sbi/neural_nets/density_estimators/flow.py @@ -0,0 +1,63 @@ +import torch +from pyknos.nflows import flows +from torch import Tensor + +from sbi.neural_nets.density_estimators.base import DensityEstimator +from sbi.types import Shape + + +class NFlowsFlow(DensityEstimator): + r"""`nflows`- based normalizing flow density estimator. + + Flow type objects already have a .log_prob() and .sample() method, so here we just + wrap them and add the .loss() method. + """ + + def __init__(self, net: flows.Flow): + + super().__init__(net) + + def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: + r"""Return the batched log probabilities of the inputs given the conditions. + + Args: + input: Inputs to evaluate the log probability of. Must have batch dimension. + condition: Conditions. Must have batch dimension. + + Returns: + Sample-wise log probabilities. + """ + return self.net.log_prob(input, context=condition) + + def loss(self, input: Tensor, condition: Tensor) -> Tensor: + r"""Return the loss for training the density estimator. + + Args: + input: Inputs to evaluate the loss on. Must be batched. + condition: Conditions. Must be batched. + + Returns: + Negative log-probability. + """ + + return -self.log_prob(input, condition) + + def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: + r"""Return samples from the density estimator. + + Args: + sample_shape: Batch dimensions of the samples to return + condition: Condition. + + Returns: + Samples. + """ + + num_samples = torch.Size(sample_shape).numel() + + # nflows.sample() expects conditions to be batched. + if len(condition.shape) == 1: + condition = condition.unsqueeze(0) + return self.net.sample(num_samples, context=condition).reshape( + (*sample_shape, -1) + ) diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py new file mode 100644 index 000000000..339070660 --- /dev/null +++ b/tests/density_estimator_test.py @@ -0,0 +1,55 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . + +from __future__ import annotations + +import pytest +from torch import eye, zeros +from torch.distributions import MultivariateNormal + +from sbi.neural_nets.density_estimators.flow import NFlowsFlow +from sbi.neural_nets.flow import build_nsf + + +@pytest.mark.parametrize("density_estimator", (NFlowsFlow,)) +@pytest.mark.parametrize("input_dim", (1, 2)) +@pytest.mark.parametrize("context_dim", (1, 2)) +def test_api_density_estimator(density_estimator, input_dim, context_dim): + r"""Checks whether we can evaluate and sample from density estimators correctly. + + Args: + density_estimator: DensityEstimator subclass. + input_dim: Dimensionality of the input. + context_dim: Dimensionality of the context. + """ + + nsamples = 10 + nsamples_test = 5 + + input_mvn = MultivariateNormal( + loc=zeros(input_dim), covariance_matrix=eye(input_dim) + ) + batch_input = input_mvn.sample((nsamples,)) + context_mvn = MultivariateNormal( + loc=zeros(context_dim), covariance_matrix=eye(context_dim) + ) + batch_context = context_mvn.sample((nsamples,)) + + net = build_nsf(batch_input, batch_context, hidden_features=10, num_transforms=2) + estimator = density_estimator(net) + + log_probs = estimator.log_prob(batch_input, batch_context) + assert log_probs.shape == (nsamples,), "log_prob shape is not correct" + + loss = estimator.loss(batch_input, batch_context) + assert loss.shape == (nsamples,), "loss shape is not correct" + + samples = estimator.sample((nsamples_test,), batch_context[0]) + assert samples.shape == (nsamples_test, input_dim), "samples shape is not correct" + + samples = estimator.sample((2, nsamples_test), batch_context[0]) + assert samples.shape == ( + 2, + nsamples_test, + input_dim, + ), "samples shape is not correct"