Add a selection module for fixed shared weights (#69)
bdvllrs authored May 15, 2024
1 parent 5a516f9 commit 98a64a1
Showing 2 changed files with 33 additions and 4 deletions.
7 changes: 3 additions & 4 deletions shimmer/modules/global_workspace.py
@@ -24,6 +24,7 @@
     LossCoefs,
 )
 from shimmer.modules.selection import (
+    FixedSharedSelection,
     RandomSelection,
     SelectionBase,
     SingleDomainSelection,
@@ -631,7 +632,7 @@ def __init__(
 
 
 class GlobalWorkspaceBayesian(
-    GlobalWorkspaceBase[GWModuleBayesian, RandomSelection, GWLossesBayesian]
+    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
 ):
     """
     A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
@@ -648,7 +649,6 @@ def __init__(
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
         loss_coefs: BroadcastLossCoefs,
-        selection_temperature: float = 0.2,
         sensitivity_selection: float = 1,
         sensitivity_precision: float = 1,
         optim_lr: float = 1e-3,
@@ -672,7 +672,6 @@ def __init__(
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
             loss_coefs (`LossCoefs`): loss coefficients
-            selection_temperature (`float`): temperature for `RandomSelection`
             sensitivity_selection (`float`): sensivity coef $c'_1$
             sensitivity_precision (`float`): sensitivity coef $c'_2$
             optim_lr (`float`): learning rate
@@ -695,7 +694,7 @@ def __init__(
             sensitivity_precision,
         )
 
-        selection_mod = RandomSelection(selection_temperature)
+        selection_mod = FixedSharedSelection()
 
         contrastive_loss = ContrastiveLoss(
             torch.tensor([1]).log(), "mean", learn_logit_scale
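For context, here is a standalone sketch (plain PyTorch, not shimmer code) of what this swap changes for GlobalWorkspaceBayesian: RandomSelection draws fresh temperature-controlled random fusion weights on every call, whereas FixedSharedSelection always yields the uniform 1/n coefficients. The softmax-over-noise below only illustrates temperature-scaled random weights; it is not RandomSelection's exact sampling scheme, which lives in selection.py:

import torch

bs, n_domains, temperature = 4, 2, 0.2

# Illustrative stand-in for the old behaviour: per-sample random weights
# over the domains, sharpened by a temperature, varying on every call.
random_weights = torch.softmax(torch.randn(bs, n_domains) / temperature, dim=1)

# New behaviour: a constant 1 / n_domains for every domain and every sample.
fixed_weights = torch.full((bs, n_domains), 1.0 / n_domains)

print(random_weights)  # rows sum to 1, values change from run to run
print(fixed_weights)   # tensor([[0.5000, 0.5000], ...]) always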
30 changes: 30 additions & 0 deletions shimmer/modules/selection.py
@@ -89,6 +89,36 @@ def forward(
         return selection
 
 
+class FixedSharedSelection(SelectionBase):
+    """
+    This selection mechanism is deterministic and always shares the weights
+    equally between domains: with 2 domains each gets 0.5, with 3 domains
+    each gets 1/3, and so on.
+    """
+
+    def forward(
+        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+    ) -> dict[str, torch.Tensor]:
+        """
+        Forward pass of the module.
+
+        Args:
+            domains (`LatentsDomainGroupT`): input unimodal latent representations
+            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion GW encodings
+                (unused by this mechanism)
+
+        Returns:
+            `dict[str, torch.Tensor]`: the selection coefficient of each domain
+                for each input in the batch.
+        """
+        selection: dict[str, torch.Tensor] = {}
+        bs = group_batch_size(domains)
+        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
+        for domain in domains:
+            selection[domain] = coef.clone()
+        return selection
+
+
 class KQFixedQSelection(SelectionBase):
     """
     Key-Query attention with a fixed gw vector.
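A minimal usage sketch of the new module, assuming shimmer is installed, that SelectionBase behaves like a standard torch.nn.Module (so instances are callable), and that a LatentsDomainGroupT is a plain mapping from domain name to a (batch, dim) tensor, as the forward signature above suggests; the domain names and sizes below are made up:

import torch

from shimmer.modules.selection import FixedSharedSelection

selection_mod = FixedSharedSelection()

# Two dummy domains with a batch of 4 latents each.
domains = {
    "v_latents": torch.randn(4, 12),
    "attr": torch.randn(4, 12),
}

# FixedSharedSelection ignores the pre-fusion encodings, so passing the
# same group twice is enough for this toy call.
weights = selection_mod(domains, domains)

print(weights["v_latents"])  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
print(weights["attr"])       # tensor([0.5000, 0.5000, 0.5000, 0.5000])

The per-domain .clone() in the new module means each returned entry is an independent tensor rather than a view of one shared buffer.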
