-
Notifications
You must be signed in to change notification settings - Fork 77
/
main.py
158 lines (132 loc) · 5.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import importlib
import threading
from absl import app
from absl import flags
from pysc2 import maps
from pysc2.env import available_actions_printer
from pysc2.env import sc2_env
from pysc2.lib import stopwatch
import tensorflow as tf
from run_loop import run_loop
# Number of finished training episodes, shared by all worker threads.
# run_thread increments it while holding LOCK; _main may overwrite it with the
# value returned by the agent's load_model before any thread is started.
COUNTER = 0
LOCK = threading.Lock()

FLAGS = flags.FLAGS

# Training schedule and persistence.
flags.DEFINE_bool("training", True, "Whether to train agents.")
# When set, _main restores the model (and COUNTER) from SNAPSHOT before training.
flags.DEFINE_bool("continuation", False, "Continuously training.")
flags.DEFINE_float("learning_rate", 5e-4, "Learning rate for training.")
flags.DEFINE_float("discount", 0.99, "Discount rate for future rewards.")
flags.DEFINE_integer("max_steps", int(1e5), "Total steps for training.")
flags.DEFINE_integer("snapshot_step", int(1e3), "Step for snapshot.")
flags.DEFINE_string("snapshot_path", "./snapshot/", "Path for snapshot.")
flags.DEFINE_string("log_path", "./log/", "Path for log.")
# Comma-separated device indices; each one is prefixed with '/gpu:' below.
flags.DEFINE_string("device", "0", "Device for training.")

# Environment configuration.
flags.DEFINE_string("map", "MoveToBeacon", "Name of a map to use.")
flags.DEFINE_bool("render", True, "Whether to render with pygame.")
flags.DEFINE_integer("screen_resolution", 64, "Resolution for screen feature layers.")
flags.DEFINE_integer("minimap_resolution", 64, "Resolution for minimap feature layers.")
flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.")

# Agent and opponent configuration.
flags.DEFINE_string("agent", "agents.a3c_agent.A3CAgent", "Which agent to run.")
flags.DEFINE_string("net", "fcn", "atari or fcn.")
flags.DEFINE_enum("agent_race", None, sc2_env.races.keys(), "Agent's race.")
flags.DEFINE_enum("bot_race", None, sc2_env.races.keys(), "Bot's race.")
flags.DEFINE_enum("difficulty", None, sc2_env.difficulties.keys(), "Bot's strength.")
flags.DEFINE_integer("max_agent_steps", 60, "Total agent steps.")

# Diagnostics and execution.
flags.DEFINE_bool("profile", False, "Whether to turn on code profiling.")
flags.DEFINE_bool("trace", False, "Whether to trace the code execution.")
flags.DEFINE_integer("parallel", 16, "How many instances to run in parallel.")
flags.DEFINE_bool("save_replay", False, "Whether to save a replay at the end.")

# Parse the command line right away: the module-level constants below are
# derived from flag values at import time.
FLAGS(sys.argv)

if FLAGS.training:
    PARALLEL = FLAGS.parallel
    MAX_AGENT_STEPS = FLAGS.max_agent_steps
    DEVICE = ['/gpu:'+dev for dev in FLAGS.device.split(',')]
else:
    # Outside training the --parallel, --max_agent_steps and --device flags
    # are ignored: a single instance runs on the CPU.
    PARALLEL = 1
    MAX_AGENT_STEPS = 1e5
    DEVICE = ['/cpu:0']

# Per-map, per-network output directories. os.path.join is used instead of
# string concatenation so that a --log_path / --snapshot_path given without a
# trailing separator still produces "<path>/<map>/<net>".
LOG = os.path.join(FLAGS.log_path, FLAGS.map, FLAGS.net)
SNAPSHOT = os.path.join(FLAGS.snapshot_path, FLAGS.map, FLAGS.net)
if not os.path.exists(LOG):
    os.makedirs(LOG)
if not os.path.exists(SNAPSHOT):
    os.makedirs(SNAPSHOT)
