Commit 5c932ac: Added testing scripts

Spijkervet committed Mar 11, 2020
Parent: c38be08
Showing 8 changed files with 38 additions and 5 deletions.
config/config.yaml: 2 changes (1 addition & 1 deletion)
@@ -13,7 +13,7 @@ n_out: 64
 temperature: 0.5

 # reload options
-model_path: "logs/182" # set to most directory containing `checkpoint_##.tar`
+model_path: "logs/0" # set to the directory containing `checkpoint_##.tar`
 model_num: 40 # set to checkpoint number

 # logistic regression options
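Aside: taken together, these two options name one checkpoint file. A minimal sketch of how they combine (the actual path handling lives in the repo's load_model, whose internals are assumed here):

    import os

    model_path = "logs/0"  # directory written by the Sacred file observer
    model_num = 40         # checkpoint number
    checkpoint = os.path.join(model_path, "checkpoint_{}.tar".format(model_num))
    # -> "logs/0/checkpoint_40.tar"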
experiment.py: 3 changes (3 additions & 0 deletions)
@@ -7,12 +7,15 @@
 from sacred.stflow import LogFileWriter
 from sacred.observers import FileStorageObserver, MongoObserver

+# from utils import CustomFileStorageObserver
+
 # custom config hook
 from utils.yaml_config_hook import yaml_config_hook


 ex = Experiment("SimCLR")


+#### file output directory
 ex.observers.append(FileStorageObserver("./logs"))

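Aside: with a FileStorageObserver attached, Sacred writes each run's config, logs, and artifacts to ./logs/<run id>, which is where model_path values like logs/0 and logs/182 come from. A minimal self-contained sketch of the pattern (the my_main body is illustrative):

    from sacred import Experiment
    from sacred.observers import FileStorageObserver

    ex = Experiment("SimCLR")
    ex.observers.append(FileStorageObserver("./logs"))

    @ex.automain
    def my_main(_run):
        # everything this run produces is stored under ./logs/<_run._id>
        print(_run._id)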
modules/__init__.py: 3 changes (2 additions & 1 deletion)
@@ -1,2 +1,3 @@
 from .simclr import SimCLR
-from .nt_xent import NT_Xent
+from .nt_xent import NT_Xent
+from .logistic_regression import LogisticRegression
modules/logistic_regression.py: 9 changes (8 additions & 1 deletion)
@@ -1,4 +1,11 @@
 import torch.nn as nn

 class LogisticRegression(nn.Module):
-
+
+    def __init__(self, n_features, n_classes):
+        super(LogisticRegression, self).__init__()
+
+        self.model = nn.Linear(n_features, n_classes)
+
+    def forward(self, x):
+        return self.model(x)
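Aside: the module is a single linear layer mapping encoder features to class logits, so usage in the linear-evaluation step is just (sizes here are illustrative):

    import torch
    from modules import LogisticRegression

    model = LogisticRegression(n_features=512, n_classes=10)  # illustrative sizes
    h = torch.randn(32, 512)  # a batch of frozen SimCLR features
    logits = model(h)         # shape (32, 10), ready for CrossEntropyLoss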
run_all.sh: 9 changes (9 additions & 0 deletions)
@@ -0,0 +1,9 @@
+# activate pip / conda environment first
+
+# Train SimCLR model
+python main.py
+
+# Train linear model and run test
+python -m testing.logistic_regression \
+    with \
+    model_path=./logs/0
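Aside: `with key=value` is Sacred's command-line syntax for config updates. The same override can be issued from Python (a sketch; it assumes the testing script exposes its Sacred Experiment object as `ex`):

    # equivalent to: python -m testing.logistic_regression with model_path=./logs/0
    from testing.logistic_regression import ex  # assumed import location

    ex.run(config_updates={"model_path": "./logs/0"})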
testing/logistic_regression.py: 5 changes (4 additions & 1 deletion)
@@ -6,6 +6,7 @@
 from model import load_model
 from utils import post_config_hook

+from modules import LogisticRegression

 def train(args, loader, simclr_model, model, criterion, optimizer):
     loss_epoch = 0
@@ -77,7 +78,9 @@ def main(_run, _log):


     ## Logistic Regression
-    model = torch.nn.Sequential(torch.nn.Linear(simclr_model.n_features, 10)).to(args.device)
+    n_classes = 10 # stl-10
+    model = LogisticRegression(simclr_model.n_features, n_classes)
+    model = model.to(args.device)

     optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
     criterion = torch.nn.CrossEntropyLoss()
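Aside: the surrounding train/test loops are elided in this hunk; one training step in the usual linear-evaluation setup is sketched below, using names from the script. The encoder output convention (h for the representation, z for the projection) is an assumption about simclr_model's forward:

    import torch

    # one illustrative step: the SimCLR encoder stays frozen, only the head trains
    with torch.no_grad():
        h, z = simclr_model(x)  # assumed forward signature
    logits = model(h)           # the LogisticRegression head
    loss = criterion(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()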
utils/__init__.py: 3 changes (2 additions & 1 deletion)
@@ -1,2 +1,3 @@
 from .masks import mask_correlated_samples
-from .yaml_config_hook import post_config_hook
+from .yaml_config_hook import post_config_hook
+from .filestorage import CustomFileStorageObserver
utils/filestorage.py: 9 changes (9 additions & 0 deletions)
@@ -0,0 +1,9 @@
+import pathlib
+from sacred.observers import FileStorageObserver
+
+class CustomFileStorageObserver(FileStorageObserver):
+    def started_event(self, ex_info, command, host_info, start_time, config, meta_info, _id):
+        if _id is None:
+            _id = "baseline"
+        (pathlib.Path(self.basedir) / _id).parent.mkdir(exist_ok=True, parents=True)
+        return super().started_event(ex_info, command, host_info, start_time, config, meta_info, _id)
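Aside: this subclass gives runs without an explicit id a fixed "baseline" directory, and the parent.mkdir call also tolerates nested ids such as "exp1/run3". Wiring it up mirrors the commented-out import in experiment.py (a sketch):

    from sacred import Experiment
    from utils import CustomFileStorageObserver

    ex = Experiment("SimCLR")
    # runs started without an id now land in ./logs/baseline
    ex.observers.append(CustomFileStorageObserver("./logs"))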
