-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
145 lines (123 loc) · 6.16 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
from argparse import ArgumentParser
import math
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from dataset import create_train_dataloaders
from utils import DotDict
from utils.lightning import LightningCoatNet
# Detect a SLURM allocation and derive the run label from it: the job id when
# running under SLURM, otherwise a local timestamp (day-month-year_hour-minute).
slurm_state = 'SLURM_JOB_ID' in os.environ
if not slurm_state:
    from datetime import datetime
    version = datetime.now().strftime('%d-%b-%y_%H-%M')
else:
    version = 'slurm' + os.environ['SLURM_JOB_ID']
def parse_cmd(argv=None):
    """Parse the command-line options of the CoatNet training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, so existing callers that
            invoke ``parse_cmd()`` keep their behaviour.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option below.
    """
    parser = ArgumentParser('PyTorch CoatNET training script.')

    # Required locations.
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Directory containing ImageNet dataset.')
    parser.add_argument('--class_dict', type=str, required=True,
                        help='Path to the pickle file containing the classes mapping dictionary.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the config toml file.')
    parser.add_argument('--ckpt_dir', type=str, required=True,
                        help='Directory to save checkpoints.')
    parser.add_argument('--log_dir', type=str, required=True,
                        help="The tensorboard logging directory.")

    # Data loading and reproducibility.
    parser.add_argument('--image_size', type=int, default=224,
                        help='Image size for training.')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed.')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size for training.')

    # Optimisation hyper-parameters.
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='Base learning rate.')
    parser.add_argument('--learning_schedule', type=str, default='cosine',
                        help='Learning rate schedule.')
    parser.add_argument('--ema_decay', type=float, default=None,
                        help='Exponential moving average decay.')
    parser.add_argument('--weight_decay', type=float, default=1e-8,
                        help='Weight decay.')
    parser.add_argument('--label_smoothing', type=float, default=0.1,
                        help='Label smoothing.')
    parser.add_argument('--alpha', type=float, default=0.8,
                        help='Mixup alpha.')
    parser.add_argument('--peak_lr', type=float, default=1e-3,
                        help='Peak learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-5,
                        help='Minimum learning rate.')
    # The default must be an int literal: argparse applies ``type`` only to
    # string defaults, so the former ``1e+4`` was passed on as float 10000.0.
    parser.add_argument('--warmups', type=int, default=10000,
                        help='Number of warmup steps.')

    # Run length and hardware.
    parser.add_argument('--max_epochs', type=int, default=100,
                        help='Maximum number of epochs to train.')
    parser.add_argument('--gpus', type=int, default=1,
                        help='Number of GPUs to use.')
    parser.add_argument('--nodes', type=int, default=1,
                        help='Number of nodes on slurm')

    return parser.parse_args(argv)
def train(args):
    """Build the data loaders, model and Lightning trainer, then run the fit.

    Args:
        args: Merged configuration with attribute access. Read here are the
            top-level fields ``seed``, ``ckpt_dir``, ``log_dir``,
            ``batch_size``, ``max_epochs``, ``gpus``, ``nodes`` and the
            optimisation settings, plus a ``MODEL`` section that carries the
            model name and architecture settings. The whole object is also
            forwarded to ``create_train_dataloaders``.

    Returns:
        None. Checkpoints are configured to go under
        ``<ckpt_dir>/<MODEL.name>`` and TensorBoard logs under ``log_dir``.
    """
    seed_everything(args.seed)

    # prepare dataset loaders
    train_loader, val_loader, ds_len = create_train_dataloaders(args, val_frac=0.05, train=True)

    model_cfg = args.MODEL

    # checkpoint callback: keep the two best epochs by validation loss plus the last one
    ckpt_callback = ModelCheckpoint(
        dirpath=os.path.join(args.ckpt_dir, model_cfg.name),
        filename=model_cfg.name + '-{epoch:02d}-{val_loss:.2f}',
        save_top_k=2,
        save_last=True,
        verbose=False,
        monitor='val_loss',
        mode='min',
    )

    # tensorboard logger; ``version`` is the module-level run label
    tb_logger = TensorBoardLogger(
        save_dir=args.log_dir,
        name=model_cfg.name,
        version=version,
        default_hp_metric=False,
    )

    # learning rate monitor
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # Total step budget handed to the model: batches per epoch times epochs.
    # NOTE(review): assumes ``ds_len`` counts training samples and that
    # ``batch_size`` is the batch size seen by one process; with several
    # devices the real number of steps per epoch may be smaller - confirm.
    max_steps = math.ceil(ds_len / args.batch_size) * args.max_epochs

    # create the model and lightning wrapper
    model = LightningCoatNet(
        image_size=model_cfg.image_size,
        num_channels=model_cfg.num_channels,
        num_classes=model_cfg.num_classes,
        lengths=model_cfg.lengths,
        depths=model_cfg.depths,
        sizes=model_cfg.sizes,
        blocks=model_cfg.blocks,
        mbconv_e=model_cfg.mbconv_e,
        mbconv_se=model_cfg.mbconv_se,
        head_dim=model_cfg.head_dim,
        mem_eff=model_cfg.mem_eff,
        tfmrel_e=model_cfg.tfmrel_e,
        qkv_bias=model_cfg.qkv_bias,
        fc_e=model_cfg.fc_e,
        stochastic_rate=model_cfg.stochastic_rate,
        learning_rate=args.learning_rate,
        learning_schedule=args.learning_schedule,
        ema_decay=args.ema_decay,
        weight_decay=args.weight_decay,
        label_smoothing=args.label_smoothing,
        alpha=args.alpha,
        peak_lr=args.peak_lr,
        warmup_steps=args.warmups,
        min_lr=args.min_lr,
        max_steps=max_steps,
    )

    # ``--gpus 0`` selects the CPU accelerator. The CPU accelerator needs a
    # positive device count, so fall back to a single device instead of
    # forwarding the 0.
    use_gpu = args.gpus > 0
    trainer = Trainer(
        devices=args.gpus if use_gpu else 1,
        num_nodes=args.nodes,
        accelerator='gpu' if use_gpu else 'cpu',
        max_epochs=args.max_epochs,
        callbacks=[ckpt_callback, lr_monitor, RichProgressBar()],
        logger=tb_logger,
        strategy='ddp',
        gradient_clip_algorithm='value',
        gradient_clip_val=1.0,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
if __name__ == '__main__':
    # Script entry point: read the CLI options, load the TOML config they
    # point to, merge the CLI options into it via update(), and start training.
    cli_options = parse_cmd()
    config = DotDict.from_toml(cli_options.config)
    config.update(**vars(cli_options))
    train(config)