density_train.py
import os
import torch
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from dataset import EMBEDMammoDataModule
from downstream_model import MammoNet


def main(hparams):
    # torch.set_float32_matmul_precision('medium')
    torch.set_float32_matmul_precision("high")

    # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
    pl.seed_everything(hparams.seed, workers=True)

    if hparams.dataset == "embed":
        data = EMBEDMammoDataModule(
            target="density",
            csv_file=hparams.csv_file,
            image_size=(512, 384),
            batch_alpha=hparams.batch_alpha,
            batch_size=hparams.batch_size,
            num_workers=hparams.num_workers,
        )
    else:
        print("Unknown dataset. Exiting.")
        return

    # model
    model = MammoNet(
        backbone=hparams.model, learning_rate=hparams.learning_rate, num_classes=4
    )

    # Create output directory
    output_dir = os.path.join(hparams.output_root, hparams.output_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    print("")
    print("=============================================================")
    print("TRAINING...")
    print("=============================================================")
    print("")

    wandb_logger = WandbLogger(save_dir=hparams.output_root, project="mammo-stuff")
    wandb_logger.watch(model, log="all", log_freq=100)

    # train
    trainer = pl.Trainer(
        val_check_interval=1000,
        max_epochs=hparams.epochs,
        accelerator="auto",
        devices=hparams.num_devices,
        precision="16-mixed",
        num_sanity_val_steps=0,
        logger=[
            TensorBoardLogger(hparams.output_root, name=hparams.output_name),
            wandb_logger,
        ],
        callbacks=[
            ModelCheckpoint(filename="last.ckpt"),
            ModelCheckpoint(monitor="val_auc", mode="max"),
            TQDMProgressBar(refresh_rate=10),
        ],
    )
    trainer.fit(
        model=model,
        datamodule=data,
    )

    print("")
    print("=============================================================")
    print("VALIDATION...")
    print("=============================================================")
    print("")

    trainer.validate(
        model=model,
        datamodule=data,
        ckpt_path=trainer.checkpoint_callback.best_model_path,
    )

    print("")
    print("=============================================================")
    print("TESTING...")
    print("=============================================================")
    print("")

    trainer.test(
        model=model,
        datamodule=data,
        ckpt_path=trainer.checkpoint_callback.best_model_path,
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--batch_alpha", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--num_workers", type=int, default=6)
    parser.add_argument("--num_devices", type=int, default=1)
    parser.add_argument("--model", type=str, default="resnet18")
    parser.add_argument("--dataset", type=str, default="embed")
    parser.add_argument(
        "--csv_file",
        type=str,
        default="/vol/biomedic3/data/EMBED/tables/mammo-net-csv/embed-non-negative.csv",
    )
    parser.add_argument("--output_root", type=str, default="output")
    parser.add_argument("--output_name", type=str, default="density-balanced")
    parser.add_argument("--seed", type=int, default=33)
    args = parser.parse_args()

    main(args)
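
# Example invocation (a minimal sketch): every value below simply mirrors the
# argument defaults defined above, so running the bare script is equivalent.
# Adjust --csv_file and --output_root to your own environment.
#
#   python density_train.py \
#       --dataset embed \
#       --model resnet18 \
#       --epochs 50 \
#       --batch_size 32 \
#       --learning_rate 0.0001 \
#       --csv_file /vol/biomedic3/data/EMBED/tables/mammo-net-csv/embed-non-negative.csv \
#       --output_root output \
#       --output_name density-balanced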