-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconfigs.py
130 lines (97 loc) · 3.92 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
from datetime import datetime
from pathlib import Path
import pprint
import torch
# Directory that contains this file; every other path is derived from it.
project_dir = Path(__file__).resolve().parent

# Root directory under which each dataset gets its own sub-directory.
datasets_dir = project_dir / 'datasets'

# Where to save checkpoint and log images
result_dir = project_dir / 'results'
# exist_ok avoids the race between a separate exists() check and mkdir()
# when several processes import this module at the same time.
result_dir.mkdir(exist_ok=True)
def get_optimizer(optimizer_name='Adam'):
    """Get optimizer class by name.

    The name is looked up on ``torch.optim`` exactly as given, so it is
    case-sensitive (e.g. ``'RMSprop'``, not ``'Rmsprop'``).

    Args:
        optimizer_name: Attribute name of an optimizer class in ``torch.optim``.

    Returns:
        The optimizer class itself (not an instance).

    Raises:
        AttributeError: If ``torch.optim`` has no attribute with that name, or
            the attribute is not a ``torch.optim.Optimizer`` subclass (for
            example the ``lr_scheduler`` sub-module).
    """
    candidate = getattr(torch.optim, optimizer_name, None)

    # torch.optim also exposes sub-modules and helpers; only accept classes
    # that really are optimizers so a typo cannot silently return one of those.
    is_optimizer = (isinstance(candidate, type)
                    and issubclass(candidate, torch.optim.Optimizer))
    if not is_optimizer:
        raise AttributeError(
            'torch.optim has no optimizer named %r' % (optimizer_name,))
    return candidate
def str2bool(arg):
    """Convert a command-line style string to a boolean.

    Intended for use as an argparse ``type=`` callable.

    Args:
        arg: A string such as ``'yes'``/``'no'``, ``'true'``/``'false'``,
            ``'t'``/``'f'``, ``'y'``/``'n'`` or ``'1'``/``'0'`` (matched
            case-insensitively, surrounding whitespace ignored). A value that
            is already a ``bool`` is returned unchanged.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    # Already a boolean (e.g. a programmatic default): nothing to convert.
    if isinstance(arg, bool):
        return arg

    text = arg.strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
class BaseConfig(object):
    """Configuration object populated from command-line arguments.

    ``parse_base`` registers the options common to all models; subclasses
    override ``parse`` to register more on ``self.parser``. ``initialize``
    then parses the command line and copies every option onto the instance
    as an attribute.
    """

    def __init__(self):
        """Base Configuration Class"""
        self.parse_base()

    def parse_base(self):
        """Base configurations for all models"""
        self.parser = argparse.ArgumentParser()

        #================ Mode ==============#
        self.parser.add_argument('--mode', type=str, default='train')

        #================ Train ==============#
        self.parser.add_argument('--batch_size', type=int, default=16)
        self.parser.add_argument('--n_epochs', type=int, default=200)
        self.parser.add_argument('--optimizer', type=str, default='RMSprop')
        # '--dataset' matches the double-dash style of every other option;
        # the single-dash '-dataset' spelling is kept as an alias so existing
        # command lines keep working. Both store into `dataset`.
        self.parser.add_argument('--dataset', '-dataset', type=str,
                                 default='CIFAR10')

        #=============== Misc ===============#
        self.parser.add_argument('--log_interval', type=int, default=100)
        self.parser.add_argument('--save_interval', type=int, default=10)

    def parse(self):
        """Update configuration with extra arguments (To be inherited)"""
        pass

    def initialize(self, parse=True, **optional_kwargs):
        """Parse arguments and set them as attributes of this instance.

        Args:
            parse: If true, parse the command line strictly with
                ``parse_args``. Otherwise use ``parse_known_args`` and
                ignore arguments the parser does not know.
            **optional_kwargs: Extra values that override (or add to) the
                parsed arguments before they are stored on the instance.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: If ``mode`` is ``'test'`` and no ``load_ckpt_time``
                was provided.
        """
        # Update parser
        self.parse()

        # Parse arguments
        if parse:
            namespace = self.parser.parse_args()
        else:
            namespace = self.parser.parse_known_args()[0]

        # Namespace => dictionary
        kwargs = vars(namespace)
        kwargs.update(optional_kwargs)

        for key, value in kwargs.items():
            # The optimizer is given by name on the command line; resolve it.
            # A non-string value (e.g. passed through optional_kwargs) is
            # stored as-is.
            if key == 'optimizer' and isinstance(value, str):
                value = get_optimizer(value)
            setattr(self, key, value)

        self.isTrain = self.mode == 'train'

        # Dataset
        # ex) ./datasets/CIFAR10/
        self.dataset_dir = datasets_dir.joinpath(self.dataset)

        # Save / Log
        # ex) ./results/
        self.model_dir = result_dir
        # exist_ok avoids failing when another process creates it first.
        self.model_dir.mkdir(parents=True, exist_ok=True)

        if self.mode == 'train':
            # ex) ./results/2017-12-10_10-09-08/
            time_now = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            self.ckpt_dir = self.model_dir.joinpath(time_now)
            self.ckpt_dir.mkdir()

            # Record every attribute set so far next to the checkpoints.
            file_path = self.ckpt_dir.joinpath('config.txt')
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write('------------ Configurations -------------\n')
                for k, v in sorted(self.__dict__.items()):
                    f.write('%s: %s\n' % (str(k), str(v)))
                f.write('----------------- End -------------------\n')

        # Previous ckpt to load (optional for evaluation)
        if self.mode == 'test':
            # The base parser does not define `load_ckpt_time`; it has to come
            # from a subclass option or from optional_kwargs. Validate
            # explicitly rather than with `assert`, which is removed under -O.
            load_ckpt_time = getattr(self, 'load_ckpt_time', None)
            if not load_ckpt_time:
                raise ValueError(
                    "mode='test' requires 'load_ckpt_time' to be set.")
            self.ckpt_dir = self.model_dir.joinpath(load_ckpt_time)

        return self

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += pprint.pformat(self.__dict__)
        return config_str
def get_config(parse=True):
    """Build a BaseConfig and initialize it in one call.

    `parse` is forwarded unchanged to `BaseConfig.initialize`, whose return
    value is returned.
    """
    config = BaseConfig()
    return config.initialize(parse=parse)
if __name__ == '__main__':
    config = get_config()

    # Drop into an interactive debugger so `config` can be inspected.
    # ipdb is an optional third-party package; fall back to the standard
    # library debugger when it is not installed.
    try:
        import ipdb as debugger
    except ImportError:
        import pdb as debugger
    debugger.set_trace()