diff --git a/autoemulate/experimental/emulators/nn/mlp_gaussian.py b/autoemulate/experimental/emulators/nn/mlp_gaussian.py
new file mode 100644
index 000000000..6c9e55fa0
--- /dev/null
+++ b/autoemulate/experimental/emulators/nn/mlp_gaussian.py
@@ -0,0 +1,114 @@
+import torch
+from autoemulate.core.device import TorchDeviceMixin
+from autoemulate.core.types import DeviceLike, GaussianLike, TensorLike
+from autoemulate.data.utils import set_random_seed
+from autoemulate.emulators.base import GaussianEmulator
+from autoemulate.emulators.nn.mlp import MLP
+from autoemulate.transforms.standardize import StandardizeTransform
+from autoemulate.transforms.utils import make_positive_definite
+from torch import nn
+
+
+class GaussianMLP(MLP, GaussianEmulator):
+    """Multi-Layer Perceptron (MLP) emulator with Gaussian outputs."""
+
+    def __init__(
+        self,
+        x: TensorLike,
+        y: TensorLike,
+        standardize_x: bool = True,
+        standardize_y: bool = True,
+        activation_cls: type[nn.Module] = nn.ReLU,
+        loss_fn_cls: type[nn.Module] = nn.MSELoss,
+        epochs: int = 100,
+        batch_size: int = 16,
+        layer_dims: list[int] | None = None,
+        weight_init: str = "default",
+        scale: float = 1.0,
+        bias_init: str = "default",
+        dropout_prob: float | None = None,
+        lr: float = 1e-2,
+        random_seed: int | None = None,
+        device: DeviceLike | None = None,
+        **scheduler_kwargs,
+    ):
+        TorchDeviceMixin.__init__(self, device=device)
+        nn.Module.__init__(self)
+
+        if random_seed is not None:
+            set_random_seed(seed=random_seed)
+
+        # Ensure x and y are tensors with correct dimensions
+        x, y = self._convert_to_tensors(x, y)
+
+        # Construct the MLP layers
+        # Total params required for last layer: mean + tril covariance
+        num_params = y.shape[1] + (y.shape[1] * (y.shape[1] + 1)) // 2
+        layer_dims = (
+            [x.shape[1], *layer_dims]
+            if layer_dims
+            else [x.shape[1], 4 * num_params, 2 * num_params]
+        )
+        layers = []
+        for idx, dim in enumerate(layer_dims[1:]):
+            layers.append(nn.Linear(layer_dims[idx], dim, device=self.device))
+            layers.append(activation_cls())
+            if dropout_prob is not None:
+                layers.append(nn.Dropout(p=dropout_prob))
+
+        # Add final layer without activation
+        layers.append(nn.Linear(layer_dims[-1], num_params, device=self.device))
+        self.nn = nn.Sequential(*layers)
+
+        # Finalize initialization
+        self._initialize_weights(weight_init, scale, bias_init)
+        self.x_transform = StandardizeTransform() if standardize_x else None
+        self.y_transform = StandardizeTransform() if standardize_y else None
+        self.epochs = epochs
+        self.loss_fn = loss_fn_cls()
+        self.lr = lr
+        self.num_tasks = y.shape[1]
+        self.batch_size = batch_size
+        self.optimizer = self.optimizer_cls(self.nn.parameters(), lr=lr)  # type: ignore # noqa: PGH003
+        self.scheduler_setup(scheduler_kwargs)
+        self.to(device)
+
+    def _predict(self, x, with_grad=False):
+        """Predict using the MLP model."""
+        with torch.set_grad_enabled(with_grad):
+            self.nn.eval()
+            return self(x)
+
+    def forward(self, x):
+        """Forward pass for the Gaussian MLP."""
+        y = self.nn(x)
+        mean = y[..., : self.num_tasks]
+
+        # Parameterize a Cholesky factor to guarantee a PSD covariance matrix
+        num_chol_params = (self.num_tasks * (self.num_tasks + 1)) // 2
+        chol_params = y[..., self.num_tasks : self.num_tasks + num_chol_params]
+
+        # Assign params to matrix
+        scale_tril = torch.zeros(
+            *y.shape[:-1], self.num_tasks, self.num_tasks, device=y.device
+        )
+        tril_indices = torch.tril_indices(
+            self.num_tasks, self.num_tasks, device=y.device
+        )
+        scale_tril[..., tril_indices[0], tril_indices[1]] = chol_params
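+        # The lower triangle now holds the raw network outputs; e.g. with
+        # num_tasks=2 the 3 chol_params fill [[p0, 0], [p1, p2]], and the
+        # diagonal is made strictly positive below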
+
+        # Ensure positive variance
+        diag_idxs = torch.arange(self.num_tasks, device=y.device)
+        diag = (
+            torch.nn.functional.softplus(scale_tril[..., diag_idxs, diag_idxs]) + 1e-6
+        )
+        scale_tril[..., diag_idxs, diag_idxs] = diag
+
+        covariance_matrix = scale_tril @ scale_tril.transpose(-1, -2)
+
+        # TODO: for large covariance matrices, numerical instability remains
+        return GaussianLike(mean, make_positive_definite(covariance_matrix))
+
+    def loss_func(self, y_pred, y_true):
+        """Negative log likelihood loss function."""
+        return -y_pred.log_prob(y_true).mean()
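
A self-contained sketch of the Cholesky head above, for reference (plain torch only; the task count `t`, the random `raw` tensor standing in for `self.nn(x)`, and the zero targets are illustrative assumptions, not part of the module):

    import torch

    # For t tasks the head emits t means plus t*(t+1)/2 lower-triangular entries
    t = 3
    num_params = t + t * (t + 1) // 2   # 3 + 6 = 9 outputs per sample
    raw = torch.randn(4, num_params)    # stands in for self.nn(x)

    mean = raw[..., :t]
    chol_params = raw[..., t:]

    # Scatter the flat params into a lower-triangular matrix per sample
    scale_tril = torch.zeros(4, t, t)
    rows, cols = torch.tril_indices(t, t)
    scale_tril[..., rows, cols] = chol_params

    # Softplus keeps the diagonal positive, so L @ L.T is positive definite
    d = torch.arange(t)
    scale_tril[..., d, d] = torch.nn.functional.softplus(scale_tril[..., d, d]) + 1e-6

    dist = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)
    nll = -dist.log_prob(torch.zeros(4, t)).mean()  # same NLL objective as loss_func

Passing scale_tril directly to MultivariateNormal, as in this sketch, avoids forming the covariance product at all; the module instead builds the full matrix, presumably because GaussianLike expects a covariance argument.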