Skip to content

Commit

Permalink
Update shimmer to latest main (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Apr 8, 2024
1 parent b83f7a9 commit 8eecd0b
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 69 deletions.
1 change: 0 additions & 1 deletion config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ domain_modules:

global_workspace:
has_uncertainty: false
cont_loss_with_uncertainty: false

domain_proportions:
- domains: ["v"]
Expand Down
21 changes: 8 additions & 13 deletions playground/exploration/contrastive_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from shimmer import (
GlobalWorkspaceWithUncertainty,
GWLossesWithUncertainty,
GWModuleWithUncertainty,
)

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
Expand Down Expand Up @@ -82,7 +81,6 @@ def main():
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
has_uncertainty=True,
)

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
Expand All @@ -95,7 +93,7 @@ def main():
)
domain_module.eval().freeze()
domain_module.to(device)
gw_mod = cast(GWModuleWithUncertainty, domain_module.gw_mod)
gw_mod = domain_module.gw_mod

batch_size = 128
n_rep = 128
Expand All @@ -118,12 +116,13 @@ def main():
attr2 = attr1[:]
v2[:, :, -n_unpaired:] = v_unpaired
attr2[:, :, -n_unpaired:] = attr_unpaired
gw_states_means, gw_states_std = gw_mod.encoded_distribution(
gw_states_means = gw_mod.encode(
{
"v_latents": v2.reshape(batch_size * n_rep, -1),
"attr": attr2.reshape(batch_size * n_rep, -1),
}
)
gw_states_std = gw_mod.log_uncertainties

actual_std_attr = (
gw_states_means["attr"].reshape(batch_size, n_rep, -1).std(dim=1).mean(dim=0)
Expand All @@ -145,21 +144,17 @@ def main():

contrastive_fn = cast(
GWLossesWithUncertainty, domain_module.loss_mod
).cont_fn_with_uncertainty
).contrastive_fn
assert contrastive_fn is not None

norm1 = 1.0 + gw_states_std["attr"].exp() + gw_states_std["v_latents"].exp()
cont_loss1 = contrastive_fn(
gw_states_means["attr"],
gw_states_std["attr"],
gw_states_means["v_latents"],
gw_states_std["v_latents"],
gw_states_means["attr"] / norm1, gw_states_means["v_latents"] / norm1
)

norm2 = 1.0 + actual_std_attr.exp() + actual_std_v.exp()
cont_loss2 = contrastive_fn(
gw_states_means["attr"],
actual_std_attr,
gw_states_means["v_latents"],
actual_std_v,
gw_states_means["attr"] / norm2, gw_states_means["v_latents"] / norm2
)

print(f"Contrastive loss 1: {cont_loss1}")
Expand Down
18 changes: 6 additions & 12 deletions playground/exploration/norm_cont_comp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, cast
from typing import Any

import torch
from shimmer import contrastive_loss
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.gw_module import GWModuleWithUncertainty
from shimmer import GlobalWorkspaceWithUncertainty, contrastive_loss

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import migrate_model
Expand Down Expand Up @@ -82,15 +80,15 @@ def main():

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, PROJECT_DIR / "migrations" / "gw")
domain_module = GlobalWorkspace.load_from_checkpoint(
domain_module = GlobalWorkspaceWithUncertainty.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_encoders=gw_encoders,
gw_decoders=gw_decoders,
)
domain_module.eval().freeze()
domain_module.to(device)
gw_mod = cast(GWModuleWithUncertainty, domain_module.gw_mod)
gw_mod = domain_module.gw_mod

