From 02fcf7c719d4a1fa2766da814ad3a916cd91b397 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 4 Mar 2022 16:09:14 +0900 Subject: [PATCH 01/21] experiment: set all threads daemon=True and remove shutdown flags --- handyrl/connection.py | 52 +++++++++++++------------------------------ handyrl/train.py | 47 ++++++++++---------------------------- handyrl/worker.py | 18 +++++---------- 3 files changed, 33 insertions(+), 84 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 3550f153..fe424018 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -131,15 +131,12 @@ def open_multiprocessing_connections(num_process, target, args_func): class MultiProcessJobExecutor: - def __init__(self, func, send_generator, num_workers, postprocess=None, num_receivers=1): + def __init__(self, func, send_generator, num_workers, postprocess=None): self.send_generator = send_generator self.postprocess = postprocess - self.num_receivers = num_receivers self.conns = [] self.waiting_conns = queue.Queue() - self.shutdown_flag = False self.output_queue = queue.Queue(maxsize=8) - self.threads = [] for i in range(num_workers): conn0, conn1 = mp.Pipe(duplex=True) @@ -148,26 +145,18 @@ def __init__(self, func, send_generator, num_workers, postprocess=None, num_rece self.conns.append(conn0) self.waiting_conns.put(conn0) - def shutdown(self): - self.shutdown_flag = True - for thread in self.threads: - thread.join() - def recv(self): return self.output_queue.get() def start(self): - self.threads.append(threading.Thread(target=self._sender)) - for i in range(self.num_receivers): - self.threads.append(threading.Thread(target=self._receiver, args=(i,))) - for thread in self.threads: - thread.start() + threading.Thread(target=self._sender, daemon=True).start() + threading.Thread(target=self._receiver, daemon=True).start() def _sender(self): print('start sender') - while not self.shutdown_flag: + while True: data = next(self.send_generator) - while not 
self.shutdown_flag: + while True: try: conn = self.waiting_conns.get(timeout=0.3) conn.send(data) @@ -176,17 +165,16 @@ def _sender(self): pass print('finished sender') - def _receiver(self, index): - print('start receiver %d' % index) - conns = [conn for i, conn in enumerate(self.conns) if i % self.num_receivers == index] - while not self.shutdown_flag: - tmp_conns = connection.wait(conns) + def _receiver(self): + print('start receiver') + while True: + tmp_conns = connection.wait(self.conns) for conn in tmp_conns: data = conn.recv() self.waiting_conns.put(conn) if self.postprocess is not None: data = self.postprocess(data) - while not self.shutdown_flag: + while True: try: self.output_queue.put(data, timeout=0.3) break @@ -202,18 +190,8 @@ def __init__(self, conns=[]): self.conns = [] for conn in conns: self.add_connection(conn) - self.shutdown_flag = False - self.threads = [ - threading.Thread(target=self._send_thread), - threading.Thread(target=self._recv_thread), - ] - for thread in self.threads: - thread.start() - - def shutdown(self): - self.shutdown_flag = True - for thread in self.threads: - thread.join() + threading.Thread(target=self._send_thread, daemon=True).start() + threading.Thread(target=self._recv_thread, daemon=True).start() def recv(self): return self.input_queue.get() @@ -229,7 +207,7 @@ def disconnect(self, conn): self.conns.remove(conn) def _send_thread(self): - while not self.shutdown_flag: + while True: try: conn, send_data = self.output_queue.get(timeout=0.3) except queue.Empty: @@ -242,7 +220,7 @@ def _send_thread(self): self.disconnect(conn) def _recv_thread(self): - while not self.shutdown_flag: + while True: conns = connection.wait(self.conns, timeout=0.3) for conn in conns: try: @@ -253,7 +231,7 @@ def _recv_thread(self): except EOFError: self.disconnect(conn) continue - while not self.shutdown_flag: + while True: try: self.input_queue.put((conn, recv_data), timeout=0.3) break diff --git a/handyrl/train.py b/handyrl/train.py index 
70451eb1..28841f48 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -265,9 +265,7 @@ class Batcher: def __init__(self, args, episodes): self.args = args self.episodes = episodes - self.shutdown_flag = False - - self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers'], num_receivers=2) + self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers']) def _selector(self): while True: @@ -275,7 +273,7 @@ def _selector(self): def _worker(self, conn, bid): print('started batcher %d' % bid) - while not self.shutdown_flag: + while True: episodes = conn.recv() batch = make_batch(episodes, self.args) conn.send(batch) @@ -308,10 +306,6 @@ def select_episode(self): def batch(self): return self.executor.recv() - def shutdown(self): - self.shutdown_flag = True - self.executor.shutdown() - class Trainer: def __init__(self, args, model): @@ -328,7 +322,6 @@ def __init__(self, args, model): self.batcher = Batcher(self.args, self.episodes) self.update_flag = False self.update_queue = queue.Queue(maxsize=1) - self.shutdown_flag = False self.wrapped_model = ModelWrapper(self.model) self.trained_model = self.wrapped_model @@ -340,10 +333,6 @@ def update(self): model, steps = self.update_queue.get() return model, steps - def shutdown(self): - self.shutdown_flag = True - self.batcher.shutdown() - def train(self): if self.optimizer is None: # non-parametric model time.sleep(0.1) @@ -354,7 +343,7 @@ def train(self): self.trained_model.cuda() self.trained_model.train() - while data_cnt == 0 or not (self.update_flag or self.shutdown_flag): + while data_cnt == 0 or not self.update_flag: batch = self.batcher.batch() batch_size = batch['value'].size(0) player_count = batch['value'].size(2) @@ -388,12 +377,12 @@ def train(self): def run(self): print('waiting training') - while not self.shutdown_flag and len(self.episodes) < self.args['minimum_episodes']: + while len(self.episodes) < self.args['minimum_episodes']: 
time.sleep(1) - if not self.shutdown_flag and self.optimizer is not None: + if self.optimizer is not None: self.batcher.run() print('started training') - while not self.shutdown_flag: + while True: model = self.train() self.update_flag = False self.update_queue.put((model, self.steps)) @@ -413,7 +402,6 @@ def __init__(self, args, net=None, remote=False): self.env = make_env(env_args) eval_modify_rate = (args['update_episodes'] ** 0.85) / args['update_episodes'] self.eval_rate = max(args['eval_rate'], eval_modify_rate) - self.shutdown_flag = False self.flags = set() # trained datum @@ -438,12 +426,6 @@ def __init__(self, args, net=None, remote=False): # thread connection self.trainer = Trainer(args, self.model) - def shutdown(self): - self.shutdown_flag = True - self.trainer.shutdown() - self.worker.shutdown() - self.thread.join() - def model_path(self, model_id): return os.path.join('models', str(model_id) + '.pth') @@ -549,7 +531,7 @@ def server(self): while self.model_epoch < self.args['epochs'] or self.args['epochs'] < 0: # no update call before storing minimum number of episodes + 1 age next_update_episodes = prev_update_episodes + self.args['update_episodes'] - while not self.shutdown_flag and self.num_returned_episodes < next_update_episodes: + while self.num_returned_episodes < next_update_episodes: conn, (req, data) = self.worker.recv() multi_req = isinstance(data, list) if not multi_req: @@ -618,16 +600,11 @@ def server(self): print('finished server') def run(self): - try: - # open training thread - self.thread = threading.Thread(target=self.trainer.run) - self.thread.start() - # open generator, evaluator - self.worker.run() - self.server() - - finally: - self.shutdown() + # open training thread + threading.Thread(target=self.trainer.run, daemon=True).start() + # open generator, evaluator + self.worker.run() + self.server() def train_main(args): diff --git a/handyrl/worker.py b/handyrl/worker.py index 7fe28bc7..0f24222b 100755 --- a/handyrl/worker.py +++ 
b/handyrl/worker.py @@ -163,11 +163,8 @@ def run(self): def gather_loop(args, conn, gaid): - try: - gather = Gather(args, conn, gaid) - gather.run() - finally: - gather.shutdown() + gather = Gather(args, conn, gaid) + gather.run() class WorkerCluster(QueueCommunicator): @@ -197,7 +194,7 @@ def run(self): def entry_server(port): print('started entry server %d' % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not self.shutdown_flag: + while True: conn = next(conn_acceptor) if conn is not None: worker_args = conn.recv() @@ -213,17 +210,14 @@ def entry_server(port): def worker_server(port): print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not self.shutdown_flag: + while True: conn = next(conn_acceptor) if conn is not None: self.add_connection(conn) print('finished worker server') - # use thread list of super class - self.threads.append(threading.Thread(target=entry_server, args=(9999,))) - self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-2].start() - self.threads[-1].start() + threading.Thread(target=entry_server, args=(9999,), daemon=True).start() + threading.Thread(target=worker_server, args=(9998,), daemon=True).start() def entry(worker_args): From fc8d3afc6bd464be7c36e15cabfffb1936260b49 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 4 Mar 2022 20:50:50 +0900 Subject: [PATCH 02/21] fix: finished receiver message --- handyrl/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index fe424018..33b83287 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -180,7 +180,7 @@ def _receiver(self): break except queue.Full: pass - print('finished receiver %d' % index) + print('finished receiver') class QueueCommunicator: From 22322bc0d3eb921eb7c87540e436b99562bb3aeb Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 8 Mar 2022 21:59:42 +0900 Subject: 
[PATCH 03/21] feature: blocking style connections --- handyrl/connection.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 33b83287..713d60fb 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -156,30 +156,20 @@ def _sender(self): print('start sender') while True: data = next(self.send_generator) - while True: - try: - conn = self.waiting_conns.get(timeout=0.3) - conn.send(data) - break - except queue.Empty: - pass + conn = self.waiting_conns.get() + conn.send(data) print('finished sender') def _receiver(self): print('start receiver') while True: - tmp_conns = connection.wait(self.conns) - for conn in tmp_conns: + conns = connection.wait(self.conns) + for conn in conns: data = conn.recv() self.waiting_conns.put(conn) if self.postprocess is not None: data = self.postprocess(data) - while True: - try: - self.output_queue.put(data, timeout=0.3) - break - except queue.Full: - pass + self.output_queue.put(data) print('finished receiver') @@ -208,10 +198,7 @@ def disconnect(self, conn): def _send_thread(self): while True: - try: - conn, send_data = self.output_queue.get(timeout=0.3) - except queue.Empty: - continue + conn, send_data = self.output_queue.get() try: conn.send(send_data) except ConnectionResetError: @@ -231,9 +218,4 @@ def _recv_thread(self): except EOFError: self.disconnect(conn) continue - while True: - try: - self.input_queue.put((conn, recv_data), timeout=0.3) - break - except queue.Full: - pass + self.input_queue.put((conn, recv_data)) From 6da14b1d0b2ff72bbc7d6b8932e31dd966a36345 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 16 Mar 2022 20:52:24 +0900 Subject: [PATCH 04/21] feature: use same buffer_length for buffering system in Gather --- handyrl/worker.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index f08028db..7e60e8b6 100755 --- a/handyrl/worker.py +++ 
b/handyrl/worker.py @@ -118,8 +118,7 @@ def __init__(self, args, conn, gaid): for conn in worker_conns: self.add_connection(conn) - self.args_buf_len = 1 + len(worker_conns) // 4 - self.result_buf_len = 1 + len(worker_conns) // 4 + self.buffer_length = 1 + len(worker_conns) // 4 def __del__(self): print('finished gather %d' % self.gather_id) @@ -131,7 +130,7 @@ def run(self): # When requested arguments, return buffered outputs if len(self.args_queue) == 0: # get multiple arguments from server and store them - self.server_conn.send((command, [None] * self.args_buf_len)) + self.server_conn.send((command, [None] * self.buffer_length)) self.args_queue += self.server_conn.recv() next_args = self.args_queue.popleft() @@ -153,7 +152,7 @@ def run(self): self.result_send_map[command].append(args) self.result_send_cnt += 1 - if self.result_send_cnt >= self.result_buf_len: + if self.result_send_cnt >= self.buffer_length: # send datum to server after buffering certain number of datum for command, args_list in self.result_send_map.items(): self.server_conn.send((command, args_list)) From 0d0a8b8a300e5e5bf2d5ba75c58572c167422183 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 31 Mar 2022 23:19:55 +0900 Subject: [PATCH 05/21] feature: observation switch in agent class --- handyrl/agent.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/handyrl/agent.py b/handyrl/agent.py index 6b6f6bd9..74b2de48 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -41,11 +41,12 @@ def print_outputs(env, prob, v): class Agent: - def __init__(self, model, temperature=0.0): + def __init__(self, model, temperature=0.0, observation=True): # model might be a neural net, or some planning algorithm such as game tree search self.model = model self.hidden = None self.temperature = temperature + self.observation = observation def reset(self, env, show=False): self.hidden = self.model.init_hidden() @@ -74,11 +75,13 @@ def action(self, env, player, show=False): return 
random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0] def observe(self, env, player, show=False): - outputs = self.plan(env.observation(player)) - v = outputs.get('value', None) - if show: - print_outputs(env, None, v) - return v if v is not None else [0.0] + v = None + if self.observation: + outputs = self.plan(env.observation(player)) + v = outputs.get('value', None) + if show: + print_outputs(env, None, v) + return v class EnsembleAgent(Agent): From a25e4364cbb225cc657f0bf7dbd756c18e64f72b Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 31 Mar 2022 23:28:41 +0900 Subject: [PATCH 06/21] feature: divide obs and plan line in agent class --- handyrl/agent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/handyrl/agent.py b/handyrl/agent.py index 6b6f6bd9..cbb5c961 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -56,7 +56,8 @@ def plan(self, obs): return outputs def action(self, env, player, show=False): - outputs = self.plan(env.observation(player)) + obs = env.observation(player) + outputs = self.plan(obs) actions = env.legal_actions(player) p = outputs['policy'] v = outputs.get('value', None) @@ -74,7 +75,8 @@ def action(self, env, player, show=False): return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0] def observe(self, env, player, show=False): - outputs = self.plan(env.observation(player)) + obs = env.observation(player) + outputs = self.plan(obs) v = outputs.get('value', None) if show: print_outputs(env, None, v) From 2aefe4d306266bf4025bc5e4d3993c9ad5543191 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 2 Apr 2022 22:42:51 +0900 Subject: [PATCH 07/21] feature: apply train_args.observation flag in Generator --- handyrl/generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/handyrl/generation.py b/handyrl/generation.py index e03e857b..8bca1c98 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -37,6 +37,8 @@ def generate(self, models, args): for 
player in self.env.players(): if player not in turn_players + observers: continue + if player not in turn_players and player in args['player'] and not self.args['observation']: + continue obs = self.env.observation(player) model = models[player] From 50a221833c76d01796a4e9806cb0c97c5bb24205 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 4 Apr 2022 17:58:03 +0900 Subject: [PATCH 08/21] feature: stop workers if received args is None --- handyrl/connection.py | 4 +- handyrl/train.py | 88 ++++++++++++++++++++++++------------------- handyrl/worker.py | 11 +++++- 3 files changed, 61 insertions(+), 42 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 3550f153..68e2159b 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -215,8 +215,8 @@ def shutdown(self): for thread in self.threads: thread.join() - def recv(self): - return self.input_queue.get() + def recv(self, timeout=None): + return self.input_queue.get(timeout=timeout) def send(self, conn, send_data): self.output_queue.put((conn, send_data)) diff --git a/handyrl/train.py b/handyrl/train.py index 971a0890..d1636da4 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -549,17 +549,24 @@ def server(self): # returns as list if getting multiple requests as list print('started server') prev_update_episodes = self.args['minimum_episodes'] - while self.model_epoch < self.args['epochs'] or self.args['epochs'] < 0: - # no update call before storing minimum number of episodes + 1 age - next_update_episodes = prev_update_episodes + self.args['update_episodes'] - while not self.shutdown_flag and self.num_returned_episodes < next_update_episodes: - conn, (req, data) = self.worker.recv() - multi_req = isinstance(data, list) - if not multi_req: - data = [data] - send_data = [] - - if req == 'args': + # no update call before storing minimum number of episodes + 1 epoch + next_update_episodes = prev_update_episodes + self.args['update_episodes'] + + while len(self.worker.conns) > 0: + 
try: + conn, (req, data) = self.worker.recv(timeout=0.3) + except queue.Empty: + continue + + multi_req = isinstance(data, list) + if not multi_req: + data = [data] + send_data = [] + + if req == 'args': + if self.shutdown_flag: + send_data = [None] * len(data) + else: for _ in data: args = {'model_id': {}} @@ -591,33 +598,38 @@ def server(self): send_data.append(args) - elif req == 'episode': - # report generated episodes - self.feed_episodes(data) - send_data = [None] * len(data) - - elif req == 'result': - # report evaluation results - self.feed_results(data) - send_data = [None] * len(data) - - elif req == 'model': - for model_id in data: - model = self.model - if model_id != self.model_epoch and model_id > 0: - try: - model = copy.deepcopy(self.model) - model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) - except: - # return latest model if failed to load specified model - pass - send_data.append(pickle.dumps(model)) - - if not multi_req and len(send_data) == 1: - send_data = send_data[0] - self.worker.send(conn, send_data) - prev_update_episodes = next_update_episodes - self.update() + elif req == 'episode': + # report generated episodes + self.feed_episodes(data) + send_data = [None] * len(data) + + elif req == 'result': + # report evaluation results + self.feed_results(data) + send_data = [None] * len(data) + + elif req == 'model': + for model_id in data: + model = self.model + if model_id != self.model_epoch and model_id > 0: + try: + model = copy.deepcopy(self.model) + model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) + except: + # return latest model if failed to load specified model + pass + send_data.append(pickle.dumps(model)) + + if not multi_req and len(send_data) == 1: + send_data = send_data[0] + self.worker.send(conn, send_data) + + if self.num_returned_episodes >= next_update_episodes: + prev_update_episodes = next_update_episodes + next_update_episodes = prev_update_episodes + 
self.args['update_episodes'] + self.update() + if self.args['epochs'] >= 0 and self.model_epoch >= self.args['epochs']: + self.shutdown_flag = True print('finished server') def run(self): diff --git a/handyrl/worker.py b/handyrl/worker.py index 112494fa..cb036925 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -12,6 +12,7 @@ import multiprocessing as mp import pickle import copy +import queue from .environment import prepare_env, make_env from .connection import QueueCommunicator @@ -65,6 +66,8 @@ def _gather_models(self, model_ids): def run(self): while True: args = send_recv(self.conn, ('args', None)) + if args is None: + break role = args['role'] models = {} @@ -125,8 +128,12 @@ def __del__(self): print('finished gather %d' % self.gather_id) def run(self): - while True: - conn, (command, args) = self.recv() + while len(self.conns) > 0: + try: + conn, (command, args) = self.recv(timeout=0.3) + except queue.Empty: + continue + if command == 'args': # When requested arguments, return buffered outputs if len(self.args_queue) == 0: From 2c324f4284998aabf2831acf94a393c859029f16 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 4 Apr 2022 18:06:16 +0900 Subject: [PATCH 09/21] feature: daemonic batcher processes --- handyrl/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index c5e57fab..19b0049a 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -140,7 +140,7 @@ def __init__(self, func, send_generator, num_workers, postprocess=None): for i in range(num_workers): conn0, conn1 = mp.Pipe(duplex=True) - mp.Process(target=func, args=(conn1, i)).start() + mp.Process(target=func, args=(conn1, i), daemon=True).start() conn1.close() self.conns.append(conn0) self.waiting_conns.put(conn0) From a659dd86a70551b637da513c67a5f1e2f0eae545 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 6 Apr 2022 22:15:46 +0900 Subject: [PATCH 10/21] fix: continue server before waiting workers --- 
handyrl/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/train.py b/handyrl/train.py index d285b1fd..8bd59e16 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -535,7 +535,7 @@ def server(self): # no update call before storing minimum number of episodes + 1 epoch next_update_episodes = prev_update_episodes + self.args['update_episodes'] - while len(self.worker.conns) > 0: + while len(self.worker.conns) > 0 or not self.shutdown_flag: try: conn, (req, data) = self.worker.recv(timeout=0.3) except queue.Empty: From 907dacf3aa06190dd7aa4c8dedd72acaa9fa5332 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 11 Apr 2022 00:08:29 +0900 Subject: [PATCH 11/21] feature: set timeout=None for connection acceptors --- handyrl/worker.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index 75661f72..759af4d5 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -193,27 +193,25 @@ def run(self): # prepare listening connections def entry_server(port): print('started entry server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) + conn_acceptor = accept_socket_connections(port=port) while True: conn = next(conn_acceptor) - if conn is not None: - worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) - worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] - args = copy.deepcopy(self.args) - args['worker'] = worker_args - conn.send(args) - conn.close() + worker_args = conn.recv() + print('accepted connection from %s!' 
% worker_args['address']) + worker_args['base_worker_id'] = self.total_worker_count + self.total_worker_count += worker_args['num_parallel'] + args = copy.deepcopy(self.args) + args['worker'] = worker_args + conn.send(args) + conn.close() print('finished entry server') def worker_server(port): print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) + conn_acceptor = accept_socket_connections(port=port) while True: conn = next(conn_acceptor) - if conn is not None: - self.add_connection(conn) + self.add_connection(conn) print('finished worker server') threading.Thread(target=entry_server, args=(9999,), daemon=True).start() From ca77c2fbdfe71976805d66c61ce450685a7b127a Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 11 Apr 2022 19:32:29 +0900 Subject: [PATCH 12/21] feature: use set instead of list for connections in QueueCommunicator --- handyrl/connection.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 3550f153..aa005ecd 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -199,7 +199,7 @@ class QueueCommunicator: def __init__(self, conns=[]): self.input_queue = queue.Queue(maxsize=256) self.output_queue = queue.Queue(maxsize=256) - self.conns = [] + self.conns = set() for conn in conns: self.add_connection(conn) self.shutdown_flag = False @@ -222,11 +222,11 @@ def send(self, conn, send_data): self.output_queue.put((conn, send_data)) def add_connection(self, conn): - self.conns.append(conn) + self.conns.add(conn) def disconnect(self, conn): print('disconnected') - self.conns.remove(conn) + self.conns.discard(conn) def _send_thread(self): while not self.shutdown_flag: From 21386da909f9238de5b7a159097b5ecccd315ccb Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 15 Apr 2022 19:24:45 +0900 Subject: [PATCH 13/21] feature: set model=None in load_model() --- handyrl/evaluation.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index ad770e30..305bd380 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -351,10 +351,11 @@ def insert_input(y): return outputs -def load_model(model_path, model): +def load_model(model_path, model=None): if model_path.endswith('.onnx'): model = OnnxModel(model_path) return model + assert model is not None import torch from .model import ModelWrapper model.load_state_dict(torch.load(model_path)) From 58034942bed73b9f202aeeecbcba8da4fdcd1775 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 21 Apr 2022 22:04:20 +0900 Subject: [PATCH 14/21] feature: connection_count() interface --- handyrl/connection.py | 3 +++ handyrl/train.py | 2 +- handyrl/worker.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 19b0049a..52612997 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -183,6 +183,9 @@ def __init__(self, conns=[]): threading.Thread(target=self._send_thread, daemon=True).start() threading.Thread(target=self._recv_thread, daemon=True).start() + def connection_count(self): + return len(self.conns) + def recv(self, timeout=None): return self.input_queue.get(timeout=timeout) diff --git a/handyrl/train.py b/handyrl/train.py index 8bd59e16..05fb2b9d 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -535,7 +535,7 @@ def server(self): # no update call before storing minimum number of episodes + 1 epoch next_update_episodes = prev_update_episodes + self.args['update_episodes'] - while len(self.worker.conns) > 0 or not self.shutdown_flag: + while self.worker.connection_count() > 0 or not self.shutdown_flag: try: conn, (req, data) = self.worker.recv(timeout=0.3) except queue.Empty: diff --git a/handyrl/worker.py b/handyrl/worker.py index 6599c0ee..a34d5d4c 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -128,7 +128,7 @@ def __del__(self): print('finished gather %d' % self.gather_id) def 
run(self): - while len(self.conns) > 0: + while self.connection_count() > 0: try: conn, (command, args) = self.recv(timeout=0.3) except queue.Empty: From 068e823fee7689af8ae17515cc35e9ee2a86db6f Mon Sep 17 00:00:00 2001 From: ikki407 Date: Fri, 25 Mar 2022 15:05:15 +0900 Subject: [PATCH 15/21] feature: add losses plot and stats plot --- scripts/loss_plot.py | 91 ++++++++++++++++++++++++++++++++++++++++ scripts/stats_plot.py | 91 ++++++++++++++++++++++++++++++++++++++++ scripts/win_rate_plot.py | 9 ++-- 3 files changed, 186 insertions(+), 5 deletions(-) create mode 100644 scripts/loss_plot.py create mode 100644 scripts/stats_plot.py diff --git a/scripts/loss_plot.py b/scripts/loss_plot.py new file mode 100644 index 00000000..aa6db5de --- /dev/null +++ b/scripts/loss_plot.py @@ -0,0 +1,91 @@ + +# Usage: python3 scripts/loss_plot.py train-log-002.txt title + +# You should not include figures generated by this script in your academic paper, because +# 1. This version of HandyRL doesn't display all the results of the matches. +# 2. Smoothing method in this script is not a simple moving average. 
+ +import sys +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +n = 10 + + +def kernel(n): + a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2, 1, -1))) + return a / a.sum() + + +def get_loss_list(path): + epoch_data_list = [] + epoch_list = [0] + step_list = [0] + game_list = [0] + + f = open(path) + lines = f.readlines() + prev_line = '' + + for line in lines: + if line.startswith('updated'): + epoch_list.append(len(epoch_list)) + step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) + if line.startswith('loss'): + elms = line.split() + epoch_data_list.append({}) + for e in elms[2:]: + name, loss = e.split(':') + loss = float(loss) + epoch_data_list[-1][name] = loss + if line.startswith('epoch '): + print(line, len(epoch_list)) + if ' ' in prev_line: + game = int(prev_line.split()[-1]) + game_list.append(game) + + prev_line = line + + game_list = game_list[:len(epoch_data_list)] + + if (len(epoch_list) - n//2 + 1 - n//2) - n <= 0: + raise Exception('Range to average `n` is too large, set small nubmer') + clipped_epoch_list = epoch_list[n//2:-n//2+1] + clipped_step_list = step_list[n//2:-n//2+1] + clipped_game_list = game_list[n//2:-n//2+1] + + kn = kernel(n) + start_epoch = {} + averaged_loss_lists = {} + for name in epoch_data_list[0].keys(): + data = [d[name] for d in epoch_data_list] + averaged_loss_lists[name] = np.convolve(data, kn, mode='valid') + start_epoch = 0 + + return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_loss_lists, start_epoch + + +# Plot +flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] +sns.set_palette(sns.color_palette(flatui, 24)) + +clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_loss_lists, start_epoch = get_loss_list(sys.argv[1]) + +fig_num = len(averaged_loss_lists) +fig = plt.figure() + +for i, (k, v) in enumerate(averaged_loss_lists.items()): + ax = fig.add_subplot((fig_num - 1) // 2 + 1, 2, i + 1) + 
start = start_epoch + ax.plot(clipped_game_list[start:], v[start:], label=k) + + ax.set_xlabel('Games') + ax.set_ylabel(k) + ax.set_title(f"loss: {k}") + ax.grid() + +fig.suptitle(sys.argv[2]) +fig.tight_layout() +plt.show() diff --git a/scripts/stats_plot.py b/scripts/stats_plot.py new file mode 100644 index 00000000..6696e0c2 --- /dev/null +++ b/scripts/stats_plot.py @@ -0,0 +1,91 @@ + +# Usage: python3 scripts/stats_plot.py train-log-002.txt title + +# You should not include figures generated by this script in your academic paper, because +# 1. This version of HandyRL doesn't display all the results of the matches. +# 2. Smoothing method in this script is not a simple moving average. + +import sys +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +n = 10 + + +def kernel(n): + a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2, 1, -1))) + return a / a.sum() + + +def get_stats_list(path): + epoch_data_list = [] + epoch_list = [0] + step_list = [0] + game_list = [0] + + f = open(path) + lines = f.readlines() + prev_line = '' + + for line in lines: + if line.startswith('updated'): + epoch_list.append(len(epoch_list)) + step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) + if line.startswith('generation stats'): + elms = line.split('=') + epoch_data_list.append({}) + for e in range(0, len(elms), 2): + name, stats = elms[e], elms[e+1].split('+-') + stats_mean, stats_std = float(stats[0]), float(stats[1]) + epoch_data_list[-1][name] = stats_mean + if line.startswith('epoch '): + print(line, len(epoch_list)) + if ' ' in prev_line: + game = int(prev_line.split()[-1]) + game_list.append(game) + + prev_line = line + + game_list = game_list[:len(epoch_data_list)] + + if (len(epoch_list) - n//2 + 1 - n//2) - n <= 0: + raise Exception('Range to average `n` is too large, set small nubmer') + clipped_epoch_list = epoch_list[n//2:-n//2+1] + clipped_step_list = step_list[n//2:-n//2+1] + clipped_game_list = 
game_list[n//2:-n//2+1] + + kn = kernel(n) + start_epoch = {} + averaged_stats_lists = {} + for name in epoch_data_list[0].keys(): + data = [d[name] for d in epoch_data_list] + averaged_stats_lists[name] = np.convolve(data, kn, mode='valid') + start_epoch = 0 + + return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_stats_lists, start_epoch + + +# Plot +flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] +sns.set_palette(sns.color_palette(flatui, 24)) + +clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_stats_lists, start_epoch = get_stats_list(sys.argv[1]) + +fig_num = len(averaged_stats_lists) +fig = plt.figure() + +for i, (k, v) in enumerate(averaged_stats_lists.items()): + ax = fig.add_subplot((fig_num - 1) // 2 + 1, 2, i + 1) + start = start_epoch + ax.plot(clipped_game_list[start:], v[start:], label=k) + + ax.set_xlabel('Games') + ax.set_ylabel(k) + ax.set_title(f"stats: {k}") + ax.grid() + +fig.suptitle(sys.argv[2]) +fig.tight_layout() +plt.show() diff --git a/scripts/win_rate_plot.py b/scripts/win_rate_plot.py index 7d8e67e3..cf18878e 100644 --- a/scripts/win_rate_plot.py +++ b/scripts/win_rate_plot.py @@ -7,10 +7,13 @@ import sys import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns n = 15 + def kernel(n): a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2,1,-1))) return a / a.sum() @@ -34,7 +37,7 @@ def get_wp_list(path): step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) if line.startswith('win rate'): elms = line.split() - opponent = elms[2].lstrip('(').rstrip(')') + opponent = elms[2].lstrip('(').rstrip(')') if elms[2] != '=' else 'total' games = int(elms[-1].lstrip('(').rstrip(')')) wp = float(elms[-4]) if games > 0 else 0.0 epoch_data_list[-1][opponent] = {'w': games * wp, 'n': games} @@ -69,8 +72,6 @@ def get_wp_list(path): return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_wp_lists, start_epoch -import 
matplotlib.pyplot as plt -import seaborn as sns flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] sns.set_palette(sns.color_palette(flatui, 24)) @@ -84,8 +85,6 @@ def get_wp_list(path): last_win_rate = {} for opponent in opponents: - if opponent == 'total': - continue wp_list = averaged_wp_lists[opponent] start = start_epoch[opponent] # ax.plot(clipped_epoch_list[start:], wp_list[start:], label=opponent) From e017a98d53d46432fa83df62e6e321898a325439 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 23 Apr 2022 14:24:59 +0900 Subject: [PATCH 16/21] feature: rulebase agents with keys --- handyrl/agent.py | 5 ++++- handyrl/envs/kaggle/hungry_geese.py | 2 +- handyrl/evaluation.py | 5 +++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/handyrl/agent.py b/handyrl/agent.py index cbb5c961..71a2dde6 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -23,9 +23,12 @@ def observe(self, env, player, show=False): class RuleBasedAgent(RandomAgent): + def __init__(self, key=None): + self.key = None + def action(self, env, player, show=False): if hasattr(env, 'rule_based_action'): - return env.rule_based_action(player) + return env.rule_based_action(player, key=self.key) else: return random.choice(env.legal_actions(player)) diff --git a/handyrl/envs/kaggle/hungry_geese.py b/handyrl/envs/kaggle/hungry_geese.py index 0a663adc..52dd00c5 100644 --- a/handyrl/envs/kaggle/hungry_geese.py +++ b/handyrl/envs/kaggle/hungry_geese.py @@ -186,7 +186,7 @@ def legal_actions(self, player): def players(self): return list(range(self.NUM_AGENTS)) - def rule_based_action(self, player): + def rule_based_action(self, player, key=None): from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, GreedyAgent action_map = {'N': Action.NORTH, 'S': Action.SOUTH, 'W': Action.WEST, 'E': Action.EAST} diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index ad770e30..720ea5f9 100755 --- 
a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -143,8 +143,9 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={ def build_agent(raw, env=None): if raw == 'random': return RandomAgent() - elif raw == 'rulebase': - return RuleBasedAgent() + elif raw.startswith('rulebase'): + key = raw.split('-')[1] if '-' in raw else None + return RuleBasedAgent(key) return None From 834a8fa41000091c94e384cb13293510ae038102 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 27 Apr 2022 18:40:43 +0900 Subject: [PATCH 17/21] fix: set key for rule-based agents --- handyrl/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/agent.py b/handyrl/agent.py index 86d2c08e..44f503ab 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -24,7 +24,7 @@ def observe(self, env, player, show=False): class RuleBasedAgent(RandomAgent): def __init__(self, key=None): - self.key = None + self.key = key def action(self, env, player, show=False): if hasattr(env, 'rule_based_action'): From 4160315be47caa78845b7c7d685489f61cd9b239 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 12 May 2022 20:12:12 +0900 Subject: [PATCH 18/21] feature: opponent selection by : --- handyrl/evaluation.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 248f5b6c..18a39b21 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -378,14 +378,18 @@ def eval_main(args, argv): prepare_env(env_args) env = make_env(env_args) - model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth' + model_paths = argv[0].split(':') if len(argv) >= 1 else ['models/latest.pth'] num_games = int(argv[1]) if len(argv) >= 2 else 100 num_process = int(argv[2]) if len(argv) >= 3 else 1 - agent1 = build_agent(model_path, env) - if agent1 is None: - model = load_model(model_path, env.net()) - agent1 = Agent(model) + def resolve_agent(model_path): + agent = build_agent(model_path, env) 
+ if agent is None: + model = load_model(model_path, env.net()) + agent = Agent(model) + return agent + + main_agent = resolve_agent(model_paths[0]) critic = None print('%d process, %d games' % (num_process, num_games)) @@ -393,7 +397,8 @@ def eval_main(args, argv): seed = random.randrange(1e8) print('seed = %d' % seed) - agents = [agent1] + [RandomAgent() for _ in range(len(env.players()) - 1)] + opponent = model_paths[1] if len(model_paths) > 1 else 'random' + agents = [main_agent] + [resolve_agent(opponent) for _ in range(len(env.players()) - 1)] evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed) From 30ca00cd79c74c35c02d6afe7b2a37ddcdb360cd Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 26 May 2022 23:47:14 +0900 Subject: [PATCH 19/21] feature: divide ep count variable --- handyrl/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 05fb2b9d..75c0a31b 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -286,8 +286,9 @@ def run(self): def select_episode(self): while True: - ep_idx = random.randrange(min(len(self.episodes), self.args['maximum_episodes'])) - accept_rate = 1 - (len(self.episodes) - 1 - ep_idx) / self.args['maximum_episodes'] + ep_count = min(len(self.episodes), self.args['maximum_episodes']) + ep_idx = random.randrange(ep_count) + accept_rate = 1 - (ep_count - 1 - ep_idx) / self.args['maximum_episodes'] if random.random() < accept_rate: break ep = self.episodes[ep_idx] From bcc4d7af846280cbb04369d542a353d164d11c42 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Sat, 28 May 2022 07:36:58 +0900 Subject: [PATCH 20/21] fix: stop calling view_transition() in reset phase --- handyrl/evaluation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index 248f5b6c..8783ce86 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -58,7 +58,8 @@ def run(self): reset = args[1] 
if reset: self.agent.reset(self.env, show=True) - view_transition(self.env) + else: + view_transition(self.env) self.conn.send(ret) From 1f8aa05d7327113cdb74d1889dc98a5222d474a4 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Mon, 13 Jun 2022 23:38:11 +0900 Subject: [PATCH 21/21] chore: add kwargs to random model --- handyrl/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handyrl/model.py b/handyrl/model.py index b4dd2a7a..d59bde82 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -70,5 +70,5 @@ def __init__(self, model, x): outputs = wrapped_model.inference(x, hidden) self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items() if key != 'hidden'} - def inference(self, *args): + def inference(self, *args, **kwargs): return self.output_dict