Commit f72d481

Merge pull request #161 from geometric-intelligence/freeze_decoder

Freeze decoder

franciscoeacosta authored May 19, 2024
2 parents: 1d739e8 + 3ad1f4a
Showing 5 changed files with 18 additions and 9 deletions.
File 1 of 5

@@ -15,13 +15,14 @@
 s_0 = [1,10,100,1000,10000]#,1000]
 x_saliency = [0.5,0.8]#,0.8]
 sigma_saliency = [0.05,0.1,0.15,0.2,0.5]#,0.5]
+freeze_decoder = True
 #integration
 n_inte_step=[50,75,100]#,100] # 50

 ###-----TRAINING PARAMETERS-----###
 load_pretrain=True
 pretrain_path=os.path.join(os.getcwd(),"logs/rnn_isometry/20240418-180712/ckpt/model/checkpoint-step25000.pth")
-num_steps_train=7500 # 10000
+num_steps_train=10000#7500 # 10000
 lr_decay_from=10000
 steps_per_logging=20
 steps_per_large_logging=500 # 500
@@ -80,4 +81,4 @@

 ###-----RAY TUNE PARAMETERS-----###
 sweep_metric= "error_reencode"
-num_samples = 1000
+num_samples = 1000#1000
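File 1 of 5 reads as a flat, module-level config: plain top-level assignments that other scripts consume as attributes (file 4 below reads them as default_config.<name>). A minimal runnable sketch of that pattern, using only names visible in this diff, with a SimpleNamespace standing in for the module:

# Stand-in for the default_config module; names from the diff, values illustrative.
from types import SimpleNamespace

default_config = SimpleNamespace(
    freeze_decoder=True,
    num_steps_train=10000,
    sweep_metric="error_reencode",
    num_samples=1000,
)

# Consumers just read attributes off the module object.
if default_config.freeze_decoder:
    print("decoder will be frozen for this sweep")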
File 2 of 5

@@ -2,7 +2,6 @@
 import pickle

 import default_config
-import matplotlib.cm as cm
 import matplotlib.pyplot as plt
 import numpy as np
 import wandb
@@ -147,7 +146,7 @@ def _draw_heatmap(activations, title):
 weight = activations[i]
 vmin, vmax = weight.min() - 0.01, weight.max()

-cmap = cm.get_cmap("jet", 1000)
+cmap = plt.get_cmap("jet", 1000)
 cmap.set_under("w")

 ax.imshow(
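Context for the cmap swap above: matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 (and later removed), and plt.get_cmap is the drop-in replacement. Called with a lut size, it returns a resampled copy of the registered colormap, so mutating it with set_under is safe. A minimal sketch with illustrative data:

import matplotlib.pyplot as plt
import numpy as np

# Resampled 1000-entry copy of "jet"; edits do not touch the global registry.
cmap = plt.get_cmap("jet", 1000)
cmap.set_under("w")  # values below vmin render as white

fig, ax = plt.subplots()
ax.imshow(np.random.rand(8, 8), cmap=cmap, vmin=0.2, vmax=1.0)
plt.close(fig)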
File 3 of 5

@@ -34,6 +34,11 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
 self.model_config = model.GridCellConfig(**config.model)
 self.model = model.GridCell(self.model_config).to(device)

+if config.model.freeze_decoder:
+    logging.info("==== freeze decoder ====")
+    for param in self.model.decoder.parameters():
+        param.requires_grad = False
+
 # initialize dataset
 logging.info("==== initialize dataset ====")
 self.train_dataset = input_pipeline.TrainDataset(self.rng, config.data, self.model_config)
@@ -47,15 +52,15 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
logging.info("==== initialize optimizer ====")
if config.train.optimizer_type == "adam":
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=config.train.lr
filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.train.lr
)
elif config.train.optimizer_type == "adam_w":
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=config.train.lr
filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.train.lr
)
elif config.train.optimizer_type == "sgd":
self.optimizer = torch.optim.SGD(
self.model.parameters(), lr=config.train.lr, momentum=0.9
filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.train.lr, momentum=0.9
)

if config.train.load_pretrain:
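The two edits above are the standard PyTorch freezing recipe: turn off requires_grad on the decoder's parameters, then build the optimizer over only the parameters that still require gradients. A minimal self-contained sketch of the same pattern (the toy encoder/decoder is illustrative, not the repository's GridCell):

import torch
import torch.nn as nn

# Toy stand-in for an encoder/decoder model.
model = nn.ModuleDict({"encoder": nn.Linear(8, 8), "decoder": nn.Linear(8, 2)})

# Freeze the decoder: its parameters stop receiving gradient updates.
for param in model["decoder"].parameters():
    param.requires_grad = False

# Hand the optimizer only the still-trainable parameters.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3
)

x = torch.randn(4, 8)
loss = model["decoder"](model["encoder"](x)).sum()
loss.backward()   # gradients still flow through the frozen decoder
optimizer.step()  # but only encoder weights are updated

Note that the frozen decoder still participates in the forward and backward passes; only its own weights stop moving.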
@@ -64,12 +69,13 @@ def __init__(self, rng, config: ml_collections.ConfigDict, device):
logging.info(f"Loading pretrain model from {ckpt_model_path}")
ckpt = torch.load(ckpt_model_path, map_location=device)
self.model.load_state_dict(ckpt["state_dict"])
logging.info("==== load pretrained optimizer ====")
self.optimizer.load_state_dict(ckpt["optimizer"])
# logging.info("==== load pretrained optimizer ====")
# self.optimizer.load_state_dict(ckpt["optimizer"])
self.starting_step = ckpt["step"]
else:
self.starting_step = 1


def train_and_evaluate(self):
logging.info("==== Experiment.train_and_evaluate() ===")

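Commenting out the optimizer reload is consistent with the freeze: the checkpoint's optimizer state was saved over the full parameter list, and torch.optim's load_state_dict raises a ValueError when the loaded param groups do not match the new, smaller groups. This motivation is inferred, not stated in the commit; a sketch of the failure mode it would avoid:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
saved = torch.optim.Adam(net.parameters(), lr=1e-3).state_dict()  # 2 params

net.weight.requires_grad = False  # "freeze" part of the model
opt = torch.optim.Adam(
    filter(lambda p: p.requires_grad, net.parameters()), lr=1e-3  # 1 param
)

try:
    opt.load_state_dict(saved)  # group sizes differ: 2 saved vs. 1 current
except ValueError as err:
    print(f"cannot reload optimizer state: {err}")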
File 4 of 5

@@ -63,6 +63,7 @@ def main_sweep(sweep_name, s_0, sigma_saliency, x_saliency,plot=True):
"add_dx_0": default_config.add_dx_0,
"small_int": default_config.small_int,
# model parameters
"freeze_decoder": default_config.freeze_decoder,
"trans_type": default_config.trans_type,
"num_grid": default_config.num_grid,
"num_neurons": default_config.num_neurons,
@@ -183,6 +184,7 @@ def _convert_config(wandb_config):

 # model parameter
 config.model = _d(
+    freeze_decoder=wandb_config.freeze_decoder,
     trans_type=wandb_config.trans_type,
     rnn_step=wandb_config.rnn_step,
     num_grid=wandb_config.num_grid,
File 5 of 5

@@ -9,6 +9,7 @@

 @dataclass
 class GridCellConfig:
+    freeze_decoder: bool
     trans_type: str
     num_grid: int
     num_neurons: int
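End to end, the commit threads a single boolean: default_config.freeze_decoder → the sweep dict in main_sweep → _convert_config → this GridCellConfig dataclass → the requires_grad loop in the trainer. A minimal sketch of the dataclass end of that chain (field values here are illustrative):

from dataclasses import dataclass

@dataclass
class GridCellConfig:
    freeze_decoder: bool
    trans_type: str
    num_grid: int
    num_neurons: int

config = GridCellConfig(
    freeze_decoder=True, trans_type="example", num_grid=40, num_neurons=1800
)
if config.freeze_decoder:
    print("trainer will set requires_grad=False on decoder parameters")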
