Skip to content

Commit

Permalink
Merge pull request #348 from DeNA/develop
Browse files Browse the repository at this point in the history
(2023/11) merge develop into master
  • Loading branch information
YuriCat authored Nov 28, 2023
2 parents 6bdcd87 + 403ae35 commit 83379bd
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.7, 3.8, 3.9]
python-version: ['3.8', '3.9', '3.10']
steps:
- name: Checkout
uses: actions/checkout@v2
Expand All @@ -24,7 +24,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r handyrl/envs/kaggle/requirements.txt
- name: pytest
run: |
python -m pytest tests
19 changes: 10 additions & 9 deletions handyrl/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def exec_match(env, agents, critic=None, show=False, game_args={}):
outcome = env.outcome()
if show:
print('final outcome = %s' % outcome)
return outcome
return {'result': outcome}


def exec_network_match(env, network_agents, critic=None, show=False, game_args={}):
Expand Down Expand Up @@ -138,7 +138,7 @@ def exec_network_match(env, network_agents, critic=None, show=False, game_args={
outcome = env.outcome()
for p, agent in network_agents.items():
agent.outcome(outcome[p])
return outcome
return {'result': outcome}


def build_agent(raw, env=None):
Expand Down Expand Up @@ -170,11 +170,11 @@ def execute(self, models, args):
else:
agents[p] = Agent(model)

outcome = exec_match(self.env, agents)
if outcome is None:
results = exec_match(self.env, agents)
if results is None:
print('None episode in evaluation!')
return None
return {'args': args, 'result': outcome, 'opponent': opponent}
return {'args': args, 'opponent': opponent, **results}


def wp_func(results):
Expand All @@ -196,10 +196,10 @@ def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue,
print('*** Game %d ***' % g)
agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)}
if isinstance(list(agent_map.values())[0], NetworkAgent):
outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args)
results = exec_network_match(env, agent_map, critic, show=show, game_args=game_args)
else:
outcome = exec_match(env, agent_map, critic, show=show, game_args=game_args)
out_queue.put((pat_idx, agent_ids, outcome))
results = exec_match(env, agent_map, critic, show=show, game_args=game_args)
out_queue.put((pat_idx, agent_ids, results))
out_queue.put(None)


Expand Down Expand Up @@ -246,7 +246,8 @@ def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_g
if ret is None:
finished_cnt += 1
continue
pat_idx, agent_ids, outcome = ret
pat_idx, agent_ids, results = ret
outcome = results.get('result')
if outcome is not None:
for idx, p in enumerate(env.players()):
agent_id = agent_ids[idx]
Expand Down
21 changes: 10 additions & 11 deletions handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,23 @@ def replace_none(a, b):

# data that is changed by training configuration
if args['turn_based_training'] and not args['observation']:
obs = [[m['observation'][m['turn'][0]]] for m in moments]
prob = np.array([[[m['selected_prob'][m['turn'][0]]]] for m in moments])
act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis]
amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments])
players_list = [[m['turn'][0]] for m in moments]
else:
obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments]
prob = np.array([[[replace_none(m['selected_prob'][player], 1.0)] for player in players] for m in moments])
act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis]
amask = np.array([[replace_none(m['action_mask'][player], amask_zeros + 1e32) for player in players] for m in moments])
players_list = [players for m in moments]

obs = [[replace_none(m['observation'][player], obs_zeros) for player in players_] for m, players_ in zip(moments, players_list)]
prob = np.array([[[replace_none(m['selected_prob'][player], 1.0)] for player in players_] for m, players_ in zip(moments, players_list)])
act = np.array([[replace_none(m['action'][player], 0) for player in players_] for m, players_ in zip(moments, players_list)], dtype=np.int64)[..., np.newaxis]
amask = np.array([[replace_none(m['action_mask'][player], amask_zeros + 1e32) for player in players_] for m, players_ in zip(moments, players_list)])

# reshape observation
obs = rotate(rotate(obs)) # (T, P, ..., ...) -> (P, ..., T, ...) -> (..., T, P, ...)
obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o))

# datum that is not changed by training configuration
v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
v = np.array([[replace_none(m['value'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
rew = np.array([[replace_none(m['reward'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
ret = np.array([[replace_none(m['return'][player], 0) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)

emask = np.ones((len(moments), 1, 1), dtype=np.float32) # episode mask
Expand Down
2 changes: 1 addition & 1 deletion tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'tictactoe',
'geister',
'parallel_tictactoe',
'kaggle.hungry_geese',
# 'kaggle.hungry_geese',
]


Expand Down

0 comments on commit 83379bd

Please sign in to comment.