Skip to content

Commit

Permalink
docs: fix docstring of selection related methods (#172)
Browse files Browse the repository at this point in the history
In particular `encode_and_fuse` and `forward` of Selection modules.
  • Loading branch information
bdvllrs authored Oct 11, 2024
1 parent b6eba5d commit 9b50160
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
3 changes: 2 additions & 1 deletion shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def encode_and_fuse(
Args:
x (`LatentsDomainGroupsT`): the input domain representations.
selection_scores (`Mapping[str, torch.Tensor]`):
selection_module (`SelectionBase`): selection module to use to obtain
selection scores.
Returns:
`dict[frozenset[str], torch.Tensor]`: the GW representations.
Expand Down
4 changes: 2 additions & 2 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def encode_and_fuse(
Args:
x (`LatentsDomainGroupT`): the input domain representations
selection_score (`Mapping[str, torch.Tensor]`): attention scores to
use to encode the reprensetation.
selection_module (`SelectionBase`): selection module to use to obtain
selection scores.
Returns:
`torch.Tensor`: The merged representation.
Expand Down
11 changes: 8 additions & 3 deletions shimmer/modules/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: for each domain in the group, the fusion
Expand Down Expand Up @@ -75,7 +77,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): input unimodal latent representations
gw_state (`torch.Tensor`): the previous GW state
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: whether the domain is selected for each input
Expand Down Expand Up @@ -105,7 +108,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): input unimodal latent representations
gw_state (`torch.Tensor`): the previous GW state
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: whether the domain is selected for each input
Expand Down Expand Up @@ -281,7 +285,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: the attention scores for each domain in the
Expand Down

0 comments on commit 9b50160

Please sign in to comment.