-
Notifications
You must be signed in to change notification settings - Fork 151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add CategoricalMADE
#1269
base: main
Are you sure you want to change the base?
Add CategoricalMADE
#1269
Changes from all commits
57ecf42
40c657a
d6dc444
298971d
e214fe1
d070d2a
089d1d3
045bf5e
bcc75db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,116 @@ | |
from typing import Optional | ||
|
||
import torch | ||
from nflows.nn.nde.made import MADE | ||
from nflows.utils import torchutils | ||
from torch import Tensor, nn | ||
from torch.distributions import Categorical | ||
from torch.nn import Sigmoid, Softmax | ||
from torch.nn import functional as F | ||
|
||
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator | ||
|
||
|
||
class CategoricalMADE(MADE): | ||
def __init__( | ||
self, | ||
categories, # Tensor[int] | ||
hidden_features, | ||
context_features=None, | ||
num_blocks=2, | ||
use_residual_blocks=True, | ||
random_mask=False, | ||
activation=F.relu, | ||
dropout_probability=0.0, | ||
use_batch_norm=False, | ||
epsilon=1e-2, | ||
custom_initialization=True, | ||
Comment on lines
+20
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add types. |
||
embedding_net: Optional[nn.Module] = nn.Identity(), | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add docstring with brief explanation and arg list. |
||
if use_residual_blocks and random_mask: | ||
raise ValueError("Residual blocks can't be used with random masks.") | ||
|
||
self.num_variables = len(categories) | ||
self.num_categories = int(max(categories)) | ||
self.categories = categories | ||
Comment on lines
+36
to
+38
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am confused by the notion of these three variables. what's the difference between variables and categories here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what type does |
||
self.mask = torch.zeros(self.num_variables, self.num_categories) | ||
for i, c in enumerate(categories): | ||
self.mask[i, :c] = 1 | ||
|
||
super().__init__( | ||
self.num_variables, | ||
hidden_features, | ||
context_features=context_features, | ||
num_blocks=num_blocks, | ||
output_multiplier=self.num_categories, | ||
use_residual_blocks=use_residual_blocks, | ||
random_mask=random_mask, | ||
activation=activation, | ||
dropout_probability=dropout_probability, | ||
use_batch_norm=use_batch_norm, | ||
) | ||
|
||
self.embedding_net = embedding_net | ||
self.hidden_features = hidden_features | ||
self.epsilon = epsilon | ||
|
||
if custom_initialization: | ||
self._initialize() | ||
|
||
def forward(self, inputs, context=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing types. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and short docstring. |
||
embedded_context = self.embedding_net.forward(context) | ||
return super().forward(inputs, context=embedded_context) | ||
|
||
def compute_probs(self, outputs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add types and short docstring. |
||
ps = F.softmax(outputs, dim=-1) * self.mask | ||
ps = ps / ps.sum(dim=-1, keepdim=True) | ||
return ps | ||
|
||
# outputs (batch_size, num_variables, num_categories) | ||
def log_prob(self, inputs, context=None): | ||
Comment on lines
+72
to
+73
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove inline comment and instead add types, return types and docstring with details on the return dimensions if needed. |
||
outputs = self.forward(inputs, context=context) | ||
outputs = outputs.reshape(*inputs.shape, self.num_categories) | ||
ps = self.compute_probs(outputs) | ||
|
||
# categorical log prob | ||
log_prob = torch.log(ps.gather(-1, inputs.unsqueeze(-1).long())) | ||
log_prob = log_prob.squeeze(-1).sum(dim=-1) | ||
|
||
return log_prob | ||
|
||
def sample(self, sample_shape, context=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add types and docstring. |
||
# Ensure sample_shape is a tuple | ||
if isinstance(sample_shape, int): | ||
sample_shape = (sample_shape,) | ||
sample_shape = torch.Size(sample_shape) | ||
|
||
# Calculate total number of samples | ||
num_samples = torch.prod(torch.tensor(sample_shape)).item() | ||
|
||
# Prepare context | ||
if context is not None: | ||
if context.ndim == 1: | ||
context = context.unsqueeze(0) | ||
context = torchutils.repeat_rows(context, num_samples) | ||
else: | ||
context = torch.zeros(num_samples, self.context_dim) | ||
|
||
with torch.no_grad(): | ||
samples = torch.zeros(num_samples, self.num_variables) | ||
for variable in range(self.num_variables): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't we get it for all variables at once? |
||
outputs = self.forward(samples, context) | ||
outputs = outputs.reshape( | ||
num_samples, self.num_variables, self.num_categories | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if this is an iteration for one variable |
||
) | ||
ps = self.compute_probs(outputs) | ||
samples[:, variable] = Categorical(probs=ps[:, variable]).sample() | ||
|
||
return samples.reshape(*sample_shape, self.num_variables) | ||
|
||
def _initialize(self): | ||
pass | ||
Comment on lines
+113
to
+114
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is the custom init? is it a abstract method from |
||
|
||
|
||
class CategoricalNet(nn.Module): | ||
"""Conditional density (mass) estimation for a categorical random variable. | ||
|
||
|
@@ -43,6 +146,7 @@ | |
self.activation = Sigmoid() | ||
self.softmax = Softmax(dim=1) | ||
self.num_categories = num_categories | ||
self.num_variables = 1 | ||
|
||
# Maybe add embedding net in front. | ||
if embedding_net is not None: | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -80,8 +80,10 @@ def sample( | |||||
sample_shape=sample_shape, | ||||||
condition=condition, | ||||||
) | ||||||
# Trailing `1` because `Categorical` has event_shape `()`. | ||||||
discrete_samples = discrete_samples.reshape(num_samples * batch_dim, 1) | ||||||
num_variables = self.discrete_net.net.num_variables | ||||||
discrete_samples = discrete_samples.reshape( | ||||||
num_samples * batch_dim, num_variables | ||||||
) | ||||||
Comment on lines
+83
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||||||
|
||||||
# repeat the batch of embedded condition to match number of choices. | ||||||
condition_event_dim = embedded_condition.dim() - 1 | ||||||
|
@@ -145,7 +147,8 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: | |||||
f"{input_batch_dim} do not match." | ||||||
) | ||||||
|
||||||
cont_input, disc_input = _separate_input(input) | ||||||
num_disc = self.discrete_net.net.num_variables | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
cont_input, disc_input = _separate_input(input, num_discrete_columns=num_disc) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# Embed continuous condition | ||||||
embedded_condition = self.condition_embedding(condition) | ||||||
# expand and repeat to match batch of inputs. | ||||||
|
@@ -204,3 +207,8 @@ def _separate_input( | |||||
Assumes the discrete data to live in the last columns of input. | ||||||
""" | ||||||
return input[..., :-num_discrete_columns], input[..., -num_discrete_columns:] | ||||||
|
||||||
|
||||||
def _is_discrete(input: Tensor) -> Tensor: | ||||||
"""Infer discrete columns in input data.""" | ||||||
return torch.tensor([torch.allclose(col, col.round()) for col in input.T]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,19 @@ | ||
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed | ||
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/> | ||
|
||
import warnings | ||
from typing import Optional | ||
|
||
from torch import Tensor, nn, unique | ||
from torch import Tensor, nn, tensor, unique | ||
|
||
from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet | ||
from sbi.utils.nn_utils import get_numel | ||
from sbi.utils.sbiutils import ( | ||
standardizing_net, | ||
z_score_parser, | ||
from sbi.neural_nets.estimators import ( | ||
CategoricalMADE, | ||
CategoricalMassEstimator, | ||
CategoricalNet, | ||
) | ||
from sbi.neural_nets.estimators.mixed_density_estimator import _is_discrete | ||
from sbi.utils.nn_utils import get_numel | ||
from sbi.utils.sbiutils import standardizing_net, z_score_parser | ||
from sbi.utils.user_input_checks import check_data_device | ||
|
||
|
||
|
@@ -61,3 +64,60 @@ | |
return CategoricalMassEstimator( | ||
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape | ||
) | ||
|
||
|
||
def build_autoregressive_categoricalmassestimator( | ||
batch_x: Tensor, | ||
batch_y: Tensor, | ||
z_score_x: Optional[str] = "none", | ||
z_score_y: Optional[str] = "independent", | ||
num_hidden: int = 20, | ||
num_layers: int = 2, | ||
categories: Optional[Tensor] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add to docstring. I suggest renaming this to |
||
embedding_net: nn.Module = nn.Identity(), | ||
): | ||
"""Returns a density estimator for a categorical random variable. | ||
|
||
Args: | ||
batch_x: A batch of input data. | ||
batch_y: A batch of condition data. | ||
z_score_x: Whether to z-score the input data. | ||
z_score_y: Whether to z-score the condition data. | ||
num_hidden: Number of hidden units per layer. | ||
num_layers: Number of hidden layers. | ||
embedding_net: Embedding net for y. | ||
""" | ||
|
||
if z_score_x != "none": | ||
raise ValueError("Categorical input should not be z-scored.") | ||
if categories is None: | ||
warnings.warn( | ||
"Inferring categories from batch_x. Ensure all categories are present.", | ||
stacklevel=2, | ||
) | ||
|
||
check_data_device(batch_x, batch_y) | ||
|
||
z_score_y_bool, structured_y = z_score_parser(z_score_y) | ||
y_numel = get_numel(batch_y, embedding_net=embedding_net) | ||
|
||
if z_score_y_bool: | ||
embedding_net = nn.Sequential( | ||
standardizing_net(batch_y, structured_y), embedding_net | ||
) | ||
|
||
batch_x_discrete = batch_x[:, _is_discrete(batch_x)] | ||
inferred_categories = tensor([unique(col).numel() for col in batch_x_discrete.T]) | ||
Comment on lines
+109
to
+110
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these two lines should be executed only |
||
categories = categories if categories is not None else inferred_categories | ||
|
||
categorical_net = CategoricalMADE( | ||
categories=categories, | ||
hidden_features=num_hidden, | ||
context_features=y_numel, | ||
num_blocks=num_layers, | ||
embedding_net=embedding_net, | ||
) | ||
|
||
return CategoricalMassEstimator( | ||
categorical_net, input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,14 @@ | |
from torch import Tensor, nn | ||
|
||
from sbi.neural_nets.estimators import MixedDensityEstimator | ||
from sbi.neural_nets.estimators.mixed_density_estimator import _separate_input | ||
from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator | ||
from sbi.neural_nets.estimators.mixed_density_estimator import ( | ||
_is_discrete, | ||
_separate_input, | ||
) | ||
from sbi.neural_nets.net_builders.categorial import ( | ||
build_autoregressive_categoricalmassestimator, | ||
build_categoricalmassestimator, | ||
) | ||
from sbi.neural_nets.net_builders.flow import ( | ||
build_made, | ||
build_maf, | ||
|
@@ -26,10 +32,7 @@ | |
build_zuko_unaf, | ||
) | ||
from sbi.neural_nets.net_builders.mdn import build_mdn | ||
from sbi.utils.sbiutils import ( | ||
standardizing_net, | ||
z_score_parser, | ||
) | ||
from sbi.utils.sbiutils import standardizing_net, z_score_parser | ||
from sbi.utils.user_input_checks import check_data_device | ||
|
||
model_builders = { | ||
|
@@ -56,6 +59,7 @@ | |
z_score_x: Optional[str] = "independent", | ||
z_score_y: Optional[str] = "independent", | ||
flow_model: str = "nsf", | ||
categorical_model: str = "mlp", | ||
embedding_net: nn.Module = nn.Identity(), | ||
combined_embedding_net: Optional[nn.Module] = None, | ||
num_transforms: int = 2, | ||
|
@@ -102,6 +106,8 @@ | |
as z_score_x. | ||
flow_model: type of flow model to use for the continuous part of the | ||
data. | ||
categorical_model: type of categorical net to use for the discrete part of | ||
the data. Can be "made" or "mlp". | ||
embedding_net: Optional embedding network for y, required if y is > 1D. | ||
combined_embedding_net: Optional embedding for combining the discrete | ||
part of the input and the embedded condition into a joined | ||
|
@@ -125,13 +131,14 @@ | |
|
||
warnings.warn( | ||
"The mixed neural likelihood estimator assumes that x contains " | ||
"continuous data in the first n-1 columns (e.g., reaction times) and " | ||
"categorical data in the last column (e.g., corresponding choices). If " | ||
"continuous data in the first n-k columns (e.g., reaction times) and " | ||
"categorical data in the last k columns (e.g., corresponding choices). If " | ||
"this is not the case for the passed `x` do not use this function.", | ||
stacklevel=2, | ||
) | ||
# Separate continuous and discrete data. | ||
cont_x, disc_x = _separate_input(batch_x) | ||
num_disc = int(torch.sum(_is_discrete(batch_x))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add |
||
cont_x, disc_x = _separate_input(batch_x, num_discrete_columns=num_disc) | ||
|
||
# Set up y-embedding net with z-scoring. | ||
z_score_y_bool, structured_y = z_score_parser(z_score_y) | ||
|
@@ -144,15 +151,31 @@ | |
combined_condition = torch.cat([disc_x, embedded_batch_y], dim=-1) | ||
|
||
# Set up a categorical RV neural net for modelling the discrete data. | ||
discrete_net = build_categoricalmassestimator( | ||
disc_x, | ||
batch_y, | ||
z_score_x="none", # discrete data should not be z-scored. | ||
z_score_y="none", # y-embedding net already z-scores. | ||
num_hidden=hidden_features, | ||
num_layers=hidden_layers, | ||
embedding_net=embedding_net, | ||
) | ||
if categorical_model == "made": | ||
discrete_net = build_autoregressive_categoricalmassestimator( | ||
disc_x, | ||
batch_y, | ||
z_score_x="none", # discrete data should not be z-scored. | ||
z_score_y="none", # y-embedding net already z-scores. | ||
num_hidden=hidden_features, | ||
num_layers=hidden_layers, | ||
embedding_net=embedding_net, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pass |
||
) | ||
elif categorical_model == "mlp": | ||
assert num_disc == 1, "MLP only supports 1D input." | ||
discrete_net = build_categoricalmassestimator( | ||
disc_x, | ||
batch_y, | ||
z_score_x="none", # discrete data should not be z-scored. | ||
z_score_y="none", # y-embedding net already z-scores. | ||
num_hidden=hidden_features, | ||
num_layers=hidden_layers, | ||
embedding_net=embedding_net, | ||
) | ||
Comment on lines
+164
to
+174
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. more generally, isn't the MLP a special case of the MADE? can't we absorb them into one class? |
||
else: | ||
raise ValueError( | ||
f"Unknown categorical net {categorical_model}. Must be 'made' or 'mlp'." | ||
) | ||
|
||
if combined_embedding_net is None: | ||
# set up linear embedding net for combining discrete and continuous | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should eventually be integrated with the MNLE tutorial in
12_iid_data_and_permutation_invariant_embeddings.ipynb