diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 53862080..b97d4cfc 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -15,7 +15,6 @@ GlobalWorkspace2Domains, GlobalWorkspaceBase, SchedulerArgs, - batch_broadcasts, batch_cycles, batch_demi_cycles, batch_translations, @@ -33,14 +32,12 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, combine_loss, ) from shimmer.modules.selection import ( - RandomSelection, SelectionBase, SingleDomainSelection, ) diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index cd5957e2..0e553096 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -9,7 +9,6 @@ GlobalWorkspace2Domains, GlobalWorkspaceBase, SchedulerArgs, - batch_broadcasts, batch_cycles, batch_demi_cycles, batch_translations, @@ -27,14 +26,12 @@ translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, GWLosses2Domains, GWLossesBase, LossCoefs, combine_loss, ) from shimmer.modules.selection import ( - RandomSelection, SelectionBase, SingleDomainSelection, ) diff --git a/shimmer/modules/attention_module.py b/shimmer/modules/attention_module.py deleted file mode 100644 index 157645ff..00000000 --- a/shimmer/modules/attention_module.py +++ /dev/null @@ -1,363 +0,0 @@ -import random -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import torch -from lightning.pytorch import LightningModule -from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig -from torch import Tensor -from torch.optim.lr_scheduler import OneCycleLR - -from shimmer.modules.global_workspace import ( - GlobalWorkspaceBase, - GWModuleBase, - SchedulerArgs, -) -from shimmer.modules.losses import GWLossesBase -from shimmer.modules.selection import ( - SelectionBase, -) -from shimmer.types import ( - LatentsDomainGroupsDT, - LatentsDomainGroupsT, - RawDomainGroupsT, - RawDomainGroupT, -) -from shimmer.utils import group_device, groups_batch_size, groups_device - - -class AttentionBase(LightningModule): - """ - Attention Lightning Module. - - This is a wrapper around the different attention modules. - It is used to train an attention/selection mechanism. - """ - - def __init__( - self, - gw: GlobalWorkspaceBase[GWModuleBase, SelectionBase, GWLossesBase], - attention: SelectionBase, - domain_names: Sequence[str], - criterion: Callable[ - [torch.Tensor, RawDomainGroupT], tuple[torch.Tensor, torch.Tensor] - ], - domain_dim: int, - fixed_corruption_vector: torch.Tensor | None = None, - corruption_scaling: list[float] | None = None, - corrupt_single_side: str | None = None, - corrupt_sides: bool = False, - two_sided_corruption: dict[str, float] | None = None, - optim_lr: float = 1e-3, - optim_weight_decay: float = 0.0, - scheduler_args: SchedulerArgs | None = None, - ): - super().__init__() - self.save_hyperparameters( - ignore=[ - "gw", - "attention", - "criterion", - ] - ) - - self.gw = gw - self.attention = attention - self.domain_names = frozenset(domain_names) - self.list_domain_names = list(domain_names) - self.criterion = criterion - self.domain_dim = domain_dim - self.fixed_corruption_vector = fixed_corruption_vector - self.corruption_scaling = corruption_scaling - self.corrupt_single_side = corrupt_single_side - self.corrupt_sides = corrupt_sides - self.test_sides_corruption = two_sided_corruption - self.optim_lr = optim_lr - self.optim_weight_decay = optim_weight_decay - self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1) - if scheduler_args is not None: - self.scheduler_args.update(scheduler_args) - - def configure_optimizers(self) -> OptimizerLRSchedulerConfig: - """ - Configure models optimizers. - - Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate - scheduler. - """ - - optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.optim_lr, - weight_decay=self.optim_weight_decay, - ) - - lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "interval": "step", - }, - } - - def forward( - self, - corrupted_batch: LatentsDomainGroupsT, - prefusion_encodings: LatentsDomainGroupsT, - ) -> LatentsDomainGroupsDT: - """ - Forward pass of the model. - - Args: - corrupted_batch: The input to the model. - prefusion_encodings: The pre-fusion encodings. - - Returns: - The attention scores. - """ - return { - domains: self.attention(latents, prefusion_encodings[domains]) - for domains, latents in corrupted_batch.items() - } - - def apply_one_sided_corruption( - self, - batch: LatentsDomainGroupsT, - ) -> LatentsDomainGroupsDT: - """ - Apply corruption to each tensor of the matched data - by use of masking. Only for two domains. - - Args: - batch: A batch of latent domains. - Returns: - A batch where either one (of the domains) of each tensor is corrupted. - """ - matched_data_dict: LatentsDomainGroupsDT = {} - - # Make a copy of the batch - for domain_names, domains in batch.items(): - for domain_name, domain in domains.items(): - matched_data_dict.setdefault(domain_names, {})[domain_name] = domain - continue - device = group_device(domains) - batch_size = groups_batch_size(batch) - n_domains = len(self.domain_names) - - # Check if the side that should be corrupted is given - if self.corrupt_single_side is not None: - corrupted_domain_index = self.list_domain_names.index( - self.corrupt_single_side - ) - masked_domains = torch.zeros(batch_size, n_domains, dtype=torch.bool) - masked_domains[:, corrupted_domain_index] = True - else: - selected_domains = torch.randint(0, n_domains, (batch_size,), device=device) - masked_domains = torch.nn.functional.one_hot( - selected_domains, n_domains - ).to(device, torch.bool) - - # Check if corruption is fixed or variable - if self.fixed_corruption_vector is not None: - corruption_vector = self.fixed_corruption_vector.expand( - batch_size, self.domain_dim - ) - else: - corruption_vector = torch.randn( - (batch_size, self.domain_dim), device=device - ) - - # Normalize the corruption vector - corruption_vector = ( - corruption_vector - corruption_vector.mean(dim=1, keepdim=True) - ) / corruption_vector.std(dim=1, keepdim=True) - - # Choose randomly from corruption scaling - amount_corruption = ( - random.choice(self.corruption_scaling) if self.corruption_scaling else 1.0 - ) - - # Scale the corruption vector based on the amount of corruption - scaled_corruption_vector = (corruption_vector * 5) * amount_corruption - for _, (domain_names, domains) in enumerate(matched_data_dict.items()): - if domain_names == self.domain_names: - for domain_name, domain in domains.items(): - if domain_name == self.list_domain_names[0]: - domain[masked_domains[:, 0]] += scaled_corruption_vector[ - masked_domains[:, 0] - ] - if domain_name == self.list_domain_names[1]: - domain[~masked_domains[:, 0]] += scaled_corruption_vector[ - ~masked_domains[:, 0] - ] - return matched_data_dict - - def apply_two_sided_corruption( - self, - batch: LatentsDomainGroupsT, - ) -> LatentsDomainGroupsDT: - """ - Apply corruption to each tensor of the matched data (two-sided corruption) - Only for two domains. - - Args: - batch: A batch of latent domains. - Returns: - A batch where either both sides of the domains are corrupted. - """ - matched_data_dict: LatentsDomainGroupsDT = {} - - # Make a copy of the batch - for domain_names, domains in batch.items(): - for domain_name, domain in domains.items(): - matched_data_dict.setdefault(domain_names, {})[domain_name] = domain - continue - device = groups_device(batch) - batch_size = groups_batch_size(batch) - n_domains = len(self.domain_names) - - corruption_matrices = {} - - # Check if a fixed or variable corruption vector should be used - for domain_idx in range(n_domains): - if self.fixed_corruption_vector is not None: - corruption_matrix = self.fixed_corruption_vector.expand( - batch_size, self.domain_dim - ).to(device) - else: - corruption_matrix = torch.randn( - (batch_size, self.domain_dim), device=device - ) - - # Normalize the corruption matrices - normalized_corruption_matrix = ( - corruption_matrix - corruption_matrix.mean(dim=1, keepdim=True) - ) / corruption_matrix.std(dim=1, keepdim=True) - - # Get the scaled corruption vector - if self.test_sides_corruption is not None: - scaled_corruption_matrix = ( - normalized_corruption_matrix * 5 - ) * self.test_sides_corruption[self.list_domain_names[domain_idx]] - else: - amount_corruption = ( - random.choice(self.corruption_scaling) - if self.corruption_scaling - else 1.0 - ) - scaled_corruption_matrix = ( - normalized_corruption_matrix * 5 - ) * amount_corruption - corruption_matrices[self.list_domain_names[domain_idx]] = ( - scaled_corruption_matrix - ) - - for domain_names, domains in matched_data_dict.items(): - if domain_names == self.domain_names: - for domain_name, domain in domains.items(): - if domain_name in corruption_matrices: - domain += corruption_matrices[domain_name] - return matched_data_dict - - def calculate_mean_attention( - self, - attention_scores: dict[frozenset[str], dict[str, Tensor]], - ) -> dict: - """ - Calculate the mean attention scores for each domain. - - Args: - attention_scores: The attention scores for each domain. - - Returns: - The mean attention scores for each domain. - """ - # Initialize variables to accumulate mean scores - mean_attention_dict = {} - - # Iterate through attention_dicts - for _, scores in attention_scores.items(): - # Check if more than 1 domains are present - if len(scores) > 1: - for key, values in scores.items(): - # Accumulate mean scores for each key - mean_score = values.mean().item() - mean_attention_dict[key] = mean_score - return mean_attention_dict - - def generic_step(self, batch: RawDomainGroupsT, mode: str) -> Tensor: - """ - Generic step used by lightning, used for training, validation and testing. - - Args: - batch: A batch of latent domains. - mode: The mode in which the model is currently in. - - Returns: - The loss of the model. - """ - latent_domains = self.gw.encode_domains(batch) - if self.corrupt_sides is True: - corrupted_batch = self.apply_two_sided_corruption(latent_domains) - else: - corrupted_batch = self.apply_one_sided_corruption(latent_domains) - prefusion_encodings = self.gw.encode(corrupted_batch) - attention_scores = self.forward(corrupted_batch, prefusion_encodings) - merged_gw_representation = self.gw.fuse(prefusion_encodings, attention_scores) - losses = [] - accuracies = [] - - for domain_names, domains in merged_gw_representation.items(): - loss, accuracy = self.criterion(domains, batch[domain_names]) - losses.append(loss) - accuracies.append(accuracy) - domain_names_str = ",".join(domain_names) - self.log( - f"{mode}/{domain_names_str}_loss", - losses[-1], - batch_size=domains.size(0), - ) - self.log( - f"{mode}/{domain_names_str}_accuracy", - accuracies[-1], - batch_size=domains.size(0), - ) - mean_attention_scores = self.calculate_mean_attention(attention_scores) - for domain_name, score in mean_attention_scores.items(): - self.log(f"{mode}/{domain_name}_mean_attention_score", score) - - loss = torch.stack(losses).mean() - self.log(f"{mode}/loss", loss, on_step=True, on_epoch=True) - self.log(f"{mode}/accuracy", torch.stack(accuracies).mean()) - - return loss - - def training_step( - self, batch: RawDomainGroupsT, batch_idx: int - ) -> Tensor | Mapping[str, Any] | None: # type: ignore - return self.generic_step(batch, "train") - - def validation_step( # type: ignore - self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: - """Validation step used by lightning""" - - batch = {frozenset(data.keys()): data} - for domain in data: - batch[frozenset([domain])] = {domain: data[domain]} - if dataloader_idx == 0: - return self.generic_step(batch, mode="val") - return self.generic_step(batch, mode="val/ood") - - def test_step( # type: ignore - self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0 - ) -> torch.Tensor: - """Test step used by lightning""" - batch = {frozenset(data.keys()): data} - for domain in data: - batch[frozenset([domain])] = {domain: data[domain]} - if dataloader_idx == 0: - return self.generic_step(batch, mode="test") - return self.generic_step(batch, mode="test/ood") diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 9ca55fe7..7cf73fb2 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -17,19 +17,15 @@ GWModule, GWModuleBase, GWModulePrediction, - broadcast_cycles, cycle, translation, ) from shimmer.modules.losses import ( - BroadcastLossCoefs, - GWLosses, GWLosses2Domains, GWLossesBase, LossCoefs, ) from shimmer.modules.selection import ( - RandomSelection, SelectionBase, SingleDomainSelection, ) @@ -177,35 +173,6 @@ def batch_translations( return predictions -def batch_broadcasts( - gw_mod: GWModuleBase, - selection_mod: SelectionBase, - latent_domains: LatentsDomainGroupsT, -) -> tuple[ - dict[frozenset[str], dict[str, torch.Tensor]], - dict[frozenset[str], dict[str, torch.Tensor]], -]: - """ - Computes all possible broadcast of a batch for each group of domains. - - Args: - gw_mod (`GWModuleBase`): the GWModuleBase - selection_mod (`SelectionBase`): selection module - latent_domains (`LatentsT`): the batch of groups of domains - - Returns: - `tuple[dict[frozenset[str], dict[str, torch.Tensor]], - dict[frozenset[str], dict[str, torch.Tensor]], ]`: broadcast predictions - for each domain.""" - predictions: dict[frozenset[str], dict[str, torch.Tensor]] = {} - cycles: dict[frozenset[str], dict[str, torch.Tensor]] = {} - for domains, latents in latent_domains.items(): - pred_broadcast, pred_cycles = broadcast_cycles(gw_mod, selection_mod, latents) - predictions[domains] = pred_broadcast - cycles[domains] = pred_cycles - return predictions, cycles - - class OneCycleSchedulerSentinel(Enum): """ Used for backward-compatibility issues to use One-Cycle Scheduler by default @@ -721,88 +688,6 @@ def __init__( ) -class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): - """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase. - - This is used to simplify a Global Workspace instanciation and only overrides the - `__init__` method. - """ - - def __init__( - self, - domain_mods: Mapping[str, DomainModule], - gw_encoders: Mapping[str, Module], - gw_decoders: Mapping[str, Module], - workspace_dim: int, - loss_coefs: BroadcastLossCoefs | Mapping[str, float], - selection_temperature: float = 0.2, - optim_lr: float = 1e-3, - optim_weight_decay: float = 0.0, - scheduler_args: SchedulerArgs | None = None, - learn_logit_scale: bool = False, - contrastive_loss: ContrastiveLossType | None = None, - scheduler: Callable[[Optimizer], LRScheduler] - | None - | OneCycleSchedulerSentinel = OneCycleSchedulerSentinel.DEFAULT, - fusion_activation_fn: Callable[[torch.Tensor], torch.Tensor] = torch.tanh, - ) -> None: - """ - Initializes a Global Workspace - - Args: - domain_mods (`Mapping[str, DomainModule]`): mapping of the domains - connected to the GW. Keys are domain names, values are the - `DomainModule`. - gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a `torch.nn.Module` class which role is to encode a - unimodal latent representations into a GW representation (pre fusion). - gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain - name to a `torch.nn.Module` class which role is to decode a - GW representation into a unimodal latent representations. - workspace_dim (`int`): dimension of the GW. - loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the - losses. - selection_temperature (`float`): temperature value for the RandomSelection - module. - optim_lr (`float`): learning rate - optim_weight_decay (`float`): weight decay - scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments - learn_logit_scale (`bool`): whether to learn the contrastive learning - contrastive loss when using the default contrastive loss. - contrastive_loss (`ContrastiveLossType | None`): a contrastive loss - function used for alignment. `learn_logit_scale` will not affect custom - contrastive losses. - scheduler: The scheduler to use for traning. If None is explicitely given, - no scheduler will be used. Defaults to use OneCycleScheduler - fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation - function to fuse the domains. - """ - domain_mods = freeze_domain_modules(domain_mods) - gw_mod = GWModule( - domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn - ) - - if contrastive_loss is None: - contrastive_loss = ContrastiveLoss( - torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale - ) - - selection_mod = RandomSelection(selection_temperature) - loss_mod = GWLosses( - gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss - ) - - super().__init__( - gw_mod, - selection_mod, - loss_mod, - optim_lr, - optim_weight_decay, - scheduler_args, - scheduler, - ) - - def pretrained_global_workspace( checkpoint_path: str | Path, domain_mods: Mapping[str, DomainModule], diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index e47c0293..c5fa64ab 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping -from itertools import product +from collections.abc import Mapping from typing import TypedDict import torch @@ -304,34 +303,9 @@ class LossCoefs(TypedDict, total=False): """Contrastive loss coefficient.""" -class BroadcastLossCoefs(TypedDict, total=False): - """ - Dict of loss coefficients used in the GWLossesFusion. - - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. - """ - - contrastives: float - """Contrastive loss coefficient.""" - - fused: float - """fused loss coefficient (encode multiple domains and decode to one of them).""" - - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles are always one-to-one""" - - cycles: float - """cycles loss coefficient. Cycles can be many-to-one""" - - translations: float - """translation loss coefficient. Translation, like cycles, can be many-to-one.""" - - def combine_loss( metrics: dict[str, torch.Tensor], - coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs, + coefs: Mapping[str, float] | LossCoefs, ) -> torch.Tensor: """ Combines the metrics according to the ones selected in coefs @@ -496,276 +470,3 @@ def step( metrics.update(self.contrastive_loss(domain_latents)) return LossOutput(combine_loss(metrics, self.loss_coefs), metrics) - - -def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]: - """ - Generates all possible partitions of zeros and ones for `n` elements, - excluding the all-zeros partition. - - Args: - n (`int`): The number of modalities to generate partitions for. - - Yields: - `tuple[int, ...]`: A partition of zeros and ones, excluding the - all-zeros partition. - """ - for perm in product([0, 1], repeat=n): - if any(perm): - yield perm - - -def broadcast_loss( - gw_mod: GWModuleBase, - selection_mod: SelectionBase, - domain_mods: Mapping[str, DomainModule], - latent_domains: LatentsDomainGroupsT, - raw_data: RawDomainGroupsT, -) -> dict[str, torch.Tensor]: - """ - Computes broadcast loss including demi-cycle, cycle, and translation losses. - - This return multiple metrics: - * `demi_cycles` - * `cycles` - * `translations` - * `fused` - * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form - "{domain1,domain2,domainN}" sorted in alphabetical order - (e.g. "from_{t,v}_to_t_loss"). - * `from_{start_group}_to_{domain}_{metric}` with - additional metrics provided by the domain_mod's - `compute_broadcast_loss` output - * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` - where `{start_group}`, `{target_group}` and `{case_group}` is of the form - "{domain1,domain2,domainN}" sorted in alphabetical order - (e.g. "from_{t}_through_{v}_to_t_case_{t,v}_loss"). `{start_group}` represents the input - domains, `{target_group}` the target domains used for the cycle and - `{case_group}` all available domains participating to the loss. - * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` - additional metrics provided by the domain_mod's `compute_broadcast_loss` - output - - Args: - gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use - selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use - domain_mods (`Mapping[str, DomainModule]`): the domain modules - latent_domains: The latent domain representations. - raw_data (`RawDomainGroupsT`): raw input data - - Returns: - A dictionary with the total loss and additional metrics. - """ # noqa: E501 - losses: dict[str, torch.Tensor] = {} - metrics: dict[str, torch.Tensor] = {} - - demi_cycle_losses: list[str] = [] - cycle_losses: list[str] = [] - translation_losses: list[str] = [] - fused_losses: list[str] = [] - - for group_domains, latents in latent_domains.items(): - encoded_latents = gw_mod.encode(latents) - partitions = generate_partitions(len(group_domains)) - group_name = "{" + ",".join(sorted(group_domains)) + "}" - - for partition in partitions: - selected_latents = { - domain: latents[domain] - for domain, present in zip(latents, partition, strict=True) - if present - } - selected_encoded_latents = { - domain: encoded_latents[domain] for domain in selected_latents - } - selected_group_label = "{" + ",".join(sorted(selected_latents)) + "}" - - selection_scores = selection_mod(selected_latents, selected_encoded_latents) - fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores) - decoded_latents = gw_mod.decode(fused_latents) - - num_active_domains = sum(partition) - num_total_domains = len(decoded_latents) - - for domain, pred in decoded_latents.items(): - if domain not in group_domains: # if we don't have ground truth - continue - ground_truth = latents[domain] - - if num_active_domains == 1 and domain in selected_latents: - loss_fn = domain_mods[domain].compute_dcy_loss - elif domain not in selected_latents: - loss_fn = domain_mods[domain].compute_tr_loss - else: - loss_fn = domain_mods[domain].compute_fused_loss - - loss_output = loss_fn( - pred, ground_truth, raw_data[group_domains][domain] - ) - if loss_output is None: - continue - - loss_label = f"from_{selected_group_label}_to_{domain}" - losses[loss_label + "_loss"] = loss_output.loss - metrics.update( - {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()} - ) - - if num_active_domains == 1 and domain in selected_latents: - demi_cycle_losses.append(loss_label + "_loss") - elif domain not in selected_latents: - translation_losses.append(loss_label + "_loss") - else: # fused loss - fused_losses.append(loss_label + "_loss") - - if num_active_domains < num_total_domains: - inverse_selected_latents = { - domain: decoded_latents[domain] - for domain in decoded_latents - if domain not in selected_latents - } - - inverse_selected_group_label = ( - "{" + ",".join(sorted(inverse_selected_latents)) + "}" - ) - - re_encoded_latents = gw_mod.encode(inverse_selected_latents) - re_selection_scores = selection_mod( - inverse_selected_latents, re_encoded_latents - ) - re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores) - re_decoded_latents = gw_mod.decode( - re_fused_latents, domains=selected_latents.keys() - ) - - for domain in selected_latents: - re_ground_truth = latents[domain] - re_loss_output = domain_mods[domain].compute_cy_loss( - re_decoded_latents[domain], - re_ground_truth, - raw_data[group_domains][domain], - ) - if re_loss_output is None: - continue - loss_label = ( - f"from_{selected_group_label}_" - f"through_{inverse_selected_group_label}_to_{domain}_" - f"case_{group_name}" - ) - losses[loss_label + "_loss"] = re_loss_output.loss - metrics.update( - { - f"{loss_label}_{k}": v - for k, v in re_loss_output.metrics.items() - } - ) - cycle_losses.append(loss_label + "_loss") - - if demi_cycle_losses: - metrics["demi_cycles"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in demi_cycle_losses]) - ) - if cycle_losses: - metrics["cycles"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in cycle_losses]) - ) - if translation_losses: - metrics["translations"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in translation_losses]) - ) - if fused_losses: - metrics["fused"] = torch.mean( - torch.stack([losses[loss_name] for loss_name in fused_losses]) - ) - - metrics.update(losses) - return metrics - - -class GWLosses(GWLossesBase): - """ - Implementation of `GWLossesBase` for fusion-based models. - """ - - def __init__( - self, - gw_mod: GWModule, - selection_mod: SelectionBase, - domain_mods: dict[str, DomainModule], - loss_coefs: BroadcastLossCoefs | Mapping[str, float], - contrastive_fn: ContrastiveLossType, - ): - """ - Initializes the loss computation module for a Global Workspace Fusion model. - - Args: - gw_mod: The GWModule for the global workspace. - selection_mod: The selection mechanism for the model. - domain_mods: A mapping of domain names to their respective DomainModule. - loss_coefs (`BroadcastLossCoefs`): coefs for the losses - contrastive_fn: The function used for computing contrastive loss. - """ - super().__init__() - self.gw_mod = gw_mod - self.selection_mod = selection_mod - self.domain_mods = domain_mods - self.loss_coefs = loss_coefs - self.contrastive_fn = contrastive_fn - - def contrastive_loss( - self, latent_domains: LatentsDomainGroupsT - ) -> dict[str, torch.Tensor]: - """ - Computes the contrastive loss for the given latent domains. - - Args: - latent_domains: The latent domain representations. - - Returns: - A dictionary of contrastive loss metrics. - """ - - return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) - - def broadcast_loss( - self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT - ) -> dict[str, torch.Tensor]: - return broadcast_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data - ) - - def step( - self, - raw_data: RawDomainGroupsT, - domain_latents: LatentsDomainGroupsT, - mode: ModelModeT, - ) -> LossOutput: - """ - Performs a step of loss computation. - - Args: - raw_data (`RawDomainGroupsT`): raw input data - domain_latents: Latent representations for all domains. - mode: The mode in which the model is currently operating. - - Returns: - A LossOutput object containing the loss and metrics for this step. - """ - - metrics: dict[str, torch.Tensor] = {} - - metrics.update(self.contrastive_loss(domain_latents)) - metrics.update(self.broadcast_loss(domain_latents, raw_data)) - - loss = combine_loss(metrics, self.loss_coefs) - - metrics["broadcast_loss"] = torch.stack( - [ - metrics[name] - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and coef > 0 and name != "contrastives" - ], - dim=0, - ).mean() - - return LossOutput(loss, metrics) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index ac03bdd1..0ef17288 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -1,8 +1,6 @@ from abc import ABC, abstractmethod -from collections.abc import Iterable import torch -import torch.nn as nn from shimmer.types import LatentsDomainGroupT from shimmer.utils import group_batch_size, group_device @@ -121,203 +119,3 @@ def forward( for domain in domains: selection[domain] = coef.clone() return selection - - -def _calculate_attention_dict( - domains: LatentsDomainGroupT, - keys: dict[str, torch.Tensor], - query: torch.Tensor, -) -> dict[str, torch.Tensor]: - """ - Args: - domains (`LatentsDomainGroupT`): Group of unimodal latent representations. - keys (`dict[str, torch.Tensor]`): The keys for each domain. - query (`torch.Tensor`): The query tensor. - - Returns: - `dict[str, torch.Tensor]`: The attention scores for each domain. - """ - dot_products = { - domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze() - for domain, key in keys.items() - } - - dot_products_tensor = torch.stack(list(dot_products.values()), dim=1) - - attention_scores = torch.softmax(dot_products_tensor, dim=1) - - attention_dict = { - domain: attention_scores[:, i] for i, domain in enumerate(domains) - } - return attention_dict - - -class RandomSelection(SelectionBase): - """ - Modified random attention to only utilize uniform-softmax scores across modalities. - This version omits the binary scaling factors and focuses on generating attention - coefficients using a uniform distribution followed by a domain-wise softmax. - """ - - def __init__(self, temperature: float): - """ - Args: - temperature (`float`): Temperature of the softmax applied to uniform - scaling factors. - """ - super().__init__() - self.temperature = temperature - - def forward( - self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT - ) -> dict[str, torch.Tensor]: - """ - Generate uniform-then-domain-wise-softmaxed samples for each domain. - - Args: - domains (`LatentsDomainGroupT`): Group of unimodal latent representations. - This is not used in the function directly but determines the structure - of the returned attention coefficients. - - Returns: - `dict[str, torch.Tensor]`: For each domain in the group, the fusion - coefficient for each item in the batch, based solely on - uniform-softmax scores. - """ - num_domains = len(domains) - batch_size = group_batch_size(domains) - device = group_device(domains) - - # Generate uniform scores - uniform_scores = torch.rand(batch_size, num_domains, device=device) - - # Apply softmax across domains with temperature scaling - softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1) - # Create attention dictionary for each domain - attention_dict = { - domain: softmax_scores[:, i] for i, domain in enumerate(domains) - } - - return attention_dict - - -class DynamicQueryAttention(SelectionBase): - """ - Key-Query attention with a dynamic gw vector. - The query is updated based on the scaled gw vector. - """ - - def __init__( - self, - head_size: int, - domain_dim: int, - domain_names: Iterable[str], - n_steps: int = 1, - ): - """ - Args: - head_size (`int`) : dimension of the key and query vectors. - domain_dim (`int`) : dimension of the input dims (assumed to be the same - for now) - domain_names (`Iterable[str]`) : list of input domains - n_steps (`int`) : number of steps to update the query vector - """ - super().__init__() - self.head_size = head_size - self.query_layer = nn.Linear(domain_dim, head_size) - self.key_layers = nn.ModuleDict( - {domain: nn.Linear(domain_dim, head_size) for domain in domain_names} - ) - self.n_steps = n_steps - self.step_limit = n_steps # Default step limit is n_steps - # Start with a random gw state - self.register_buffer("initial_gw_state", torch.rand(domain_dim)) - - def set_step_limit(self, step_limit: int): - """ - Sets the step limit for the dynamic attention update loop. - - Args: - step_limit (`int`): Maximum number of steps to run the loop. - """ - if step_limit > self.n_steps: - raise ValueError( - f"Step limit cannot exceed the maximum n_steps ({self.n_steps})." - ) - self.step_limit = step_limit - - def fuse_weighted_encodings( - self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor] - ) -> torch.Tensor: - """ - Fuse the weighted encodings using the attention scores. - - Args: - encodings (`LatentsDomainGroupT`): Unimodal latent representation - attention_dict (`dict[str, torch.Tensor]`): The attention scores for each - domain in the group. - - Returns: - `torch.Tensor`: The fused tensor. - """ - # Apply attention scores to the encodings - weighted_encodings = {} - for key in attention_dict: - if key in encodings: - # Perform element-wise multiplication - weighted_encodings[key] = ( - attention_dict[key].unsqueeze(1) * encodings[key] - ) - - # Stack the tensors along a new dimension (dimension 0) - stacked_tensors = torch.stack(list(weighted_encodings.values())) - - # Apply fusion by summing along the newly created dimension - summed_tensor = torch.sum(stacked_tensors, dim=0) - return summed_tensor - - def forward( - self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT - ) -> dict[str, torch.Tensor]: - """ - Compute keys and queries, match them with dot product and softmax. - Does this twice, once with the static query and once with a dynamic query. - - Args: - domains (`LatentsDomainGroupT`): Group of unimodal latent representations. - encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent - representation. - - Returns: - `dict[str, torch.Tensor]`: the attention scores for each domain in the - group. - """ - - keys = { - domain: self.key_layers[domain](encoding) - for domain, encoding in domains.items() - } - - batch_size = group_batch_size(domains) - - # Retrieve random query - query = self.query_layer(self.initial_gw_state.expand(batch_size, -1)) - - # Calculate the attention scores - attention_dict = _calculate_attention_dict(domains, keys, query) - - if self.n_steps > 0: - # Update the query based on the static attention scores - for _ in range(min(self.step_limit, self.n_steps)): - # Apply the attention scores to the encodings - summed_tensor = self.fuse_weighted_encodings( - encodings_pre_fusion, attention_dict - ) - - # Retrieve query (now it is dependent on the new gw state) - query = self.query_layer(summed_tensor) - - # Calculate the attention scores again - attention_dict = _calculate_attention_dict(domains, keys, query) - - return attention_dict diff --git a/tests/save_model.py b/tests/save_model.py index e8517e6b..7c9f25d1 100644 --- a/tests/save_model.py +++ b/tests/save_model.py @@ -4,7 +4,6 @@ from utils import DummyDomainModule from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder -from shimmer.modules.global_workspace import GlobalWorkspace here = Path(__file__).parent @@ -54,19 +53,11 @@ def save_gw_ckpt(): workspace_dim=16, loss_coefs={}, ) - gw = GlobalWorkspace( - domains, - gw_encoders, - gw_decoders, - workspace_dim=16, - loss_coefs={}, - ) torch.save( {"state_dict": gw_2_domains.state_dict()}, here / "data" / "old_gw_2_domains.ckpt", ) - torch.save({"state_dict": gw.state_dict()}, here / "data" / "old_gw.ckpt") if __name__ == "__main__": diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py deleted file mode 100644 index bde29e8d..00000000 --- a/tests/test_broadcast.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Any - -import torch -from torch import nn - -from shimmer.modules.domain import DomainModule, LossOutput -from shimmer.modules.global_workspace import GlobalWorkspace -from shimmer.modules.losses import BroadcastLossCoefs - - -class DummyDomainModule(DomainModule): - def __init__(self, latent_dim: int): - super().__init__(latent_dim) - self.encoder = nn.Linear(latent_dim, latent_dim) # Simplified encoder - self.decoder = nn.Linear(latent_dim, latent_dim) # Simplified decoder - - def encode(self, x: torch.Tensor) -> torch.Tensor: - return self.encoder(x) # Simple forward pass through encoder - - def decode(self, z: torch.Tensor) -> torch.Tensor: - return self.decoder(z) # Simple forward pass through decoder - - def compute_loss( - self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any - ) -> LossOutput: - loss = torch.mean((pred - target) ** 2) # Simple MSE loss - return LossOutput(loss=loss) # Constructing LossOutput with the loss - - -def test_broadcast_loss(): - domain_mods: dict[str, DomainModule] = { - "domain1": DummyDomainModule(latent_dim=10), - "domain2": DummyDomainModule(latent_dim=10), - } - gw_encoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} - gw_decoders = {"domain1": nn.Linear(10, 10), "domain2": nn.Linear(10, 10)} - workspace_dim = 10 - loss_coefs: BroadcastLossCoefs = { - "fused": 1.0, - "cycles": 1.0, - "demi_cycles": 1.0, - "translations": 1.0, - "contrastives": 0.1, - } - - gw_fusion = GlobalWorkspace( - domain_mods, - gw_encoders, - gw_decoders, - workspace_dim, - loss_coefs, - selection_temperature=0.2, - optim_lr=1e-3, - optim_weight_decay=0.0, - scheduler_args=None, # Simplified for testing - learn_logit_scale=False, - ) - - # Adjusting the dummy data to fit the expected input structure for broadcast_loss - # Now using a frozenset for the keys to match LatentsDomainGroupsT - latent_domains = { - frozenset(["domain1", "domain2"]): { - "domain1": torch.rand(5, 10), # Batch size of 5, feature dimension of 10 - "domain2": torch.rand(5, 10), - } - } - - # Test broadcast_loss with the corrected structure - output = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) - - er_msg = "Demi-cycle, cycle, fused and translation metrics should be in the output." - assert all( - metric in output - for metric in ["demi_cycles", "cycles", "translations", "fused"] - ), er_msg - - er_msg = "Losses should be scalar tensors or 1D tensor with size equal to one." - assert all( - (loss.dim() == 0 or (loss.dim() == 1 and loss.size(0) == 1)) - for key, loss in output.items() - if key.endswith("_loss") - ), er_msg diff --git a/tests/test_ckpt_migrations.py b/tests/test_ckpt_migrations.py index 09cb98fc..1eb233ca 100644 --- a/tests/test_ckpt_migrations.py +++ b/tests/test_ckpt_migrations.py @@ -6,7 +6,6 @@ from utils import DummyDomainModule from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder -from shimmer.modules.global_workspace import GlobalWorkspace from shimmer.utils import MIGRATION_DIR here = Path(__file__).parent @@ -70,63 +69,3 @@ def test_ckpt_migration_2_domains(): ) gw.load_state_dict(new_ckpt["state_dict"]) - - -def test_ckpt_migration_gw(): - old_ckpt_path = here / "data" / "old_gw.ckpt" - old_ckpt = torch.load(old_ckpt_path, weights_only=True) - new_ckpt, done_migrations = migrate_from_folder(old_ckpt, MIGRATION_DIR) - - old_keys = set(old_ckpt["state_dict"].keys()) - new_keys = set(new_ckpt["state_dict"].keys()) - print(f"Removed keys: {old_keys - new_keys}") - print(f"New keys: {new_keys - old_keys}") - - print("Done migrations:", ", ".join(map(lambda x: x.name, done_migrations))) - - domains = { - "v": DummyDomainModule(latent_dim=32), - "t": DummyDomainModule(latent_dim=32), - } - - workspace_dim = 16 - - gw_encoders = { - "v": GWEncoder( - domains["v"].latent_dim, - hidden_dim=64, - out_dim=workspace_dim, - n_layers=1, - ), - "t": GWEncoder( - domains["t"].latent_dim, - hidden_dim=64, - out_dim=workspace_dim, - n_layers=1, - ), - } - - gw_decoders = { - "v": GWDecoder( - workspace_dim, - hidden_dim=64, - out_dim=domains["v"].latent_dim, - n_layers=1, - ), - "t": GWDecoder( - workspace_dim, - hidden_dim=64, - out_dim=domains["t"].latent_dim, - n_layers=1, - ), - } - - gw = GlobalWorkspace( - domains, - gw_encoders, - gw_decoders, - workspace_dim=16, - loss_coefs={}, - ) - - gw.load_state_dict(new_ckpt["state_dict"]) diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py deleted file mode 100644 index 74b8b8e6..00000000 --- a/tests/test_kq_onepass_attention.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch - -from shimmer.modules.selection import DynamicQueryAttention - - -def test_single_domain(): - domain_dim = 12 - head_size = 6 - batch_size = 2056 - domains = ["v_latents"] - - attention = DynamicQueryAttention(head_size, domain_dim, domains, n_steps=0) - gw_state = torch.rand(batch_size, domain_dim) - attention.update_gw_state(gw_state) - - single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)} - encodings_pre_fusion = {"v_latents": torch.rand(batch_size, domain_dim)} - attention_scores = attention(single_domain_input, encodings_pre_fusion) - - expected_scores = torch.ones(batch_size, 1) - assert torch.allclose( - attention_scores["v_latents"], expected_scores - ), "Attention scores for single domain should be all 1s" - - -def test_multiple_domains_sumis1(): - domain_dim = 12 - head_size = 5 - batch_size = 2056 - domains = ["v_latents", "attr"] - attention = DynamicQueryAttention(head_size, domain_dim, domains, n_steps=0) - gw_state = torch.rand(batch_size, domain_dim) - attention.update_gw_state(gw_state) - - multiple_domain_input = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - encodings_pre_fusion = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - - attention_scores = attention(multiple_domain_input, encodings_pre_fusion) - - scores_sum = sum( - attention_scores[domain].squeeze() for domain in multiple_domain_input - ) - assert isinstance(scores_sum, torch.Tensor) - - expected_sum = torch.ones(batch_size) - - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of attention scores across domains should be 1" diff --git a/tests/test_query_attention.py b/tests/test_query_attention.py deleted file mode 100644 index 61d99bd2..00000000 --- a/tests/test_query_attention.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch - -from shimmer.modules.selection import DynamicQueryAttention - - -def test_non_random(): - x = { - frozenset(["vision"]): { - "vision": torch.Tensor([[1.0, 0.0, 0.3], [1.0, 0.0, 0.3]]), - }, - frozenset(["language"]): { - "language": torch.Tensor([[1.0, 0.2, 0.9], [1.0, 0.2, 0.9]]), - }, - frozenset(["vision", "language"]): { - "vision": torch.Tensor([[1.0, 0.0, 0.3], [1.0, 0.0, 0.3]]), - "language": torch.Tensor([[1.0, 0.2, 0.9], [1.0, 0.2, 0.9]]), - }, - } - - y = { - frozenset(["vision"]): { - "vision": torch.Tensor([[1.0, 0.0, 0.3], [1.0, 0.0, 0.3]]), - }, - frozenset(["language"]): { - "language": torch.Tensor([[1.0, 0.2, 0.9], [1.0, 0.2, 0.9]]), - }, - frozenset(["vision", "language"]): { - "vision": torch.Tensor([[1.0, 0.0, 0.3], [1.0, 0.0, 0.3]]), - "language": torch.Tensor([[1.0, 0.2, 0.9], [1.0, 0.2, 0.9]]), - }, - } - domain_dim = 3 - head_size = 5 - domain_names = ["vision", "language"] - batch_size = 2 - attention = DynamicQueryAttention(head_size, domain_dim, domain_names) - - attention_scores = { - domains: attention(latents, y[domains]) for domains, latents in x.items() - } - assert attention_scores[frozenset(["vision"])]["vision"].shape == torch.Size( - [batch_size] - ) - assert attention_scores[frozenset(["vision", "language"])][ - "vision" - ].shape == torch.Size([batch_size]) - - -def test_single_domain(): - batch_size = 2056 - domain_dim = 12 - head_size = 6 - domain_names = ["v_latents"] - attention = DynamicQueryAttention(head_size, domain_dim, domain_names) - - single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)} - prefusion_encodings = {"v_latents": torch.rand(batch_size, domain_dim)} - - attention_scores = attention(single_domain_input, prefusion_encodings) - - expected_scores = torch.ones(batch_size, 1) - assert torch.allclose( - attention_scores["v_latents"], expected_scores - ), "Attention scores for single domain should be all 1s" - - -def test_multiple_domains_sumis1(): - domain_dim = 12 - head_size = 5 - batch_size = 2056 - domain_names = ["v_latents", "attr"] - - attention = DynamicQueryAttention(head_size, domain_dim, domain_names) - - multiple_domain_input = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - prefusion_encodings = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - attention_scores = attention(multiple_domain_input, prefusion_encodings) - - scores_sum = sum( - attention_scores[domain].squeeze() for domain in multiple_domain_input - ) - assert isinstance(scores_sum, torch.Tensor) - - expected_sum = torch.ones(batch_size) - - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of attention scores across domains should be 1" - - -def test_attention_backward(): - domain_dim = 12 - head_size = 6 - batch_size = 2056 - domain_names = ["v_latents", "attr"] - - attention = DynamicQueryAttention(head_size, domain_dim, domain_names) - - domains = { - "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), - "attr": torch.rand(batch_size, domain_dim, requires_grad=True), - } - prefusion_encodings = { - "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), - "attr": torch.rand(batch_size, domain_dim, requires_grad=True), - } - - attention_scores = attention(domains, prefusion_encodings) - - loss = sum(score.mean() for score in attention_scores.values()) - - assert isinstance(loss, torch.Tensor) diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py deleted file mode 100644 index 5422fb73..00000000 --- a/tests/test_random_attention.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch - -from shimmer.modules.selection import RandomSelection - - -def test_multiple_domains(): - temperature = 1.0 - domain_dim = 12 - batch_size = 2056 - - selection = RandomSelection(temperature) - multiple_domain_input = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - - prefusion_encodings = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - } - - selection_scores = selection(multiple_domain_input, prefusion_encodings) - - # Ensure the sum of attention scores across domains equals 1 - scores_sum = sum( - selection_scores[domain].squeeze() for domain in multiple_domain_input - ) - assert isinstance(scores_sum, torch.Tensor) - - expected_sum = torch.ones(batch_size) - - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of selection scores across domains should be 1" - - -def test_three_domains(): - temperature = 1.0 - domain_dim = 12 - batch_size = 2056 - - selection = RandomSelection(temperature) - three_domain_input = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - "audio": torch.rand(batch_size, domain_dim), - } - - prefusion_encodings = { - "v_latents": torch.rand(batch_size, domain_dim), - "attr": torch.rand(batch_size, domain_dim), - "audio": torch.rand(batch_size, domain_dim), - } - - selection_scores = selection(three_domain_input, prefusion_encodings) - - # Ensure that the shape of the selection scores matches the input domains - for domain in three_domain_input: - assert selection_scores[domain].shape == ( - batch_size, - ), f"Scores shape mismatch for {domain}" - - # Ensure the sum of attention scores across domains equals 1 - scores_sum = sum( - selection_scores[domain].squeeze() for domain in three_domain_input - ) - assert isinstance(scores_sum, torch.Tensor) - - expected_sum = torch.ones(batch_size) - - assert torch.allclose( - scores_sum, expected_sum - ), "Sum of selection scores across three domains should be 1"