# train.py
import numpy as np
import time
from common import get_args,experiment_setup
if __name__ == '__main__':
    # Training entry point: build the experiment objects, declare every item
    # the logger will report, then run the epoch/cycle loop.
    #
    # NOTE(review): the nesting below (setup -> epoch loop -> cycle loop, with
    # epoch_summary once per epoch and final_summary once at the end) is
    # reconstructed from the statement order and the method names; confirm it
    # against the upstream copy of this script.
    args = get_args()
    env, agent, buffer, learner, tester = experiment_setup(args)

    args.logger.summary_init(agent.graph, agent.sess)

    # Progress info
    args.logger.add_item('Epoch')
    args.logger.add_item('Cycle')
    # NOTE(review): registered as 'Episodes@green' but recorded below as
    # 'Episodes' -- presumably the logger treats '@green' as a display tag and
    # strips it from the key; verify against the logger implementation.
    args.logger.add_item('Episodes@green')
    args.logger.add_item('Timesteps')
    args.logger.add_item('TimeCost(sec)/train')
    args.logger.add_item('TimeCost(sec)/test')

    # Algorithm info
    for key in agent.train_info.keys():
        args.logger.add_item(key, 'scalar')
    for key in learner.learner_info:
        args.logger.add_item(key, 'scalar')

    # Test info
    for key in agent.step_info.keys():
        args.logger.add_item(key, 'scalar')
    for key in env.env_info.keys():
        args.logger.add_item(key, 'scalar')
    for key in tester.info:
        args.logger.add_item(key, 'scalar')

    args.logger.summary_setup()

    for epoch in range(args.epochs):
        for cycle in range(args.cycles):
            # Start each cycle with empty tabular/summary records.
            args.logger.tabular_clear()
            args.logger.summary_clear()

            # Wall-clock time spent in the learner for this cycle.
            start_time = time.time()
            learner.learn(args, env, agent, buffer)
            args.logger.add_record('TimeCost(sec)/train', time.time() - start_time)

            # Wall-clock time spent in the tester's per-cycle summary.
            start_time = time.time()
            tester.cycle_summary()
            args.logger.add_record('TimeCost(sec)/test', time.time() - start_time)

            # Progress is reported as "<zero-based index>/<total>".
            args.logger.add_record('Epoch', str(epoch) + '/' + str(args.epochs))
            args.logger.add_record('Cycle', str(cycle) + '/' + str(args.cycles))
            args.logger.add_record('Episodes', buffer.counter)
            args.logger.add_record('Timesteps', learner.steps_counter)

            args.logger.tabular_show(args.tag)
            # buffer.counter is passed as the argument to summary_show
            # (presumably the summary step index -- confirm in the logger).
            args.logger.summary_show(buffer.counter)

        tester.epoch_summary()

    tester.final_summary()