@@ -15,7 +15,6 @@
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
-    GWModuleBayesian,
     GWModulePrediction,
     broadcast_cycles,
     cycle,
@@ -26,11 +25,9 @@
     GWLosses,
     GWLosses2Domains,
     GWLossesBase,
-    GWLossesBayesian,
     LossCoefs,
 )
 from shimmer.modules.selection import (
-    FixedSharedSelection,
     RandomSelection,
     SelectionBase,
     SingleDomainSelection,
@@ -793,107 +790,6 @@ def __init__(
         )
 
 
-class GlobalWorkspaceBayesian(
-    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
-):
-    """
-    A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
-    prediction.
-
-    This is used to simplify a Global Workspace instanciation and only overrides the
-    `__init__` method.
-    """
-
-    def __init__(
-        self,
-        domain_mods: Mapping[str, DomainModule],
-        gw_encoders: Mapping[str, Module],
-        gw_decoders: Mapping[str, Module],
-        workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
-        sensitivity_selection: float = 1,
-        sensitivity_precision: float = 1,
-        optim_lr: float = 1e-3,
-        optim_weight_decay: float = 0.0,
-        scheduler_args: SchedulerArgs | None = None,
-        learn_logit_scale: bool = False,
-        use_normalized_constrastive: bool = True,
-        contrastive_loss: ContrastiveLossType | None = None,
-        precision_softmax_temp: float = 0.01,
-        scheduler: LRScheduler
-        | None
-        | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
-    ) -> None:
-        """
-        Initializes a Global Workspace
-
-        Args:
-            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
-                connected to the GW. Keys are domain names, values are the
-                `DomainModule`.
-            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a `torch.nn.Module` class which role is to encode a
-                unimodal latent representations into a GW representation (pre fusion).
-            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
-                name to a `torch.nn.Module` class which role is to decode a
-                GW representation into a unimodal latent representations.
-            workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`LossCoefs`): loss coefficients
-            sensitivity_selection (`float`): sensivity coef $c'_1$
-            sensitivity_precision (`float`): sensitivity coef $c'_2$
-            optim_lr (`float`): learning rate
-            optim_weight_decay (`float`): weight decay
-            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
-            learn_logit_scale (`bool`): whether to learn the contrastive learning
-                contrastive loss when using the default contrastive loss.
-            use_normalized_constrastive (`bool`): whether to use the normalized cont
-                loss by the precision coefs
-            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
-                function used for alignment. `learn_logit_scale` will not affect custom
-                contrastive losses.
-            precision_softmax_temp (`float`): temperature to use in softmax of
-                precision
-            scheduler: The scheduler to use for traning. If None is explicitely given,
-                no scheduler will be used. Defaults to use OneCycleScheduler
-        """
-        domain_mods = freeze_domain_modules(domain_mods)
-
-        gw_mod = GWModuleBayesian(
-            domain_mods,
-            workspace_dim,
-            gw_encoders,
-            gw_decoders,
-            sensitivity_selection,
-            sensitivity_precision,
-            precision_softmax_temp,
-        )
-
-        selection_mod = FixedSharedSelection()
-
-        contrastive_loss = ContrastiveLoss(
-            torch.tensor([1]).log(), "mean", learn_logit_scale
-        )
-
-        loss_mod = GWLossesBayesian(
-            gw_mod,
-            selection_mod,
-            domain_mods,
-            loss_coefs,
-            contrastive_loss,
-            use_normalized_constrastive,
-        )
-
-        super().__init__(
-            gw_mod,
-            selection_mod,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-            scheduler,
-        )
-
-
 def pretrained_global_workspace(
     checkpoint_path: str | Path,
     domain_mods: Mapping[str, DomainModule],
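Taken together, these hunks delete the Bayesian variant of the Global Workspace: the `GWModuleBayesian`, `GWLossesBayesian`, and `FixedSharedSelection` imports, and the `GlobalWorkspaceBayesian` convenience class that wired those three pieces into `GlobalWorkspaceBase`. One quirk visible in the removed body: the `contrastive_loss` argument was accepted but then unconditionally shadowed by a freshly constructed `ContrastiveLoss(torch.tensor([1]).log(), "mean", learn_logit_scale)`, so a custom contrastive loss passed to this class never took effect. For anyone migrating code off this class, here is a minimal sketch of how it was instantiated before this commit, reconstructed only from the removed signature above; the toy domain, encoders, decoders, and coefficient values are hypothetical placeholders, and the import paths assume the file shown is `shimmer/modules/global_workspace.py` with `DomainModule` re-exported at the package root.

```python
# Sketch of pre-removal usage, reconstructed from the deleted signature above.
# Everything below is a hypothetical placeholder, not code from this repository.
import torch
from torch import nn

from shimmer import DomainModule  # assumed top-level re-export
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian  # removed by this commit


class ToyDomain(DomainModule):
    """Placeholder domain module whose latent vectors are used as-is."""

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z


latent_dim, workspace_dim = 32, 12
domain_mods = {"v": ToyDomain(latent_dim), "t": ToyDomain(latent_dim)}

gw = GlobalWorkspaceBayesian(
    domain_mods=domain_mods,
    # one encoder/decoder per domain: unimodal latent <-> GW representation
    gw_encoders={name: nn.Linear(latent_dim, workspace_dim) for name in domain_mods},
    gw_decoders={name: nn.Linear(workspace_dim, latent_dim) for name in domain_mods},
    workspace_dim=workspace_dim,
    loss_coefs={"contrastives": 0.1},  # BroadcastLossCoefs entries; placeholder value
    sensitivity_selection=1.0,  # c'_1 in the removed docstring
    sensitivity_precision=1.0,  # c'_2 in the removed docstring
    precision_softmax_temp=0.01,
)
```

After this commit there is no drop-in replacement in this file: the surrounding non-Bayesian `GlobalWorkspaceBase` subclasses and `pretrained_global_workspace` are untouched, but since `GWModuleBayesian`, `FixedSharedSelection`, and `GWLossesBayesian` are no longer imported here, the equivalent composition cannot be rebuilt from this module's imports alone.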