From cac5823e919788f1c08e1136f62cf5e0ddf4f15c Mon Sep 17 00:00:00 2001 From: leoxhwang <1134086740@qq.com> Date: Mon, 8 Jul 2024 15:51:30 +0800 Subject: [PATCH] fix: fix reward bugs in SMACv1 https://github.com/oxwhirl/smac/pull/76, and add obs_agent_id --- smac_pettingzoo/env/smacv1.py | 2 +- smac_pettingzoo/env/smacv2.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/smac_pettingzoo/env/smacv1.py b/smac_pettingzoo/env/smacv1.py index 93c3a1c..dfccce7 100644 --- a/smac_pettingzoo/env/smacv1.py +++ b/smac_pettingzoo/env/smacv1.py @@ -793,7 +793,7 @@ def reward_battle(self): delta_enemy += prev_health - e_unit.health - e_unit.shield if self.reward_only_positive: - reward = abs(delta_enemy + delta_deaths) # shield regeneration + reward = max(delta_enemy + delta_deaths, 0) # shield regeneration else: reward = delta_enemy + delta_deaths - delta_ally diff --git a/smac_pettingzoo/env/smacv2.py b/smac_pettingzoo/env/smacv2.py index 1cb0517..59cb863 100644 --- a/smac_pettingzoo/env/smacv2.py +++ b/smac_pettingzoo/env/smacv2.py @@ -623,6 +623,7 @@ def step(self, actions: List[int]): # Observe here so that we know if the episode is over. self._obs = self._controller.observe() except (protocol.ProtocolError, protocol.ConnectionError): + self.full_restart() terminated = True available_actions = [] for i in range(self.n_agents): @@ -1446,6 +1447,7 @@ def get_obs_agent(self, agent_id, fully_observable=False): enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32) ally_feats = np.zeros(ally_feats_dim, dtype=np.float32) own_feats = np.zeros(own_feats_dim, dtype=np.float32) + agent_id_feats = np.zeros(self.n_agents, dtype=np.float32) if unit.health > 0 and self.obs_starcraft: # otherwise dead, return all zeros x = unit.pos.x @@ -1600,6 +1602,9 @@ def get_obs_agent(self, agent_id, fully_observable=False): own_feats.flatten(), ) ) + if self.obs_agent_id: + agent_id_feats[agent_id] = 1.0 + agent_obs = np.concatenate((agent_obs, agent_id_feats.flatten())) if self.obs_timestep_number: if self.obs_starcraft: @@ -1607,6 +1612,9 @@ def get_obs_agent(self, agent_id, fully_observable=False): else: agent_obs = np.zeros(1, dtype=np.float32) agent_obs[:] = self._episode_steps / self.episode_limit + if self.obs_agent_id: + agent_id_feats[agent_id] = 1.0 + agent_obs = np.concatenate((agent_obs, agent_id_feats.flatten())) if self.debug: logging.debug("Obs Agent: {}".format(agent_id).center(60, "-")) @@ -2050,17 +2058,21 @@ def get_obs_size(self): all_feats = move_feats + enemy_feats + ally_feats + own_feats + agent_id_feats = 0 timestep_feats = 0 + if self.obs_agent_id: + agent_id_feats = self.n_agents + all_feats += agent_id_feats if self.obs_timestep_number: timestep_feats = 1 all_feats += timestep_feats return [ - all_feats, + all_feats * self.stacked_frames if self.use_stacked_frames else all_feats, [n_allies, n_ally_feats], [n_enemies, n_enemy_feats], [1, move_feats], - [1, own_feats + timestep_feats], + [1, own_feats + agent_id_feats + timestep_feats], ] def get_state_size(self): @@ -2095,7 +2107,7 @@ def get_state_size(self): all_feats += timestep_feats return [ - all_feats, + all_feats * self.stacked_frames if self.use_stacked_frames else all_feats, [n_allies, n_ally_feats], [self.n_enemies, n_enemy_feats], [1, move_feats],