diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 02a37b6..9ca55fe 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -657,7 +657,7 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: LossCoefs, + loss_coefs: LossCoefs | Mapping[str, float], optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -682,7 +682,7 @@ def __init__( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. - loss_coefs (`LossCoefs`): loss coefficients + loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -734,7 +734,7 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: BroadcastLossCoefs, + loss_coefs: BroadcastLossCoefs | Mapping[str, float], selection_temperature: float = 0.2, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, @@ -760,7 +760,8 @@ def __init__( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. - loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses. + loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the + losses. selection_temperature (`float`): temperature value for the RandomSelection module. optim_lr (`float`): learning rate @@ -808,7 +809,7 @@ def pretrained_global_workspace( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: LossCoefs, + loss_coefs: LossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, scheduler: LRScheduler | None @@ -831,7 +832,7 @@ def pretrained_global_workspace( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. - loss_coefs (`LossCoefs`): loss coefficients + loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients contrastive_loss (`ContrastiveLossType`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses.