Skip to content

Commit

Permalink
Merge branch 'develop' into feature/output_whole_result
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriCat committed Apr 20, 2022
2 parents 9180f1e + 981d7fe commit a526d0e
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 19 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.DS_Store
*.pyc
*.pth
*.onnx
6 changes: 3 additions & 3 deletions handyrl/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def plan(self, obs):
outputs = {}
for i, model in enumerate(self.model):
o = model.inference(obs, self.hidden[i])
for k, v in o:
for k, v in o.items():
if k == 'hidden':
self.hidden[i] = v
else:
outputs[k] = outputs.get(k, []) + [o]
for k, vl in outputs:
outputs[k] = outputs.get(k, []) + [v]
for k, vl in outputs.items():
outputs[k] = np.mean(vl, axis=0)
return outputs

Expand Down
16 changes: 11 additions & 5 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def __init__(self, agent, env, conn):

def run(self):
while True:
command, args = self.conn.recv()
try:
command, args = self.conn.recv()
except ConnectionResetError:
break
if command == 'quit':
break
elif command == 'outcome':
Expand Down Expand Up @@ -137,7 +140,7 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={
return outcome


def build_agent(raw, env):
def build_agent(raw, env=None):
if raw == 'random':
return RandomAgent()
elif raw == 'rulebase':
Expand Down Expand Up @@ -297,18 +300,21 @@ def _open_session(self):

self.ort_session = onnxruntime.InferenceSession(self.model_path, sess_options=opts)

def init_hidden(self):
def init_hidden(self, batch_size=None):
if self.ort_session is None:
self._open_session()
hidden_inputs = [y for y in self.ort_session.get_inputs() if y.name.startswith('hidden')]
if len(hidden_inputs) == 0:
return None

if batch_size is None:
batch_size = []
import numpy as np
type_map = {
'tensor(float)': np.float32,
'tensor(int64)': np.int64,
}
hidden_tensors = [np.zeros(y.shape[1:], dtype=type_map[y.type]) for y in hidden_inputs]
hidden_tensors = [np.zeros(batch_size + list(y.shape[1:]), dtype=type_map[y.type]) for y in hidden_inputs]
return hidden_tensors

def inference(self, x, hidden=None, batch_input=False):
Expand Down Expand Up @@ -414,7 +420,7 @@ def eval_client_main(args, argv):
host = argv[1] if len(argv) >= 2 else 'localhost'
conn = connect_socket_connection(host, network_match_port)
env_args = conn.recv()
except EOFError:
except ConnectionResetError:
break

model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth'
Expand Down
2 changes: 1 addition & 1 deletion handyrl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, model, x):
wrapped_model = ModelWrapper(model)
hidden = wrapped_model.init_hidden()
outputs = wrapped_model.inference(x, hidden)
self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items()}
self.output_dict = {key: np.zeros_like(value) for key, value in outputs.items() if key != 'hidden'}

def inference(self, *args):
return self.output_dict
13 changes: 7 additions & 6 deletions handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def replace_none(a, b):
if not args['turn_based_training']: # solo training
players = [random.choice(players)]

obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o)) # template for padding
amask_zeros = np.zeros_like(moments[0]['action_mask'][moments[0]['turn'][0]]) # template for padding
# template for padding
obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))
amask_zeros = np.zeros_like(moments[0]['action_mask'][moments[0]['turn'][0]])

# data that is changed by training configuration
if args['turn_based_training'] and not args['observation']:
Expand Down Expand Up @@ -113,7 +114,8 @@ def replace_none(a, b):

return {
'observation': obs,
'selected_prob': prob, 'value': v,
'selected_prob': prob,
'value': v,
'action': act, 'outcome': oc,
'reward': rew, 'return': ret,
'episode_mask': emask,
Expand Down Expand Up @@ -336,8 +338,6 @@ def __init__(self, args, model):
self.trained_model = nn.DataParallel(self.wrapped_model)

def update(self):
if len(self.episodes) < self.args['minimum_episodes']:
return None, 0 # return None before training
self.update_flag = True
model, steps = self.update_queue.get()
return model, steps
Expand Down Expand Up @@ -522,7 +522,8 @@ def output_wp(name, results):
name_tag = ' (%s)' % name if name != '' else ''
print('win rate%s = %.3f (%.1f / %d)' % (name_tag, (mean + 1) / 2, (r + n) / 2, n))

if len(self.args.get('eval', {}).get('opponent', [])) <= 1:
keys = self.results_per_opponent[self.model_epoch]
if len(self.args.get('eval', {}).get('opponent', [])) <= 1 and len(keys) <= 1:
output_wp('', self.results[self.model_epoch])
else:
output_wp('total', self.results[self.model_epoch])
Expand Down
8 changes: 6 additions & 2 deletions handyrl/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,11 @@ def run(self):
p.terminate()


def worker_main(args):
def worker_main(args, argv):
# offline generation worker
worker = RemoteWorkerCluster(args=args['worker_args'])
worker_args = args['worker_args']
if len(argv) >= 1:
worker_args['num_parallel'] = int(argv[0])

worker = RemoteWorkerCluster(args=worker_args)
worker.run()
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
main(args)
elif mode == '--worker' or mode == '-w':
from handyrl.worker import worker_main as main
main(args)
main(args, sys.argv[2:])
elif mode == '--eval' or mode == '-e':
from handyrl.evaluation import eval_main as main
main(args, sys.argv[2:])
Expand Down
3 changes: 2 additions & 1 deletion scripts/make_onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

# Usage: python3 script/make_onnx_model.py MODEL_PATH

import re
import sys
import yaml
import torch
Expand All @@ -13,7 +14,7 @@


model_path = sys.argv[-1]
saved_model_path = model_path.rstrip('.pth') + '.onnx'
saved_model_path = re.sub('\.pth$', '.onnx', model_path)

with open('config.yaml') as f:
args = yaml.safe_load(f)
Expand Down

0 comments on commit a526d0e

Please sign in to comment.