def run_thread(agent, map_name, visualize):
    """Run one agent inside its own SC2 environment.

    While training, every item yielded by run_loop is buffered; when an
    episode is reported as done the shared COUNTER is bumped under LOCK, the
    agent is updated with the buffered items and a decayed learning rate, and
    a snapshot is written whenever ``counter % snapshot_step == 1``. The loop
    stops once the counter reaches FLAGS.max_steps. Outside training, the
    cumulative score is printed at the end of each episode.

    Args:
      agent: Object providing ``update``, ``save_model`` and ``name``, as used
        below.
      map_name: Map name forwarded to ``sc2_env.SC2Env``.
      visualize: Forwarded to ``sc2_env.SC2Env`` as its ``visualize`` argument.
    """
    global COUNTER

    screen_px = (FLAGS.screen_resolution, FLAGS.screen_resolution)
    minimap_px = (FLAGS.minimap_resolution, FLAGS.minimap_resolution)

    with sc2_env.SC2Env(
            map_name=map_name,
            agent_race=FLAGS.agent_race,
            bot_race=FLAGS.bot_race,
            difficulty=FLAGS.difficulty,
            step_mul=FLAGS.step_mul,
            screen_size_px=screen_px,
            minimap_size_px=minimap_px,
            visualize=visualize) as env:
        env = available_actions_printer.AvailableActionsPrinter(env)

        # Only for a single player!
        episode = []
        for recorder, is_done in run_loop([agent], env, MAX_AGENT_STEPS):
            if not FLAGS.training:
                # Evaluation: just report the score when an episode finishes.
                if is_done:
                    last_obs = recorder[-1].observation
                    score = last_obs["score_cumulative"][0]
                    print('Your score is '+str(score)+'!')
                continue

            episode.append(recorder)
            if not is_done:
                continue

            # Take a private copy of the shared counter while holding the
            # lock so the rest of the update works on a stable value.
            with LOCK:
                COUNTER += 1
                step = COUNTER

            # Learning rate schedule
            decay = 1 - 0.9 * step / FLAGS.max_steps
            learning_rate = FLAGS.learning_rate * decay
            agent.update(episode, FLAGS.discount, learning_rate, step)
            episode = []

            if step % FLAGS.snapshot_step == 1:
                agent.save_model(SNAPSHOT, step)
            if step >= FLAGS.max_steps:
                break

        if FLAGS.save_replay:
            env.save_replay(agent.name)
def _main(unused_argv):
    """Run agents.

    Builds PARALLEL agents sharing one TensorFlow session and summary writer,
    optionally restores a snapshot, then runs all but the last agent in daemon
    threads and the last agent in the calling thread.

    Args:
      unused_argv: Positional command-line arguments passed by ``app.run``;
        not used.

    Raises:
      ValueError: If PARALLEL is smaller than 1, since at least one agent is
        needed to initialize the model and to run in the calling thread.
    """
    global COUNTER

    # Fail early with a clear message instead of an UnboundLocalError on the
    # first use of an agent after the session has already been created.
    if PARALLEL < 1:
        raise ValueError("--parallel must be at least 1, got %d" % PARALLEL)

    stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace
    stopwatch.sw.trace = FLAGS.trace

    maps.get(FLAGS.map)  # Assert the map exists.

    # Setup agents
    agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
    agent_cls = getattr(importlib.import_module(agent_module), agent_name)

    agents = []
    for i in range(PARALLEL):
        agent = agent_cls(FLAGS.training, FLAGS.minimap_resolution, FLAGS.screen_resolution)
        # Devices are assigned round-robin. The first argument is True for
        # every agent but the first one -- presumably a variable-reuse switch;
        # confirm against the agent class.
        agent.build_model(i > 0, DEVICE[i % len(DEVICE)], FLAGS.net)
        agents.append(agent)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    summary_writer = tf.summary.FileWriter(LOG)
    for agent in agents:
        agent.setup(sess, summary_writer)

    # The last agent is the one that initializes / restores the model and that
    # later runs in the calling thread.
    main_agent = agents[-1]
    main_agent.initialize()
    if not FLAGS.training or FLAGS.continuation:
        # load_model's return value becomes the starting episode counter.
        COUNTER = main_agent.load_model(SNAPSHOT)

    # Run threads
    threads = []
    for worker in agents[:-1]:
        t = threading.Thread(target=run_thread, args=(worker, FLAGS.map, False))
        threads.append(t)
        t.daemon = True
        t.start()
        # Stagger worker start-up by five seconds each.
        time.sleep(5)

    run_thread(main_agent, FLAGS.map, FLAGS.render)

    for t in threads:
        t.join()

    # Every worker has finished, so nothing writes summaries any more; close
    # the writer so pending events are flushed to disk before exit.
    summary_writer.close()

    if FLAGS.profile:
        print(stopwatch.sw)
if __name__ == "__main__":
    # Script entry point: let absl's app runner invoke _main.
    app.run(_main)