Module attention base (#84)
* changed query attention module to more general use

* change name

* re-changed ShapesClassifier to match state dict

* change to old classifier

* added accuracy log

* changed shape classifier to the new classifier

* changed inputs of fixed query attention in pytest

* change corruption from per batch to every instance

* changes for pr

* deleted print and tuple

* changed to _attention dict
larascipio authored May 21, 2024
1 parent d068372 commit d862b5d
Showing 3 changed files with 81 additions and 109 deletions.
73 changes: 32 additions & 41 deletions shimmer/modules/attention_module.py
@@ -5,13 +5,18 @@
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch import Tensor, nn
from torch import Tensor
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.global_workspace import GlobalWorkspaceBase, SchedulerArgs
from shimmer.modules.gw import GWModuleBase
from shimmer.modules.global_workspace import (
GlobalWorkspaceBase,
GWModuleBase,
SchedulerArgs,
)
from shimmer.modules.losses import GWLossesBase
from shimmer.modules.selection import DynamicQueryAttention, SelectionBase
from shimmer.modules.selection import (
SelectionBase,
)
from shimmer.types import (
LatentsDomainGroupsDT,
LatentsDomainGroupsT,
@@ -20,45 +25,22 @@
)


class ShapesClassifier(nn.Sequential):
def __init__(self, input_dim, output_dim):
layers = [
nn.Linear(input_dim, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(128, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(64, 32),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(32, output_dim),
]
super().__init__(*layers)


class DynamicAttention(LightningModule):
class AttentionBase(LightningModule):
"""
Attention Lightning Module.
This is a wrapper around the DynamicQueryAttention module.
It is used to train the Dynamic Query Attention mechanism.
This is a wrapper around the different attention modules.
It is used to train an attention/selection mechanism.
"""

def __init__(
self,
gw: GlobalWorkspaceBase[GWModuleBase, SelectionBase, GWLossesBase],
domain_dim: int,
head_size: int,
attention: SelectionBase,
domain_names: Sequence[str],
criterion: Callable[[torch.Tensor, RawDomainGroupT], torch.Tensor],
criterion: Callable[
[torch.Tensor, RawDomainGroupT], tuple[torch.Tensor, torch.Tensor]
],
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
@@ -67,12 +49,13 @@ def __init__(
self.save_hyperparameters(
ignore=[
"gw",
"attention",
"criterion",
]
)

self.gw = gw
self.attention = DynamicQueryAttention(head_size, domain_dim, domain_names)
self.attention = attention
self.domain_names = domain_names
self.criterion = criterion
self.optim_lr = optim_lr
@@ -119,7 +102,6 @@ def apply_corruption(
self,
batch: LatentsDomainGroupsT,
corruption_vector: torch.Tensor | None = None,
corrupted_domain: str | None = None,
) -> LatentsDomainGroupsDT:
"""
Apply corruption to the batch.
@@ -132,12 +114,11 @@ def apply_corruption(
Returns:
A batch where one of the latent domains is corrupted.
"""
if corrupted_domain is None:
# Specify which domain will be corrupted
corrupted_domain = random.choice(list(self.domain_names))

matched_data_dict: LatentsDomainGroupsDT = {}
for domain_names, domains in batch.items():
# Randomly select a domain to be corrupted for this instance
corrupted_domain = random.choice(list(self.domain_names))
for domain_name, domain in domains.items():
if domain_names != self.domain_names or domain_name != corrupted_domain:
matched_data_dict.setdefault(domain_names, {})[domain_name] = domain
@@ -160,18 +141,28 @@ def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor:
prefusion_encodings = self.gw.encode(corrupted_batch)
attention_scores = self.forward(corrupted_batch, prefusion_encodings)
merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores)

losses = []
accuracies = []

for domain_names, domains in merged_gw_representation.items():
losses.append(self.criterion(domains, batch[domain_names]))
loss, accuracy = self.criterion(domains, batch[domain_names])
losses.append(loss)
accuracies.append(accuracy)
domain_names_str = ",".join(domain_names)
self.log(
f"{mode}/{domain_names_str}_loss",
losses[-1],
batch_size=domains.size(0),
)
self.log(
f"{mode}/{domain_names_str}_accuracy",
accuracies[-1],
batch_size=domains.size(0),
)
loss = torch.stack(losses).mean()
print(f"loss: {loss}")
self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True)
self.log(f"{mode}/accuracy", torch.stack(accuracies).mean())

return loss
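
A minimal sketch of a criterion compatible with the refactored AttentionBase, whose criterion must now return a (loss, accuracy) pair instead of a single loss. The toy classifier, tensor sizes, and plain label tensor standing in for the real fused GW representation and RawDomainGroupT targets are illustrative assumptions, not part of the shimmer API:

import torch
from torch import nn
from torch.nn.functional import cross_entropy

# Toy stand-ins: a 12-dim fused representation and 3 target classes.
classifier = nn.Linear(12, 3)


def criterion(
    fused: torch.Tensor, labels: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Return a (loss, accuracy) pair, as the new AttentionBase expects.
    logits = classifier(fused)
    loss = cross_entropy(logits, labels)
    accuracy = (logits.argmax(dim=1) == labels).float().mean()
    return loss, accuracy


loss, accuracy = criterion(torch.rand(8, 12), torch.randint(0, 3, (8,)))

The selection mechanism itself, e.g. DynamicQueryAttention(head_size, domain_dim, domain_names), is now injected through the new attention argument rather than constructed inside the module.
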

113 changes: 47 additions & 66 deletions shimmer/modules/selection.py
@@ -119,75 +119,85 @@ def forward(
return selection


def _calculate_attention_dict(
domains: LatentsDomainGroupT,
keys: dict[str, torch.Tensor],
query: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
keys (`dict[str, torch.Tensor]`): The keys for each domain.
query (`torch.Tensor`): The query tensor.
Returns:
`dict[str, torch.Tensor]`: The attention scores for each domain.
"""
dot_products = {
domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
for domain, key in keys.items()
}

dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

attention_scores = torch.softmax(dot_products_tensor, dim=1)

attention_dict = {
domain: attention_scores[:, i] for i, domain in enumerate(domains)
}
return attention_dict


class KQFixedQSelection(SelectionBase):
"""
Key-Query attention with a fixed gw vector.
"""

def __init__(self, domain_dim: int, head_size: int, domains: Iterable[str]):
def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
"""
Args:
domain_dim (`int`) : dimension of the input dims
(assumed to be the same for now)
head_size (`int`) : dimension of the key and query vectors.
domains (`Iterable[str]`) : list of input domains
domain_dim (`int`) : dimension of the input dims (assumed to be the same
for now)
domain_names (`Iterable[str]`) : list of input domains
"""
super().__init__()
self.head_size = head_size
self.query_layer = nn.Linear(domain_dim, head_size)
self.key_layers = nn.ModuleDict(
{domain: nn.Linear(domain_dim, head_size) for domain in domains}
{domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
)
self.gw_state: torch.Tensor | None = None

def update_gw_state(self, gw_state: torch.Tensor) -> None:
"""
Set the internal copy of the fixed gw state. You are meant to call this only once.
Args:
gw_state (`torch.Tensor`): the previous GW state
"""
self.gw_state = gw_state
# Start with a random gw state
self.register_buffer("initial_gw_state", torch.rand(domain_dim))

def forward(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
"""
Compute keys from the domains and a query from the fixed gw state, then
match them with a dot product and softmax.
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
Returns:
`dict[str, torch.Tensor]`: for each domain in the group, the fusion
coefficient for each item in the batch.
`dict[str, torch.Tensor]`: the attention scores for each domain in the
group.
"""

if self.gw_state is None:
raise ValueError("GW state has not been initialized.")

keys = {
domain: self.key_layers[domain](encoding)
for domain, encoding in domains.items()
}

device = group_device(domains)
query = self.query_layer(self.gw_state.to(device))

dot_products = {
domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
for domain, key in keys.items()
}

dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

attention_scores = torch.softmax(dot_products_tensor, dim=1)
batch_size = group_batch_size(domains)

attention_dict = {
domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
}
# Retrieve random query
query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

return attention_dict
# Calculate the attention scores
return _calculate_attention_dict(domains, keys, query)


class RandomSelection(SelectionBase):
@@ -262,35 +272,6 @@ def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str])
# Start with a random gw state
self.register_buffer("initial_gw_state", torch.rand(domain_dim))

def calculate_attention_dict(
self,
domains: LatentsDomainGroupT,
keys: dict[str, torch.Tensor],
query: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
keys (`dict[str, torch.Tensor]`): The keys for each domain.
query (`torch.Tensor`): The query tensor.
Returns:
`dict[str, torch.Tensor]`: The attention scores for each domain.
"""
dot_products = {
domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
for domain, key in keys.items()
}

dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

attention_scores = torch.softmax(dot_products_tensor, dim=1)

attention_dict = {
domain: attention_scores[:, i] for i, domain in enumerate(domains)
}
return attention_dict

def fuse_weighted_encodings(
self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
) -> torch.Tensor:
@@ -348,7 +329,7 @@ def forward(
query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

# Calculate the attention scores
static_attention_dict = self.calculate_attention_dict(domains, keys, query)
static_attention_dict = _calculate_attention_dict(domains, keys, query)

# Apply the attention scores to the encodings
summed_tensor = self.fuse_weighted_encodings(
@@ -359,6 +340,6 @@ def forward(
query = self.query_layer(summed_tensor)

# Calculate the attention scores again
dynamic_attention_dict = self.calculate_attention_dict(domains, keys, query)
dynamic_attention_dict = _calculate_attention_dict(domains, keys, query)

return dynamic_attention_dict
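
The new module-level helper _calculate_attention_dict (shared by KQFixedQSelection and DynamicQueryAttention, replacing the former method on the latter) can be exercised on its own. A minimal sketch, assuming this version of shimmer is importable; the random keys and query stand in for the outputs of the key and query layers:

import torch

from shimmer.modules.selection import _calculate_attention_dict

batch_size, domain_dim, head_size = 4, 12, 5
domains = {
    "v_latents": torch.rand(batch_size, domain_dim),
    "attr": torch.rand(batch_size, domain_dim),
}
# Random stand-ins for the outputs of the key and query linear layers.
keys = {name: torch.rand(batch_size, head_size) for name in domains}
query = torch.rand(batch_size, head_size)

attention = _calculate_attention_dict(domains, keys, query)
# One score per batch item and per domain; across domains they sum to 1.
assert torch.allclose(sum(attention.values()), torch.ones(batch_size))
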
4 changes: 2 additions & 2 deletions tests/test_kq_onepass_attention.py
@@ -9,7 +9,7 @@ def test_single_domain():
batch_size = 2056
domains = ["v_latents"]

attention = KQFixedQSelection(domain_dim, head_size, domains)
attention = KQFixedQSelection(head_size, domain_dim, domains)
gw_state = torch.rand(batch_size, domain_dim)
attention.update_gw_state(gw_state)

@@ -28,7 +28,7 @@ def test_multiple_domains_sumis1():
head_size = 5
batch_size = 2056
domains = ["v_latents", "attr"]
attention = KQFixedQSelection(domain_dim, head_size, domains)
attention = KQFixedQSelection(head_size, domain_dim, domains)
gw_state = torch.rand(batch_size, domain_dim)
attention.update_gw_state(gw_state)
