-
Notifications
You must be signed in to change notification settings - Fork 42
/
train.py
65 lines (51 loc) · 1.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
from utilities.trainer import *
import torch
from arguments import *
import os
from utilities.util import *
from utilities.logger import Logger
import argparse
parser = argparse.ArgumentParser(description='Test rl agent.')
parser.add_argument('--save-path', type=str, nargs='?', default='./', help='Please input the directory of saving model.')
argv = parser.parse_args()
# Output root for checkpoints and TensorBoard logs. The rest of the script
# builds paths by string concatenation (save_path + 'model_save', ...), so it
# must end with '/'.
# `--save-path` given with no value yields None (nargs='?' without const), and
# an empty string would crash on indexing; fall back to the default for both.
save_path = argv.save_path or './'
# Use endswith() rather than `is`: identity of str objects is an interpreter
# detail (and a SyntaxWarning on Python 3.8+).
if not save_path.endswith('/'):
    save_path += '/'
# Create the output folders: <save_path>/model_save/<log_name> for checkpoints
# and <save_path>/tensorboard/<log_name> for the Logger. `log_name` comes from
# the star-import of `arguments`.
# os.makedirs creates missing parents (including save_path itself, which the
# old os.listdir(save_path) checks required to already exist) and is a no-op
# for directories that are already there.
os.makedirs(os.path.join(save_path, 'model_save', log_name), exist_ok=True)
tensorboard_dir = os.path.join(save_path, 'tensorboard', log_name)
if os.path.isdir(tensorboard_dir):
    # Re-using a run name: delete the plain files left by a previous run so
    # old logs don't mix with the new ones. Sub-directories are left alone.
    for entry in os.listdir(tensorboard_dir):
        entry_path = os.path.join(tensorboard_dir, entry)
        if os.path.isfile(entry_path):
            os.remove(entry_path)
else:
    os.makedirs(tensorboard_dir)
# NOTE(review): Logger is project-local (utilities.logger); presumably it
# writes TensorBoard event files into this directory — confirm.
logger = Logger(tensorboard_dir)
# Resolve the model and its training strategy from the configured model name
# (Model, Strategy, model_name, args, env, PGTrainer all come from the
# star-imports above), then echo the full configuration.
model = Model[model_name]
strategy = Strategy[model_name]
print('{}\n'.format(args))

# Only the policy-gradient ('pg') strategy has a trainer so far; 'q' is a known
# but unimplemented strategy, anything else is a configuration error.
if strategy != 'pg':
    if strategy == 'q':
        raise NotImplementedError('This needs to be implemented.')
    raise RuntimeError('Please input the correct strategy, e.g. pg or q.')
train = PGTrainer(args, model, env(), logger, args.online)
# Main training loop. `stat` is handed to the trainer's run() and logging()
# every episode; presumably run() fills it with per-episode statistics —
# confirm against utilities.trainer.
stat = {}
checkpoint_dir = os.path.join(save_path, 'model_save', log_name)
last_episode = args.train_episodes_num - 1
for i in range(args.train_episodes_num):
    train.run(stat)
    train.logging(stat)
    # Checkpoint every `save_model_freq` episodes, and also after the final
    # episode. Without the latter, the last train_episodes_num % save_model_freq
    # episodes were never saved, and a run shorter than save_model_freq
    # produced no checkpoint at all.
    periodic = i % args.save_model_freq == args.save_model_freq - 1
    if periodic or i == last_episode:
        train.print_info(stat)
        # Each checkpoint overwrites the previous model.pt.
        torch.save({'model_state_dict': train.behaviour_net.state_dict()},
                   os.path.join(checkpoint_dir, 'model.pt'))
        print('The model is saved!\n')
        # Record the run configuration and the index of the last saved episode.
        with open(os.path.join(checkpoint_dir, 'log.txt'), 'w') as file:
            file.write(str(args) + '\n')
            file.write(str(i))