Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the fusion script and add the bayesian GW models #11

Merged
merged 29 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b928cfc
Use systematic_coeffs_fusion shimmer branch
bdvllrs Apr 22, 2024
fd32859
Update shimmer fusion branch and add loss coefs
bdvllrs Apr 8, 2024
8b411da
Update shimmer
bdvllrs Apr 8, 2024
182720e
update lock
RolandBERTINJOHANNET Apr 11, 2024
f895230
use branch add_coeffs of shimmer
RolandBERTINJOHANNET Apr 11, 2024
125e3ef
new shimmer + changed run name
RolandBERTINJOHANNET Apr 11, 2024
b4b3146
fixing config for longer training + update shimmer
RolandBERTINJOHANNET Apr 11, 2024
1a3c58f
Add selection temperature coef
bdvllrs Apr 16, 2024
6f776cf
Use main shimmer branch now that fusion has been merged
bdvllrs Apr 18, 2024
194a79b
Update shimmer to correct branch
bdvllrs Apr 18, 2024
627617f
Use BroadcastLossCoefs for uncertainty model
bdvllrs Apr 16, 2024
32be18c
Update shimmer
bdvllrs Apr 16, 2024
8a2fec5
log variances on wandb during training
bdvllrs Apr 18, 2024
7c121a3
use branch feature/uncertainty_eq for shimmer
bdvllrs Apr 18, 2024
4f13768
uncertainty -> confidence
bdvllrs Apr 18, 2024
64dd0e3
Update logger to log precision
bdvllrs Apr 18, 2024
292d1c0
Rename broadcast to fuse
bdvllrs Apr 19, 2024
a0b75f7
update number of training steps
bdvllrs Apr 22, 2024
5a8fea9
Update shimmer
bdvllrs Apr 22, 2024
7d736cb
Add coef for unpaired loss
bdvllrs Apr 23, 2024
6f1dbfa
remove coefs for attr paired and unpaired
bdvllrs Apr 24, 2024
c080fae
Update shimmer
bdvllrs Apr 24, 2024
0ead7a0
Rename Confidence to Bayesian
bdvllrs Apr 24, 2024
9d9a711
Fix use of decode function on GlobalWorkspace instead of gw_mod
bdvllrs Apr 25, 2024
eae8236
Use gw_fusion when using fusion model
bdvllrs Apr 26, 2024
c05e7f0
Add script to evaluate bayesian reps
bdvllrs Apr 26, 2024
983a26a
Use shimmer main branch
bdvllrs Apr 26, 2024
1214d4d
Update shimmer
bdvllrs May 21, 2024
c520253
Rename GlobalWorkspace to GlobalWorkspace2Domains. GlobalWorkspace is…
bdvllrs May 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@
"CUDA_DEVICE_ORDER": "PCI_BUS_ID"
}
},
{
"type": "python",
"request": "launch",
"name": "Exp test common space fusion",
"program": "${workspaceFolder}/playground/exploration/test_common_space_fusion.py",
"justMyCode": false,
"env": {
"DEBUG": "1",
"CUDA_VISIBLE_DEVICES": "1",
"CUDA_DEVICE_ORDER": "PCI_BUS_ID"
}
},
{
"type": "python",
"request": "launch",
Expand Down
4 changes: 2 additions & 2 deletions config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ training:
max_lr: 5e-3
weight_decay: 1e-5

max_steps: 15_000
max_steps: 200000

logging:
log_val_medias_every_n_epochs: 10
Expand All @@ -26,7 +26,7 @@ domain_modules:
color_blind: false

global_workspace:
has_uncertainty: false
bayesian_gw: false

domain_proportions:
- domains: ["v"]
Expand Down
2 changes: 2 additions & 0 deletions config/exp_test_common_space_fusion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
exploration:
gw_checkpoint: "{default_root_dir}/shimmer-simple-shapes-t3p8nyq8/epoch=204.ckpt"
2 changes: 1 addition & 1 deletion config/exp_test_t.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
exploration:
gw_checkpoint: shimmer-simple-shapes-syhl3yoh/epoch=408.ckpt
gw_checkpoint: "{default_root_dir}/shimmer-simple-shapes-syhl3yoh/epoch=408.ckpt"
2 changes: 1 addition & 1 deletion config/train_gw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ training:
max_lr: 5e-3
weight_decay: 1e-6

max_steps: 100_000
max_steps: 50000

logging:
log_val_medias_every_n_epochs: 10
Expand Down
12 changes: 5 additions & 7 deletions playground/exploration/contrastive_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch
from shimmer import (
GlobalWorkspaceWithUncertainty,
GWLossesWithUncertainty,
GlobalWorkspaceBayesian,
GWLossesBayesian,
)

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
Expand Down Expand Up @@ -85,7 +85,7 @@ def main():

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, PROJECT_DIR / "migrations" / "gw")
domain_module = GlobalWorkspaceWithUncertainty.load_from_checkpoint(
domain_module = GlobalWorkspaceBayesian.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_encoders=gw_encoders,
Expand Down Expand Up @@ -122,7 +122,7 @@ def main():
"attr": attr2.reshape(batch_size * n_rep, -1),
}
)
gw_states_std = gw_mod.log_uncertainties
gw_states_std = gw_mod.precisions

