Commit 62f82ed

add step_limit to dynamicqueryattention (#159)
This adds a way for a given attention mechanism to stop before it has completed all of its steps. Even though an attention mechanism is trained for n steps, we may not always need to run the whole n steps. I'm using this right now to train it for classification on the average of all steps. I think it should make it into shimmer, since we'll probably use this functionality quite a lot.
1 parent ca2ab8c commit 62f82ed

File tree

1 file changed: +15 −1 lines changed

shimmer/modules/selection.py (+15 −1)
@@ -225,9 +225,23 @@ def __init__(
             {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
         )
         self.n_steps = n_steps
+        self.step_limit = n_steps  # Default step limit is n_steps
         # Start with a random gw state
         self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
+    def set_step_limit(self, step_limit: int):
+        """
+        Sets the step limit for the dynamic attention update loop.
+
+        Args:
+            step_limit (`int`): Maximum number of steps to run the loop.
+        """
+        if step_limit > self.n_steps:
+            raise ValueError(
+                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
+            )
+        self.step_limit = step_limit
+
     def fuse_weighted_encodings(
         self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
     ) -> torch.Tensor:
@@ -289,7 +303,7 @@ def forward(
 
         if self.n_steps > 0:
             # Update the query based on the static attention scores
-            for _ in range(self.n_steps):
+            for _ in range(min(self.step_limit, self.n_steps)):
                 # Apply the attention scores to the encodings
                 summed_tensor = self.fuse_weighted_encodings(
                     encodings_pre_fusion, attention_dict
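
For reference, here is a minimal usage sketch of the new method. The constructor keyword arguments (`head_size`, `domain_dim`, `domain_names`, `n_steps`) and their example values are assumptions inferred from the diff above rather than taken from the commit; only `set_step_limit` and the step-limit cap come from this change.

```python
# Hypothetical usage sketch: the constructor arguments below are inferred
# from the diff above and may not match the actual signature exactly.
from shimmer.modules.selection import DynamicQueryAttention

attention = DynamicQueryAttention(
    head_size=8,
    domain_dim=12,
    domain_names={"v_latents", "attr"},
    n_steps=4,
)

# Run only the first 2 of the 4 trained query-update steps in forward().
attention.set_step_limit(2)

# Asking for more steps than the module was trained for raises ValueError.
try:
    attention.set_step_limit(10)
except ValueError as err:
    print(err)  # Step limit cannot exceed the maximum n_steps (4).
```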
