Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CategoricalMADE #1269

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
181 changes: 181 additions & 0 deletions sbi/made_mnle.ipynb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should eventually be integrated with the MNLE tutorial in 12_iid_data_and_permutation_invariant_embeddings.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sbi/neural_nets/estimators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sbi.neural_nets.estimators.categorical_net import (
CategoricalMassEstimator,
CategoricalNet,
CategoricalMADE,
)
from sbi.neural_nets.estimators.flowmatching_estimator import FlowMatchingEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
Expand Down
104 changes: 104 additions & 0 deletions sbi/neural_nets/estimators/categorical_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,116 @@
from typing import Optional

import torch
from nflows.nn.nde.made import MADE
from nflows.utils import torchutils
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.nn import Sigmoid, Softmax
from torch.nn import functional as F

from sbi.neural_nets.estimators.base import ConditionalDensityEstimator


class CategoricalMADE(MADE):
def __init__(
self,
categories, # Tensor[int]
hidden_features,
context_features=None,
num_blocks=2,
use_residual_blocks=True,
random_mask=False,
activation=F.relu,
dropout_probability=0.0,
use_batch_norm=False,
epsilon=1e-2,
custom_initialization=True,
Comment on lines +20 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add types.

embedding_net: Optional[nn.Module] = nn.Identity(),
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add docstring with brief explanation and arg list.

if use_residual_blocks and random_mask:
raise ValueError("Residual blocks can't be used with random masks.")

Check warning on line 34 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L33-L34

Added lines #L33 - L34 were not covered by tests

self.num_variables = len(categories)
self.num_categories = int(max(categories))
self.categories = categories
Comment on lines +36 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused by the notion of these three variables.

what's the difference between variables and categories here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what type does categories have?

self.mask = torch.zeros(self.num_variables, self.num_categories)
for i, c in enumerate(categories):
self.mask[i, :c] = 1

Check warning on line 41 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L36-L41

Added lines #L36 - L41 were not covered by tests

super().__init__(

Check warning on line 43 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L43

Added line #L43 was not covered by tests
self.num_variables,
hidden_features,
context_features=context_features,
num_blocks=num_blocks,
output_multiplier=self.num_categories,
use_residual_blocks=use_residual_blocks,
random_mask=random_mask,
activation=activation,
dropout_probability=dropout_probability,
use_batch_norm=use_batch_norm,
)

self.embedding_net = embedding_net
self.hidden_features = hidden_features
self.epsilon = epsilon

Check warning on line 58 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L56-L58

Added lines #L56 - L58 were not covered by tests

if custom_initialization:
self._initialize()

Check warning on line 61 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L60-L61

Added lines #L60 - L61 were not covered by tests

def forward(self, inputs, context=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and short docstring.

embedded_context = self.embedding_net.forward(context)
return super().forward(inputs, context=embedded_context)

Check warning on line 65 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L64-L65

Added lines #L64 - L65 were not covered by tests

def compute_probs(self, outputs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add types and short docstring.

ps = F.softmax(outputs, dim=-1) * self.mask
ps = ps / ps.sum(dim=-1, keepdim=True)
return ps

Check warning on line 70 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L68-L70

Added lines #L68 - L70 were not covered by tests

# outputs (batch_size, num_variables, num_categories)
def log_prob(self, inputs, context=None):
Comment on lines +72 to +73
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove inline comment and instead add types, return types and docstring with details on the return dimensions if needed.

outputs = self.forward(inputs, context=context)
outputs = outputs.reshape(*inputs.shape, self.num_categories)
ps = self.compute_probs(outputs)

Check warning on line 76 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L74-L76

Added lines #L74 - L76 were not covered by tests

# categorical log prob
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long()))
log_prob = log_prob.squeeze(-1).sum(dim=-1)

Check warning on line 80 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L79-L80

Added lines #L79 - L80 were not covered by tests

return log_prob

Check warning on line 82 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L82

Added line #L82 was not covered by tests

def sample(self, sample_shape, context=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add types and docstring.

# Ensure sample_shape is a tuple
if isinstance(sample_shape, int):
sample_shape = (sample_shape,)
sample_shape = torch.Size(sample_shape)

Check warning on line 88 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L86-L88

Added lines #L86 - L88 were not covered by tests

# Calculate total number of samples
num_samples = torch.prod(torch.tensor(sample_shape)).item()

Check warning on line 91 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L91

Added line #L91 was not covered by tests

# Prepare context
if context is not None:
if context.ndim == 1:
context = context.unsqueeze(0)
context = torchutils.repeat_rows(context, num_samples)

Check warning on line 97 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L94-L97

Added lines #L94 - L97 were not covered by tests
else:
context = torch.zeros(num_samples, self.context_dim)

Check warning on line 99 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L99

Added line #L99 was not covered by tests

with torch.no_grad():
samples = torch.zeros(num_samples, self.num_variables)
for variable in range(self.num_variables):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we get it for all variables at once?

outputs = self.forward(samples, context)
outputs = outputs.reshape(

Check warning on line 105 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L101-L105

Added lines #L101 - L105 were not covered by tests
num_samples, self.num_variables, self.num_categories
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is an iteration for one variable i why do we have to reshape into num_variables here?

)
ps = self.compute_probs(outputs)
samples[:, variable] = Categorical(probs=ps[:, variable]).sample()

Check warning on line 109 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L108-L109

Added lines #L108 - L109 were not covered by tests

return samples.reshape(*sample_shape, self.num_variables)

Check warning on line 111 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L111

Added line #L111 was not covered by tests

def _initialize(self):
pass

Check warning on line 114 in sbi/neural_nets/estimators/categorical_net.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/estimators/categorical_net.py#L114

Added line #L114 was not covered by tests
Comment on lines +113 to +114
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the custom init? is it a abstract method from nflows MADE?
if so, we should probably raise "not implemented" here.



class CategoricalNet(nn.Module):
"""Conditional density (mass) estimation for a categorical random variable.

Expand Down Expand Up @@ -43,6 +146,7 @@
self.activation = Sigmoid()
self.softmax = Softmax(dim=1)
self.num_categories = num_categories
self.num_variables = 1

# Maybe add embedding net in front.
if embedding_net is not None:
Expand Down
14 changes: 11 additions & 3 deletions sbi/neural_nets/estimators/mixed_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def sample(
sample_shape=sample_shape,
condition=condition,
)
# Trailing `1` because `Categorical` has event_shape `()`.
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1)
num_variables = self.discrete_net.net.num_variables
discrete_samples = discrete_samples.reshape(
num_samples * batch_dim, num_variables
)
Comment on lines +83 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


# repeat the batch of embedded condition to match number of choices.
condition_event_dim = embedded_condition.dim() - 1
Expand Down Expand Up @@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
f"{input_batch_dim} do not match."
)

cont_input, disc_input = _separate_input(input)
num_disc = self.discrete_net.net.num_variables
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
num_disc = self.discrete_net.net.num_variables
num_discrete_variables = self.discrete_net.net.num_variables

cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc)
cont_input, disc_input = _separate_input(input, num_discrete_variables)

# Embed continuous condition
embedded_condition = self.condition_embedding(condition)
# expand and repeat to match batch of inputs.
Expand Down Expand Up @@ -204,3 +207,8 @@ def _separate_input(
Assumes the discrete data to live in the last columns of input.
"""
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:]


def _is_discrete(input: Tensor) -> Tensor:
"""Infer discrete columns in input data."""
return torch.tensor([torch.allclose(col, col.round()) for col in input.T])
72 changes: 66 additions & 6 deletions sbi/neural_nets/net_builders/categorial.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional

from torch import Tensor, nn, unique
from torch import Tensor, nn, tensor, unique

from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
from sbi.neural_nets.estimators import (
CategoricalMADE,
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete
from sbi.utils.nn_utils import get_numel
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device


Expand Down Expand Up @@ -61,3 +64,60 @@
return CategoricalMassEstimator(
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
)


def build_autoregressive_categoricalmassestimator(
batch_x: Tensor,
batch_y: Tensor,
z_score_x: Optional[str] = "none",
z_score_y: Optional[str] = "independent",
num_hidden: int = 20,
num_layers: int = 2,
categories: Optional[Tensor] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to docstring.

I suggest renaming this to num_categories. and I like the approach with the warning that if it is None it will be inferred from the data.

embedding_net: nn.Module = nn.Identity(),
):
"""Returns a density estimator for a categorical random variable.

Args:
batch_x: A batch of input data.
batch_y: A batch of condition data.
z_score_x: Whether to z-score the input data.
z_score_y: Whether to z-score the condition data.
num_hidden: Number of hidden units per layer.
num_layers: Number of hidden layers.
embedding_net: Embedding net for y.
"""

if z_score_x != "none":
raise ValueError("Categorical input should not be z-scored.")
if categories is None:
warnings.warn(

Check warning on line 94 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L91-L94

Added lines #L91 - L94 were not covered by tests
"Inferring categories from batch_x. Ensure all categories are present.",
stacklevel=2,
)

check_data_device(batch_x, batch_y)

Check warning on line 99 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L99

Added line #L99 was not covered by tests

z_score_y_bool, structured_y = z_score_parser(z_score_y)
y_numel = get_numel(batch_y, embedding_net=embedding_net)

Check warning on line 102 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L101-L102

Added lines #L101 - L102 were not covered by tests

if z_score_y_bool:
embedding_net = nn.Sequential(

Check warning on line 105 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L104-L105

Added lines #L104 - L105 were not covered by tests
standardizing_net(batch_y, structured_y), embedding_net
)

batch_x_discrete = batch_x[:, _is_discrete(batch_x)]
inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T])
Comment on lines +109 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two lines should be executed only if categories is None.

