diff --git a/handyrl/agent.py b/handyrl/agent.py index 6b6f6bd9..44f503ab 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -23,9 +23,12 @@ def observe(self, env, player, show=False): class RuleBasedAgent(RandomAgent): + def __init__(self, key=None): + self.key = key + def action(self, env, player, show=False): if hasattr(env, 'rule_based_action'): - return env.rule_based_action(player) + return env.rule_based_action(player, key=self.key) else: return random.choice(env.legal_actions(player)) @@ -41,11 +44,12 @@ def print_outputs(env, prob, v): class Agent: - def __init__(self, model, temperature=0.0): + def __init__(self, model, temperature=0.0, observation=True): # model might be a neural net, or some planning algorithm such as game tree search self.model = model self.hidden = None self.temperature = temperature + self.observation = observation def reset(self, env, show=False): self.hidden = self.model.init_hidden() @@ -56,7 +60,8 @@ def plan(self, obs): return outputs def action(self, env, player, show=False): - outputs = self.plan(env.observation(player)) + obs = env.observation(player) + outputs = self.plan(obs) actions = env.legal_actions(player) p = outputs['policy'] v = outputs.get('value', None) @@ -74,11 +79,14 @@ def action(self, env, player, show=False): return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0] def observe(self, env, player, show=False): - outputs = self.plan(env.observation(player)) - v = outputs.get('value', None) - if show: - print_outputs(env, None, v) - return v if v is not None else [0.0] + v = None + if self.observation: + obs = env.observation(player) + outputs = self.plan(obs) + v = outputs.get('value', None) + if show: + print_outputs(env, None, v) + return v class EnsembleAgent(Agent): diff --git a/handyrl/connection.py b/handyrl/connection.py index 3550f153..49a176ee 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -131,109 +131,77 @@ def 
open_multiprocessing_connections(num_process, target, args_func): class MultiProcessJobExecutor: - def __init__(self, func, send_generator, num_workers, postprocess=None, num_receivers=1): + def __init__(self, func, send_generator, num_workers, postprocess=None): self.send_generator = send_generator self.postprocess = postprocess - self.num_receivers = num_receivers self.conns = [] self.waiting_conns = queue.Queue() - self.shutdown_flag = False self.output_queue = queue.Queue(maxsize=8) - self.threads = [] for i in range(num_workers): conn0, conn1 = mp.Pipe(duplex=True) - mp.Process(target=func, args=(conn1, i)).start() + mp.Process(target=func, args=(conn1, i), daemon=True).start() conn1.close() self.conns.append(conn0) self.waiting_conns.put(conn0) - def shutdown(self): - self.shutdown_flag = True - for thread in self.threads: - thread.join() - def recv(self): return self.output_queue.get() def start(self): - self.threads.append(threading.Thread(target=self._sender)) - for i in range(self.num_receivers): - self.threads.append(threading.Thread(target=self._receiver, args=(i,))) - for thread in self.threads: - thread.start() + threading.Thread(target=self._sender, daemon=True).start() + threading.Thread(target=self._receiver, daemon=True).start() def _sender(self): print('start sender') - while not self.shutdown_flag: + while True: data = next(self.send_generator) - while not self.shutdown_flag: - try: - conn = self.waiting_conns.get(timeout=0.3) - conn.send(data) - break - except queue.Empty: - pass + conn = self.waiting_conns.get() + conn.send(data) print('finished sender') - def _receiver(self, index): - print('start receiver %d' % index) - conns = [conn for i, conn in enumerate(self.conns) if i % self.num_receivers == index] - while not self.shutdown_flag: - tmp_conns = connection.wait(conns) - for conn in tmp_conns: + def _receiver(self): + print('start receiver') + while True: + conns = connection.wait(self.conns) + for conn in conns: data = conn.recv() 
self.waiting_conns.put(conn) if self.postprocess is not None: data = self.postprocess(data) - while not self.shutdown_flag: - try: - self.output_queue.put(data, timeout=0.3) - break - except queue.Full: - pass - print('finished receiver %d' % index) + self.output_queue.put(data) + print('finished receiver') class QueueCommunicator: def __init__(self, conns=[]): self.input_queue = queue.Queue(maxsize=256) self.output_queue = queue.Queue(maxsize=256) - self.conns = [] + self.conns = set() for conn in conns: self.add_connection(conn) - self.shutdown_flag = False - self.threads = [ - threading.Thread(target=self._send_thread), - threading.Thread(target=self._recv_thread), - ] - for thread in self.threads: - thread.start() - - def shutdown(self): - self.shutdown_flag = True - for thread in self.threads: - thread.join() + threading.Thread(target=self._send_thread, daemon=True).start() + threading.Thread(target=self._recv_thread, daemon=True).start() - def recv(self): - return self.input_queue.get() + def connection_count(self): + return len(self.conns) + + def recv(self, timeout=None): + return self.input_queue.get(timeout=timeout) def send(self, conn, send_data): self.output_queue.put((conn, send_data)) def add_connection(self, conn): - self.conns.append(conn) + self.conns.add(conn) def disconnect(self, conn): print('disconnected') - self.conns.remove(conn) + self.conns.discard(conn) def _send_thread(self): - while not self.shutdown_flag: - try: - conn, send_data = self.output_queue.get(timeout=0.3) - except queue.Empty: - continue + while True: + conn, send_data = self.output_queue.get() try: conn.send(send_data) except ConnectionResetError: @@ -242,7 +210,7 @@ def _send_thread(self): self.disconnect(conn) def _recv_thread(self): - while not self.shutdown_flag: + while True: conns = connection.wait(self.conns, timeout=0.3) for conn in conns: try: @@ -253,9 +221,4 @@ def _recv_thread(self): except EOFError: self.disconnect(conn) continue - while not self.shutdown_flag: 
- try: - self.input_queue.put((conn, recv_data), timeout=0.3) - break - except queue.Full: - pass + self.input_queue.put((conn, recv_data)) diff --git a/handyrl/envs/kaggle/hungry_geese.py b/handyrl/envs/kaggle/hungry_geese.py index 0a663adc..52dd00c5 100644 --- a/handyrl/envs/kaggle/hungry_geese.py +++ b/handyrl/envs/kaggle/hungry_geese.py @@ -186,7 +186,7 @@ def legal_actions(self, player): def players(self): return list(range(self.NUM_AGENTS)) - def rule_based_action(self, player): + def rule_based_action(self, player, key=None): from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, GreedyAgent action_map = {'N': Action.NORTH, 'S': Action.SOUTH, 'W': Action.WEST, 'E': Action.EAST} diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index ad770e30..2d391bce 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -58,7 +58,8 @@ def run(self): reset = args[1] if reset: self.agent.reset(self.env, show=True) - view_transition(self.env) + else: + view_transition(self.env) self.conn.send(ret) @@ -143,8 +144,9 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={ def build_agent(raw, env=None): if raw == 'random': return RandomAgent() - elif raw == 'rulebase': - return RuleBasedAgent() + elif raw.startswith('rulebase'): + key = raw.split('-')[1] if '-' in raw else None + return RuleBasedAgent(key) return None @@ -351,10 +353,11 @@ def insert_input(y): return outputs -def load_model(model_path, model): +def load_model(model_path, model=None): if model_path.endswith('.onnx'): model = OnnxModel(model_path) return model + assert model is not None import torch from .model import ModelWrapper model.load_state_dict(torch.load(model_path)) @@ -376,14 +379,18 @@ def eval_main(args, argv): prepare_env(env_args) env = make_env(env_args) - model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth' + model_paths = argv[0].split(':') if len(argv) >= 1 else ['models/latest.pth'] num_games 
= int(argv[1]) if len(argv) >= 2 else 100 num_process = int(argv[2]) if len(argv) >= 3 else 1 - agent1 = build_agent(model_path, env) - if agent1 is None: - model = load_model(model_path, env.net()) - agent1 = Agent(model) + def resolve_agent(model_path): + agent = build_agent(model_path, env) + if agent is None: + model = load_model(model_path, env.net()) + agent = Agent(model) + return agent + + main_agent = resolve_agent(model_paths[0]) critic = None print('%d process, %d games' % (num_process, num_games)) @@ -391,7 +398,8 @@ def eval_main(args, argv): seed = random.randrange(1e8) print('seed = %d' % seed) - agents = [agent1] + [RandomAgent() for _ in range(len(env.players()) - 1)] + opponent = model_paths[1] if len(model_paths) > 1 else 'random' + agents = [main_agent] + [resolve_agent(opponent) for _ in range(len(env.players()) - 1)] evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed) diff --git a/handyrl/generation.py b/handyrl/generation.py index e03e857b..8bca1c98 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -37,6 +37,8 @@ def generate(self, models, args): for player in self.env.players(): if player not in turn_players + observers: continue + if player not in turn_players and player in args['player'] and not self.args['observation']: + continue obs = self.env.observation(player) model = models[player] diff --git a/handyrl/model.py b/handyrl/model.py index b4dd2a7a..d59bde82 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -70,5 +70,5 @@ def __init__(self, model, x): outputs = wrapped_model.inference(x, hidden) self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items() if key != 'hidden'} - def inference(self, *args): + def inference(self, *args, **kwargs): return self.output_dict diff --git a/handyrl/train.py b/handyrl/train.py index 24135e59..59c9b82c 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -267,9 +267,7 @@ class Batcher: def __init__(self, args, 
episodes): self.args = args self.episodes = episodes - self.shutdown_flag = False - - self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers'], num_receivers=2) + self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers']) def _selector(self): while True: @@ -277,7 +275,7 @@ def _selector(self): def _worker(self, conn, bid): print('started batcher %d' % bid) - while not self.shutdown_flag: + while True: episodes = conn.recv() batch = make_batch(episodes, self.args) conn.send(batch) @@ -288,8 +286,9 @@ def run(self): def select_episode(self): while True: - ep_idx = random.randrange(min(len(self.episodes), self.args['maximum_episodes'])) - accept_rate = 1 - (len(self.episodes) - 1 - ep_idx) / self.args['maximum_episodes'] + ep_count = min(len(self.episodes), self.args['maximum_episodes']) + ep_idx = random.randrange(ep_count) + accept_rate = 1 - (ep_count - 1 - ep_idx) / self.args['maximum_episodes'] if random.random() < accept_rate: break ep = self.episodes[ep_idx] @@ -310,10 +309,6 @@ def select_episode(self): def batch(self): return self.executor.recv() - def shutdown(self): - self.shutdown_flag = True - self.executor.shutdown() - class Trainer: def __init__(self, args, model): @@ -330,7 +325,6 @@ def __init__(self, args, model): self.batcher = Batcher(self.args, self.episodes) self.update_flag = False self.update_queue = queue.Queue(maxsize=1) - self.shutdown_flag = False self.wrapped_model = ModelWrapper(self.model) self.trained_model = self.wrapped_model @@ -342,10 +336,6 @@ def update(self): model, steps = self.update_queue.get() return model, steps - def shutdown(self): - self.shutdown_flag = True - self.batcher.shutdown() - def train(self): if self.optimizer is None: # non-parametric model time.sleep(0.1) @@ -356,7 +346,7 @@ def train(self): self.trained_model.cuda() self.trained_model.train() - while data_cnt == 0 or not (self.update_flag or self.shutdown_flag): + while data_cnt 
== 0 or not self.update_flag: batch = self.batcher.batch() batch_size = batch['value'].size(0) player_count = batch['value'].size(2) @@ -390,12 +380,12 @@ def train(self): def run(self): print('waiting training') - while not self.shutdown_flag and len(self.episodes) < self.args['minimum_episodes']: + while len(self.episodes) < self.args['minimum_episodes']: time.sleep(1) - if not self.shutdown_flag and self.optimizer is not None: + if self.optimizer is not None: self.batcher.run() print('started training') - while not self.shutdown_flag: + while True: model = self.train() self.update_flag = False self.update_queue.put((model, self.steps)) @@ -440,12 +430,6 @@ def __init__(self, args, net=None, remote=False): # thread connection self.trainer = Trainer(args, self.model) - def shutdown(self): - self.shutdown_flag = True - self.trainer.shutdown() - self.worker.shutdown() - self.thread.join() - def model_path(self, model_id): return os.path.join('models', str(model_id) + '.pth') @@ -551,17 +535,24 @@ def server(self): # returns as list if getting multiple requests as list print('started server') prev_update_episodes = self.args['minimum_episodes'] - while self.model_epoch < self.args['epochs'] or self.args['epochs'] < 0: - # no update call before storing minimum number of episodes + 1 age - next_update_episodes = prev_update_episodes + self.args['update_episodes'] - while not self.shutdown_flag and self.num_returned_episodes < next_update_episodes: - conn, (req, data) = self.worker.recv() - multi_req = isinstance(data, list) - if not multi_req: - data = [data] - send_data = [] - - if req == 'args': + # no update call before storing minimum number of episodes + 1 epoch + next_update_episodes = prev_update_episodes + self.args['update_episodes'] + + while self.worker.connection_count() > 0 or not self.shutdown_flag: + try: + conn, (req, data) = self.worker.recv(timeout=0.3) + except queue.Empty: + continue + + multi_req = isinstance(data, list) + if not multi_req: + data 
= [data] + send_data = [] + + if req == 'args': + if self.shutdown_flag: + send_data = [None] * len(data) + else: for _ in data: args = {'model_id': {}} @@ -593,46 +584,46 @@ def server(self): send_data.append(args) - elif req == 'episode': - # report generated episodes - self.feed_episodes(data) - send_data = [None] * len(data) - - elif req == 'result': - # report evaluation results - self.feed_results(data) - send_data = [None] * len(data) - - elif req == 'model': - for model_id in data: - model = self.model - if model_id != self.model_epoch and model_id > 0: - try: - model = copy.deepcopy(self.model) - model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) - except: - # return latest model if failed to load specified model - pass - send_data.append(pickle.dumps(model)) - - if not multi_req and len(send_data) == 1: - send_data = send_data[0] - self.worker.send(conn, send_data) - prev_update_episodes = next_update_episodes - self.update() + elif req == 'episode': + # report generated episodes + self.feed_episodes(data) + send_data = [None] * len(data) + + elif req == 'result': + # report evaluation results + self.feed_results(data) + send_data = [None] * len(data) + + elif req == 'model': + for model_id in data: + model = self.model + if model_id != self.model_epoch and model_id > 0: + try: + model = copy.deepcopy(self.model) + model.load_state_dict(torch.load(self.model_path(model_id)), strict=False) + except: + # return latest model if failed to load specified model + pass + send_data.append(pickle.dumps(model)) + + if not multi_req and len(send_data) == 1: + send_data = send_data[0] + self.worker.send(conn, send_data) + + if self.num_returned_episodes >= next_update_episodes: + prev_update_episodes = next_update_episodes + next_update_episodes = prev_update_episodes + self.args['update_episodes'] + self.update() + if self.args['epochs'] >= 0 and self.model_epoch >= self.args['epochs']: + self.shutdown_flag = True print('finished server') def 
run(self): - try: - # open training thread - self.thread = threading.Thread(target=self.trainer.run) - self.thread.start() - # open generator, evaluator - self.worker.run() - self.server() - - finally: - self.shutdown() + # open training thread + threading.Thread(target=self.trainer.run, daemon=True).start() + # open generator, evaluator + self.worker.run() + self.server() def train_main(args): diff --git a/handyrl/worker.py b/handyrl/worker.py index 112494fa..0cf47b63 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -12,6 +12,7 @@ import multiprocessing as mp import pickle import copy +import queue from .environment import prepare_env, make_env from .connection import QueueCommunicator @@ -65,6 +66,8 @@ def _gather_models(self, model_ids): def run(self): while True: args = send_recv(self.conn, ('args', None)) + if args is None: + break role = args['role'] models = {} @@ -118,20 +121,23 @@ def __init__(self, args, conn, gaid): for conn in worker_conns: self.add_connection(conn) - self.args_buf_len = 1 + len(worker_conns) // 4 - self.result_buf_len = 1 + len(worker_conns) // 4 + self.buffer_length = 1 + len(worker_conns) // 4 def __del__(self): print('finished gather %d' % self.gather_id) def run(self): - while True: - conn, (command, args) = self.recv() + while self.connection_count() > 0: + try: + conn, (command, args) = self.recv(timeout=0.3) + except queue.Empty: + continue + if command == 'args': # When requested arguments, return buffered outputs if len(self.args_queue) == 0: # get multiple arguments from server and store them - self.server_conn.send((command, [None] * self.args_buf_len)) + self.server_conn.send((command, [None] * self.buffer_length)) self.args_queue += self.server_conn.recv() next_args = self.args_queue.popleft() @@ -153,7 +159,7 @@ def run(self): self.result_send_map[command].append(args) self.result_send_cnt += 1 - if self.result_send_cnt >= self.result_buf_len: + if self.result_send_cnt >= self.buffer_length: # send datum to 
server after buffering certain number of datum for command, args_list in self.result_send_map.items(): self.server_conn.send((command, args_list)) @@ -163,11 +169,8 @@ def run(self): def gather_loop(args, conn, gaid): - try: - gather = Gather(args, conn, gaid) - gather.run() - finally: - gather.shutdown() + gather = Gather(args, conn, gaid) + gather.run() class WorkerCluster(QueueCommunicator): @@ -196,34 +199,29 @@ def run(self): # prepare listening connections def entry_server(port): print('started entry server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not self.shutdown_flag: + conn_acceptor = accept_socket_connections(port=port) + while True: conn = next(conn_acceptor) - if conn is not None: - worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) - worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] - args = copy.deepcopy(self.args) - args['worker'] = worker_args - conn.send(args) - conn.close() + worker_args = conn.recv() + print('accepted connection from %s!' 
% worker_args['address']) + worker_args['base_worker_id'] = self.total_worker_count + self.total_worker_count += worker_args['num_parallel'] + args = copy.deepcopy(self.args) + args['worker'] = worker_args + conn.send(args) + conn.close() print('finished entry server') def worker_server(port): print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not self.shutdown_flag: + conn_acceptor = accept_socket_connections(port=port) + while True: conn = next(conn_acceptor) - if conn is not None: - self.add_connection(conn) + self.add_connection(conn) print('finished worker server') - # use thread list of super class - self.threads.append(threading.Thread(target=entry_server, args=(9999,))) - self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-2].start() - self.threads[-1].start() + threading.Thread(target=entry_server, args=(9999,), daemon=True).start() + threading.Thread(target=worker_server, args=(9998,), daemon=True).start() def entry(worker_args): diff --git a/scripts/loss_plot.py b/scripts/loss_plot.py new file mode 100644 index 00000000..aa6db5de --- /dev/null +++ b/scripts/loss_plot.py @@ -0,0 +1,91 @@ + +# Usage: python3 scripts/loss_plot.py train-log-002.txt title + +# You should not include figures generated by this script in your academic paper, because +# 1. This version of HandyRL doesn't display all the results of the matches. +# 2. Smoothing method in this script is not a simple moving average. 
+ +import sys +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +n = 10 + + +def kernel(n): + a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2, 1, -1))) + return a / a.sum() + + +def get_loss_list(path): + epoch_data_list = [] + epoch_list = [0] + step_list = [0] + game_list = [0] + + f = open(path) + lines = f.readlines() + prev_line = '' + + for line in lines: + if line.startswith('updated'): + epoch_list.append(len(epoch_list)) + step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) + if line.startswith('loss'): + elms = line.split() + epoch_data_list.append({}) + for e in elms[2:]: + name, loss = e.split(':') + loss = float(loss) + epoch_data_list[-1][name] = loss + if line.startswith('epoch '): + print(line, len(epoch_list)) + if ' ' in prev_line: + game = int(prev_line.split()[-1]) + game_list.append(game) + + prev_line = line + + game_list = game_list[:len(epoch_data_list)] + + if (len(epoch_list) - n//2 + 1 - n//2) - n <= 0: + raise Exception('Range to average `n` is too large, set a smaller number') + clipped_epoch_list = epoch_list[n//2:-n//2+1] + clipped_step_list = step_list[n//2:-n//2+1] + clipped_game_list = game_list[n//2:-n//2+1] + + kn = kernel(n) + start_epoch = {} + averaged_loss_lists = {} + for name in epoch_data_list[0].keys(): + data = [d[name] for d in epoch_data_list] + averaged_loss_lists[name] = np.convolve(data, kn, mode='valid') + start_epoch = 0 + + return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_loss_lists, start_epoch + + +# Plot +flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] +sns.set_palette(sns.color_palette(flatui, 24)) + +clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_loss_lists, start_epoch = get_loss_list(sys.argv[1]) + +fig_num = len(averaged_loss_lists) +fig = plt.figure() + +for i, (k, v) in enumerate(averaged_loss_lists.items()): + ax = fig.add_subplot((fig_num - 1) // 2 + 1, 2, i + 1) + 
start = start_epoch + ax.plot(clipped_game_list[start:], v[start:], label=k) + + ax.set_xlabel('Games') + ax.set_ylabel(k) + ax.set_title(f"loss: {k}") + ax.grid() + +fig.suptitle(sys.argv[2]) +fig.tight_layout() +plt.show() diff --git a/scripts/stats_plot.py b/scripts/stats_plot.py new file mode 100644 index 00000000..6696e0c2 --- /dev/null +++ b/scripts/stats_plot.py @@ -0,0 +1,91 @@ + +# Usage: python3 scripts/stats_plot.py train-log-002.txt title + +# You should not include figures generated by this script in your academic paper, because +# 1. This version of HandyRL doesn't display all the results of the matches. +# 2. Smoothing method in this script is not a simple moving average. + +import sys +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + + +n = 10 + + +def kernel(n): + a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2, 1, -1))) + return a / a.sum() + + +def get_stats_list(path): + epoch_data_list = [] + epoch_list = [0] + step_list = [0] + game_list = [0] + + f = open(path) + lines = f.readlines() + prev_line = '' + + for line in lines: + if line.startswith('updated'): + epoch_list.append(len(epoch_list)) + step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) + if line.startswith('generation stats'): + elms = line.split('=') + epoch_data_list.append({}) + for e in range(0, len(elms), 2): + name, stats = elms[e], elms[e+1].split('+-') + stats_mean, stats_std = float(stats[0]), float(stats[1]) + epoch_data_list[-1][name] = stats_mean + if line.startswith('epoch '): + print(line, len(epoch_list)) + if ' ' in prev_line: + game = int(prev_line.split()[-1]) + game_list.append(game) + + prev_line = line + + game_list = game_list[:len(epoch_data_list)] + + if (len(epoch_list) - n//2 + 1 - n//2) - n <= 0: + raise Exception('Range to average `n` is too large, set a smaller number') + clipped_epoch_list = epoch_list[n//2:-n//2+1] + clipped_step_list = step_list[n//2:-n//2+1] + clipped_game_list = 
game_list[n//2:-n//2+1] + + kn = kernel(n) + start_epoch = {} + averaged_stats_lists = {} + for name in epoch_data_list[0].keys(): + data = [d[name] for d in epoch_data_list] + averaged_stats_lists[name] = np.convolve(data, kn, mode='valid') + start_epoch = 0 + + return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_stats_lists, start_epoch + + +# Plot +flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] +sns.set_palette(sns.color_palette(flatui, 24)) + +clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_stats_lists, start_epoch = get_stats_list(sys.argv[1]) + +fig_num = len(averaged_stats_lists) +fig = plt.figure() + +for i, (k, v) in enumerate(averaged_stats_lists.items()): + ax = fig.add_subplot((fig_num - 1) // 2 + 1, 2, i + 1) + start = start_epoch + ax.plot(clipped_game_list[start:], v[start:], label=k) + + ax.set_xlabel('Games') + ax.set_ylabel(k) + ax.set_title(f"stats: {k}") + ax.grid() + +fig.suptitle(sys.argv[2]) +fig.tight_layout() +plt.show() diff --git a/scripts/win_rate_plot.py b/scripts/win_rate_plot.py index 7d8e67e3..cf18878e 100644 --- a/scripts/win_rate_plot.py +++ b/scripts/win_rate_plot.py @@ -7,10 +7,13 @@ import sys import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns n = 15 + def kernel(n): a = np.array(list(range(1, 1 + (n+1)//2)) + list(range(1 + n//2,1,-1))) return a / a.sum() @@ -34,7 +37,7 @@ def get_wp_list(path): step_list.append(int(line.split('(')[1].rstrip().rstrip(')'))) if line.startswith('win rate'): elms = line.split() - opponent = elms[2].lstrip('(').rstrip(')') + opponent = elms[2].lstrip('(').rstrip(')') if elms[2] != '=' else 'total' games = int(elms[-1].lstrip('(').rstrip(')')) wp = float(elms[-4]) if games > 0 else 0.0 epoch_data_list[-1][opponent] = {'w': games * wp, 'n': games} @@ -69,8 +72,6 @@ def get_wp_list(path): return clipped_epoch_list, clipped_step_list, clipped_game_list, averaged_wp_lists, start_epoch -import 
matplotlib.pyplot as plt -import seaborn as sns flatui = ["#9b59b6", "#95a5a6", "#34495e", "#3498db", "#e74c3c", "#2ecc71", "#b22222"] sns.set_palette(sns.color_palette(flatui, 24)) @@ -84,8 +85,6 @@ def get_wp_list(path): last_win_rate = {} for opponent in opponents: - if opponent == 'total': - continue wp_list = averaged_wp_lists[opponent] start = start_epoch[opponent] # ax.plot(clipped_epoch_list[start:], wp_list[start:], label=opponent)