Fixed MA env reporting, including SC2 (#224)
* fixed seed
* fixed reporting
Denys88 authored Feb 16, 2023
1 parent 726da21 · commit 537a899
Showing 4 changed files with 11 additions and 9 deletions.
1 change: 0 additions & 1 deletion rl_games/algos_torch/a2c_discrete.py
@@ -49,7 +49,6 @@ def __init__(self, base_name, params):
                 'config' : self.central_value_config,
                 'writter' : self.writer,
                 'max_epochs' : self.max_epochs,
-                'max_frames' : self.max_frames,
                 'multi_gpu' : self.multi_gpu,
                 'zero_rnn_on_done' : self.zero_rnn_on_done
             }
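The deleted 'max_frames' entry is the whole change in this file. A plausible reason (an assumption, not stated in the commit): this dict looks like it is splatted into the central-value trainer's constructor, and a keyword the constructor does not accept fails at startup. A minimal sketch of that failure mode; CentralValueStub and its signature are hypothetical stand-ins, not rl_games code:

# Hypothetical stand-in for the central-value consumer; the real class
# lives in rl_games and its signature may differ.
class CentralValueStub:
    def __init__(self, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
        self.max_epochs = max_epochs

cv_config = {
    'config': {}, 'writter': None, 'max_epochs': 10,
    'max_frames': 1000,  # the stale key this commit deletes
    'multi_gpu': False, 'zero_rnn_on_done': True,
}
try:
    CentralValueStub(**cv_config)
except TypeError as e:
    print(e)  # unexpected keyword argument 'max_frames'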
7 changes: 4 additions & 3 deletions rl_games/common/a2c_common.py
@@ -685,8 +685,8 @@ def play_steps(self):
             self.current_rewards += rewards
             self.current_lengths += 1
             all_done_indices = self.dones.nonzero(as_tuple=False)
-            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)
-
+            env_done_indices = all_done_indices[::self.num_agents]
+
             self.game_rewards.update(self.current_rewards[env_done_indices])
             self.game_lengths.update(self.current_lengths[env_done_indices])
             self.algo_observer.process_infos(infos, env_done_indices)
@@ -755,7 +755,8 @@ def play_steps_rnn(self):
             self.current_rewards += rewards
             self.current_lengths += 1
             all_done_indices = self.dones.nonzero(as_tuple=False)
-            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)
+            env_done_indices = all_done_indices[::self.num_agents]
+
             if len(all_done_indices) > 0:
                 if self.zero_rnn_on_done:
                     for s in self.rnn_states:
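Both hunks replace the same line in play_steps and play_steps_rnn. Reading the change (the commit message only says "fixed reporting"): self.dones holds one flag per agent in a flattened (num_actors * num_agents) batch, and the old view(...).all(dim=1) produced env indices in [0, num_actors), which do not line up with per-agent buffers such as current_rewards; slicing every num_agents-th row of all_done_indices yields one index per finished env while staying in the flattened space, assuming all agents of an env terminate together. A toy sketch of the new indexing (illustration only, not rl_games code):

import torch

num_agents = 3
dones = torch.tensor([1, 1, 1, 0, 0, 0])          # 2 envs x 3 agents; env 0 done

all_done_indices = dones.nonzero(as_tuple=False)  # tensor([[0], [1], [2]])
env_done_indices = all_done_indices[::num_agents] # tensor([[0]]) -- one row per
                                                  # finished env, still valid in
                                                  # the flattened agent space
current_rewards = torch.tensor([[5.], [5.], [5.], [0.], [0.], [0.]])
print(current_rewards[env_done_indices])          # env 0's reward, reported once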
4 changes: 3 additions & 1 deletion rl_games/configs/smac/3m_torch.yaml
@@ -42,11 +42,13 @@ params:
     lr_schedule: None
     kl_threshold: 0.05
     normalize_input: True
+    #normalize_value: True
     use_action_masks: True
     ignore_dead_batches : False
 
     env_config:
       name: 3m
       frames: 1
       transpose: False
-      random_invalid_step: False
+      random_invalid_step: False
+      obs_last_action: True
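The duplicated random_invalid_step line is likely just trailing-newline churn (the old file probably ended without a newline); the substantive addition is obs_last_action: True. Judging by the smac_env.py diff below, SMACEnv pops the keys it handles itself and forwards the rest to StarCraft2Env, so the new key reaches SMAC's own constructor. A small sketch of that pass-through, with an abbreviated kwargs dict and the wiring taken from the diff below:

kwargs = {'random_invalid_step': False, 'obs_last_action': True}

random_invalid_step = kwargs.pop('random_invalid_step', False)  # consumed by SMACEnv
print(kwargs)  # {'obs_last_action': True}
# the remainder is forwarded, per the diff below:
# self.env = StarCraft2Env(map_name=name, seed=self._seed, **kwargs)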
8 changes: 4 additions & 4 deletions rl_games/envs/smac_env.py
@@ -6,14 +6,14 @@
 class SMACEnv(gym.Env):
     def __init__(self, name="3m", **kwargs):
         gym.Env.__init__(self)
-        self.seed = kwargs.pop('seed', None)
+        self._seed = kwargs.pop('seed', None)
         self.reward_sparse = kwargs.get('reward_sparse', False)
         self.use_central_value = kwargs.pop('central_value', False)
         self.concat_infos = True
         self.random_invalid_step = kwargs.pop('random_invalid_step', False)
         self.replay_save_freq = kwargs.pop('replay_save_freq', 10000)
-        self.apply_agent_ids = kwargs.pop('apply_agent_ids', False)
-        self.env = StarCraft2Env(map_name=name, seed=self.seed, **kwargs)
+        self.apply_agent_ids = kwargs.pop('apply_agent_ids', True)
+        self.env = StarCraft2Env(map_name=name, seed=self._seed, **kwargs)
         self.env_info = self.env.get_env_info()
 
         self._game_num = 0
@@ -101,6 +101,6 @@ def get_action_mask(self):
     def has_action_mask(self):
         return not self.random_invalid_step
 
-    def seed(self, val):
+    def seed(self, _):
         pass
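The seed rename fixes a classic Python name clash: the old __init__ (first hunk) assigned self.seed, and that instance attribute shadows the seed() method shown here, so a later env.seed(...) call raised TypeError; storing the value as self._seed keeps the method callable (the first hunk also flips the apply_agent_ids default to True, and the now-unused parameter of seed() is renamed to _). A minimal reproduction with throwaway classes; BadEnv and GoodEnv are illustrations, not rl_games code:

class BadEnv:
    def __init__(self, seed=None):
        self.seed = seed            # attribute now hides the method below

    def seed(self, val):
        pass

try:
    BadEnv(seed=42).seed(123)       # fails: the attribute shadows the method
except TypeError as e:
    print(e)                        # 'int' object is not callable

class GoodEnv:
    def __init__(self, seed=None):
        self._seed = seed           # private attribute; method stays callable

    def seed(self, _):
        pass

GoodEnv(seed=42).seed(123)          # works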
