diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index 9bfda767..4f911437 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -49,7 +49,6 @@ def __init__(self, base_name, params):
                 'config' : self.central_value_config,
                 'writter' : self.writer,
                 'max_epochs' : self.max_epochs,
-                'max_frames' : self.max_frames,
                 'multi_gpu' : self.multi_gpu,
                 'zero_rnn_on_done' : self.zero_rnn_on_done
             }
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 93900668..a085abf1 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -685,8 +685,8 @@ def play_steps(self):
             self.current_rewards += rewards
             self.current_lengths += 1
             all_done_indices = self.dones.nonzero(as_tuple=False)
-            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)
-
+            env_done_indices = all_done_indices[::self.num_agents]
+
             self.game_rewards.update(self.current_rewards[env_done_indices])
             self.game_lengths.update(self.current_lengths[env_done_indices])
             self.algo_observer.process_infos(infos, env_done_indices)
@@ -755,7 +755,8 @@ def play_steps_rnn(self):
             self.current_rewards += rewards
             self.current_lengths += 1
             all_done_indices = self.dones.nonzero(as_tuple=False)
-            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)
+            env_done_indices = all_done_indices[::self.num_agents]
+
             if len(all_done_indices) > 0:
                 if self.zero_rnn_on_done:
                     for s in self.rnn_states:
diff --git a/rl_games/configs/smac/3m_torch.yaml b/rl_games/configs/smac/3m_torch.yaml
index 28b7b800..10bf9f2e 100644
--- a/rl_games/configs/smac/3m_torch.yaml
+++ b/rl_games/configs/smac/3m_torch.yaml
@@ -42,6 +42,7 @@ params:
     lr_schedule: None
     kl_threshold: 0.05
     normalize_input: True
+    #normalize_value: True
     use_action_masks: True
     ignore_dead_batches : False
 
@@ -49,4 +50,5 @@ params:
       name: 3m
       frames: 1
       transpose: False
-      random_invalid_step: False
\ No newline at end of file
+      random_invalid_step: False
+      obs_last_action: True
\ No newline at end of file
diff --git a/rl_games/envs/smac_env.py b/rl_games/envs/smac_env.py
index c36f4ca6..a4039e55 100644
--- a/rl_games/envs/smac_env.py
+++ b/rl_games/envs/smac_env.py
@@ -6,14 +6,14 @@ class SMACEnv(gym.Env):
     def __init__(self, name="3m", **kwargs):
         gym.Env.__init__(self)
-        self.seed = kwargs.pop('seed', None)
+        self._seed = kwargs.pop('seed', None)
         self.reward_sparse = kwargs.get('reward_sparse', False)
         self.use_central_value = kwargs.pop('central_value', False)
         self.concat_infos = True
         self.random_invalid_step = kwargs.pop('random_invalid_step', False)
         self.replay_save_freq = kwargs.pop('replay_save_freq', 10000)
-        self.apply_agent_ids = kwargs.pop('apply_agent_ids', False)
-        self.env = StarCraft2Env(map_name=name, seed=self.seed, **kwargs)
+        self.apply_agent_ids = kwargs.pop('apply_agent_ids', True)
+        self.env = StarCraft2Env(map_name=name, seed=self._seed, **kwargs)
         self.env_info = self.env.get_env_info()
 
         self._game_num = 0
@@ -101,6 +101,6 @@ def get_action_mask(self):
     def has_action_mask(self):
         return not self.random_invalid_step
 
-    def seed(self, val):
+    def seed(self, _):
         pass
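
Note on the `env_done_indices` change in `a2c_common.py`: the snippet below is a minimal standalone sketch, not part of the patch. It assumes `self.dones` is laid out agent-major per environment (the first `num_agents` entries belong to env 0, the next `num_agents` to env 1, and so on) and that all agents of an environment terminate on the same step; the concrete values of `num_actors` and `num_agents` are illustrative. Under those assumptions it shows what the stride-based slice selects and how it differs from the `view(...).all(dim=1)` expression it replaces.

```python
import torch

# Illustrative sizes (assumptions, not taken from the patch).
num_actors, num_agents = 4, 3
dones = torch.zeros(num_actors * num_agents, dtype=torch.bool)
dones[3:6] = True    # env 1: all 3 of its agents are done
dones[9:12] = True   # env 3: all 3 of its agents are done

# New expression: take every num_agents-th done index, i.e. one entry per
# finished environment. Valid only if agents of an env terminate together
# and are stored consecutively.
all_done_indices = dones.nonzero(as_tuple=False)   # flat indices 3,4,5,9,10,11
env_done_indices = all_done_indices[::num_agents]  # flat indices 3, 9

# Replaced expression: reduce per environment, yielding env indices 1, 3.
env_ids = dones.view(num_actors, num_agents).all(dim=1).nonzero(as_tuple=False)

print(env_done_indices.flatten().tolist())  # [3, 9]  (flat agent indices)
print(env_ids.flatten().tolist())           # [1, 3]  (environment indices)
```

Both expressions flag the same environments under the shared-done assumption, but the slice returns flat agent indices (one per finished env) rather than environment indices, so the tensors indexed downstream must use that granularity.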