From d862b5d8d42f0e76e15508e540a6293190b6a545 Mon Sep 17 00:00:00 2001
From: larascipio <57267136+larascipio@users.noreply.github.com>
Date: Tue, 21 May 2024 14:47:50 +0200
Subject: [PATCH] Module attention base (#84)

* changed query attention module to more general use

* change name

* changed ShapesClassifier back to match the state dict

* changed to old classifier

* added accuracy log

* changed shape classifier to the new classifier

* changed inputs of fixed query attention in pytest

* changed corruption from once per batch to once per instance

* changes for pr

* deleted print and tuple

* changed to module-level _calculate_attention_dict helper
---
 shimmer/modules/attention_module.py |  73 ++++++++----------
 shimmer/modules/selection.py        | 113 ++++++++++++----------------
 tests/test_kq_onepass_attention.py  |   4 +-
 3 files changed, 81 insertions(+), 109 deletions(-)

diff --git a/shimmer/modules/attention_module.py b/shimmer/modules/attention_module.py
index 1ed6ca32..024a37fe 100644
--- a/shimmer/modules/attention_module.py
+++ b/shimmer/modules/attention_module.py
@@ -5,13 +5,18 @@
 import torch
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
-from torch import Tensor, nn
+from torch import Tensor
 from torch.optim.lr_scheduler import OneCycleLR
 
-from shimmer.modules.global_workspace import GlobalWorkspaceBase, SchedulerArgs
-from shimmer.modules.gw import GWModuleBase
+from shimmer.modules.global_workspace import (
+    GlobalWorkspaceBase,
+    GWModuleBase,
+    SchedulerArgs,
+)
 from shimmer.modules.losses import GWLossesBase
-from shimmer.modules.selection import DynamicQueryAttention, SelectionBase
+from shimmer.modules.selection import (
+    SelectionBase,
+)
 from shimmer.types import (
     LatentsDomainGroupsDT,
     LatentsDomainGroupsT,
@@ -20,45 +25,22 @@
 )
 
 
-class ShapesClassifier(nn.Sequential):
-    def __init__(self, input_dim, output_dim):
-        layers = [
-            nn.Linear(input_dim, 256),
-            nn.BatchNorm1d(256),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(256, 128),
-            nn.BatchNorm1d(128),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(128, 64),
-            nn.BatchNorm1d(64),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(64, 32),
-            nn.BatchNorm1d(32),
-            nn.ReLU(),
-            nn.Dropout(p=0.5),
-            nn.Linear(32, output_dim),
-        ]
-        super().__init__(*layers)
-
-
-class DynamicAttention(LightningModule):
+class AttentionBase(LightningModule):
     """
     Attention Lightning Module.
 
-    This is a wrapper around the DynamicQueryAttention module.
-    It is used to train the Dynamic Query Attention mechanism.
+    This is a wrapper around the different attention modules.
+    It is used to train an attention/selection mechanism.
     """
 
     def __init__(
         self,
         gw: GlobalWorkspaceBase[GWModuleBase, SelectionBase, GWLossesBase],
-        domain_dim: int,
-        head_size: int,
+        attention: SelectionBase,
         domain_names: Sequence[str],
-        criterion: Callable[[torch.Tensor, RawDomainGroupT], torch.Tensor],
+        criterion: Callable[
+            [torch.Tensor, RawDomainGroupT], tuple[torch.Tensor, torch.Tensor]
+        ],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
@@ -67,12 +49,13 @@ def __init__(
         self.save_hyperparameters(
             ignore=[
                 "gw",
+                "attention",
                 "criterion",
             ]
         )
 
         self.gw = gw
-        self.attention = DynamicQueryAttention(head_size, domain_dim, domain_names)
+        self.attention = attention
         self.domain_names = domain_names
         self.criterion = criterion
         self.optim_lr = optim_lr
@@ -119,7 +102,6 @@ def apply_corruption(
         self,
         batch: LatentsDomainGroupsT,
         corruption_vector: torch.Tensor | None = None,
-        corrupted_domain: str | None = None,
     ) -> LatentsDomainGroupsDT:
         """
         Apply corruption to the batch.
@@ -132,12 +114,11 @@ def apply_corruption(
         Returns:
             A batch where one of the latent domains is corrupted.
         """
-        if corrupted_domain is None:
-            # Specify which domain will be corrupted
-            corrupted_domain = random.choice(list(self.domain_names))
 
         matched_data_dict: LatentsDomainGroupsDT = {}
         for domain_names, domains in batch.items():
+            # Randomly select a domain to be corrupted for this instance
+            corrupted_domain = random.choice(list(self.domain_names))
             for domain_name, domain in domains.items():
                 if domain_names != self.domain_names or domain_name != corrupted_domain:
                     matched_data_dict.setdefault(domain_names, {})[domain_name] = domain
@@ -160,18 +141,28 @@ def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
         prefusion_encodings = self.gw.encode(corrupted_batch)
         attention_scores = self.forward(corrupted_batch, prefusion_encodings)
         merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores)
+
         losses = []
+        accuracies = []
+
         for domain_names, domains in merged_gw_representation.items():
-            losses.append(self.criterion(domains, batch[domain_names]))
+            loss, accuracy = self.criterion(domains, batch[domain_names])
+            losses.append(loss)
+            accuracies.append(accuracy)
             domain_names_str = ",".join(domain_names)
             self.log(
                 f"{mode}/{domain_names_str}_loss",
                 losses[-1],
                 batch_size=domains.size(0),
             )
+            self.log(
+                f"{mode}/{domain_names_str}_accuracy",
+                accuracies[-1],
+                batch_size=domains.size(0),
+            )
         loss = torch.stack(losses).mean()
-        print(f"loss: {loss}")
         self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True)
+        self.log(f"{mode}/accuracy", torch.stack(accuracies).mean())
 
         return loss
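
Usage sketch (illustrative, not part of the diff): AttentionBase now accepts any
SelectionBase implementation together with a criterion that returns a
(loss, accuracy) tuple. The snippet assumes a trained GlobalWorkspaceBase gw, a
classifier network my_classifier and a label-extraction helper extract_labels,
none of which are defined in this PR.

    import torch.nn.functional as F

    from shimmer.modules.attention_module import AttentionBase
    from shimmer.modules.selection import DynamicQueryAttention

    domain_names = ["v_latents", "attr"]

    def classification_criterion(fused, raw_targets):
        # Must return a (loss, accuracy) tuple; the classifier and the label
        # extraction are placeholders for whatever the experiment provides.
        logits = my_classifier(fused)
        labels = extract_labels(raw_targets)
        loss = F.cross_entropy(logits, labels)
        accuracy = (logits.argmax(dim=1) == labels).float().mean()
        return loss, accuracy

    attention_lm = AttentionBase(
        gw=gw,  # pre-trained global workspace (assumed available)
        attention=DynamicQueryAttention(
            head_size=5, domain_dim=12, domain_names=domain_names
        ),
        domain_names=domain_names,
        criterion=classification_criterion,
        optim_lr=1e-3,
    )

The resulting LightningModule can then be trained with a regular Lightning Trainer.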
 
diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py
index aad07b84..3bf64d87 100644
--- a/shimmer/modules/selection.py
+++ b/shimmer/modules/selection.py
@@ -119,75 +119,85 @@ def forward(
         return selection
 
 
+def _calculate_attention_dict(
+    domains: LatentsDomainGroupT,
+    keys: dict[str, torch.Tensor],
+    query: torch.Tensor,
+) -> dict[str, torch.Tensor]:
+    """
+    Args:
+        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+        keys (`dict[str, torch.Tensor]`): The keys for each domain.
+        query (`torch.Tensor`): The query tensor.
+
+    Returns:
+        `dict[str, torch.Tensor]`: The attention scores for each domain.
+    """
+    dot_products = {
+        domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
+        for domain, key in keys.items()
+    }
+
+    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
+
+    attention_scores = torch.softmax(dot_products_tensor, dim=1)
+
+    attention_dict = {
+        domain: attention_scores[:, i] for i, domain in enumerate(domains)
+    }
+    return attention_dict
+
+
 class KQFixedQSelection(SelectionBase):
     """
     Key-Query attention with a fixed gw vector.
     """
 
-    def __init__(self, domain_dim: int, head_size: int, domains: Iterable[str]):
+    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
         """
         Args:
-            domain_dim (`int`) : dimension of the input dims
-                (assumed to be the same for now)
             head_size (`int`) : dimension of the key and query vectors.
-            domains (`Iterable[str]`) : list of input domains
+            domain_dim (`int`) : dimension of the input representations (assumed to
+                be the same for all domains for now)
+            domain_names (`Iterable[str]`) : list of input domain names
         """
         super().__init__()
         self.head_size = head_size
         self.query_layer = nn.Linear(domain_dim, head_size)
         self.key_layers = nn.ModuleDict(
-            {domain: nn.Linear(domain_dim, head_size) for domain in domains}
+            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
         )
-        self.gw_state: torch.Tensor | None = None
-
-    def update_gw_state(self, gw_state: torch.Tensor) -> None:
-        """
-        Set the internal copy of the fixed gw state. You're meant to only call this once
-
-        Args:
-            gw_state (`torch.Tensor`): the previous GW state
-        """
-        self.gw_state = gw_state
+        # Start with a random gw state
+        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
     def forward(
         self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
     ) -> dict[str, torch.Tensor]:
         """
         Compute keys and queries, match them with dot product and softmax.
+        The query is computed from a fixed, randomly initialized gw state.
 
         Args:
             domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion encodings.
 
         Returns:
-            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
-            coefficient for each item in the batch.
+            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+            group.
         """
 
-        if self.gw_state is None:
-            raise ValueError("GW state has not been initialized.")
-
         keys = {
             domain: self.key_layers[domain](encoding)
             for domain, encoding in domains.items()
         }
 
-        device = group_device(domains)
-        query = self.query_layer(self.gw_state.to(device))
-
-        dot_products = {
-            domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
-            for domain, key in keys.items()
-        }
-
-        dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
-
-        attention_scores = torch.softmax(dot_products_tensor, dim=1)
+        batch_size = group_batch_size(domains)
 
-        attention_dict = {
-            domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
-        }
+        # Compute the query from the fixed, randomly initialized gw state
+        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
 
-        return attention_dict
+        # Calculate the attention scores
+        return _calculate_attention_dict(domains, keys, query)
 
 
 class RandomSelection(SelectionBase):
@@ -262,35 +272,6 @@ def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str])
         # Start with a random gw state
         self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
-    def calculate_attention_dict(
-        self,
-        domains: LatentsDomainGroupT,
-        keys: dict[str, torch.Tensor],
-        query: torch.Tensor,
-    ) -> dict[str, torch.Tensor]:
-        """
-        Args:
-            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
-            keys (`dict[str, torch.Tensor]`): The keys for each domain.
-            query (`torch.Tensor`): The query tensor.
-
-        Returns:
-            `dict[str, torch.Tensor]`: The attention scores for each domain.
-        """
-        dot_products = {
-            domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
-            for domain, key in keys.items()
-        }
-
-        dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
-
-        attention_scores = torch.softmax(dot_products_tensor, dim=1)
-
-        attention_dict = {
-            domain: attention_scores[:, i] for i, domain in enumerate(domains)
-        }
-        return attention_dict
-
     def fuse_weighted_encodings(
         self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
     ) -> torch.Tensor:
@@ -348,7 +329,7 @@ def forward(
         query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
 
         # Calculate the attention scores
-        static_attention_dict = self.calculate_attention_dict(domains, keys, query)
+        static_attention_dict = _calculate_attention_dict(domains, keys, query)
 
         # Apply the attention scores to the encodings
         summed_tensor = self.fuse_weighted_encodings(
@@ -359,6 +340,6 @@ def forward(
         query = self.query_layer(summed_tensor)
 
         # Calculate the attention scores again
-        dynamic_attention_dict = self.calculate_attention_dict(domains, keys, query)
+        dynamic_attention_dict = _calculate_attention_dict(domains, keys, query)
 
         return dynamic_attention_dict
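
Usage sketch (illustrative, not part of the diff): KQFixedQSelection now owns a
randomly initialized gw state buffer and shares the score computation with the
dynamic variant through _calculate_attention_dict, so forward can be called
without any prior update_gw_state call. Tensor sizes are arbitrary.

    import torch

    from shimmer.modules.selection import KQFixedQSelection

    batch_size, domain_dim, head_size = 8, 12, 5
    latents = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
    }

    attention = KQFixedQSelection(head_size, domain_dim, ["v_latents", "attr"])
    scores = attention(latents, latents)  # latents stand in for pre-fusion encodings

    # The softmax runs across domains, so the coefficients of every batch item
    # sum to one.
    assert torch.allclose(
        scores["v_latents"] + scores["attr"], torch.ones(batch_size)
    )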
diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py
index c049f94e..ed5d14b9 100644
--- a/tests/test_kq_onepass_attention.py
+++ b/tests/test_kq_onepass_attention.py
@@ -9,7 +9,7 @@ def test_single_domain():
     batch_size = 2056
     domains = ["v_latents"]
 
-    attention = KQFixedQSelection(domain_dim, head_size, domains)
+    attention = KQFixedQSelection(head_size, domain_dim, domains)
     gw_state = torch.rand(batch_size, domain_dim)
     attention.update_gw_state(gw_state)
 
@@ -28,7 +28,7 @@ def test_multiple_domains_sumis1():
     head_size = 5
     batch_size = 2056
     domains = ["v_latents", "attr"]
-    attention = KQFixedQSelection(domain_dim, head_size, domains)
+    attention = KQFixedQSelection(head_size, domain_dim, domains)
     gw_state = torch.rand(batch_size, domain_dim)
     attention.update_gw_state(gw_state)
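
A matching test for the two-pass variant could look like the sketch below
(illustrative, not part of the diff), assuming DynamicQueryAttention keeps the
(head_size, domain_dim, domain_names) constructor used elsewhere in this patch.

    import torch

    from shimmer.modules.selection import DynamicQueryAttention

    def test_dynamic_query_sumis1():
        domain_dim = 12
        head_size = 6
        batch_size = 2056
        domains = ["v_latents", "attr"]

        attention = DynamicQueryAttention(head_size, domain_dim, domains)
        latents = {name: torch.rand(batch_size, domain_dim) for name in domains}

        scores = attention(latents, latents)
        scores_sum = sum(scores.values())

        assert torch.allclose(scores_sum, torch.ones(batch_size))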