categories = categories if categories is not None else inferred_categories

Check warning on line 111 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L109-L111

Added lines #L109 - L111 were not covered by tests

categorical_net = CategoricalMADE(

Check warning on line 113 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L113

Added line #L113 was not covered by tests
categories=categories,
hidden_features=num_hidden,
context_features=y_numel,
num_blocks=num_layers,
embedding_net=embedding_net,
)

return CategoricalMassEstimator(

Check warning on line 121 in sbi/neural_nets/net_builders/categorial.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/categorial.py#L121

Added line #L121 was not covered by tests
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape
)
59 changes: 41 additions & 18 deletions sbi/neural_nets/net_builders/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
from torch import Tensor, nn

from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input
from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
from sbi.neural_nets.estimators.mixed_density_estimator import (
_is_discrete,
_separate_input,
)
from sbi.neural_nets.net_builders.categorial import (
build_autoregressive_categoricalmassestimator,
build_categoricalmassestimator,
)
from sbi.neural_nets.net_builders.flow import (
build_made,
build_maf,
Expand All @@ -26,10 +32,7 @@
build_zuko_unaf,
)
from sbi.neural_nets.net_builders.mdn import build_mdn
from sbi.utils.sbiutils import (
standardizing_net,
z_score_parser,
)
from sbi.utils.sbiutils import standardizing_net, z_score_parser
from sbi.utils.user_input_checks import check_data_device

