@@ -1,12 +1,16 @@
 from collections.abc import Iterable, Mapping
+from enum import Enum, auto
 from pathlib import Path
 from typing import Any, Generic, TypedDict, TypeVar, cast
 
 import torch
 from lightning.pytorch import LightningModule
-from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
+from lightning.pytorch.utilities.types import (
+    OptimizerLRScheduler,
+)
 from torch.nn import Module, ModuleDict
-from torch.optim.lr_scheduler import OneCycleLR
+from torch.optim.adamw import AdamW
+from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 
 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
 from shimmer.modules.domain import DomainModule
@@ -206,6 +210,14 @@ def batch_broadcasts(
     return predictions, cycles
 
 
+class OneCycleSchedulerSentinel(Enum):
+    """
+    Used for backward compatibility: keeps the One-Cycle scheduler as the default.
+    """
+
+    DEFAULT = auto()
+
+
 class GlobalWorkspaceBase(
     Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
 ):
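A note on the sentinel: `None` now carries a meaning of its own ("train without a scheduler"), so it can no longer double as the "argument not supplied" default. The enum member fills that role. The following standalone sketch (the `pick_scheduler` helper is hypothetical, not part of this patch) mirrors the branching that `configure_optimizers` gains further down:

    def pick_scheduler(scheduler=OneCycleSchedulerSentinel.DEFAULT):
        # Illustration only: same three-way dispatch as the patched configure_optimizers.
        if scheduler is None:
            return "no scheduler"                      # caller explicitly opted out
        if isinstance(scheduler, OneCycleSchedulerSentinel):
            return "default OneCycleLR schedule"       # argument left unset
        return "caller-provided LRScheduler instance"  # custom scheduler passed in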
@@ -223,6 +235,9 @@ def __init__(
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a GW
@@ -235,6 +250,8 @@ def __init__(
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
                 scheduler parameters.
+            scheduler: scheduler to use. If None is explicitly given, no
+                scheduler will be used. Defaults to `OneCycleLR`.
         """
         super().__init__()
         self.save_hyperparameters(
@@ -248,6 +265,7 @@ def __init__(
                 "cont_loss_bayesian",
                 "gw_encoders",
                 "gw_decoders",
+                "scheduler",
             ]
         )
 
@@ -262,6 +280,7 @@ def __init__(
 
         self.optim_lr = optim_lr
         self.optim_weight_decay = optim_weight_decay
+        self.scheduler = scheduler
         self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
         if scheduler_args is not None:
             self.scheduler_args.update(scheduler_args)
@@ -537,21 +556,28 @@ def predict_step( # type: ignore
         domain_latents = self.encode_domains(batch)
         return self.forward(domain_latents)
 
-    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
+    def configure_optimizers(self) -> OptimizerLRScheduler:
         """
         Configure models optimizers.
 
         Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
         scheduler.
         """
 
-        optimizer = torch.optim.AdamW(
+        optimizer = AdamW(
             self.parameters(),
             lr=self.optim_lr,
             weight_decay=self.optim_weight_decay,
         )
 
-        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+        if self.scheduler is None:
+            return {"optimizer": optimizer}
+
+        lr_scheduler: LRScheduler
+        if isinstance(self.scheduler, OneCycleSchedulerSentinel):
+            lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+        else:
+            lr_scheduler = self.scheduler
 
         return {
             "optimizer": optimizer,
@@ -607,6 +633,9 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -631,6 +660,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
+            scheduler: The scheduler to use for training. If None is explicitly
+                given, no scheduler will be used. Defaults to `OneCycleLR`.
         """
         domain_mods = freeze_domain_modules(domain_mods)
 
@@ -651,6 +682,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )
 
 
@@ -674,6 +706,9 @@ def __init__(
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
         contrastive_loss: ContrastiveLossType | None = None,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -700,6 +735,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
+            scheduler: The scheduler to use for training. If None is explicitly
+                given, no scheduler will be used. Defaults to `OneCycleLR`.
         """
         domain_mods = freeze_domain_modules(domain_mods)
         gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
@@ -721,6 +758,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )
 
 
@@ -751,6 +789,9 @@ def __init__(
         use_normalized_constrastive: bool = True,
         contrastive_loss: ContrastiveLossType | None = None,
         precision_softmax_temp: float = 0.01,
+        scheduler: LRScheduler
+        | None
+        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -781,6 +822,8 @@ def __init__(
                 contrastive losses.
             precision_softmax_temp (`float`): temperature to use in softmax of
                 precision
+            scheduler: The scheduler to use for training. If None is explicitly
+                given, no scheduler will be used. Defaults to `OneCycleLR`.
         """
         domain_mods = freeze_domain_modules(domain_mods)
 
@@ -816,6 +859,7 @@ def __init__(
             optim_lr,
             optim_weight_decay,
             scheduler_args,
+            scheduler,
         )
 
 
@@ -827,6 +871,9 @@ def pretrained_global_workspace(
     workspace_dim: int,
     loss_coefs: LossCoefs,
     contrastive_fn: ContrastiveLossType,
+    scheduler: LRScheduler
+    | None
+    | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
     **kwargs,
 ) -> GlobalWorkspace2Domains:
     """
@@ -848,6 +895,8 @@ def pretrained_global_workspace(
         contrastive_loss (`ContrastiveLossType`): a contrastive loss
             function used for alignment. `learn_logit_scale` will not affect custom
             contrastive losses.
+        scheduler: The scheduler to use for training. If None is explicitly
+            given, no scheduler will be used. Defaults to `OneCycleLR`.
         **kwargs: additional arguments to pass to
             `GlobalWorkspace.load_from_checkpoint`.
 
@@ -870,6 +919,7 @@ def pretrained_global_workspace(
         selection_mid=selection_mod,
         loss_coefs=loss_coefs,
         loss_mod=loss_mod,
+        scheduler=scheduler,
         **kwargs,
     )
     if not isinstance(gw, GlobalWorkspace2Domains):