From c38be08d92973d5c8b0ccda2e77fa089e6e49899 Mon Sep 17 00:00:00 2001
From: Janne Spijkervet
Date: Wed, 11 Mar 2020 20:30:34 +0100
Subject: [PATCH] Added logistic regression

---
 README.md                      |  10 ++-
 config/config.yaml             |  17 ++++-
 model.py                       |   7 +-
 modules/logistic_regression.py |  10 +++
 testing/logistic_regression.py | 113 +++++++++++++++++++++++++++++++++
 utils/yaml_config_hook.py      |   3 +-
 6 files changed, 155 insertions(+), 5 deletions(-)
 create mode 100644 modules/logistic_regression.py
 create mode 100644 testing/logistic_regression.py

diff --git a/README.md b/README.md
index e7d2999..ae54d5a 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,15 @@ Then, simply run:
 python main.py
 ```
 
-### Dependencies
+### Testing
+To test a trained model, set the `model_path` variable in `config/config.yaml` to the log directory of the training run (e.g. `logs/0`).
+Set `model_num` to the epoch number of the checkpoint you want to load (e.g. `40`).
+
+```
+python -m testing.logistic_regression
+```
+
+## Dependencies
 ```
 torch
 torchvision
diff --git a/config/config.yaml b/config/config.yaml
index bb0360c..fd8d150 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -1,8 +1,21 @@
+# train options
 batch_size: 256
 workers: 16
 start_epoch: 0
 epochs: 40
+
+# model options
+resnet: "resnet18"
 normalize: True
-temperature: 0.5
 n_out: 64
-resnet: "resnet18"
\ No newline at end of file
+
+# loss options
+temperature: 0.5
+
+# reload options
+model_path: "logs/182" # set to the log directory containing `checkpoint_##.tar`
+model_num: 40 # set to the checkpoint epoch number
+
+# logistic regression options
+logistic_batch_size: 256
+logistic_epochs: 100
\ No newline at end of file
diff --git a/model.py b/model.py
index df76418..9e830c9 100644
--- a/model.py
+++ b/model.py
@@ -2,10 +2,15 @@ import torch
 
 from modules import SimCLR
 
-def load_model(args):
+def load_model(args, reload_model=False):
     model = SimCLR(args)
+
+    if reload_model:
+        model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.model_num))
+        model.load_state_dict(torch.load(model_fp))
     return model
 
+
 def save_model(args, model, optimizer):
     out = os.path.join(args.out_dir, "checkpoint_{}.tar".format(args.current_epoch))
diff --git a/modules/logistic_regression.py b/modules/logistic_regression.py
new file mode 100644
index 0000000..39ab5d9
--- /dev/null
+++ b/modules/logistic_regression.py
@@ -0,0 +1,10 @@
+import torch.nn as nn
+
+class LogisticRegression(nn.Module):
+    def __init__(self, n_features, n_classes):
+        super(LogisticRegression, self).__init__()
+        # a single linear layer mapping encoder features to class logits
+        self.model = nn.Linear(n_features, n_classes)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/testing/logistic_regression.py b/testing/logistic_regression.py
new file mode 100644
index 0000000..4fd0070
--- /dev/null
+++ b/testing/logistic_regression.py
@@ -0,0 +1,113 @@
+import torch
+import torchvision
+import argparse
+
+from experiment import ex
+from model import load_model
+from utils import post_config_hook
+
+
+def train(args, loader, simclr_model, model, criterion, optimizer):
+    loss_epoch = 0
+    accuracy_epoch = 0
+    for step, (x, y) in enumerate(loader):
+        optimizer.zero_grad()
+
+        x = x.to(args.device)
+        y = y.to(args.device)
+
+        # get encoding from the frozen SimCLR model
+        with torch.no_grad():
+            h, z = simclr_model(x)
+        # h: representation, shape (batch, 512) for resnet18
+        # z: projection, shape (batch, 64)
+
+        output = model(h)
+        loss = criterion(output, y)
+
+        predicted = output.argmax(1)
+        acc = (predicted == y).sum().item() / y.size(0)
+        accuracy_epoch += acc
+
+        loss.backward()
+        optimizer.step()
+
+        loss_epoch += loss.item()
+
+    return loss_epoch, accuracy_epoch
+
+def test(args, loader, simclr_model, model, criterion, optimizer):
+    loss_epoch = 0
+    accuracy_epoch = 0
+    model.eval()
+    for step, (x, y) in enumerate(loader):
+        model.zero_grad()
+
+        x = x.to(args.device)
+        y = y.to(args.device)
+
+        # get encoding from the frozen SimCLR model
+        with torch.no_grad():
+            h, z = simclr_model(x)
+        # h: representation, shape (batch, 512) for resnet18
+        # z: projection, shape (batch, 64)
+
+        output = model(h)
+        loss = criterion(output, y)
+
+        predicted = output.argmax(1)
+        acc = (predicted == y).sum().item() / y.size(0)
+        accuracy_epoch += acc
+
+        loss_epoch += loss.item()
+
+    return loss_epoch, accuracy_epoch
+
+@ex.automain
+def main(_run, _log):
+    args = argparse.Namespace(**_run.config)
+    args = post_config_hook(args, _run)
+
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    root = "./datasets"
+    simclr_model = load_model(args, reload_model=True)
+    simclr_model = simclr_model.to(args.device)
+    simclr_model.eval()
+
+
+    ## Logistic Regression
+    model = torch.nn.Sequential(torch.nn.Linear(simclr_model.n_features, 10)).to(args.device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    train_dataset = torchvision.datasets.STL10(
+        root, split="train", download=True, transform=torchvision.transforms.ToTensor()
+    )
+
+    test_dataset = torchvision.datasets.STL10(
+        root, split="test", download=True, transform=torchvision.transforms.ToTensor()
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.logistic_batch_size,
+        drop_last=True,
+        num_workers=args.workers,
+    )
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.logistic_batch_size,
+        drop_last=True,
+        num_workers=args.workers,
+    )
+
+    for epoch in range(args.logistic_epochs):
+        loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model, model, criterion, optimizer)
+        print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")
+
+    # final testing
+    loss_epoch, accuracy_epoch = test(args, test_loader, simclr_model, model, criterion, optimizer)
+    print(f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}")
diff --git a/utils/yaml_config_hook.py b/utils/yaml_config_hook.py
index ba5f0f8..18547e8 100644
--- a/utils/yaml_config_hook.py
+++ b/utils/yaml_config_hook.py
@@ -17,7 +17,8 @@ def yaml_config_hook(config_file, ex):
         config_dir, cf = d.popitem()
         cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml")
         with open(cf) as f:
-            loaded_cfg = yaml.safe_load(f)
+            cfg.update(loaded_cfg)
     if "defaults" in cfg.keys():
         del cfg["defaults"]
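
Note on the new module: `testing/logistic_regression.py` builds its linear head inline with `torch.nn.Sequential(torch.nn.Linear(...))`, while `modules/logistic_regression.py` provides an equivalent `LogisticRegression` class. Below is a minimal, self-contained sketch of the evaluation step the patch implements, assuming the module as completed above; the dummy tensors stand in for the frozen encoder's output `h`, and `n_features=512` assumes the `resnet18` encoder from `config/config.yaml`:

```python
import torch
import torch.nn as nn

from modules.logistic_regression import LogisticRegression

# Dummy batch standing in for the frozen SimCLR encoder's representation h.
h = torch.randn(256, 512)         # (batch, n_features) for resnet18
y = torch.randint(0, 10, (256,))  # STL-10 has 10 classes

model = LogisticRegression(n_features=512, n_classes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

# One optimization step, mirroring train() in testing/logistic_regression.py:
# only the linear head receives gradients; the encoder never appears here.
optimizer.zero_grad()
output = model(h)  # logits, shape (256, 10)
loss = criterion(output, y)
loss.backward()
optimizer.step()

accuracy = (output.argmax(1) == y).float().mean().item()
print(loss.item(), accuracy)
```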
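
One caveat on the reload path in `model.py`: `torch.load(model_fp)` restores tensors to the device they were saved from, so a checkpoint written on a GPU machine fails to load on a CPU-only one. A small sketch of the usual guard, not part of this patch, using the same path composition as `load_model(args, reload_model=True)` and the hypothetical `logs/182` directory from `config.yaml`:

```python
import os
import torch

model_path, model_num = "logs/182", 40  # values from config/config.yaml
model_fp = os.path.join(model_path, "checkpoint_{}.tar".format(model_num))

# map_location remaps saved GPU tensors onto the current device, so the
# checkpoint also loads on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(model_fp, map_location=device)
# model.load_state_dict(state_dict)  # as in load_model()
```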