actual_std_attr = (
gw_states_means["attr"].reshape(batch_size, n_rep, -1).std(dim=1).mean(dim=0)
Expand All @@ -142,9 +142,7 @@ def main():
print(f"Predicted std attr: {predicted_std_attr}")
print(f"Predicted std v: {predicted_std_v}")

contrastive_fn = cast(
GWLossesWithUncertainty, domain_module.loss_mod
).contrastive_fn
contrastive_fn = cast(GWLossesBayesian, domain_module.loss_mod).contrastive_fn
assert contrastive_fn is not None

norm1 = 1.0 + gw_states_std["attr"].exp() + gw_states_std["v_latents"].exp()
Expand Down
4 changes: 2 additions & 2 deletions playground/exploration/norm_cont_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

import torch
from shimmer import GlobalWorkspaceWithUncertainty, contrastive_loss
from shimmer import GlobalWorkspaceBayesian, contrastive_loss

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import migrate_model
Expand Down Expand Up @@ -80,7 +80,7 @@ def main():

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, PROJECT_DIR / "migrations" / "gw")
domain_module = GlobalWorkspaceWithUncertainty.load_from_checkpoint(
domain_module = GlobalWorkspaceBayesian.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_encoders=gw_encoders,
Expand Down
107 changes: 107 additions & 0 deletions playground/exploration/test_common_space_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import logging
from collections.abc import Callable
from typing import Any, cast

from lightning.pytorch import Callback, Trainer, seed_everything
from lightning.pytorch.callbacks import (
RichProgressBar,
)
from shimmer.modules.global_workspace import (
GlobalWorkspaceBayesian,
GWPredictionsBase,
)
from torch import set_float32_matmul_precision

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.config import load_config
from simple_shapes_dataset.dataset import SimpleShapesDataModule
from simple_shapes_dataset.dataset.pre_process import (
color_blind_visual_domain,
nullify_attribute_rotation,
)
from simple_shapes_dataset.modules.domains import load_pretrained_domains


def main():
    """Load a pretrained Bayesian Global Workspace checkpoint and print its
    fused common-space states for every sample of the prediction set.

    Exploration script: no training happens here — the Lightning ``Trainer``
    is only used as a convenient batched-prediction driver.
    """
    # Resolve the experiment configuration (overridden by the exploration file).
    cfg = load_config(
        PROJECT_DIR / "config",
        load_files=["exp_test_common_space_fusion.yaml"],
        debug_mode=DEBUG_MODE,
    )

    # The checkpoint path lives under `exploration`; fail fast if it is absent.
    if cfg.exploration is None:
        raise ValueError("Config value 'exploration' must be set")

    seed_everything(cfg.seed, workers=True)

    # Map each (frozen) set of jointly-presented domains to its dataset share.
    proportions = {
        frozenset(entry.domains): entry.proportion
        for entry in cfg.global_workspace.domain_proportions
    }

    # Optional per-domain input transforms, toggled from the config.
    extra_transforms: dict[str, list[Callable[[Any], Any]]] = {}
    if cfg.domain_modules.attribute.nullify_rotation:
        logging.info("Nullifying rotation in the attr domain.")
        extra_transforms["attr"] = [nullify_attribute_rotation]
    if cfg.domain_modules.visual.color_blind:
        logging.info("v domain will be color blind.")
        extra_transforms["v"] = [color_blind_visual_domain]

    datamodule = SimpleShapesDataModule(
        cfg.dataset.path,
        proportions,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.training.num_workers,
        seed=cfg.seed,
        ood_seed=cfg.ood_seed,
        domain_args=cfg.global_workspace.domain_args,
        additional_transforms=extra_transforms,
    )

    # Pretrained unimodal modules plus their workspace encoders/decoders.
    domain_mods, encoders, decoders = load_pretrained_domains(
        cfg.default_root_dir,
        cfg.global_workspace.domains,
        cfg.global_workspace.latent_dim,
        cfg.global_workspace.encoders.hidden_dim,
        cfg.global_workspace.encoders.n_layers,
        cfg.global_workspace.decoders.hidden_dim,
        cfg.global_workspace.decoders.n_layers,
        is_linear=cfg.global_workspace.linear_domains,
        bias=cfg.global_workspace.linear_domains_use_bias,
    )

    gw = GlobalWorkspaceBayesian.load_from_checkpoint(
        cfg.exploration.gw_checkpoint,
        domain_mods=domain_mods,
        gw_encoders=encoders,
        gw_decoders=decoders,
    )

    set_float32_matmul_precision(cfg.training.float32_matmul_precision)
    callbacks: list[Callback] = [RichProgressBar()]

    trainer = Trainer(
        fast_dev_run=cfg.training.fast_dev_run,
        max_steps=cfg.training.max_steps,
        enable_progress_bar=cfg.training.enable_progress_bar,
        default_root_dir=cfg.default_root_dir,
        callbacks=callbacks,
        precision=cfg.training.precision,
        accelerator=cfg.training.accelerator,
        devices=cfg.training.devices,
    )

    # Run prediction over the datamodule; each element is one batch of
    # workspace predictions (project-declared GWPredictionsBase dicts).
    predictions = cast(
        list[GWPredictionsBase],
        trainer.predict(gw, datamodule, return_predictions=True),
    )

    # Dump the fused state of every domain for each sample, "--" separating samples.
    # NOTE(review): assumes all domains in pred["states"] share the batch size
    # of "attr" — confirm against GWPredictionsBase.
    for pred in predictions:
        states = pred["states"]
        for sample_idx in range(states["attr"].size(0)):
            for domain in states:
                print(domain, states[domain][sample_idx])
            print("--")


if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions playground/exploration/var_norm_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

from lightning.pytorch import Trainer
from shimmer.modules.global_workspace import GlobalWorkspaceWithUncertainty
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import (
Expand Down Expand Up @@ -63,7 +63,7 @@ def main():

ckpt_path = config.exploration.gw_checkpoint
migrate_model(ckpt_path, PROJECT_DIR / "migrations" / "gw")
gw = GlobalWorkspaceWithUncertainty.load_from_checkpoint(
gw = GlobalWorkspaceBayesian.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_encoders=gw_encoders,
Expand Down
6 changes: 3 additions & 3 deletions playground/exploration/var_norm_cont_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any

import torch
from shimmer.modules.global_workspace import GlobalWorkspaceWithUncertainty
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian

from simple_shapes_dataset import DEBUG_MODE, PROJECT_DIR
from simple_shapes_dataset.ckpt_migrations import (
Expand Down Expand Up @@ -82,7 +82,7 @@ def main():

ckpt_path = config.default_root_dir / config.exploration.gw_checkpoint
migrate_model(ckpt_path, PROJECT_DIR / "migrations" / "gw")
domain_module = GlobalWorkspaceWithUncertainty.load_from_checkpoint(
domain_module = GlobalWorkspaceBayesian.load_from_checkpoint(
ckpt_path,
domain_mods=domain_description,
gw_encoders=gw_encoders,
Expand All @@ -103,7 +103,7 @@ def main():
v_test[:, :12] = v_paired[None, :12]
attr_test[:, :12] = attr_paired[None, :12]
gw_states_means = gw_mod.encode({"v_latents": v_test, "attr": attr_test})
gw_states_std = gw_mod.log_uncertainties
gw_states_std = gw_mod.precisions
v_gw_var = (0.5 * gw_states_std["v_latents"]).exp() # noqa: F841
attr_gw_var = (0.5 * gw_states_std["attr"]).exp() # noqa: F841
print(gw_states_means)
Expand Down
58 changes: 42 additions & 16 deletions playground/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
RichProgressBar,
)
from lightning.pytorch.loggers.wandb import WandbLogger
from shimmer import ContrastiveLossType, GlobalWorkspaceBase, LossCoefs, SaveMigrations
from shimmer import (
BroadcastLossCoefs,
ContrastiveLossType,
GlobalWorkspaceBase,
LossCoefs,
SaveMigrations,
)
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspaceFusion,
GlobalWorkspaceWithUncertainty,
GlobalWorkspace2Domains,
GlobalWorkspaceBayesian,
SchedulerArgs,
)
from torch import set_float32_matmul_precision
Expand Down Expand Up @@ -76,13 +82,6 @@ def main():
bias=config.global_workspace.linear_domains_use_bias,
)

loss_coefs: LossCoefs = {
"demi_cycles": config.global_workspace.loss_coefficients.demi_cycles,
"cycles": config.global_workspace.loss_coefficients.cycles,
"translations": config.global_workspace.loss_coefficients.translations,
"contrastives": config.global_workspace.loss_coefficients.contrastives,
}

contrastive_fn: ContrastiveLossType | None = None
if config.global_workspace.vsepp_contrastive_loss:
contrastive_fn = VSEPPContrastiveLoss(
Expand All @@ -93,13 +92,23 @@ def main():
)

module: GlobalWorkspaceBase
if config.global_workspace.has_uncertainty:
module = GlobalWorkspaceWithUncertainty(
gw_type: str
if config.global_workspace.bayesian_gw:
gw_type = "gw_bayesian"
loss_coefs_bayesian: BroadcastLossCoefs = {
"contrastives": config.global_workspace.loss_coefficients.contrastives,
"fused": config.global_workspace.loss_coefficients.fused,
"translations": config.global_workspace.loss_coefficients.translations,
"demi_cycles": config.global_workspace.loss_coefficients.demi_cycles,
"cycles": config.global_workspace.loss_coefficients.cycles,
}
module = GlobalWorkspaceBayesian(
domain_modules,
gw_encoders,
gw_decoders,
config.global_workspace.latent_dim,
loss_coefs,
loss_coefs_bayesian,
config.global_workspace.selection_temperature,
config.training.optim.lr,
config.training.optim.weight_decay,
scheduler_args=SchedulerArgs(
Expand All @@ -110,11 +119,21 @@ def main():
contrastive_loss=contrastive_fn,
)
elif config.global_workspace.use_fusion_model:
module = GlobalWorkspaceFusion(
gw_type = "gw_fusion"
loss_coefs_fusion: BroadcastLossCoefs = {
"contrastives": config.global_workspace.loss_coefficients.contrastives,
"fused": config.global_workspace.loss_coefficients.fused,
"translations": config.global_workspace.loss_coefficients.translations,
"demi_cycles": config.global_workspace.loss_coefficients.demi_cycles,
"cycles": config.global_workspace.loss_coefficients.cycles,
}
module = GlobalWorkspace(
domain_modules,
gw_encoders,
gw_decoders,
config.global_workspace.latent_dim,
loss_coefs_fusion,
config.global_workspace.selection_temperature,
config.training.optim.lr,
config.training.optim.weight_decay,
scheduler_args=SchedulerArgs(
Expand All @@ -125,7 +144,15 @@ def main():
contrastive_loss=contrastive_fn,
)
else:
module = GlobalWorkspace(
gw_type = "gw"
loss_coefs: LossCoefs = {
"demi_cycles": config.global_workspace.loss_coefficients.demi_cycles,
"cycles": config.global_workspace.loss_coefficients.cycles,
"translations": config.global_workspace.loss_coefficients.translations,
"contrastives": config.global_workspace.loss_coefficients.contrastives,
}

module = GlobalWorkspace2Domains(
domain_modules,
gw_encoders,
gw_decoders,
Expand Down Expand Up @@ -216,7 +243,6 @@ def main():

wandb_logger = None
if config.wandb.enabled:
gw_type = "gw_uncertainty" if config.global_workspace.has_uncertainty else "gw"
run_name = f"{gw_type}_z={config.global_workspace.latent_dim}"
wandb_logger = WandbLogger(
save_dir=config.wandb.save_dir,
Expand Down
Loading
Loading