add step_limit to dynamicqueryattention (#159)
This just adds a way to have a given attention mechanism stop before it
has completed all of its steps.
This makes sense because, even though an attention mechanism is trained
for n steps, we may not always need to run the whole n steps.
I'm using this right now to train it for classification on the average
of all steps. I think it should make it into shimmer since we'll probably
use this functionality quite a lot.
RolandBERTINJOHANNET authored Sep 27, 2024
1 parent ca2ab8c commit 62f82ed
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion shimmer/modules/selection.py
@@ -225,9 +225,23 @@ def __init__(
             {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
         )
         self.n_steps = n_steps
+        self.step_limit = n_steps  # Default step limit is n_steps
         # Start with a random gw state
         self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
+    def set_step_limit(self, step_limit: int):
+        """
+        Sets the step limit for the dynamic attention update loop.
+
+        Args:
+            step_limit (`int`): Maximum number of steps to run the loop.
+        """
+        if step_limit > self.n_steps:
+            raise ValueError(
+                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
+            )
+        self.step_limit = step_limit
+
     def fuse_weighted_encodings(
         self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
     ) -> torch.Tensor:
@@ -289,7 +303,7 @@ def forward(
 
         if self.n_steps > 0:
             # Update the query based on the static attention scores
-            for _ in range(self.n_steps):
+            for _ in range(min(self.step_limit, self.n_steps)):
                 # Apply the attention scores to the encodings
                 summed_tensor = self.fuse_weighted_encodings(
                     encodings_pre_fusion, attention_dict
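
For reference, a minimal usage sketch of the new limit. Only domain_dim, head_size, domain_names and n_steps are visible in this diff, so the constructor keyword arguments and domain names below are illustrative assumptions rather than the exact DynamicQueryAttention API; the set_step_limit behaviour follows the code above.

# Hypothetical usage sketch -- constructor arguments are assumptions based on
# the parameters visible in this diff, not the full DynamicQueryAttention API.
from shimmer.modules.selection import DynamicQueryAttention

attention = DynamicQueryAttention(
    head_size=32,
    domain_dim=12,
    domain_names=["v_latents", "attr"],  # illustrative domain names
    n_steps=4,
)

# Run only the first 2 of the 4 trained update steps.
attention.set_step_limit(2)

# Requesting more steps than the module was trained for is rejected:
# attention.set_step_limit(10)  # raises ValueError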
