Skip to content

Commit 0ee7a40

Browse files
authored
Remove Bayesian models (#164)
1 parent 4e011e7 commit 0ee7a40

8 files changed

+2
-539
lines changed

docs/q_and_a.md

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ To get inspiration, you can look at the source code of
1919
## How can I change the loss function?
2020
If you are using pre-made GW architecture
2121
([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
22-
[`GlobalWorkspaceBayesian`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceBayesian),
2322
[`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss
2423
used for demi-cycles, cycles, translations or broadcast, you can do so directly from
2524
your definition of the

shimmer/__init__.py

-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from shimmer.modules.global_workspace import (
1515
GlobalWorkspace2Domains,
1616
GlobalWorkspaceBase,
17-
GlobalWorkspaceBayesian,
1817
SchedulerArgs,
1918
batch_broadcasts,
2019
batch_cycles,
@@ -28,7 +27,6 @@
2827
GWEncoderLinear,
2928
GWModule,
3029
GWModuleBase,
31-
GWModuleBayesian,
3230
GWModulePrediction,
3331
broadcast,
3432
broadcast_cycles,
@@ -39,7 +37,6 @@
3937
BroadcastLossCoefs,
4038
GWLosses2Domains,
4139
GWLossesBase,
42-
GWLossesBayesian,
4340
LossCoefs,
4441
)
4542
from shimmer.modules.selection import (
@@ -75,7 +72,6 @@
7572
"SchedulerArgs",
7673
"GlobalWorkspaceBase",
7774
"GlobalWorkspace2Domains",
78-
"GlobalWorkspaceBayesian",
7975
"pretrained_global_workspace",
8076
"LossOutput",
8177
"DomainModule",
@@ -84,7 +80,6 @@
8480
"GWEncoderLinear",
8581
"GWModuleBase",
8682
"GWModule",
87-
"GWModuleBayesian",
8883
"GWModulePrediction",
8984
"ContrastiveLossType",
9085
"contrastive_loss",
@@ -93,7 +88,6 @@
9388
"BroadcastLossCoefs",
9489
"GWLossesBase",
9590
"GWLosses2Domains",
96-
"GWLossesBayesian",
9791
"RepeatedDataset",
9892
"batch_cycles",
9993
"batch_demi_cycles",

shimmer/modules/__init__.py

-8
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from shimmer.data.dataset import RepeatedDataset
22
from shimmer.modules.contrastive_loss import (
33
ContrastiveLoss,
4-
ContrastiveLossBayesianType,
54
ContrastiveLossType,
65
contrastive_loss,
76
)
87
from shimmer.modules.domain import DomainModule, LossOutput
98
from shimmer.modules.global_workspace import (
109
GlobalWorkspace2Domains,
1110
GlobalWorkspaceBase,
12-
GlobalWorkspaceBayesian,
1311
SchedulerArgs,
1412
batch_broadcasts,
1513
batch_cycles,
@@ -23,7 +21,6 @@
2321
GWEncoderLinear,
2422
GWModule,
2523
GWModuleBase,
26-
GWModuleBayesian,
2724
GWModulePrediction,
2825
broadcast,
2926
broadcast_cycles,
@@ -34,7 +31,6 @@
3431
BroadcastLossCoefs,
3532
GWLosses2Domains,
3633
GWLossesBase,
37-
GWLossesBayesian,
3834
LossCoefs,
3935
)
4036
from shimmer.modules.selection import (
@@ -55,7 +51,6 @@
5551
"SchedulerArgs",
5652
"GlobalWorkspaceBase",
5753
"GlobalWorkspace2Domains",
58-
"GlobalWorkspaceBayesian",
5954
"pretrained_global_workspace",
6055
"LossOutput",
6156
"DomainModule",
@@ -64,17 +59,14 @@
6459
"GWEncoderLinear",
6560
"GWModuleBase",
6661
"GWModule",
67-
"GWModuleBayesian",
6862
"GWModulePrediction",
6963
"ContrastiveLossType",
70-
"ContrastiveLossBayesianType",
7164
"contrastive_loss",
7265
"ContrastiveLoss",
7366
"LossCoefs",
7467
"BroadcastLossCoefs",
7568
"GWLossesBase",
7669
"GWLosses2Domains",
77-
"GWLossesBayesian",
7870
"RepeatedDataset",
7971
"reparameterize",
8072
"kl_divergence_loss",

shimmer/modules/contrastive_loss.py

-10
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,6 @@
1515
A function taking the prediction and targets and returning a LossOutput.
1616
"""
1717

18-
ContrastiveLossBayesianType = Callable[
19-
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
20-
]
21-
"""
22-
Contrastive loss function type for GlobalWorkspaceBayesian.
23-
24-
A function taking the prediction mean, prediction std, target mean and target std and
25-
returns a LossOutput.
26-
"""
27-
2818

2919
def info_nce(
3020
x: torch.Tensor,

shimmer/modules/global_workspace.py

-104
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from shimmer.modules.gw_module import (
1616
GWModule,
1717
GWModuleBase,
18-
GWModuleBayesian,
1918
GWModulePrediction,
2019
broadcast_cycles,
2120
cycle,
@@ -26,11 +25,9 @@
2625
GWLosses,
2726
GWLosses2Domains,
2827
GWLossesBase,
29-
GWLossesBayesian,
3028
LossCoefs,
3129
)
3230
from shimmer.modules.selection import (
33-
FixedSharedSelection,
3431
RandomSelection,
3532
SelectionBase,
3633
SingleDomainSelection,
@@ -793,107 +790,6 @@ def __init__(
793790
)
794791

795792

796-
class GlobalWorkspaceBayesian(
797-
GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
798-
):
799-
"""
800-
A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
801-
prediction.
802-
803-
This is used to simplify a Global Workspace instantiation and only overrides the
804-
`__init__` method.
805-
"""
806-
807-
def __init__(
808-
self,
809-
domain_mods: Mapping[str, DomainModule],
810-
gw_encoders: Mapping[str, Module],
811-
gw_decoders: Mapping[str, Module],
812-
workspace_dim: int,
813-
loss_coefs: BroadcastLossCoefs,
814-
sensitivity_selection: float = 1,
815-
sensitivity_precision: float = 1,
816-
optim_lr: float = 1e-3,
817-
optim_weight_decay: float = 0.0,
818-
scheduler_args: SchedulerArgs | None = None,
819-
learn_logit_scale: bool = False,
820-
use_normalized_constrastive: bool = True,
821-
contrastive_loss: ContrastiveLossType | None = None,
822-
precision_softmax_temp: float = 0.01,
823-
scheduler: LRScheduler
824-
| None
825-
| OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT,
826-
) -> None:
827-
"""
828-
Initializes a Global Workspace
829-
830-
Args:
831-
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
832-
connected to the GW. Keys are domain names, values are the
833-
`DomainModule`.
834-
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
835-
name to a `torch.nn.Module` class whose role is to encode a
836-
unimodal latent representation into a GW representation (pre fusion).
837-
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
838-
name to a `torch.nn.Module` class whose role is to decode a
839-
GW representation into a unimodal latent representation.
840-
workspace_dim (`int`): dimension of the GW.
841-
loss_coefs (`BroadcastLossCoefs`): loss coefficients
842-
sensitivity_selection (`float`): sensitivity coef $c'_1$
843-
sensitivity_precision (`float`): sensitivity coef $c'_2$
844-
optim_lr (`float`): learning rate
845-
optim_weight_decay (`float`): weight decay
846-
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
847-
learn_logit_scale (`bool`): whether to learn the logit scale of the
848-
contrastive loss when using the default contrastive loss.
849-
use_normalized_constrastive (`bool`): whether to use the normalized contrastive
850-
loss by the precision coefs
851-
contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
852-
function used for alignment. `learn_logit_scale` will not affect custom
853-
contrastive losses.
854-
precision_softmax_temp (`float`): temperature to use in softmax of
855-
precision
856-
scheduler: The scheduler to use for training. If None is explicitly given,
857-
no scheduler will be used. Defaults to use OneCycleScheduler
858-
"""
859-
domain_mods = freeze_domain_modules(domain_mods)
860-
861-
gw_mod = GWModuleBayesian(
862-
domain_mods,
863-
workspace_dim,
864-
gw_encoders,
865-
gw_decoders,
866-
sensitivity_selection,
867-
sensitivity_precision,
868-
precision_softmax_temp,
869-
)
870-
871-
selection_mod = FixedSharedSelection()
872-
873-
contrastive_loss = ContrastiveLoss(
874-
torch.tensor([1]).log(), "mean", learn_logit_scale
875-
)
876-
877-
loss_mod = GWLossesBayesian(
878-
gw_mod,
879-
selection_mod,
880-
domain_mods,
881-
loss_coefs,
882-
contrastive_loss,
883-
use_normalized_constrastive,
884-
)
885-
886-
super().__init__(
887-
gw_mod,
888-
selection_mod,
889-
loss_mod,
890-
optim_lr,
891-
optim_weight_decay,
892-
scheduler_args,
893-
scheduler,
894-
)
895-
896-
897793
def pretrained_global_workspace(
898794
checkpoint_path: str | Path,
899795
domain_mods: Mapping[str, DomainModule],

0 commit comments

Comments
 (0)