diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 24d18fab..01509dd0 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -296,7 +296,8 @@ def encode_and_fuse( Args: x (`LatentsDomainGroupsT`): the input domain representations. - selection_scores (`Mapping[str, torch.Tensor]`): + selection_module (`SelectionBase`): selection module to use to obtain + selection scores. Returns: `dict[frozenset[str], torch.Tensor]`: the GW representations. diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 90b9abcc..a7c039bb 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -298,8 +298,8 @@ def encode_and_fuse( Args: x (`LatentsDomainGroupT`): the input domain representations - selection_score (`Mapping[str, torch.Tensor]`): attention scores to - use to encode the reprensetation. + selection_module (`SelectionBase`): selection module to use to obtain + selection scores. Returns: `torch.Tensor`: The merged representation. diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index 9ab6caac..ac03bdd1 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -38,6 +38,8 @@ def forward( Args: domains (`LatentsDomainGroupT`): Group of unimodal latent representations. + encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent + representation. Returns: `dict[str, torch.Tensor]`: for each domain in the group, the fusion @@ -75,7 +77,8 @@ def forward( Args: domains (`LatentsDomainGroupT`): input unimodal latent representations - gw_state (`torch.Tensor`): the previous GW state + encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent + representation. Returns: `dict[str, torch.Tensor]`: whether the domain is selected for each input @@ -105,7 +108,8 @@ def forward( Args: domains (`LatentsDomainGroupT`): input unimodal latent representations - gw_state (`torch.Tensor`): the previous GW state + encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent + representation. Returns: `dict[str, torch.Tensor]`: whether the domain is selected for each input @@ -281,7 +285,8 @@ def forward( Args: domains (`LatentsDomainGroupT`): Group of unimodal latent representations. - encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings. + encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent + representation. Returns: `dict[str, torch.Tensor]`: the attention scores for each domain in the