
Commit 5369d9e

fix: global workspace scheduler arg as a callback (#169)
The previous implementation (#132) could not work with a new scheduler, because it had no way to receive the optimizer as a parameter. Here, the scheduler should be given as a callback instead:

```python
def get_scheduler(optimizer: Optimizer) -> LRScheduler:
    return StepLR(optimizer, ...)


gw = GlobalWorkspace(
    ...,
    scheduler=get_scheduler,
)
```
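For context, a slightly fuller version of the snippet above might look like the following sketch. The import paths and the `step_size`/`gamma` values are illustrative assumptions, not part of the commit, and the remaining `GlobalWorkspace` constructor arguments are elided because they are unchanged.

```python
from torch.optim.lr_scheduler import LRScheduler, StepLR
from torch.optim.optimizer import Optimizer

# Assumed import path for illustration; GlobalWorkspace lives in the module
# touched by this commit.
from shimmer.modules.global_workspace import GlobalWorkspace


def get_scheduler(optimizer: Optimizer) -> LRScheduler:
    # Halve the learning rate every 10 epochs (example values only).
    return StepLR(optimizer, step_size=10, gamma=0.5)


gw = GlobalWorkspace(
    ...,  # other constructor arguments, unchanged by this commit
    scheduler=get_scheduler,
)
```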
1 parent 9b50160 commit 5369d9e

File tree: 1 file changed (+7, -5 lines)


shimmer/modules/global_workspace.py (+7, -5)
```diff
@@ -9,6 +9,7 @@
 from torch.nn import Module, ModuleDict
 from torch.optim.adamw import AdamW
 from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
+from torch.optim.optimizer import Optimizer
 
 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
 from shimmer.modules.domain import DomainModule, LossOutput
@@ -230,7 +231,7 @@ def __init__(
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
-        scheduler: LRScheduler
+        scheduler: Callable[[Optimizer], LRScheduler]
         | None
         | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
@@ -245,7 +246,8 @@ def __init__(
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
                 scheduler parameters.
-            scheduler: scheduler to use. If None is explicitely given, no scheduler
+            scheduler (`Callable[[Optimizer], LRScheduler]`): Callback that returns the
+                scheduler to use. If None is explicitely given, no scheduler
                 will be used. By default, uses OneCycleScheduler
         """
         super().__init__()
@@ -604,7 +606,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
         if isinstance(self.scheduler, OneCycleSchedulerSentinel):
             lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
         else:
-            lr_scheduler = self.scheduler
+            lr_scheduler = self.scheduler(optimizer)
 
         return {
             "optimizer": optimizer,
@@ -661,7 +663,7 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
-        scheduler: LRScheduler
+        scheduler: Callable[[Optimizer], LRScheduler]
         | None
         | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
         fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
@@ -739,7 +741,7 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
-        scheduler: LRScheduler
+        scheduler: Callable[[Optimizer], LRScheduler]
         | None
         | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
         fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh,
```
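Because the new argument is just a callable from `Optimizer` to `LRScheduler`, scheduler hyperparameters can also be bound with `functools.partial` instead of a named helper. A minimal sketch, assuming the same `GlobalWorkspace` constructor as in the commit message and an arbitrary `gamma` chosen only for illustration:

```python
from functools import partial

from torch.optim.lr_scheduler import ExponentialLR

# partial(...) defers construction until configure_optimizers supplies the
# optimizer: self.scheduler(optimizer) -> ExponentialLR(optimizer, gamma=0.95).
get_exponential_scheduler = partial(ExponentialLR, gamma=0.95)

gw = GlobalWorkspace(
    ...,  # other constructor arguments, unchanged by this commit
    scheduler=get_exponential_scheduler,
)
```

Passing `scheduler=None` explicitly still disables the scheduler, as stated in the updated docstring, and leaving the argument at its sentinel default keeps the `OneCycleLR` behaviour built from `scheduler_args`.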
