Commit
fix: also allow the GlobalWorkspace to take flexible coefs
bdvllrs committed Oct 15, 2024
1 parent 74398e3 commit fa57de9
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions shimmer/modules/global_workspace.py
@@ -657,7 +657,7 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: LossCoefs,
+        loss_coefs: LossCoefs | Mapping[str, float],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
@@ -682,7 +682,7 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`LossCoefs`): loss coefficients
+            loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients
             optim_lr (`float`): learning rate
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
@@ -734,7 +734,7 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
+        loss_coefs: BroadcastLossCoefs | Mapping[str, float],
         selection_temperature: float = 0.2,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
@@ -760,7 +760,8 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+            loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the
+                losses.
             selection_temperature (`float`): temperature value for the RandomSelection
                 module.
             optim_lr (`float`): learning rate
@@ -808,7 +809,7 @@ def pretrained_global_workspace(
     gw_encoders: Mapping[str, Module],
     gw_decoders: Mapping[str, Module],
     workspace_dim: int,
-    loss_coefs: LossCoefs,
+    loss_coefs: LossCoefs | Mapping[str, float],
     contrastive_fn: ContrastiveLossType,
     scheduler: LRScheduler
     | None
@@ -831,7 +832,7 @@ def pretrained_global_workspace(
             name to a `torch.nn.Module` class which role is to decode a
             GW representation into a unimodal latent representations.
         workspace_dim (`int`): dimension of the GW.
-        loss_coefs (`LossCoefs`): loss coefficients
+        loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients
         contrastive_loss (`ContrastiveLossType`): a contrastive loss
             function used for alignment. `learn_logit_scale` will not affect custom
             contrastive losses.
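In practical terms, the widened annotation means loss coefficients no longer have to be built through the `LossCoefs` / `BroadcastLossCoefs` TypedDicts; any `Mapping[str, float]` now satisfies the signature. A minimal sketch of the idea follows — the coefficient names and the commented-out constructor call are illustrative assumptions, not taken from this diff:

```python
from collections.abc import Mapping

# A plain mapping now passes the `loss_coefs` annotation, e.g. one loaded
# from a YAML/JSON config. The key names below are assumed for illustration.
loss_coefs: Mapping[str, float] = {
    "demi_cycles": 1.0,
    "cycles": 1.0,
    "translations": 1.0,
    "contrastives": 0.01,
}

# Hypothetical usage; the other GlobalWorkspace arguments are omitted here.
# gw = GlobalWorkspace(
#     domain_mods,
#     gw_encoders,
#     gw_decoders,
#     workspace_dim=12,
#     loss_coefs=loss_coefs,  # a plain dict works after this commit
# )
```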