model_builders = {
Expand All @@ -56,6 +59,7 @@
z_score_x: Optional[str] = "independent",
z_score_y: Optional[str] = "independent",
flow_model: str = "nsf",
categorical_model: str = "mlp",
embedding_net: nn.Module = nn.Identity(),
combined_embedding_net: Optional[nn.Module] = None,
num_transforms: int = 2,
Expand Down Expand Up @@ -102,6 +106,8 @@
as z_score_x.
flow_model: type of flow model to use for the continuous part of the
data.
categorical_model: type of categorical net to use for the discrete part of
the data. Can be "made" or "mlp".
embedding_net: Optional embedding network for y, required if y is > 1D.
combined_embedding_net: Optional embedding for combining the discrete
part of the input and the embedded condition into a joined
Expand All @@ -125,13 +131,14 @@

warnings.warn(
"The mixed neural likelihood estimator assumes that x contains "
"continuous data in the first n-1 columns (e.g., reaction times) and "
"categorical data in the last column (e.g., corresponding choices). If "
"continuous data in the first n-k columns (e.g., reaction times) and "
"categorical data in the last k columns (e.g., corresponding choices). If "
"this is not the case for the passed `x` do not use this function.",
stacklevel=2,
)
# Separate continuous and discrete data.
cont_x, disc_x = _separate_input(batch_x)
num_disc = int(torch.sum(_is_discrete(batch_x)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add num_categorical_columns as input arg to this function to enable users to pass this number. Inferring it from batch_x should be the fallback with warning.

cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc)

# Set up y-embedding net with z-scoring.
z_score_y_bool, structured_y = z_score_parser(z_score_y)
Expand All @@ -144,15 +151,31 @@
combined_condition = torch.cat([disc_x, embedded_batch_y], dim=-1)

# Set up a categorical RV neural net for modelling the discrete data.
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
)
if categorical_model == "made":
discrete_net = build_autoregressive_categoricalmassestimator(

Check warning on line 155 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L155

Added line #L155 was not covered by tests
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pass num_categories here as well.

)
elif categorical_model == "mlp":
assert num_disc == 1, "MLP only supports 1D input."
discrete_net = build_categoricalmassestimator(
disc_x,
batch_y,
z_score_x="none", # discrete data should not be z-scored.
z_score_y="none", # y-embedding net already z-scores.
num_hidden=hidden_features,
num_layers=hidden_layers,
embedding_net=embedding_net,
)
Comment on lines +164 to +174
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more generally, isn't the MLP a special case of the MADE? can't we absorb them into one class?

else:
raise ValueError(

Check warning on line 176 in sbi/neural_nets/net_builders/mnle.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/net_builders/mnle.py#L176

Added line #L176 was not covered by tests
f"Unknown categorical net {categorical_model}. Must be 'made' or 'mlp'."
)

if combined_embedding_net is None:
# set up linear embedding net for combining discrete and continuous
Expand Down
Loading