From 903ac151e37d5556b9603b248c8ee19bec127e7c Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 25 Jan 2022 02:47:33 +0900 Subject: [PATCH 01/10] feature: abolish entry_server and worker_server accepts entry --- handyrl/worker.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index 58cd12f7..f791f4c4 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -191,8 +191,8 @@ def __init__(self, args): def run(self): # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) + def worker_server(port): + print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) while not self.shutdown_flag: conn = next(conn_acceptor) @@ -202,33 +202,14 @@ def entry_server(port): args = copy.deepcopy(self.args) args['worker'] = worker_args conn.send(args) - conn.close() - print('finished entry server') - - def worker_server(port): - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - print('started worker server %d' % port) - while not self.shutdown_flag: # use super class's flag - conn = next(conn_acceptor) - if conn is not None: self.add_connection(conn) print('finished worker server') # use thread list of super class - self.threads.append(threading.Thread(target=entry_server, args=(9999,))) self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-2].start() self.threads[-1].start() -def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) - args = conn.recv() - conn.close() - return args - - class RemoteWorkerCluster: def __init__(self, args): args['address'] = gethostname() @@ -238,15 +219,19 @@ def __init__(self, args): self.args = args def run(self): - args = entry(self.args) - print(args) - prepare_env(args['env']) - # open worker process = [] try: for i in range(self.args['num_gathers']): + # entry conn = connect_socket_connection(self.args['server_address'], 9998) + conn.send(self.args) + args = conn.recv() + + if i == 0: # call once at every machine + print(args) + prepare_env(args['env']) + p = mp.Process(target=gather_loop, args=(args, conn, i)) p.start() conn.close() From 41b80175d4f4c8fa8a6a03c7b251c4d12b5500ff Mon Sep 17 00:00:00 2001 From: YuriCat Date: Tue, 25 Jan 2022 05:56:25 +0900 Subject: [PATCH 02/10] experiment: use websocket instead of socket --- handyrl/connection.py | 70 +++++++++++++++++++++++++++++++++++++++++++ handyrl/train.py | 8 ++++- handyrl/worker.py | 50 +++++++++++-------------------- requirements.txt | 2 ++ 4 files changed, 97 insertions(+), 33 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 28805d00..5c5a6005 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -6,11 +6,15 @@ import struct import socket import pickle +import base64 import threading import queue import multiprocessing as mp import multiprocessing.connection as connection +from websocket import create_connection +from websocket_server import WebsocketServer + def send_recv(conn, sdata): conn.send(sdata) @@ -196,6 +200,28 @@ def _receiver(self, index): print('finished receiver %d' % index) +class WebsocketConnection: + def __init__(self, conn): + self.conn = conn + + def send(self, data): + message = base64.b64encode(pickle.dumps(data)) + self.conn.send(message) + + def recv(self): + message = self.conn.recv() + return pickle.loads(base64.b64decode(message)) + + def 
close(self): + self.conn.close() + + +def connect_websocket_connection(host, port): + host = socket.gethostbyname(host) + conn = create_connection('ws://%s:%d' % (host, port)) + return WebsocketConnection(conn) + + class QueueCommunicator: def __init__(self, conns=[]): self.input_queue = queue.Queue(maxsize=256) @@ -262,3 +288,47 @@ def _recv_thread(self): break except queue.Full: pass + + +class WebsocketCommunicator(WebsocketServer): + def __init__(self): + super().__init__(port=9998, host='127.0.0.1') + + self.input_queue = queue.Queue(maxsize=256) + self.output_queue = queue.Queue(maxsize=256) + self.shutdown_flag = False + + def run(self): + self.set_fn_new_client(self._new_client) + self.set_fn_message_received(self._message_received) + self.run_forever(threaded=True) + + def shutdown(self): + self.shutdown_flag = True + self.shutdown_gracefully() + + def recv(self): + return self.input_queue.get() + + def send(self, client, send_data): + self.output_queue.put((client, send_data)) + + @staticmethod + def _new_client(client, server): + print('New client {}:{} has joined.'.format(client['address'][0], client['address'][1])) + + @staticmethod + def _message_received(client, server, message): + while not server.shutdown_flag: + try: + server.input_queue.put((client, pickle.loads(base64.b64decode(message))), timeout=0.3) + break + except queue.Full: + pass + while not server.shutdown_flag: + try: + client, reply_message = server.output_queue.get(timeout=0.3) + break + except queue.Empty: + continue + server.send_message(client, base64.b64encode(pickle.dumps(reply_message))) diff --git a/handyrl/train.py b/handyrl/train.py index 79ae2afc..57ea5e10 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -555,7 +555,13 @@ def server(self): data = [data] send_data = [] - if req == 'args': + if req == 'entry': + for worker_args in data: + args = copy.deepcopy(self.args) + args['worker'] = worker_args + send_data.append(args) + + elif req == 'args': for _ in data: args = {'model_id': {}} diff --git a/handyrl/worker.py b/handyrl/worker.py index f791f4c4..bc17fb8e 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -3,6 +3,7 @@ # worker and gather +import base64 import random import threading import time @@ -14,9 +15,9 @@ import copy from .environment import prepare_env, make_env -from .connection import QueueCommunicator +from .connection import QueueCommunicator, WebsocketCommunicator from .connection import send_recv, open_multiprocessing_connections -from .connection import connect_socket_connection, accept_socket_connections +from .connection import connect_websocket_connection from .evaluation import Evaluator from .generation import Generator from .model import ModelWrapper, RandomModel @@ -161,6 +162,18 @@ def run(self): def gather_loop(args, conn, gaid): + if conn is None: + # entry + conn = connect_websocket_connection(args['server_address'], 9998) + + #conn.send(self.args) + conn.send(('entry', args)) + args = conn.recv() + print('entry finished') + + print(args) + prepare_env(args['env']) # TODO: call once + try: gather = Gather(args, conn, gaid) gather.run() @@ -184,30 +197,13 @@ def run(self): self.add_connection(conn0) -class WorkerServer(QueueCommunicator): +class WorkerServer(WebsocketCommunicator): def __init__(self, args): super().__init__() self.args = args def run(self): - # prepare listening connections - def worker_server(port): - print('started worker server %d' % port) - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - while not 
self.shutdown_flag: - conn = next(conn_acceptor) - if conn is not None: - worker_args = conn.recv() - print('accepted connection from %s!' % worker_args['address']) - args = copy.deepcopy(self.args) - args['worker'] = worker_args - conn.send(args) - self.add_connection(conn) - print('finished worker server') - - # use thread list of super class - self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-1].start() + super().run() class RemoteWorkerCluster: @@ -223,18 +219,8 @@ def run(self): process = [] try: for i in range(self.args['num_gathers']): - # entry - conn = connect_socket_connection(self.args['server_address'], 9998) - conn.send(self.args) - args = conn.recv() - - if i == 0: # call once at every machine - print(args) - prepare_env(args['env']) - - p = mp.Process(target=gather_loop, args=(args, conn, i)) + p = mp.Process(target=gather_loop, args=(self.args, None, i)) p.start() - conn.close() process.append(p) while True: time.sleep(100) diff --git a/requirements.txt b/requirements.txt index 90443c01..5c353d48 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ numpy torch pytest psutil +websocket-server +websocket-client From 2d385bc698b196623f41431a6d8276c9134b8295 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 26 Jan 2022 12:25:08 +0900 Subject: [PATCH 03/10] feature: entry in gather_loop --- handyrl/agent.py | 17 ++++------ handyrl/environment.py | 6 ++++ handyrl/envs/geister.py | 18 ++++------- handyrl/evaluation.py | 8 +++-- handyrl/generation.py | 45 ++++++++++++++------------ handyrl/model.py | 6 +++- handyrl/train.py | 11 +++---- handyrl/worker.py | 72 +++++++++++++++++------------------------ 8 files changed, 88 insertions(+), 95 deletions(-) diff --git a/handyrl/agent.py b/handyrl/agent.py index f26dfb87..4af84f5e 100755 --- a/handyrl/agent.py +++ b/handyrl/agent.py @@ -41,11 +41,10 @@ def print_outputs(env, prob, v): class Agent: - def __init__(self, model, observation=False, temperature=0.0): + def __init__(self, model, temperature=0.0): # model might be a neural net, or some planning algorithm such as game tree search self.model = model self.hidden = None - self.observation = observation self.temperature = temperature def reset(self, env, show=False): @@ -75,12 +74,10 @@ def action(self, env, player, show=False): return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0] def observe(self, env, player, show=False): - v = None - if self.observation: - outputs = self.plan(env.observation(player)) - v = outputs.get('value', None) - if show: - print_outputs(env, None, v) + outputs = self.plan(env.observation(player)) + v = outputs.get('value', None) + if show: + print_outputs(env, None, v) return v if v is not None else [0.0] @@ -103,5 +100,5 @@ def plan(self, obs): class SoftAgent(Agent): - def __init__(self, model, observation=False): - super().__init__(model, observation=observation, temperature=1.0) + def __init__(self, model): + super().__init__(model, temperature=1.0) diff --git a/handyrl/environment.py b/handyrl/environment.py index 9bca1713..02bb836c 100755 --- a/handyrl/environment.py +++ b/handyrl/environment.py @@ -77,6 +77,12 @@ def turn(self): def turns(self): return [self.turn()] + # + # Should be defined if players except turn player also observe game states + # + def observers(self): + return [] + # # Should be defined in all games # diff --git a/handyrl/envs/geister.py b/handyrl/envs/geister.py index 9a395659..a82bd6af 100755 --- a/handyrl/envs/geister.py +++ 
b/handyrl/envs/geister.py @@ -34,16 +34,10 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias): ) def init_hidden(self, input_size, batch_size): - if batch_size is None: # for inference - return tuple([ - np.zeros((self.hidden_dim, *input_size), dtype=np.float32), - np.zeros((self.hidden_dim, *input_size), dtype=np.float32) - ]) - else: # for training - return tuple([ - torch.zeros(*batch_size, self.hidden_dim, *input_size), - torch.zeros(*batch_size, self.hidden_dim, *input_size) - ]) + return tuple([ + torch.zeros(*batch_size, self.hidden_dim, *input_size), + torch.zeros(*batch_size, self.hidden_dim, *input_size) + ]) def forward(self, input_tensor, cur_state): h_cur, c_cur = cur_state @@ -150,7 +144,7 @@ def __init__(self): self.head_v = ScalarHead((filters * 2, 6, 6), 1, 1) self.head_r = ScalarHead((filters * 2, 6, 6), 1, 1) - def init_hidden(self, batch_size=None): + def init_hidden(self, batch_size=[]): return self.body.init_hidden(self.input_size[1:], batch_size) def forward(self, x, hidden): @@ -453,6 +447,8 @@ def legal(self, action): if self.turn_count < 0: layout = action - 4 * 6 * 6 return 0 <= layout < 70 + elif not 0 <= action < 4 * 6 * 6: + return False pos_from = self.action2from(action, self.color) pos_to = self.action2to(action, self.color) diff --git a/handyrl/evaluation.py b/handyrl/evaluation.py index a6d0ddf1..b3a7cb4a 100755 --- a/handyrl/evaluation.py +++ b/handyrl/evaluation.py @@ -88,11 +88,12 @@ def exec_match(env, agents, critic, show=False, game_args={}): if show and critic is not None: print('cv = ', critic.observe(env, None, show=False)[0]) turn_players = env.turns() + observers = env.observers() actions = {} for p, agent in agents.items(): if p in turn_players: actions[p] = agent.action(env, p, show=show) - else: + elif p in observers: agent.observe(env, p, show=show) if env.step(actions): return None @@ -117,12 +118,13 @@ def exec_network_match(env, network_agents, critic, show=False, game_args={}): if show and critic is not None: print('cv = ', critic.observe(env, None, show=False)[0]) turn_players = env.turns() + observers = env.observers() actions = {} for p, agent in network_agents.items(): if p in turn_players: action = agent.action(p) actions[p] = env.str2action(action, p) - else: + elif p in observers: agent.observe(p) if env.step(actions): return None @@ -161,7 +163,7 @@ def execute(self, models, args): if model is None: agents[p] = build_agent(opponent, self.env) else: - agents[p] = Agent(model, self.args['observation']) + agents[p] = Agent(model) outcome = exec_match(self.env, agents, None) if outcome is None: diff --git a/handyrl/generation.py b/handyrl/generation.py index 63b7e553..3ad3b76e 100755 --- a/handyrl/generation.py +++ b/handyrl/generation.py @@ -33,28 +33,31 @@ def generate(self, models, args): moment = {key: {p: None for p in self.env.players()} for key in moment_keys} turn_players = self.env.turns() + observers = self.env.observers() for player in self.env.players(): - if player in turn_players or self.args['observation']: - obs = self.env.observation(player) - model = models[player] - outputs = model.inference(obs, hidden[player]) - hidden[player] = outputs.get('hidden', None) - v = outputs.get('value', None) - - moment['observation'][player] = obs - moment['value'][player] = v - - if player in turn_players: - p_ = outputs['policy'] - legal_actions = self.env.legal_actions(player) - action_mask = np.ones_like(p_) * 1e32 - action_mask[legal_actions] = 0 - p = p_ - action_mask - action = random.choices(legal_actions, 
weights=softmax(p[legal_actions]))[0] - - moment['policy'][player] = p - moment['action_mask'][player] = action_mask - moment['action'][player] = action + if player not in turn_players + observers: + continue + + obs = self.env.observation(player) + model = models[player] + outputs = model.inference(obs, hidden[player]) + hidden[player] = outputs.get('hidden', None) + v = outputs.get('value', None) + + moment['observation'][player] = obs + moment['value'][player] = v + + if player in turn_players: + p_ = outputs['policy'] + legal_actions = self.env.legal_actions(player) + action_mask = np.ones_like(p_) * 1e32 + action_mask[legal_actions] = 0 + p = p_ - action_mask + action = random.choices(legal_actions, weights=softmax(p[legal_actions]))[0] + + moment['policy'][player] = p + moment['action_mask'][player] = action_mask + moment['action'][player] = action err = self.env.step(moment['action']) if err: diff --git a/handyrl/model.py b/handyrl/model.py index 621d703f..9eb7b94b 100755 --- a/handyrl/model.py +++ b/handyrl/model.py @@ -37,7 +37,11 @@ def __init__(self, model): def init_hidden(self, batch_size=None): if hasattr(self.model, 'init_hidden'): - return self.model.init_hidden(batch_size) + if batch_size is None: # for inference + hidden = self.model.init_hidden([]) + return map_r(hidden, lambda h: h.detach().numpy() if isinstance(h, torch.Tensor) else h) + else: # for training + return self.model.init_hidden(batch_size) return None def forward(self, *args, **kwargs): diff --git a/handyrl/train.py b/handyrl/train.py index 79ae2afc..0cee1013 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -26,7 +26,6 @@ from .model import to_torch, to_gpu, ModelWrapper from .losses import compute_target from .connection import MultiProcessJobExecutor -from .connection import accept_socket_connections from .worker import WorkerCluster, WorkerServer @@ -60,7 +59,7 @@ def replace_none(a, b): obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o)) # template for padding p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]]) # template for padding - # data that is chainge by training configuration + # data that is changed by training configuration if args['turn_based_training'] and not args['observation']: obs = [[m['observation'][m['turn'][0]]] for m in moments] p = np.array([[m['policy'][m['turn'][0]]] for m in moments]) @@ -154,7 +153,7 @@ def forward_prediction(model, hidden, batch, args): outputs_ = model(obs, hidden_) for k, o in outputs_.items(): if k == 'hidden': - next_hidden = outputs_['hidden'] + next_hidden = o else: outputs[k] = outputs.get(k, []) + [o] next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:])) # (..., B, P or 1, ...) 
@@ -349,8 +348,8 @@ def shutdown(self): def train(self): if self.optimizer is None: # non-parametric model - print() - return + time.sleep(0.1) + return self.model batch_cnt, data_cnt, loss_sum = 0, 0, {} if self.gpu > 0: @@ -395,7 +394,7 @@ def run(self): if len(self.episodes) < self.args['minimum_episodes']: time.sleep(1) continue - if self.steps == 0: + if self.steps == 0 and self.optimizer is not None: self.batcher.run() print('started training') model = self.train() diff --git a/handyrl/worker.py b/handyrl/worker.py index 7048097c..a76d875b 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -84,8 +84,8 @@ def run(self): send_recv(self.conn, ('result', result)) -def make_worker_args(args, n_ga, gaid, base_wid, wid, conn): - return args, conn, base_wid + wid * n_ga + gaid +def make_worker_args(args, base_wid, wid, conn): + return args, conn, base_wid + wid def open_worker(args, conn, wid): @@ -94,25 +94,20 @@ def open_worker(args, conn, wid): class Gather(QueueCommunicator): - def __init__(self, args, conn, gaid): - print('started gather %d' % gaid) + def __init__(self, args, conn, gather_id, base_worker_id, num_workers): + print('started gather %d' % gather_id) super().__init__() - self.gather_id = gaid + self.gather_id = gather_id self.server_conn = conn self.args_queue = deque([]) self.data_map = {'model': {}} self.result_send_map = {} self.result_send_cnt = 0 - n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] - - num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga) - base_wid = args['worker'].get('base_worker_id', 0) - worker_conns = open_multiprocessing_connections( - num_workers_per_gather, + num_workers, open_worker, - functools.partial(make_worker_args, args, n_ga, gaid, base_wid) + functools.partial(make_worker_args, args, base_worker_id) ) for conn in worker_conns: @@ -162,9 +157,25 @@ def run(self): self.result_send_cnt = 0 -def gather_loop(args, conn, gaid): +def gather_loop(args, conn, gather_id): + n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] + n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga) + args['worker']['num_parallel_per_gather'] = n_pro_w + base_worker_id = 0 + + if conn is None: + # entry + conn = connect_socket_connection(args['worker']['server_address'], 9998) + conn.send(args['worker']) + args = conn.recv() + + if gather_id == 0: # call once at every machine + print(args) + prepare_env(args['env']) + base_worker_id = args['worker'].get('base_worker_id', 0) + try: - gather = Gather(args, conn, gaid) + gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w) gather.run() finally: gather.shutdown() @@ -194,8 +205,8 @@ def __init__(self, args): def run(self): # prepare listening connections - def entry_server(port): - print('started entry server %d' % port) + def worker_server(port): + print('started worker server %d' % port) conn_acceptor = accept_socket_connections(port=port, timeout=0.3) while not self.shutdown_flag: conn = next(conn_acceptor) @@ -203,37 +214,18 @@ def entry_server(port): worker_args = conn.recv() print('accepted connection from %s!' 
% worker_args['address']) worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] + self.total_worker_count += worker_args['num_parallel_per_gather'] args = copy.deepcopy(self.args) args['worker'] = worker_args conn.send(args) - conn.close() - print('finished entry server') - - def worker_server(port): - conn_acceptor = accept_socket_connections(port=port, timeout=0.3) - print('started worker server %d' % port) - while not self.shutdown_flag: # use super class's flag - conn = next(conn_acceptor) - if conn is not None: self.add_connection(conn) print('finished worker server') # use thread list of super class - self.threads.append(threading.Thread(target=entry_server, args=(9999,))) self.threads.append(threading.Thread(target=worker_server, args=(9998,))) - self.threads[-2].start() self.threads[-1].start() -def entry(worker_args): - conn = connect_socket_connection(worker_args['server_address'], 9999) - conn.send(worker_args) - args = conn.recv() - conn.close() - return args - - class RemoteWorkerCluster: def __init__(self, args): args['address'] = gethostname() @@ -243,18 +235,12 @@ def __init__(self, args): self.args = args def run(self): - args = entry(self.args) - print(args) - prepare_env(args['env']) - # open worker process = [] try: for i in range(self.args['num_gathers']): - conn = connect_socket_connection(self.args['server_address'], 9998) - p = mp.Process(target=gather_loop, args=(args, conn, i)) + p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i)) p.start() - conn.close() process.append(p) while True: time.sleep(100) From cb03657d58e19e351d868efe9a170f7728a9e2af Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 2 Feb 2022 15:02:07 +0900 Subject: [PATCH 04/10] Merge develop 2 --- handyrl/connection.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 7dc1ef97..28805d00 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -6,15 +6,11 @@ import struct import socket import pickle -import base64 import threading import queue import multiprocessing as mp import multiprocessing.connection as connection -from websocket import create_connection -from websocket_server import WebsocketServer - def send_recv(conn, sdata): conn.send(sdata) @@ -200,28 +196,6 @@ def _receiver(self, index): print('finished receiver %d' % index) -class WebsocketConnection: - def __init__(self, conn): - self.conn = conn - - def send(self, data): - message = base64.b64encode(pickle.dumps(data)) - self.conn.send(message) - - def recv(self): - message = self.conn.recv() - return pickle.loads(base64.b64decode(message)) - - def close(self): - self.conn.close() - - -def connect_websocket_connection(host, port): - host = socket.gethostbyname(host) - conn = create_connection('ws://%s:%d' % (host, port)) - return WebsocketConnection(conn) - - class QueueCommunicator: def __init__(self, conns=[]): self.input_queue = queue.Queue(maxsize=256) From 2fc0756cc5583919f12fb537d11684650763c670 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 2 Feb 2022 15:09:28 +0900 Subject: [PATCH 05/10] chore: remove unused imports --- handyrl/connection.py | 1 - handyrl/worker.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/handyrl/connection.py b/handyrl/connection.py index 28805d00..e5a163af 100755 --- a/handyrl/connection.py +++ b/handyrl/connection.py @@ -2,7 +2,6 @@ # Licensed under The MIT License [see LICENSE for details] import io -import time import struct 
import socket import pickle diff --git a/handyrl/worker.py b/handyrl/worker.py index a4ef5b9c..3ef1e369 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -5,7 +5,6 @@ import base64 import random -import threading import time import functools from socket import gethostname @@ -17,7 +16,6 @@ from .environment import prepare_env, make_env from .connection import QueueCommunicator from .connection import send_recv, open_multiprocessing_connections -from .connection import connect_websocket_connection from .evaluation import Evaluator from .generation import Generator from .model import ModelWrapper, RandomModel From 211733d4a12f0139cc32c2bb69fd5b505e85071a Mon Sep 17 00:00:00 2001 From: YuriCat Date: Wed, 2 Feb 2022 15:13:16 +0900 Subject: [PATCH 06/10] feature: stop producing entry in server() --- handyrl/train.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 0027940c..07dbf81f 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -571,13 +571,7 @@ def server(self): data = [data] send_data = [] - if req == 'entry': - for worker_args in data: - args = copy.deepcopy(self.args) - args['worker'] = worker_args - send_data.append(args) - - elif req == 'args': + if req == 'args': for _ in data: args = {'model_id': {}} From 188057bf0c79c46cae197865a38444faa3f4e3c0 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 3 Feb 2022 19:19:12 +0900 Subject: [PATCH 07/10] feature: enable to set server port --- handyrl/train.py | 2 ++ handyrl/worker.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 07dbf81f..69016ad6 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -423,7 +423,9 @@ class Learner: def __init__(self, args, net=None, remote=False): train_args = args['train_args'] env_args = args['env_args'] + worker_args = args['worker_args'] train_args['env'] = env_args + train_args['worker'] = worker_args args = train_args self.args = args diff --git a/handyrl/worker.py b/handyrl/worker.py index f881228c..e525648e 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -164,7 +164,8 @@ def gather_loop(args, conn, gather_id): if conn is None: # entry - conn = connect_websocket_connection(args['worker']['server_address'], 9998) + port = int(args['worker'].get('server_port', 9998)) + conn = connect_websocket_connection(args['worker']['server_address'], port) conn.send(('entry', args['worker'])) args = conn.recv() @@ -229,7 +230,8 @@ def connect_websocket_connection(host, port): class WorkerServer(WebsocketServer): def __init__(self, args): - super().__init__(port=9998, host='127.0.0.1') + port = int(args['worker'].get('server_port', 9998)) + super().__init__(port=port, host='0.0.0.0') self.input_queue = queue.Queue(maxsize=256) self.output_queue = queue.Queue(maxsize=256) self.shutdown_flag = False From 803d5be06ee7b462d663262117f5804b1384468a Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 4 Feb 2022 15:35:58 +0900 Subject: [PATCH 08/10] fix: revert using worker_args in train.py --- handyrl/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/handyrl/train.py b/handyrl/train.py index 69016ad6..07dbf81f 100755 --- a/handyrl/train.py +++ b/handyrl/train.py @@ -423,9 +423,7 @@ class Learner: def __init__(self, args, net=None, remote=False): train_args = args['train_args'] env_args = args['env_args'] - worker_args = args['worker_args'] train_args['env'] = env_args - train_args['worker'] = worker_args args = train_args self.args = args From 
29a86b0f28e1ca6b35ca5da51b12bd3ce6ec9405 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Fri, 4 Feb 2022 15:55:47 +0900 Subject: [PATCH 09/10] chore: add WebsocketConnection.dumps() and loads() --- handyrl/worker.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index e525648e..7a2a55b3 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -210,13 +210,21 @@ class WebsocketConnection: def __init__(self, conn): self.conn = conn + @staticmethod + def dumps(data): + return base64.b64encode(pickle.dumps(data)) + + @staticmethod + def loads(message): + return pickle.loads(base64.b64decode(message)) + def send(self, data): - message = base64.b64encode(pickle.dumps(data)) + message = self.dumps(data) self.conn.send(message) def recv(self): message = self.conn.recv() - return pickle.loads(base64.b64decode(message)) + return self.loads(message) def close(self): self.conn.close() @@ -260,7 +268,7 @@ def _new_client(client, server): @staticmethod def _message_received(client, server, message): - message_ = pickle.loads(base64.b64decode(message)) + message_ = WebsocketConnection.loads(message) if message_[0] == 'entry': worker_args = message_[1] print('accepted connection from %s' % worker_args['address']) @@ -282,7 +290,7 @@ def _message_received(client, server, message): break except queue.Empty: continue - reply_message_ = base64.b64encode(pickle.dumps(reply_message)) + reply_message_ = WebsocketConnection.dumps(reply_message) server.send_message(client, reply_message_) From d4ec6f65ccd1157c75820b460adbfaa560053ab2 Mon Sep 17 00:00:00 2001 From: YuriCat Date: Thu, 7 Jul 2022 02:14:52 +0900 Subject: [PATCH 10/10] fix: update for new communication interface --- handyrl/worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/handyrl/worker.py b/handyrl/worker.py index 54bc5ce6..5c6a46ca 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -250,6 +250,9 @@ def __init__(self, args): self.args = args self.total_worker_count = 0 + def connection_count(self): + return len(self.clients) + def run(self): self.set_fn_new_client(self._new_client) self.set_fn_message_received(self._message_received) @@ -259,8 +262,8 @@ def shutdown(self): self.shutdown_flag = True self.shutdown_gracefully() - def recv(self): - return self.input_queue.get() + def recv(self, timeout=None): + return self.input_queue.get(timeout=timeout) def send(self, client, send_data): self.output_queue.put((client, send_data))
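
The wire format used throughout these patches is a pickled Python object wrapped in base64 so it can travel as a websocket frame, and the first message a gather process sends is the ('entry', worker_args) tuple that replaces the old dedicated entry server. The snippet below is a minimal sketch of that round trip, mirroring WebsocketConnection.dumps()/loads() from PATCH 09; the worker_args payload and its values are hypothetical, stand-ins for the dict built in RemoteWorkerCluster.

import base64
import pickle

def dumps(data):
    # Pickle the payload, then base64-encode it so the bytes are safe to send
    # as a websocket message (mirrors WebsocketConnection.dumps()).
    return base64.b64encode(pickle.dumps(data))

def loads(message):
    # Reverse the transformation applied by dumps()
    # (mirrors WebsocketConnection.loads()).
    return pickle.loads(base64.b64decode(message))

# Hypothetical worker_args; the real dict is assembled in RemoteWorkerCluster.
worker_args = {'address': 'worker-host', 'num_parallel': 8, 'num_gathers': 2}

# The entry request sent by gather_loop() and decoded by
# WorkerServer._message_received() on the training server.
message = dumps(('entry', worker_args))
req, received_args = loads(message)
assert req == 'entry' and received_args == worker_args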
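
PATCH 03 also replaces the per-agent observation flag with an observers() hook on the environment: players returned by turns() act, players returned by observers() only receive observations (for example to produce value targets). The sketch below uses a hypothetical toy environment, not one from the repository, purely to illustrate the control flow consumed by Generator.generate() and exec_match().

# Hypothetical environment illustrating the observers() hook from PATCH 03.
class ToyTurnBasedEnv:
    def players(self):
        return [0, 1]

    def turn(self):
        return 0  # assume player 0 is to move in this toy state

    def turns(self):
        return [self.turn()]

    def observers(self):
        # Non-turn players still observe the game state.
        return [p for p in self.players() if p not in self.turns()]

env = ToyTurnBasedEnv()
for player in env.players():
    if player in env.turns():
        pass  # compute a policy and choose an action, as in Generator.generate()
    elif player in env.observers():
        pass  # only query the value head, as in Agent.observe()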