agent.py
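"""Train and evaluate a reinforcement-learning agent on OpenAI Gym environments.

Summary docstring added for orientation, based on the code below: the script
fills the agent's replay memory with random steps, trains for TRAIN_EPISODES
episodes, then evaluates for TEST_EPISODES episodes. Pressing Enter while the
agent runs toggles environment rendering on and off.
"""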
import argparse
import gym
import select
import sys
import threading
import traceback
import time
from dqn import DQN
from enums import EnvTypes
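# Note: the DQN class imported above is expected (inferred from the calls in this
# script) to expose process_observation, notify_state_transition, training_predict,
# testing_predict, batch_train, and restore_algorithm.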
# number of episodes to train and test the agent for
TRAIN_EPISODES = 3000
TEST_EPISODES = 1000
# number of random actions taken for initialization
INIT_STEPS = 10000
# supported Atari environments mapped to the size of their action space
atari_environments = {
    'Breakout-v0': 6,
    'MsPacman-v0': 8,
    'Phoenix-v0': 8,
    'SpaceInvaders-v0': 6
}
# supported standard environments mapped to [observation dims, action dims]
standard_environments = {
    'CartPole-v0': [4, 2]
}
# supported learning algorithms
algorithms = {
    'dqn': DQN
}
# rendering is off by default; toggled at runtime by pressing Enter (see render_toggle)
render = False
# keeps the stdin-polling thread alive until main() shuts it down
polling = True

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('env_name',
                        help="Name of the OpenAI Gym environment to run")
    parser.add_argument('network_algorithm',
                        help="The algorithm to be trained")
    parser.add_argument('--monitor', default=None,
                        help="Directory for monitor recording of training")
    parser.add_argument('--initstep', type=int, default=INIT_STEPS,
                        help="Number of steps taken during initialization")
    parser.add_argument('--save', default=None,
                        help="Save algorithm periodically to specified directory")
    parser.add_argument('--restore', default=None,
                        help="Restores and resumes training of specified algorithm")
    parser.add_argument('--predict', default=None,
                        help="Run prediction from specified save directory")
    return parser.parse_args()
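
# Example invocations (illustrative only; they assume dqn.py and enums.py sit
# alongside this file and that the listed environments are installed):
#   python agent.py Breakout-v0 dqn --monitor ./monitor --save ./checkpoints
#   python agent.py CartPole-v0 dqn --predict ./checkpoints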

def initialize_training(env, network, iterations):
    observation = env.reset()
    for i in range(iterations):
        # step environment with random action and fill replay memory
        old_observation = observation
        action = env.action_space.sample()
        observation, reward, done, _ = env.step(action)
        # add state transition to replay memory
        network.process_observation(old_observation)
        network.notify_state_transition(action, reward, done)
        # reset the environment if done
        if done:
            observation = env.reset()
        # show progress every thousand steps
        step = i + 1
        if step % 1000 == 0:
            print("Training initialization step {} completed".format(step))

def train_agent(env, network, save_dir):
    print("Beginning training")
    # train for TRAIN_EPISODES episodes
    curr_episode = 0
    training_iterations = 0
    tot_reward = 0
    observation = env.reset()
    while curr_episode < TRAIN_EPISODES:
        # render the environment
        if render:
            env.render()
        # take an action and step the environment
        action = network.training_predict(env, observation)
        old_observation = observation
        observation, reward, done, _ = env.step(action)
        tot_reward += reward
        # update network with state transition and train
        network.notify_state_transition(action, reward, done)
        network.batch_train(save_dir)
        # reset the environment and start new episode if done
        if done:
            print("Episode {} completed; total reward is {}".format(curr_episode, tot_reward))
            if render:
                env.render()
            observation = env.reset()
            curr_episode += 1
            tot_reward = 0
        # report progress every 100 training iterations
        training_iterations += 1
        if training_iterations % 100 == 0:
            print("Agent training iteration {} completed".format(training_iterations))

def test_agent(env, network):
    print("Beginning testing")
    # test for TEST_EPISODES episodes
    avg_ep_reward = 0
    curr_episode = 0
    tot_reward = 0
    observation = env.reset()
    while curr_episode < TEST_EPISODES:
        if render:
            env.render()
        observation, reward, done, _ = env.step(network.testing_predict(observation))
        tot_reward += reward
        if done:
            print("Episode {} completed; total reward is {}".format(curr_episode, tot_reward))
            if render:
                env.render()
            observation = env.reset()
            curr_episode += 1
            avg_ep_reward += tot_reward
            tot_reward = 0
    avg_ep_reward /= TEST_EPISODES
    print("Average total reward per episode is {}".format(avg_ep_reward))

def render_toggle():
    # poll stdin so that pressing Enter toggles environment rendering on and off
    global render
    while polling:
        register = select.select([sys.stdin], [], [], 0.1)[0]
        if len(register) and register[0].readline():
            render = not render

def main():
    global polling
    # parse command line flag arguments
    args = parse_arguments()
    # currently only support certain environments
    assert (args.env_name in atari_environments or
            args.env_name in standard_environments)
    # add keyboard polling for render toggling
    render_toggle_thread = threading.Thread(target=render_toggle)
    render_toggle_thread.start()
    try:
        if args.env_name in atari_environments:
            env_type = EnvTypes.ATARI
            state_dims = [105, 80, 1]
            action_dims = atari_environments[args.env_name]
        elif args.env_name in standard_environments:
            env_type = EnvTypes.STANDARD
            state_dims = [standard_environments[args.env_name][0]]
            action_dims = standard_environments[args.env_name][1]
        # initialize network and prepare for training
        network = algorithms[args.network_algorithm](env_type, state_dims, action_dims)
        if args.predict is None:
            initialize_training(gym.make(args.env_name), network, args.initstep)
            if args.restore is not None:
                network.restore_algorithm(args.restore)
            # begin training
            train_env = gym.make(args.env_name)
            if args.monitor is not None:
                train_env.monitor.start(args.monitor + '/train')
            train_agent(train_env, network, args.save)
            if args.monitor is not None:
                train_env.monitor.close()
        else:
            network.restore_algorithm(args.predict)
        # evaluate agent
        test_env = gym.make(args.env_name)
        if args.monitor is not None:
            test_env.monitor.start(args.monitor + '/test')
        test_agent(test_env, network)
        if args.monitor is not None:
            test_env.monitor.close()
    except KeyboardInterrupt:
        print("\nInterrupt received. Terminating agent...")
    except Exception:
        traceback.print_exc()
    finally:
        # stop the keyboard polling thread
        polling = False
        render_toggle_thread.join()
        sys.exit()


if __name__ == '__main__':
    main()