Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions autoemulate/core/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,6 @@ def fit_from_reinitialized(
neural networks.

"""
from autoemulate.emulators import get_emulator_class

transformed_emulator_params = (
transformed_emulator_params or self.transformed_emulator_params
)
Expand Down
9 changes: 8 additions & 1 deletion autoemulate/emulators/gaussian_process/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@
rbf_times_linear,
rq_kernel,
)
from .mean import constant_mean, linear_mean, poly_mean, zero_mean
from .mean import (
constant_mean,
linear_mean,
partially_learnable_mean,
poly_mean,
zero_mean,
)


class GaussianProcess(GaussianProcessEmulator, gpytorch.models.ExactGP):
Expand Down Expand Up @@ -284,6 +290,7 @@ def get_tune_params():
zero_mean,
linear_mean,
poly_mean,
partially_learnable_mean,
],
"covar_module_fn": [
rbf,
Expand Down
43 changes: 43 additions & 0 deletions autoemulate/emulators/gaussian_process/mean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections.abc import Callable

import torch
from gpytorch.means import ConstantMean, LinearMean, ZeroMean

from .partially_learnable import PartiallyLearnableMean
from .poly_mean import PolyMean


Expand Down Expand Up @@ -93,3 +96,43 @@ def poly_mean(n_features: int, n_outputs: torch.Size | None) -> PolyMean:
if n_outputs is not None
else PolyMean(degree=2, input_size=n_features)
)


def partially_learnable_mean(
n_features: int,
n_outputs: torch.Size | None,
mean_func: Callable = torch.sin,
known_dim: int = 0,
) -> PartiallyLearnableMean:
"""
PartiallyLearnableMean module with known function for one dimension.

Parameters
----------
n_features: int
Number of input features.
n_outputs: torch.Size | None
Batch shape of the mean. If None, the mean is not initialized with a batch
shape.
mean_func: callable
Function to apply to the known dimension. Defaults to torch.sin.
known_dim: int
Dimension index for the known function. Defaults to 0.

Returns
-------
PartiallyLearnableMean
The initialized PartiallyLearnableMean module.
"""
return (
PartiallyLearnableMean(
mean_func=mean_func,
known_dim=known_dim,
input_size=n_features,
batch_shape=n_outputs,
)
if n_outputs is not None
else PartiallyLearnableMean(
mean_func=mean_func, known_dim=known_dim, input_size=n_features
)
)
72 changes: 72 additions & 0 deletions autoemulate/emulators/gaussian_process/partially_learnable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Callable

import gpytorch
import torch
from gpytorch.means import LinearMean

from autoemulate.core.types import TensorLike


class PartiallyLearnableMean(gpytorch.means.Mean):
"""
A mixed mean module that combines a known function with learnable components.

Parameters
----------
mean_func : Callable
A function that takes a tensor and returns a tensor of the same shape.
This function will be applied to the specified dimension.
known_dim : int
The dimension index to which the custom mean function will be applied.
Default is 0 (first dimension).
input_size : int
The total number of input features.
batch_shape : torch.Size | None
Optional batch dimension for multi-task GPs.
"""

def __init__(
self,
mean_func: Callable,
known_dim: int = 0,
input_size: int = 1,
batch_shape: torch.Size | None = None,
):
super().__init__()

# Store the custom function and dimension
self.mean_func = mean_func
self.known_dim = known_dim
self.input_size = input_size

if batch_shape is None:
batch_shape = torch.Size()

# Create indices for learnable dimensions (all except known_dim)
self.learnable_dims = [i for i in range(input_size) if i != known_dim]

# Only create linear mean if there are learnable dimensions
self.linear_mean = LinearMean(
input_size=len(self.learnable_dims), batch_shape=batch_shape
)

def forward(self, x: TensorLike) -> TensorLike:
"""Forward pass through the partially learnable mean module."""
# Apply custom mean function to the known dimension
known_part = self.mean_func(x[..., self.known_dim])

learnable_data = x[..., self.learnable_dims]
learnable_part = self.linear_mean(
learnable_data
) # this part could be replaced with other function / NN
return known_part + learnable_part

def __repr__(self) -> str:
"""Return string representation of the PartiallyLearnableMean module."""
func_name = getattr(self.mean_func, "__name__", "mean_func")
return (
f"PartiallyLearnableMean("
f"mean_func={func_name}, "
f"known_dim={self.known_dim}, "
f"input_size={self.input_size})"
)
237 changes: 237 additions & 0 deletions autoemulate/experimental/exploratory/universal_kriging.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"# Demo for Universal Kriging implementation\n",
"\n",
"This terminology is borrowed from this paper \n",
"https://arxiv.org/pdf/2408.02331\n",
"\n",
"1. Simple Kriging: Known mean function, no noise\n",
"2. Ordinary Kriging: Unknown constant mean, no noise\n",
"3. Universal Kriging: Unknown mean as linear combination of known functions\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
"from autoemulate.simulations.projectile import Projectile\n",
"\n",
"projectile = Projectile(log_level=\"error\")\n",
"n_samples = 50\n",
"x = projectile.sample_inputs(n_samples).float()\n",
"y = projectile.forward_batch(x).float()\n",
"x.shape, y.shape"
]
},
{
"cell_type": "markdown",
"id": "3",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
"from autoemulate import AutoEmulate \n",
"from autoemulate.emulators import GaussianProcess\n"
]
},
{
"cell_type": "markdown",
"id": "5",
"metadata": {},
"source": [
"# define custom mean \n",
"for example here we try to incorporates some knowledge of projectile motion physics into the GP. With drag, there is no simple closed form solution and simulaiton is solved numerically. for no drag : $ R = v_0^2 sin(2\\theta)/g$ \n",
"\n",
"cutom_mean is returning the `mean_module.PartiallyLearnableMean` class which has `projectile_mean` mean_func\n",
"\n",
"and we replace `mean_module.partially_learnable_mean`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"import autoemulate.emulators.gaussian_process.mean as mean_module\n",
"\n",
"def projectile_mean(x):\n",
" return x**2/9.8\n",
"\n",
"def custom_mean(n_features, n_outputs):\n",
" return mean_module.PartiallyLearnableMean(\n",
" mean_func=projectile_mean,\n",
" known_dim=0,\n",
" input_size=n_features,\n",
" batch_shape=n_outputs\n",
" )\n",
"\n",
"mean_module.partially_learnable_mean = custom_mean\n"
]
},
{
"cell_type": "markdown",
"id": "7",
"metadata": {},
"source": [
"This means that here , \n",
"\n",
"```python \n",
" return {\n",
" \"mean_module_fn\": [\n",
" constant_mean,\n",
" zero_mean,\n",
" linear_mean,\n",
" poly_mean,\n",
" partially_learnable_mean,\n",
" ],\n",
"```\n",
"\n",
"it itterates over these and our updated `partially_learnable_mean` and choose the best, as here the result is not the best for `partially_learnable_mean` I also check this for a case without tuning "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"ae = AutoEmulate(x, y, models=[GaussianProcess], log_level=\"error\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"ae.summarise()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10",
"metadata": {},
"outputs": [],
"source": [
"ae.plot(0)\n"
]
},
{
"cell_type": "markdown",
"id": "11",
"metadata": {},
"source": [
"# No tuning \n",
"here ` mean_module.partially_learnable_mean` is forced without tuning "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12",
"metadata": {},
"outputs": [],
"source": [
"\n",
"ae_2 = AutoEmulate(\n",
" x, y, \n",
" models=[GaussianProcess],\n",
" model_tuning=False,\n",
" model_params={\"mean_module_fn\": mean_module.partially_learnable_mean},\n",
" log_level=\"error\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"ae_2.summarise()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14",
"metadata": {},
"outputs": [],
"source": [
"ae_2.plot(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading