Skip to content

Commit

Permalink
Update shimmer
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Dec 12, 2023
1 parent c2557ca commit 4471b7f
Show file tree
Hide file tree
Showing 8 changed files with 887 additions and 1,059 deletions.
6 changes: 3 additions & 3 deletions playground/exploration/contrastive_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
from shimmer import load_structured_config
from shimmer.modules.global_workspace import VariationalGlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.gw_module import VariationalGWModule
from shimmer.modules.losses import (VariationalGWLosses,
contrastive_loss_with_uncertainty)
Expand Down Expand Up @@ -76,8 +76,8 @@ def main():
)

domain_module = cast(
VariationalGlobalWorkspace,
VariationalGlobalWorkspace.load_from_checkpoint(
GlobalWorkspace,
GlobalWorkspace.load_from_checkpoint(
config.exploration.gw_checkpoint,
domain_descriptions=domain_description,
),
Expand Down
6 changes: 3 additions & 3 deletions playground/exploration/var_norm_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from lightning.pytorch import Trainer
from shimmer import load_structured_config
from shimmer.modules.global_workspace import VariationalGlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspace

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.config.root import Config
Expand Down Expand Up @@ -54,8 +54,8 @@ def main():
)

gw = cast(
VariationalGlobalWorkspace,
VariationalGlobalWorkspace.load_from_checkpoint(
GlobalWorkspace,
GlobalWorkspace.load_from_checkpoint(
config.exploration.gw_checkpoint,
domain_descriptions=domain_description,
),
Expand Down
6 changes: 3 additions & 3 deletions playground/exploration/var_norm_cont_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
from shimmer import load_structured_config
from shimmer.modules.global_workspace import VariationalGlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.gw_module import VariationalGWModule

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
Expand Down Expand Up @@ -74,8 +74,8 @@ def main():
)

domain_module = cast(
VariationalGlobalWorkspace,
VariationalGlobalWorkspace.load_from_checkpoint(
GlobalWorkspace,
GlobalWorkspace.load_from_checkpoint(
config.exploration.gw_checkpoint,
domain_descriptions=domain_description,
),
Expand Down
8 changes: 4 additions & 4 deletions playground/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import OmegaConf
from shimmer import load_structured_config
from shimmer.modules.global_workspace import (DeterministicGlobalWorkspace,
SchedulerArgs, VariationalGlobalWorkspace)
from shimmer.modules.global_workspace import (SchedulerArgs, global_workspace,
variational_global_workspace)
from torch import set_float32_matmul_precision

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
Expand Down Expand Up @@ -85,7 +85,7 @@ def main():
# )

if config.global_workspace.is_variational:
module = VariationalGlobalWorkspace(
module = variational_global_workspace(
domain_modules,
config.global_workspace.latent_dim,
loss_coefs,
Expand All @@ -98,7 +98,7 @@ def main():
),
)
else:
module = DeterministicGlobalWorkspace(
module = global_workspace(
domain_modules,
config.global_workspace.latent_dim,
loss_coefs,
Expand Down
21 changes: 9 additions & 12 deletions playground/visualizations/explore_vae_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,19 @@
import numpy as np
import torch
import torchvision.transforms.functional as F
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from PIL.Image import Image
from shimmer.config import load_structured_config
from shimmer.modules.global_workspace import VariationalGlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspace
from torchvision.utils import make_grid

import wandb
from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.config.root import Config
from simple_shapes_dataset.logging import attribute_image_grid, get_pil_image
from simple_shapes_dataset.modules.domains.pretrained import (
load_pretrained_domains,
)
from simple_shapes_dataset.modules.domains.visual import (
VisualLatentDomainModule,
)
from simple_shapes_dataset.modules.domains.pretrained import load_pretrained_domains
from simple_shapes_dataset.modules.domains.visual import VisualLatentDomainModule

matplotlib.use("Agg")

Expand All @@ -36,7 +33,7 @@ def image_grid_from_v_tensor(


def dim_exploration_figure(
module: VariationalGlobalWorkspace,
module: GlobalWorkspace,
z_size: int,
device: torch.device,
domain: str,
Expand All @@ -46,13 +43,13 @@ def dim_exploration_figure(
image_size: int = 32,
plot_dims: Sequence[int] | None = None,
fig_dim: int = 5,
) -> plt.Figure:
) -> Figure:
possible_dims = plot_dims or np.arange(z_size)

fig_size = (len(possible_dims) - 1) * fig_dim

fig = cast(
plt.Figure,
Figure,
plt.figure(
constrained_layout=True, figsize=(fig_size, fig_size), dpi=1
),
Expand Down Expand Up @@ -161,8 +158,8 @@ def main() -> None:
)

domain_module = cast(
VariationalGlobalWorkspace,
VariationalGlobalWorkspace.load_from_checkpoint(
GlobalWorkspace,
GlobalWorkspace.load_from_checkpoint(
config.visualization.explore_gw.checkpoint,
domain_description=domain_description,
),
Expand Down
Loading

0 comments on commit 4471b7f

Please sign in to comment.