Skip to content

Commit

Permalink
Merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriCat committed Apr 27, 2022
2 parents 704f501 + 4d6b3a3 commit 95d3551
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 171 deletions.
24 changes: 16 additions & 8 deletions handyrl/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ def observe(self, env, player, show=False):


class RuleBasedAgent(RandomAgent):
def __init__(self, key=None):
self.key = None

def action(self, env, player, show=False):
if hasattr(env, 'rule_based_action'):
return env.rule_based_action(player)
return env.rule_based_action(player, key=self.key)
else:
return random.choice(env.legal_actions(player))

Expand All @@ -41,11 +44,12 @@ def print_outputs(env, prob, v):


class Agent:
def __init__(self, model, temperature=0.0):
def __init__(self, model, temperature=0.0, observation=True):
# model might be a neural net, or some planning algorithm such as game tree search
self.model = model
self.hidden = None
self.temperature = temperature
self.observation = observation

def reset(self, env, show=False):
self.hidden = self.model.init_hidden()
Expand All @@ -56,7 +60,8 @@ def plan(self, obs):
return outputs

def action(self, env, player, show=False):
outputs = self.plan(env.observation(player))
obs = env.observation(player)
outputs = self.plan(obs)
actions = env.legal_actions(player)
p = outputs['policy']
v = outputs.get('value', None)
Expand All @@ -74,11 +79,14 @@ def action(self, env, player, show=False):
return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]

def observe(self, env, player, show=False):
outputs = self.plan(env.observation(player))
v = outputs.get('value', None)
if show:
print_outputs(env, None, v)
return v if v is not None else [0.0]
v = None
if self.observation:
obs = env.observation(player)
outputs = self.plan(obs)
v = outputs.get('value', None)
if show:
print_outputs(env, None, v)
return v


class EnsembleAgent(Agent):
Expand Down
93 changes: 28 additions & 65 deletions handyrl/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,109 +131,77 @@ def open_multiprocessing_connections(num_process, target, args_func):


class MultiProcessJobExecutor:
def __init__(self, func, send_generator, num_workers, postprocess=None, num_receivers=1):
def __init__(self, func, send_generator, num_workers, postprocess=None):
self.send_generator = send_generator
self.postprocess = postprocess
self.num_receivers = num_receivers
self.conns = []
self.waiting_conns = queue.Queue()
self.shutdown_flag = False
self.output_queue = queue.Queue(maxsize=8)
self.threads = []

for i in range(num_workers):
conn0, conn1 = mp.Pipe(duplex=True)
mp.Process(target=func, args=(conn1, i)).start()
mp.Process(target=func, args=(conn1, i), daemon=True).start()
conn1.close()
self.conns.append(conn0)
self.waiting_conns.put(conn0)

def shutdown(self):
self.shutdown_flag = True
for thread in self.threads:
thread.join()

def recv(self):
return self.output_queue.get()

def start(self):
self.threads.append(threading.Thread(target=self._sender))
for i in range(self.num_receivers):
self.threads.append(threading.Thread(target=self._receiver, args=(i,)))
for thread in self.threads:
thread.start()
threading.Thread(target=self._sender, daemon=True).start()
threading.Thread(target=self._receiver, daemon=True).start()

def _sender(self):
print('start sender')
while not self.shutdown_flag:
while True:
data = next(self.send_generator)
while not self.shutdown_flag:
try:
conn = self.waiting_conns.get(timeout=0.3)
conn.send(data)
break
except queue.Empty:
pass
conn = self.waiting_conns.get()
conn.send(data)
print('finished sender')

def _receiver(self, index):
print('start receiver %d' % index)
conns = [conn for i, conn in enumerate(self.conns) if i % self.num_receivers == index]
while not self.shutdown_flag:
tmp_conns = connection.wait(conns)
for conn in tmp_conns:
def _receiver(self):
print('start receiver')
while True:
conns = connection.wait(self.conns)
for conn in conns:
data = conn.recv()
self.waiting_conns.put(conn)
if self.postprocess is not None:
data = self.postprocess(data)
while not self.shutdown_flag:
try:
self.output_queue.put(data, timeout=0.3)
break
except queue.Full:
pass
print('finished receiver %d' % index)
self.output_queue.put(data)
print('finished receiver')


class QueueCommunicator:
def __init__(self, conns=[]):
self.input_queue = queue.Queue(maxsize=256)
self.output_queue = queue.Queue(maxsize=256)
self.conns = []
self.conns = set()
for conn in conns:
self.add_connection(conn)
self.shutdown_flag = False
self.threads = [
threading.Thread(target=self._send_thread),
threading.Thread(target=self._recv_thread),
]
for thread in self.threads:
thread.start()

def shutdown(self):
self.shutdown_flag = True
for thread in self.threads:
thread.join()
threading.Thread(target=self._send_thread, daemon=True).start()
threading.Thread(target=self._recv_thread, daemon=True).start()

def recv(self):
return self.input_queue.get()
def connection_count(self):
return len(self.conns)

def recv(self, timeout=None):
return self.input_queue.get(timeout=timeout)

def send(self, conn, send_data):
self.output_queue.put((conn, send_data))

def add_connection(self, conn):
self.conns.append(conn)
self.conns.add(conn)

def disconnect(self, conn):
print('disconnected')
self.conns.remove(conn)
self.conns.discard(conn)

def _send_thread(self):
while not self.shutdown_flag:
try:
conn, send_data = self.output_queue.get(timeout=0.3)
except queue.Empty:
continue
while True:
conn, send_data = self.output_queue.get()
try:
conn.send(send_data)
except ConnectionResetError:
Expand All @@ -242,7 +210,7 @@ def _send_thread(self):
self.disconnect(conn)

def _recv_thread(self):
while not self.shutdown_flag:
while True:
conns = connection.wait(self.conns, timeout=0.3)
for conn in conns:
try:
Expand All @@ -253,9 +221,4 @@ def _recv_thread(self):
except EOFError:
self.disconnect(conn)
continue
while not self.shutdown_flag:
try:
self.input_queue.put((conn, recv_data), timeout=0.3)
break
except queue.Full:
pass
self.input_queue.put((conn, recv_data))
2 changes: 1 addition & 1 deletion handyrl/envs/kaggle/hungry_geese.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def legal_actions(self, player):
def players(self):
return list(range(self.NUM_AGENTS))

def rule_based_action(self, player):
def rule_based_action(self, player, key=None):
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, GreedyAgent
action_map = {'N': Action.NORTH, 'S': Action.SOUTH, 'W': Action.WEST, 'E': Action.EAST}

Expand Down
8 changes: 5 additions & 3 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={
def build_agent(raw, env=None):
if raw == 'random':
return RandomAgent()
elif raw == 'rulebase':
return RuleBasedAgent()
elif raw.startswith('rulebase'):
key = raw.split('-')[1] if '-' in raw else None
return RuleBasedAgent(key)
return None


Expand Down Expand Up @@ -351,10 +352,11 @@ def insert_input(y):
return outputs


def load_model(model_path, model):
def load_model(model_path, model=None):
if model_path.endswith('.onnx'):
model = OnnxModel(model_path)
return model
assert model is not None
import torch
from .model import ModelWrapper
model.load_state_dict(torch.load(model_path))
Expand Down
2 changes: 2 additions & 0 deletions handyrl/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def generate(self, models, args):
for player in self.env.players():
if player not in turn_players + observers:
continue
if player not in turn_players and player in args['player'] and not self.args['observation']:
continue

obs = self.env.observation(player)
model = models[player]
Expand Down
Loading

0 comments on commit 95d3551

Please sign in to comment.