Commit f13ac8c

Option to change the GlobalWorkspace LR scheduler (#132)
* Add possibility to choose a learning rate scheduler for training GW
* make sentinel public
* more general output for configure_optimizers
1 parent 3e78b27 commit f13ac8c

File tree

1 file changed: +55 -5 lines

shimmer/modules/global_workspace.py

@@ -1,12 +1,16 @@
 from collections.abc import Iterable, Mapping
+from enum import Enum, auto
 from pathlib import Path
 from typing import Any, Generic, TypedDict, TypeVar, cast

 import torch
 from lightning.pytorch import LightningModule
-from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
+from lightning.pytorch.utilities.types import (
+    OptimizerLRScheduler,
+)
 from torch.nn import Module, ModuleDict
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import LRScheduler, OneCycleLR

 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
 from shimmer.modules.domain import DomainModule
@@ -206,6 +210,14 @@ def batch_broadcasts(
         return predictions, cycles


+class OneCycleSchedulerSentinel(Enum):
+    """
+    Used for backward-compatibility issues to use One-Cycle Scheduler by default
+    """
+
+    DEFAULT = auto()
+
+
 class GlobalWorkspaceBase(
     Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
 ):
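The sentinel enum lets the constructors tell "no scheduler argument given" (keep the OneCycleLR default) apart from an explicit scheduler=None (disable scheduling), which a plain None default could not express. A minimal, self-contained sketch of this sentinel-default pattern outside shimmer; the names Sentinel and pick_scheduler below are illustrative, not part of the library:

from enum import Enum, auto


class Sentinel(Enum):
    # Marks "the caller did not pass anything", so that None keeps its own meaning.
    DEFAULT = auto()


def pick_scheduler(scheduler: str | None | Sentinel = Sentinel.DEFAULT) -> str:
    if scheduler is None:
        return "no scheduler"           # caller explicitly opted out
    if isinstance(scheduler, Sentinel):
        return "OneCycleLR (default)"   # caller passed nothing
    return scheduler                    # caller supplied their own

print(pick_scheduler())                     # OneCycleLR (default)
print(pick_scheduler(None))                 # no scheduler
print(pick_scheduler("CosineAnnealingLR"))  # CosineAnnealingLR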
@@ -223,6 +235,9 @@ def __init__(
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a GW
@@ -235,6 +250,8 @@ def __init__(
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
                 scheduler parameters.
+            scheduler: scheduler to use. If None is explicitly given, no scheduler
+                will be used. By default, uses OneCycleScheduler
         """
         super().__init__()
         self.save_hyperparameters(
@@ -248,6 +265,7 @@ def __init__(
                 "cont_loss_bayesian",
                 "gw_encoders",
                 "gw_decoders",
+                "scheduler",
             ]
         )

@@ -262,6 +280,7 @@ def __init__(

         self.optim_lr = optim_lr
         self.optim_weight_decay = optim_weight_decay
+        self.scheduler = scheduler
         self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
         if scheduler_args is not None:
             self.scheduler_args.update(scheduler_args)
@@ -537,21 +556,28 @@ def predict_step( # type: ignore
         domain_latents = self.encode_domains(batch)
         return self.forward(domain_latents)

-    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
+    def configure_optimizers(self) -> OptimizerLRScheduler:
         """
         Configure models optimizers.

         Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
         scheduler.
         """

-        optimizer = torch.optim.AdamW(
+        optimizer = AdamW(
             self.parameters(),
             lr=self.optim_lr,
             weight_decay=self.optim_weight_decay,
         )

-        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+        if self.scheduler is None:
+            return {"optimizer": optimizer}
+
+        lr_scheduler: LRScheduler
+        if isinstance(self.scheduler, OneCycleSchedulerSentinel):
+            lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+        else:
+            lr_scheduler = self.scheduler

         return {
             "optimizer": optimizer,
@@ -607,6 +633,9 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -631,6 +660,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
+            scheduler: The scheduler to use for training. If None is explicitly given,
+                no scheduler will be used. Defaults to use OneCycleScheduler
         """
         domain_mods = freeze_domain_modules(domain_mods)

@@ -651,6 +682,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )

@@ -674,6 +706,9 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -700,6 +735,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
+            scheduler: The scheduler to use for training. If None is explicitly given,
+                no scheduler will be used. Defaults to use OneCycleScheduler
         """
         domain_mods = freeze_domain_modules(domain_mods)
         gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
@@ -721,6 +758,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )

@@ -751,6 +789,9 @@ def __init__(
         use_normalized_constrastive: bool = True,
         contrastive_loss: ContrastiveLossType | None = None,
         precision_softmax_temp: float = 0.01,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -781,6 +822,8 @@ def __init__(
                 contrastive losses.
             precision_softmax_temp (`float`): temperature to use in softmax of
                 precision
+            scheduler: The scheduler to use for training. If None is explicitly given,
+                no scheduler will be used. Defaults to use OneCycleScheduler
         """
         domain_mods = freeze_domain_modules(domain_mods)

@@ -816,6 +859,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )

@@ -827,6 +871,9 @@ def pretrained_global_workspace(
     workspace_dim: int,
     loss_coefs: LossCoefs,
     contrastive_fn: ContrastiveLossType,
+    scheduler: LRScheduler
+    | None
+    | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     **kwargs,
 ) -> GlobalWorkspace2Domains:
     """
@@ -848,6 +895,8 @@ def pretrained_global_workspace(
         contrastive_loss (`ContrastiveLossType`): a contrastive loss
             function used for alignment. `learn_logit_scale` will not affect custom
             contrastive losses.
+        scheduler: The scheduler to use for training. If None is explicitly given,
+            no scheduler will be used. Defaults to use OneCycleScheduler
         **kwargs: additional arguments to pass to
             `GlobalWorkspace.load_from_checkpoint`.
@@ -870,6 +919,7 @@ def pretrained_global_workspace(
         selection_mid=selection_mod,
         loss_coefs=loss_coefs,
         loss_mod=loss_mod,
+        scheduler=scheduler,
         **kwargs,
     )
     if not isinstance(gw, GlobalWorkspace2Domains):
