88from gpytorch .kernels import MultitaskKernel , ScaleKernel
99from gpytorch .likelihoods import MultitaskGaussianLikelihood
1010from gpytorch .means import MultitaskMean
11- from torch import nn , optim
11+ from torch import optim
1212from torch .optim .lr_scheduler import LRScheduler
1313
1414from autoemulate .experimental .callbacks .early_stopping import (
@@ -63,7 +63,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
6363 mean_module_fn : MeanModuleFn = constant_mean ,
6464 covar_module_fn : CovarModuleFn = rbf ,
6565 epochs : int = 50 ,
66- activation : type [nn .Module ] = nn .ReLU ,
6766 lr : float = 1e-1 ,
6867 early_stopping : EarlyStopping | None = None ,
6968 device : DeviceLike | None = None ,
@@ -87,8 +86,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
8786 Function to create the covariance module.
8887 epochs : int, default=50
8988 Number of training epochs.
90- activation : type[nn.Module], default=nn.ReLU
91- Activation function to use in the model.
9289 lr : float, default=2e-1
9390 Learning rate for the optimizer.
9491 device : DeviceLike | None, default=None
@@ -130,7 +127,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
130127 self .covar_module = covar_module
131128 self .epochs = epochs
132129 self .lr = lr
133- self .activation = activation
134130 self .optimizer = self .optimizer_cls (self .parameters (), lr = self .lr ) # type: ignore[call-arg] since all optimizers include lr
135131 # Extract scheduler-specific kwargs if present
136132 scheduler_kwargs = kwargs .pop ("scheduler_kwargs" , {})
@@ -206,8 +202,27 @@ def _predict(self, x: TensorLike, with_grad: bool) -> GaussianProcessLike:
206202 x = x .to (self .device )
207203 return self (x )
208204
205+ @classmethod
206+ def scheduler_config (cls ) -> dict :
207+ """
208+ Returns a random configuration for the learning rate scheduler.
209+ This should be added to the `get_tune_config()` method of subclasses
210+ to allow tuning of the scheduler parameters.
211+ """
212+ all_params = [
213+ {"scheduler_cls" : None , "scheduler_kwargs" : None },
214+ {
215+ "scheduler_cls" : [LRScheduler ],
216+ "scheduler_kwargs" : [
217+ {"policy" : "ReduceLROnPlateau" , "patience" : 5 , "factor" : 0.5 }
218+ ],
219+ },
220+ ]
221+ return np .random .choice (all_params )
222+
209223 @staticmethod
210224 def get_tune_config ():
225+ scheduler_params = GaussianProcessExact .scheduler_config ()
211226 return {
212227 "mean_module_fn" : [
213228 constant_mean ,
@@ -226,12 +241,10 @@ def get_tune_config():
226241 rbf_times_linear ,
227242 ],
228243 "epochs" : [50 , 100 , 200 ],
229- "activation" : [
230- nn .ReLU ,
231- nn .GELU ,
232- ],
233- "lr" : list (np .logspace (- 3 , - 1 )),
244+ "lr" : list (np .logspace (- 3 , 0 , 100 )),
234245 "likelihood_cls" : [MultitaskGaussianLikelihood ],
246+ "scheduler_cls" : scheduler_params ["scheduler_cls" ],
247+ "scheduler_kwargs" : scheduler_params ["scheduler_kwargs" ],
235248 }
236249
237250
@@ -255,7 +268,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
255268 mean_module_fn : MeanModuleFn = constant_mean ,
256269 covar_module_fn : CovarModuleFn = rbf ,
257270 epochs : int = 50 ,
258- activation : type [nn .Module ] = nn .ReLU ,
259271 lr : float = 2e-1 ,
260272 early_stopping : EarlyStopping | None = None ,
261273 seed : int | None = None ,
@@ -332,7 +344,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
332344 self .covar_module = covar_module
333345 self .epochs = epochs
334346 self .lr = lr
335- self .activation = activation
336347 self .optimizer = self .optimizer_cls (self .parameters (), lr = self .lr ) # type: ignore[call-arg] since all optimizers include lr
337348 # Extract scheduler-specific kwargs if present
338349 scheduler_kwargs = kwargs .pop ("scheduler_kwargs" , {})
0 commit comments