Skip to content

Commit

Permalink
Update shimmer (#1)
Browse files Browse the repository at this point in the history
* Update to latest shimmer version
* Remove dependency on OmegaConf and use pydantic for config instead
* Update to latest auto_sbatch to remove OmegaConf dependency
* Add a checkpoint migration system to always have checkpoints that work.
* Add checkpoint migration to update to new shimmer version
  • Loading branch information
bdvllrs authored Feb 20, 2024
1 parent 2afda03 commit 9939bbf
Show file tree
Hide file tree
Showing 81 changed files with 2,138 additions and 1,992 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[flake8]
exclude = .git,__pycache__,images,config
ignore = E203,W503
extend-ignore = E203,E704,W503
max-line-length = 88
6 changes: 6 additions & 0 deletions .github/workflows/quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ jobs:
eval `ssh-agent -s`
ssh-add - <<< '${{ secrets.SHIMMER_DEPLOY }}'
poetry install --with dev
- name: isort
run: |
poetry run isort --profile=black --check-only --diff --line-length=88 simple_shapes_dataset playground
- uses: psf/black@stable
with:
src: "."
- name: Analysing the code with flake8
run: |
poetry run flake8 .
Expand Down
File renamed without changes.
7 changes: 3 additions & 4 deletions config/default/main.yaml → config/default.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
default_root_dir: ???
default_root_dir: ??? # add this in local.yaml

dataset:
path: ???
path: ??? # add this in local.yaml

seed: 0

Expand Down Expand Up @@ -29,12 +29,11 @@ global_workspace:
is_variational: false
var_contrastive_loss: false

sync_prop: 1.0
domain_proportions:
- domains: ["v"]
proportion: 1.0
- domains: ["attr"]
proportion: 1.0
- domains: ["v", "attr"]
proportion: ${global_workspace.sync_prop}
proportion: 1.0

2 changes: 2 additions & 0 deletions config/exp_test_t.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
exploration:
gw_checkpoint: shimmer-simple-shapes-syhl3yoh/epoch=408.ckpt
2 changes: 0 additions & 2 deletions config/exp_test_t/main.yaml

This file was deleted.

2 changes: 2 additions & 0 deletions config/exp_var_cont.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
exploration:
gw_checkpoint: shimmer-simple-shapes-qfe2thvs/epoch=200.ckpt
2 changes: 0 additions & 2 deletions config/exp_var_cont/main.yaml

This file was deleted.

50 changes: 50 additions & 0 deletions config/local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
default_root_dir: "/home/bdevillers/projects/simple-shapes-dataset/checkpoints"

dataset:
path: "/shared/datasets/simple_shapes_dataset"

wandb:
enabled: true
save_dir: "/home/bdevillers/projects/simple-shapes-dataset/wandb"
project: "shimmer-simple-shapes"
entity: "vanrullen-lab"
reinit: true

training:
precision: "16-mixed"
float32_matmul_precision: "medium"

domain_modules:
text:
latent_filename: bert-base-uncased

ood_seed: 0

global_workspace:
domains:
# - checkpoint_path: shimmer-simple-shapes-t1nbecx3/epoch=201.ckpt
# domain_type: attr
# args:
# n_unpaired: ${global_workspace.domain_args.attr.n_unpaired}
- checkpoint_path: "."
domain_type: attr_legacy
# - checkpoint_path: shimmer-simple-shapes-z5ec9tmq/epoch=501.ckpt
# domain_type: v
- checkpoint_path: pretrained/vae_v_shimmer.ckpt
domain_type: v_latents

domain_args:
v_latents:
# presaved_path: shimmer-simple-shapes-z5ec9tmq_epoch=501_color_blind.npy
# presaved_path: shimmer-simple-shapes-z5ec9tmq_epoch=501.npy
presaved_path: vae_v_shimmer.npy
attr:
n_unpaired: 1


slurm:
script: "playground/train_v.py"
run_workdir: "."
python_env: "test"
command: "python test.py"

9 changes: 9 additions & 0 deletions config/save_v_latents.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
domain_checkpoint:
# checkpoint_path: shimmer-simple-shapes-z5ec9tmq/epoch=501.ckpt
checkpoint_path: pretrained/vae_v_shimmer.ckpt
domain_type: v

presaved_latents_path:
# v: shimmer-simple-shapes-z5ec9tmq_epoch=501.npy
v: vae_v_shimmer.npy

9 changes: 0 additions & 9 deletions config/save_v_latents/main.yaml

This file was deleted.

File renamed without changes.
1 change: 0 additions & 1 deletion config/train_gw/main.yaml → config/train_gw.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

training:
optim:
lr: 5e-3
Expand Down
File renamed without changes.
File renamed without changes.
4 changes: 4 additions & 0 deletions config/viz_vae_attr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
visualization:
explore_vae:
wandb_name: tlexjzo7
checkpoint: shimmer-simple-shapes-tlexjzo7/epoch=196.ckpt
4 changes: 0 additions & 4 deletions config/viz_vae_attr/main.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion config/viz_vae_gw/main.yaml → config/viz_vae_gw.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
visualization:
explore_gw:
# wandb_name: zdufxid0
checkpoint: ${default_root_dir}/shimmer-simple-shapes-zdufxid0/epoch=203.ckpt
checkpoint: shimmer-simple-shapes-zdufxid0/epoch=203.ckpt
domain: v_latents
4 changes: 4 additions & 0 deletions config/viz_vae_v.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
visualization:
explore_vae:
# wandb_name: z5ec9tmq
checkpoint: shimmer-simple-shapes-z5ec9tmq/epoch=501.ckpt
4 changes: 0 additions & 4 deletions config/viz_vae_v/main.yaml

This file was deleted.

4 changes: 0 additions & 4 deletions mypy.ini

This file was deleted.

67 changes: 31 additions & 36 deletions playground/exploration/contrastive_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
from typing import Any, cast

import torch
from shimmer import load_structured_config
from shimmer.modules.global_workspace import VariationalGlobalWorkspace
from shimmer.modules.gw_module import VariationalGWModule
from shimmer.modules.losses import (VariationalGWLosses,
contrastive_loss_with_uncertainty)
from shimmer import VariationalGlobalWorkspace, VariationalGWLosses, VariationalGWModule

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.config.root import Config
from simple_shapes_dataset.ckpt_migrations import migrate_model, var_gw_migrations
from simple_shapes_dataset.config import load_config
from simple_shapes_dataset.dataset.data_module import SimpleShapesDataModule
from simple_shapes_dataset.dataset.pre_process import (color_blind_visual_domain,
nullify_attribute_rotation)
from simple_shapes_dataset.dataset.pre_process import (
color_blind_visual_domain,
nullify_attribute_rotation,
)
from simple_shapes_dataset.modules.domains.pretrained import load_pretrained_domains


Expand All @@ -35,13 +34,15 @@ def put_on_device(


def main():
config = load_structured_config(
config = load_config(
PROJECT_DIR / "config",
Config,
load_dirs=["exp_var_cont"],
load_files=["exp_var_cont.yaml"],
debug_mode=DEBUG_MODE,
)

if config.exploration is None:
raise ValueError("Exploration config should be set for this script")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

domain_proportion = {
Expand All @@ -67,20 +68,23 @@ def main():
additional_transforms=additional_transforms,
)

domain_description = load_pretrained_domains(
domain_description, interfaces = load_pretrained_domains(
config.default_root_dir,
config.global_workspace.domains,
config.global_workspace.latent_dim,
config.global_workspace.encoders.hidden_dim,
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
is_variational=True,
)

domain_module = cast(
VariationalGlobalWorkspace,
VariationalGlobalWorkspace.load_from_checkpoint(
config.exploration.gw_checkpoint,
domain_descriptions=domain_description,
),
ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, var_gw_migrations)
domain_module = VariationalGlobalWorkspace.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_interfaces=interfaces,
)
domain_module.eval().freeze()
domain_module.to(device)
Expand All @@ -90,9 +94,7 @@ def main():
n_rep = 128
n_unpaired = 8

val_samples = put_on_device(
data_module.get_samples("val", batch_size), device
)
val_samples = put_on_device(data_module.get_samples("val", batch_size), device)
encoded_samples = domain_module.encode_domains(val_samples)[
frozenset(["v_latents", "attr"])
]
Expand All @@ -102,12 +104,7 @@ def main():
.expand((batch_size, n_rep, -1))
.clone()
)
attr1 = (
encoded_samples["attr"]
.unsqueeze(1)
.expand((batch_size, n_rep, -1))
.clone()
)
attr1 = encoded_samples["attr"].unsqueeze(1).expand((batch_size, n_rep, -1)).clone()
v_unpaired = torch.randn(batch_size, n_rep, n_unpaired).to(device)
attr_unpaired = torch.randn(batch_size, n_rep, n_unpaired).to(device)
v2 = v1[:]
Expand All @@ -122,10 +119,7 @@ def main():
)

actual_std_attr = (
gw_states_means["attr"]
.reshape(batch_size, n_rep, -1)
.std(dim=1)
.mean(dim=0)
gw_states_means["attr"].reshape(batch_size, n_rep, -1).std(dim=1).mean(dim=0)
)
actual_std_v = (
gw_states_means["v_latents"]
Expand All @@ -142,22 +136,23 @@ def main():
print(f"Predicted std attr: {predicted_std_attr}")
print(f"Predicted std v: {predicted_std_v}")

logit_scale = cast(VariationalGWLosses, domain_module.loss_mod).logit_scale
contrastive_fn = cast(
VariationalGWLosses, domain_module.loss_mod
).var_contrastive_fn
assert contrastive_fn is not None

cont_loss1 = contrastive_loss_with_uncertainty(
cont_loss1 = contrastive_fn(
gw_states_means["attr"],
gw_states_std["attr"],
gw_states_means["v_latents"],
gw_states_std["v_latents"],
logit_scale,
)

cont_loss2 = contrastive_loss_with_uncertainty(
cont_loss2 = contrastive_fn(
gw_states_means["attr"],
actual_std_attr,
gw_states_means["v_latents"],
actual_std_v,
logit_scale,
)

print(f"Contrastive loss 1: {cont_loss1}")
Expand Down
36 changes: 21 additions & 15 deletions playground/exploration/norm_cont_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
from typing import Any, cast

import torch
from shimmer import load_structured_config
from shimmer import contrastive_loss
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.gw_module import VariationalGWModule
from shimmer.modules.losses import contrastive_loss

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.config.root import Config
from simple_shapes_dataset.ckpt_migrations import gw_migrations, migrate_model
from simple_shapes_dataset.config import load_config
from simple_shapes_dataset.dataset.data_module import SimpleShapesDataModule
from simple_shapes_dataset.dataset.pre_process import (color_blind_visual_domain,
nullify_attribute_rotation)
from simple_shapes_dataset.dataset.pre_process import (
color_blind_visual_domain,
nullify_attribute_rotation,
)
from simple_shapes_dataset.modules.domains.pretrained import load_pretrained_domains


Expand All @@ -34,13 +36,15 @@ def put_on_device(


def main():
config = load_structured_config(
config = load_config(
PROJECT_DIR / "config",
Config,
load_dirs=["exp_var_cont"],
load_files=["exp_var_cont.yaml"],
debug_mode=DEBUG_MODE,
)

if config.exploration is None:
raise ValueError("Exploration config should be set for this script")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

domain_proportion = {
Expand All @@ -66,20 +70,22 @@ def main():
additional_transforms=additional_transforms,
)

domain_description = load_pretrained_domains(
domain_description, interfaces = load_pretrained_domains(
config.default_root_dir,
config.global_workspace.domains,
config.global_workspace.latent_dim,
config.global_workspace.encoders.hidden_dim,
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
)

domain_module = cast(
GlobalWorkspace,
GlobalWorkspace.load_from_checkpoint(
config.exploration.gw_checkpoint,
domain_descriptions=domain_description,
),
ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, gw_migrations)
domain_module = GlobalWorkspace.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_interfaces=interfaces,
)
domain_module.eval().freeze()
domain_module.to(device)
Expand Down
Loading

0 comments on commit 9939bbf

Please sign in to comment.