
Added logistic regression
Spijkervet committed Mar 11, 2020
1 parent e36661c · commit c38be08
Showing 6 changed files with 149 additions and 5 deletions.
README.md: 9 additions, 1 deletion
@@ -25,7 +25,15 @@ Then, simply run:
 python main.py
 ```

-### Dependencies
+### Testing
+To test a trained model, set the `model_path` variable in `config/config.yaml` to the log ID of the training run (e.g. `logs/0`).
+Set `model_num` to the epoch number you want to load the checkpoint from (e.g. `40`).
+
+```
+python -m testing.logistic_regression
+```
+
+## Dependencies
 ```
 torch
 torchvision
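For orientation, a small sketch (not part of the commit) of how the two config values above locate a checkpoint; the path join mirrors `load_model()` in `model.py` further down:

```
import os

# Hypothetical values matching the README example above.
model_path, model_num = "logs/0", 40
checkpoint = os.path.join(model_path, "checkpoint_{}.tar".format(model_num))
print(checkpoint)  # logs/0/checkpoint_40.tar
```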
config/config.yaml: 15 additions, 2 deletions
@@ -1,8 +1,21 @@
 # train options
 batch_size: 256
 workers: 16
 start_epoch: 0
 epochs: 40
+
+# model options
+resnet: "resnet18"
 normalize: True
-temperature: 0.5
-resnet: "resnet18"
+n_out: 64
+
+# loss options
+temperature: 0.5
+
+# reload options
+model_path: "logs/182" # set to the directory containing `checkpoint_##.tar`
+model_num: 40 # set to checkpoint number
+
+# logistic regression options
+logistic_batch_size: 256
+logistic_epochs: 100
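These flat YAML keys are consumed through Sacred and unpacked into an `argparse.Namespace` in the new test script; a minimal, self-contained illustration of that pattern (the `config` dict stands in for Sacred's `_run.config`):

```
import argparse

# Mirrors `args = argparse.Namespace(**_run.config)` in testing/logistic_regression.py.
config = {
    "model_path": "logs/182",
    "model_num": 40,
    "logistic_batch_size": 256,
    "logistic_epochs": 100,
}
args = argparse.Namespace(**config)
print(args.model_path, args.model_num)  # logs/182 40
```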
model.py: 6 additions, 1 deletion
@@ -2,10 +2,15 @@
 import torch
 from modules import SimCLR

-def load_model(args):
+def load_model(args, reload_model=False):
     model = SimCLR(args)
+
+    if reload_model:
+        model_fp = os.path.join(args.model_path, "checkpoint_{}.tar".format(args.model_num))
+        model.load_state_dict(torch.load(model_fp))
     return model
+

 def save_model(args, model, optimizer):
     out = os.path.join(args.out_dir, "checkpoint_{}.tar".format(args.current_epoch))
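A round-trip sketch of the checkpoint convention implied above; `save_model`'s body is truncated in this view, so saving the `state_dict` is an assumption, while the file naming and the `load_state_dict` call match the diff:

```
import os
import torch
import torch.nn as nn

out_dir, epoch = "/tmp/logs_demo", 40  # illustrative paths, not the repo's defaults
os.makedirs(out_dir, exist_ok=True)

model = nn.Linear(4, 2)
out = os.path.join(out_dir, "checkpoint_{}.tar".format(epoch))
torch.save(model.state_dict(), out)        # assumed counterpart of save_model()

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(out))  # exactly what load_model() does on reload
```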
modules/logistic_regression.py: 4 additions, 0 deletions
@@ -0,0 +1,4 @@
import torch.nn as nn

class LogisticRegression(nn.Module):
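The file is committed as a bare stub (its class body does not survive in this capture). A minimal completion, offered as an assumption rather than the commit's code, wraps a single `nn.Linear`, matching what the test script below builds inline with `torch.nn.Sequential`:

```
import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    """Hypothetical completion: one linear layer from features to class logits."""

    def __init__(self, n_features, n_classes):
        super().__init__()
        self.model = nn.Linear(n_features, n_classes)

    def forward(self, x):
        return self.model(x)

# Usage: 512-dim ResNet-18 features h mapped to the 10 STL-10 classes.
head = LogisticRegression(512, 10)
logits = head(torch.randn(8, 512))  # shape (8, 10)
```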

testing/logistic_regression.py: 113 additions, 0 deletions
@@ -0,0 +1,113 @@
import torch
import torchvision
import argparse

from experiment import ex
from model import load_model
from utils import post_config_hook


def train(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding
        with torch.no_grad():
            h, z = simclr_model(x)
        # h = 512
        # z = 64

        output = model(h)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding
        with torch.no_grad():
            h, z = simclr_model(x)
        # h = 512
        # z = 64

        output = model(h)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch


@ex.automain
def main(_run, _log):
    args = argparse.Namespace(**_run.config)
    args = post_config_hook(args, _run)

    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    root = "./datasets"
    simclr_model = load_model(args, reload_model=True)
    simclr_model = simclr_model.to(args.device)
    simclr_model.eval()


    ## Logistic Regression
    model = torch.nn.Sequential(torch.nn.Linear(simclr_model.n_features, 10)).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    train_dataset = torchvision.datasets.STL10(
        root, split="train", download=True, transform=torchvision.transforms.ToTensor()
    )

    test_dataset = torchvision.datasets.STL10(
        root, split="test", download=True, transform=torchvision.transforms.ToTensor()
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.logistic_batch_size,
        drop_last=True,
        num_workers=args.workers,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.logistic_batch_size,
        drop_last=True,
        num_workers=args.workers,
    )

    for epoch in range(args.logistic_epochs):
        loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model, model, criterion, optimizer)
        print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")

    # final testing
    loss_epoch, accuracy_epoch = test(args, test_loader, simclr_model, model, criterion, optimizer)
    print(f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}")
utils/yaml_config_hook.py: 2 additions, 1 deletion
@@ -17,7 +17,8 @@ def yaml_config_hook(config_file, ex):
         config_dir, cf = d.popitem()
         cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml")
         with open(cf) as f:
-            cfg.update(yaml.safe_load(f))
+            l = yaml.safe_load(f)
+            cfg.update(l)

     if "defaults" in cfg.keys():
         del cfg["defaults"]
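For context, the patched function merges every sub-config listed under a `defaults` key into one flat dict, then drops the key itself; a condensed, self-contained sketch of that flow (the file lookup is replaced by inline YAML here):

```
import yaml

cfg = yaml.safe_load("""
batch_size: 256
defaults:
  - dataset: stl10
""")

for d in list(cfg.get("defaults", [])):
    config_dir, cf = d.popitem()  # e.g. ("dataset", "stl10") -> config/dataset/stl10.yaml
    sub = yaml.safe_load("dataset_root: ./datasets")  # stands in for the loaded file
    cfg.update(sub)

if "defaults" in cfg.keys():
    del cfg["defaults"]

print(cfg)  # {'batch_size': 256, 'dataset_root': './datasets'}
```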