val_samples = put_on_device(data_module.get_samples("val", 2048), device)
encoded_samples = domain_module.encode_domains(val_samples)[
Expand All @@ -104,12 +102,8 @@ def main():
attr2 = attr1[:]
v2[:, -1] = v_unpaired[:, 0]
attr2[:, -1] = attr_unpaired[:, 0]
gw_states_means1, gw_states_std1 = gw_mod.encoded_distribution(
{"v_latents": v1, "attr": attr1}
)
gw_states_means2, gw_states_std2 = gw_mod.encoded_distribution(
{"v_latents": v2, "attr": attr2}
)
gw_states_means1 = gw_mod.encode({"v_latents": v1, "attr": attr1})
gw_states_means2 = gw_mod.encode({"v_latents": v2, "attr": attr2})
lossv1 = contrastive_loss(
gw_states_means1["v_latents"],
gw_states_means2["v_latents"],
Expand Down
1 change: 0 additions & 1 deletion playground/exploration/var_norm_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def main():
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
has_uncertainty=True,
)

ckpt_path = config.exploration.gw_checkpoint
Expand Down
15 changes: 6 additions & 9 deletions playground/exploration/var_norm_cont_loss.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from collections.abc import Callable, Mapping
from typing import Any, cast
from typing import Any

import torch
from shimmer.modules.global_workspace import GlobalWorkspaceWithUncertainty
from shimmer.modules.gw_module import GWModuleWithUncertainty

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import (
Expand Down Expand Up @@ -79,7 +78,6 @@ def main():
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
has_uncertainty=True,
)

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
Expand All @@ -92,7 +90,7 @@ def main():
)
domain_module.eval().freeze()
domain_module.to(device)
gw_mod = cast(GWModuleWithUncertainty, domain_module.gw_mod)
gw_mod = domain_module.gw_mod

val_samples = put_on_device(data_module.get_samples("val", 1), device)
encoded_samples = domain_module.encode_domains(val_samples)[
Expand All @@ -104,11 +102,10 @@ def main():
attr_test = torch.randn(64, 13).to(device)
v_test[:, :12] = v_paired[None, :12]
attr_test[:, :12] = attr_paired[None, :12]
gw_states_means, gw_stats_std = gw_mod.encoded_distribution(
{"v_latents": v_test, "attr": attr_test}
)
v_gw_var = (0.5 * gw_stats_std["v_latents"]).exp() # noqa: F841
attr_gw_var = (0.5 * gw_stats_std["attr"]).exp() # noqa: F841
gw_states_means = gw_mod.encode({"v_latents": v_test, "attr": attr_test})
gw_states_std = gw_mod.log_uncertainties
v_gw_var = (0.5 * gw_states_std["v_latents"]).exp() # noqa: F841
attr_gw_var = (0.5 * gw_states_std["attr"]).exp() # noqa: F841
print(gw_states_means)


Expand Down
2 changes: 0 additions & 2 deletions playground/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def main():
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
has_uncertainty=config.global_workspace.has_uncertainty,
is_linear=config.global_workspace.linear_domains,
bias=config.global_workspace.linear_domains_use_bias,
)
Expand Down Expand Up @@ -101,7 +100,6 @@ def main():
gw_decoders,
config.global_workspace.latent_dim,
loss_coefs,
config.global_workspace.cont_loss_with_uncertainty,
config.training.optim.lr,
config.training.optim.weight_decay,
scheduler_args=SchedulerArgs(
Expand Down
1 change: 0 additions & 1 deletion playground/visualizations/explore_vae_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def main() -> None:
config.global_workspace.encoders.n_layers,
config.global_workspace.decoders.hidden_dim,
config.global_workspace.decoders.n_layers,
has_uncertainty=True,
)

ckpt_path = config.default_root_dir / config.visualization.explore_gw.checkpoint
Expand Down
10 changes: 4 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ torchdata = "^0.6.1"
pillow = "^9.5.0"
numpy = "^1.25"
torch = "^2.0.1"
shimmer = {git = "[email protected]:bdvllrs/shimmer.git", rev = "b4fbb8d18f5102196d9a806fb0a894c64418df66"}
shimmer = {git = "[email protected]:bdvllrs/shimmer.git", rev = "main"}
wandb = "^0.15.4"
lightning = "^2.1.0"
pydantic = "^2.6.0"
Expand Down
24 changes: 19 additions & 5 deletions simple_shapes_dataset/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from matplotlib import gridspec
from matplotlib.figure import Figure
from PIL import Image
from shimmer import batch_cycles, batch_demi_cycles, batch_translations
from shimmer import (
SingleDomainSelection,
batch_cycles,
batch_demi_cycles,
batch_translations,
)
from shimmer.modules.global_workspace import GlobalWorkspaceBase
from torchvision.utils import make_grid

