From 5c932ac909a0036641443518381ae8761593f3c6 Mon Sep 17 00:00:00 2001 From: Janne Spijkervet Date: Wed, 11 Mar 2020 20:50:07 +0100 Subject: [PATCH] Added testing scripts --- config/config.yaml | 2 +- experiment.py | 3 +++ modules/__init__.py | 3 ++- modules/logistic_regression.py | 9 ++++++++- run_all.sh | 9 +++++++++ testing/logistic_regression.py | 5 ++++- utils/__init__.py | 3 ++- utils/filestorage.py | 9 +++++++++ 8 files changed, 38 insertions(+), 5 deletions(-) create mode 100755 run_all.sh create mode 100644 utils/filestorage.py diff --git a/config/config.yaml b/config/config.yaml index fd8d150..22b60f7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -13,7 +13,7 @@ n_out: 64 temperature: 0.5 # reload options -model_path: "logs/182" # set to most directory containing `checkpoint_##.tar` +model_path: "logs/0" # set to the directory containing `checkpoint_##.tar` model_num: 40 # set to checkpoint number # logistic regression options diff --git a/experiment.py b/experiment.py index e39b755..5787cdf 100644 --- a/experiment.py +++ b/experiment.py @@ -7,12 +7,15 @@ from sacred.stflow import LogFileWriter from sacred.observers import FileStorageObserver, MongoObserver +# from utils import CustomFileStorageObserver + # custom config hook from utils.yaml_config_hook import yaml_config_hook ex = Experiment("SimCLR") + #### file output directory ex.observers.append(FileStorageObserver("./logs")) diff --git a/modules/__init__.py b/modules/__init__.py index 82bb216..9d4ea83 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -1,2 +1,3 @@ from .simclr import SimCLR -from .nt_xent import NT_Xent \ No newline at end of file +from .nt_xent import NT_Xent +from .logistic_regression import LogisticRegression \ No newline at end of file diff --git a/modules/logistic_regression.py b/modules/logistic_regression.py index 39ab5d9..a055f7d 100644 --- a/modules/logistic_regression.py +++ b/modules/logistic_regression.py @@ -1,4 +1,11 @@ import torch.nn as nn class LogisticRegression(nn.Module): - \ No newline at end of file + + def __init__(self, n_features, n_classes): + super(LogisticRegression, self).__init__() + + self.model = nn.Linear(n_features, n_classes) + + def forward(self, x): + return self.model(x) \ No newline at end of file diff --git a/run_all.sh b/run_all.sh new file mode 100755 index 0000000..a23d12b --- /dev/null +++ b/run_all.sh @@ -0,0 +1,9 @@ +# activate pip / conda environment first + +# Train SimCLR model +python main.py + +# Train linear model and run test +python -m testing.logistic_regression \ + with \ + model_path=./logs/0 \ No newline at end of file diff --git a/testing/logistic_regression.py b/testing/logistic_regression.py index 4fd0070..c5a6ad9 100644 --- a/testing/logistic_regression.py +++ b/testing/logistic_regression.py @@ -6,6 +6,7 @@ from model import load_model from utils import post_config_hook +from modules import LogisticRegression def train(args, loader, simclr_model, model, criterion, optimizer): loss_epoch = 0 @@ -77,7 +78,9 @@ def main(_run, _log): ## Logistic Regression - model = torch.nn.Sequential(torch.nn.Linear(simclr_model.n_features, 10)).to(args.device) + n_classes = 10 # stl-10 + model = LogisticRegression(simclr_model.n_features, n_classes) + model = model.to(args.device) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) criterion = torch.nn.CrossEntropyLoss() diff --git a/utils/__init__.py b/utils/__init__.py index bc1b803..e40f29f 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,2 +1,3 @@ from .masks import mask_correlated_samples -from .yaml_config_hook import post_config_hook \ No newline at end of file +from .yaml_config_hook import post_config_hook +from .filestorage import CustomFileStorageObserver \ No newline at end of file diff --git a/utils/filestorage.py b/utils/filestorage.py new file mode 100644 index 0000000..3b037ad --- /dev/null +++ b/utils/filestorage.py @@ -0,0 +1,9 @@ +import pathlib +from sacred.observers import FileStorageObserver + +class CustomFileStorageObserver(FileStorageObserver): + def started_event(self, ex_info, command, host_info, start_time, config, meta_info, _id): + if _id is None: + _id = "baseline" + (pathlib.Path(self.basedir) / _id).parent.mkdir(exist_ok=True, parents=True) + return super().started_event(ex_info, command, host_info, start_time, config, meta_info, _id) \ No newline at end of file