diff --git a/shimmer/modules/attention_module.py b/shimmer/modules/attention_module.py
index 1ed6ca32..024a37fe 100644
--- a/shimmer/modules/attention_module.py
+++ b/shimmer/modules/attention_module.py
@@ -5,13 +5,18 @@
 import torch
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
-from torch import Tensor, nn
+from torch import Tensor
 from torch.optim.lr_scheduler import OneCycleLR
 
-from shimmer.modules.global_workspace import GlobalWorkspaceBase, SchedulerArgs
-from shimmer.modules.gw import GWModuleBase
+from shimmer.modules.global_workspace import (
+    GlobalWorkspaceBase,
+    GWModuleBase,
+    SchedulerArgs,
+)
 from shimmer.modules.losses import GWLossesBase
-from shimmer.modules.selection import DynamicQueryAttention, SelectionBase
+from shimmer.modules.selection import (
+    SelectionBase,
+)
 from shimmer.types import (
     LatentsDomainGroupsDT,
     LatentsDomainGroupsT,
@@ -20,45 +25,22 @@
 )
 
 
-class ShapesClassifier(nn.Sequential):
-    def __init__(self, input_dim, output_dim):
-        layers = [
-            nn.Linear(input_dim, 256),
-            nn.BatchNorm1d(256),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(256, 128),
-            nn.BatchNorm1d(128),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(128, 64),
-            nn.BatchNorm1d(64),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(64, 32),
-            nn.BatchNorm1d(32),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(32, output_dim),
-        ]
-        super().__init__(*layers)
-
-
-class DynamicAttention(LightningModule):
+class AttentionBase(LightningModule):
     """
     Attention Lightning Module.
 
-    This is a wrapper around the DynamicQueryAttention module.
-    It is used to train the Dynamic Query Attention mechanism.
+    This is a wrapper around the different attention modules.
+    It is used to train an attention/selection mechanism.
     """
 
     def __init__(
         self,
         gw: GlobalWorkspaceBase[GWModuleBase, SelectionBase, GWLossesBase],
-        domain_dim: int,
-        head_size: int,
+        attention: SelectionBase,
         domain_names: Sequence[str],
-        criterion: Callable[[torch.Tensor, RawDomainGroupT], torch.Tensor],
+        criterion: Callable[
+            [torch.Tensor, RawDomainGroupT], tuple[torch.Tensor, torch.Tensor]
+        ],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
@@ -67,12 +49,13 @@ def __init__(
         self.save_hyperparameters(
             ignore=[
                 "gw",
+                "attention",
                 "criterion",
             ]
         )
 
         self.gw = gw
-        self.attention = DynamicQueryAttention(head_size, domain_dim, domain_names)
+        self.attention = attention
        self.domain_names = domain_names
         self.criterion = criterion
         self.optim_lr = optim_lr
@@ -119,7 +102,6 @@ def apply_corruption(
         self,
         batch: LatentsDomainGroupsT,
         corruption_vector: torch.Tensor | None = None,
-        corrupted_domain: str | None = None,
     ) -> LatentsDomainGroupsDT:
         """
         Apply corruption to the batch.
@@ -132,12 +114,11 @@
         Returns:
             A batch where one of the latent domains is corrupted.
""" - if corrupted_domain is None: - # Specify which domain will be corrupted - corrupted_domain = random.choice(list(self.domain_names)) matched_data_dict: LatentsDomainGroupsDT = {} for domain_names, domains in batch.items(): + # Randomly select a domain to be corrupted for this instance + corrupted_domain = random.choice(list(self.domain_names)) for domain_name, domain in domains.items(): if domain_names != self.domain_names or domain_name != corrupted_domain: matched_data_dict.setdefault(domain_names, {})[domain_name] = domain @@ -160,18 +141,28 @@ def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor: prefusion_encodings = self.gw.encode(corrupted_batch) attention_scores = self.forward(corrupted_batch, prefusion_encodings) merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores) + losses = [] + accuracies = [] + for domain_names, domains in merged_gw_representation.items(): - losses.append(self.criterion(domains, batch[domain_names])) + loss, accuracy = self.criterion(domains, batch[domain_names]) + losses.append(loss) + accuracies.append(accuracy) domain_names_str = ",".join(domain_names) self.log( f"{mode}/{domain_names_str}_loss", losses[-1], batch_size=domains.size(0), ) + self.log( + f"{mode}/{domain_names_str}_accuracy", + accuracies[-1], + batch_size=domains.size(0), + ) loss = torch.stack(losses).mean() - print(f"loss: {loss}") self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True) + self.log(f"{mode}/accuracy", torch.stack(accuracies).mean()) return loss diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index aad07b84..3bf64d87 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -119,75 +119,85 @@ def forward( return selection +def _calculate_attention_dict( + domains: LatentsDomainGroupT, + keys: dict[str, torch.Tensor], + query: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Args: + domains (`LatentsDomainGroupT`): Group of unimodal latent representations. + keys (`dict[str, torch.Tensor]`): The keys for each domain. + query (`torch.Tensor`): The query tensor. + + Returns: + `dict[str, torch.Tensor]`: The attention scores for each domain. + """ + dot_products = { + domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze() + for domain, key in keys.items() + } + + dot_products_tensor = torch.stack(list(dot_products.values()), dim=1) + + attention_scores = torch.softmax(dot_products_tensor, dim=1) + + attention_dict = { + domain: attention_scores[:, i] for i, domain in enumerate(domains) + } + return attention_dict + + class KQFixedQSelection(SelectionBase): """ Key-Query attention with a fixed gw vector. """ - def __init__(self, domain_dim: int, head_size: int, domains: Iterable[str]): + def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]): """ Args: - domain_dim (`int`) : dimension of the input dims - (assumed to be the same for now) head_size (`int`) : dimension of the key and query vectors. 
-            domains (`Iterable[str]`) : list of input domains
+            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+                for now)
+            domain_names (`Iterable[str]`) : list of input domains
         """
         super().__init__()
         self.head_size = head_size
         self.query_layer = nn.Linear(domain_dim, head_size)
         self.key_layers = nn.ModuleDict(
-            {domain: nn.Linear(domain_dim, head_size) for domain in domains}
+            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
         )
 
-        self.gw_state: torch.Tensor | None = None
-
-    def update_gw_state(self, gw_state: torch.Tensor) -> None:
-        """
-        Set the internal copy of the fixed gw state. You're meant to only call this once
-
-        Args:
-            gw_state (`torch.Tensor`): the previous GW state
-        """
-        self.gw_state = gw_state
+        # Start with a random gw state
+        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
     def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
     ) -> dict[str, torch.Tensor]:
         """
         Compute keys and queries, match them with dot product and softmax.
+        The query is built once from a fixed, randomly initialized gw state.
 
         Args:
             domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+            encodings_pre_fusion (`LatentsDomainGroupT`): Group of pre-fusion encodings.
 
         Returns:
-            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
-            coefficient for each item in the batch.
+            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+            group.
         """
-        if self.gw_state is None:
-            raise ValueError("GW state has not been initialized.")
-
         keys = {
             domain: self.key_layers[domain](encoding)
             for domain, encoding in domains.items()
         }
 
-        device = group_device(domains)
-        query = self.query_layer(self.gw_state.to(device))
-
-        dot_products = {
-            domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
-            for domain, key in keys.items()
-        }
-
-        dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
-
-        attention_scores = torch.softmax(dot_products_tensor, dim=1)
+        batch_size = group_batch_size(domains)
 
-        attention_dict = {
-            domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
-        }
+        # Retrieve random query
+        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
 
-        return attention_dict
+        # Calculate the attention scores
+        return _calculate_attention_dict(domains, keys, query)
 
 
 class RandomSelection(SelectionBase):
@@ -262,35 +272,6 @@ def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str])
         # Start with a random gw state
         self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
-    def calculate_attention_dict(
-        self,
-        domains: LatentsDomainGroupT,
-        keys: dict[str, torch.Tensor],
-        query: torch.Tensor,
-    ) -> dict[str, torch.Tensor]:
-        """
-        Args:
-            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
-            keys (`dict[str, torch.Tensor]`): The keys for each domain.
-            query (`torch.Tensor`): The query tensor.
-
-        Returns:
-            `dict[str, torch.Tensor]`: The attention scores for each domain.
- """ - dot_products = { - domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze() - for domain, key in keys.items() - } - - dot_products_tensor = torch.stack(list(dot_products.values()), dim=1) - - attention_scores = torch.softmax(dot_products_tensor, dim=1) - - attention_dict = { - domain: attention_scores[:, i] for i, domain in enumerate(domains) - } - return attention_dict - def fuse_weighted_encodings( self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor] ) -> torch.Tensor: @@ -348,7 +329,7 @@ def forward( query = self.query_layer(self.initial_gw_state.expand(batch_size, -1)) # Calculate the attention scores - static_attention_dict = self.calculate_attention_dict(domains, keys, query) + static_attention_dict = _calculate_attention_dict(domains, keys, query) # Apply the attention scores to the encodings summed_tensor = self.fuse_weighted_encodings( @@ -359,6 +340,6 @@ def forward( query = self.query_layer(summed_tensor) # Calculate the attention scores again - dynamic_attention_dict = self.calculate_attention_dict(domains, keys, query) + dynamic_attention_dict = _calculate_attention_dict(domains, keys, query) return dynamic_attention_dict diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py index c049f94e..ed5d14b9 100644 --- a/tests/test_kq_onepass_attention.py +++ b/tests/test_kq_onepass_attention.py @@ -9,7 +9,7 @@ def test_single_domain(): batch_size = 2056 domains = ["v_latents"] - attention = KQFixedQSelection(domain_dim, head_size, domains) + attention = KQFixedQSelection(head_size, domain_dim, domains) gw_state = torch.rand(batch_size, domain_dim) attention.update_gw_state(gw_state) @@ -28,7 +28,7 @@ def test_multiple_domains_sumis1(): head_size = 5 batch_size = 2056 domains = ["v_latents", "attr"] - attention = KQFixedQSelection(domain_dim, head_size, domains) + attention = KQFixedQSelection(head_size, domain_dim, domains) gw_state = torch.rand(batch_size, domain_dim) attention.update_gw_state(gw_state)