Expand Down Expand Up @@ -381,15 +386,24 @@ def on_callback(
f"ref_{'-'.join(domain_names)}_{domain_name}",
)

latents = pl_module.encode_domains(samples)
latent_groups = pl_module.encode_domains(samples)

selection_mod = SingleDomainSelection()

with torch.no_grad():
pl_module.eval()
prediction_demi_cycles = batch_demi_cycles(pl_module.gw_mod, latents)
prediction_demi_cycles = batch_demi_cycles(
pl_module.gw_mod, selection_mod, latent_groups
)
prediction_cycles = batch_cycles(
pl_module.gw_mod, latents, pl_module.domain_mods.keys()
pl_module.gw_mod,
selection_mod,
latent_groups,
pl_module.domain_mods.keys(),
)
prediction_translations = batch_translations(
pl_module.gw_mod, selection_mod, latent_groups
)
prediction_translations = batch_translations(pl_module.gw_mod, latents)
pl_module.train()

for logger in loggers:
Expand Down
18 changes: 1 addition & 17 deletions simple_shapes_dataset/modules/domains/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from collections.abc import Sequence
from pathlib import Path

from shimmer import (
DomainModule,
GWDecoder,
GWEncoder,
GWEncoderLinear,
GWEncoderWithUncertainty,
)
from shimmer import DomainModule, GWDecoder, GWEncoder, GWEncoderLinear
from torch.nn import Linear, Module

from simple_shapes_dataset import PROJECT_DIR
Expand Down Expand Up @@ -82,7 +76,6 @@ def load_pretrained_domain(
encoder_n_layers: int,
decoder_hidden_dim: int,
decoder_n_layers: int,
has_uncertainty: bool = False,
is_linear: bool = False,
bias: bool = False,
) -> tuple[DomainModule, Module, Module]:
Expand All @@ -93,13 +86,6 @@ def load_pretrained_domain(
if is_linear:
gw_encoder = GWEncoderLinear(module.latent_dim, workspace_dim, bias=bias)
gw_decoder = Linear(workspace_dim, module.latent_dim, bias=bias)
elif has_uncertainty:
gw_encoder = GWEncoderWithUncertainty(
module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers
)
gw_decoder = GWDecoder(
workspace_dim, decoder_hidden_dim, module.latent_dim, decoder_n_layers
)
else:
gw_encoder = GWEncoder(
module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers
Expand All @@ -119,7 +105,6 @@ def load_pretrained_domains(
encoders_n_layers: int,
decoders_hidden_dim: int,
decoders_n_layers: int,
has_uncertainty: bool = False,
is_linear: bool = False,
bias: bool = False,
) -> tuple[dict[str, DomainModule], dict[str, Module], dict[str, Module]]:
Expand All @@ -137,7 +122,6 @@ def load_pretrained_domains(
encoders_n_layers,
decoders_hidden_dim,
decoders_n_layers,
has_uncertainty,
is_linear,
bias,
)
Expand Down
1 change: 0 additions & 1 deletion simple_shapes_dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ class GlobalWorkspace(BaseModel):
latent_dim: int = 12
has_uncertainty: bool = False
use_fusion_model: bool = False
cont_loss_with_uncertainty: bool = False
learn_logit_scale: bool = False
vsepp_contrastive_loss: bool = False
vsepp_margin: float = 0.2
Expand Down

0 comments on commit 8eecd0b

Please sign in to comment.