fix: fix reward bugs in SMACv1 oxwhirl/smac#76, and add obs_agent_id
leoxhwang committed Jul 8, 2024
1 parent d182f79 commit cac5823
Showing 2 changed files with 16 additions and 4 deletions.
smac_pettingzoo/env/smacv1.py (2 changes: 1 addition & 1 deletion)
@@ -793,7 +793,7 @@ def reward_battle(self):
                     delta_enemy += prev_health - e_unit.health - e_unit.shield

         if self.reward_only_positive:
-            reward = abs(delta_enemy + delta_deaths)  # shield regeneration
+            reward = max(delta_enemy + delta_deaths, 0)  # shield regeneration
         else:
             reward = delta_enemy + delta_deaths - delta_ally
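Why the fix matters: with `reward_only_positive`, `abs()` turns a net-negative delta (e.g., enemy shields regenerating faster than damage is dealt) into a positive reward, paying the agent for doing nothing. Clamping with `max(..., 0)` discards the spurious signal instead. A minimal standalone sketch with made-up deltas:

    # Hypothetical one-step deltas in reward_only_positive mode:
    # no enemy took damage, but enemy shields regenerated by 5 points.
    delta_enemy = -5.0   # net change in enemy health + shield
    delta_deaths = 0.0   # no kill bonuses this step

    buggy_reward = abs(delta_enemy + delta_deaths)     # 5.0: rewards enemy shield regen
    fixed_reward = max(delta_enemy + delta_deaths, 0)  # 0: clamps the spurious signal

    print(buggy_reward, fixed_reward)  # 5.0 0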
smac_pettingzoo/env/smacv2.py (18 changes: 15 additions & 3 deletions)
@@ -623,6 +623,7 @@ def step(self, actions: List[int]):
             # Observe here so that we know if the episode is over.
             self._obs = self._controller.observe()
         except (protocol.ProtocolError, protocol.ConnectionError):
+            self.full_restart()
             terminated = True
             available_actions = []
             for i in range(self.n_agents):
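The added `self.full_restart()` relaunches the game when the SC2 connection drops mid-step, so the environment is usable again after the episode terminates. A rough sketch of the recovery pattern with simplified stand-ins (`FlakyController` and this `Env` are illustrative, not the repo's classes):

    class FlakyController:
        # Stand-in for the SC2 controller; fails once to simulate a dropped connection.
        def __init__(self, healthy=False):
            self.healthy = healthy

        def observe(self):
            if not self.healthy:
                raise ConnectionError("SC2 pipe closed")
            return {"units": []}

    class Env:
        def __init__(self):
            self._controller = FlakyController()

        def full_restart(self):
            # In SMAC this relaunches SC2 and recreates the game;
            # here we just swap in a healthy controller.
            self._controller = FlakyController(healthy=True)

        def step(self):
            try:
                return self._controller.observe(), False
            except ConnectionError:
                self.full_restart()  # without this, every later step would fail too
                return None, True    # episode terminates, but the env can reset cleanly

    env = Env()
    print(env.step())  # (None, True)
    print(env.step())  # ({'units': []}, False): recovered after the restart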
@@ -1446,6 +1447,7 @@ def get_obs_agent(self, agent_id, fully_observable=False):
         enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
         ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
         own_feats = np.zeros(own_feats_dim, dtype=np.float32)
+        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)

         if unit.health > 0 and self.obs_starcraft:  # otherwise dead, return all zeros
             x = unit.pos.x
@@ -1600,13 +1602,19 @@ def get_obs_agent(self, agent_id, fully_observable=False):
                     own_feats.flatten(),
                 )
             )
+            if self.obs_agent_id:
+                agent_id_feats[agent_id] = 1.0
+                agent_obs = np.concatenate((agent_obs, agent_id_feats.flatten()))

         if self.obs_timestep_number:
             if self.obs_starcraft:
                 agent_obs = np.append(agent_obs, self._episode_steps / self.episode_limit)
             else:
                 agent_obs = np.zeros(1, dtype=np.float32)
                 agent_obs[:] = self._episode_steps / self.episode_limit
+                if self.obs_agent_id:
+                    agent_id_feats[agent_id] = 1.0
+                    agent_obs = np.concatenate((agent_obs, agent_id_feats.flatten()))

         if self.debug:
             logging.debug("Obs Agent: {}".format(agent_id).center(60, "-"))
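The new `obs_agent_id` option appends a one-hot agent index to each observation, which lets a parameter-shared policy tell agents apart. A minimal sketch of the construction (illustrative sizes; only the `np.zeros`/`np.concatenate` usage mirrors the diff):

    import numpy as np

    n_agents = 3
    agent_id = 1
    base_obs = np.array([0.5, 0.2], dtype=np.float32)  # stand-in for move/enemy/ally/own features

    agent_id_feats = np.zeros(n_agents, dtype=np.float32)
    agent_id_feats[agent_id] = 1.0                     # one-hot: this is agent 1
    agent_obs = np.concatenate((base_obs, agent_id_feats.flatten()))

    print(agent_obs)  # [0.5 0.2 0.  1.  0. ]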
@@ -2050,17 +2058,21 @@ def get_obs_size(self):

         all_feats = move_feats + enemy_feats + ally_feats + own_feats

+        agent_id_feats = 0
         timestep_feats = 0
+        if self.obs_agent_id:
+            agent_id_feats = self.n_agents
+            all_feats += agent_id_feats
         if self.obs_timestep_number:
             timestep_feats = 1
             all_feats += timestep_feats

         return [
-            all_feats,
+            all_feats * self.stacked_frames if self.use_stacked_frames else all_feats,
             [n_allies, n_ally_feats],
             [n_enemies, n_enemy_feats],
             [1, move_feats],
-            [1, own_feats + timestep_feats],
+            [1, own_feats + agent_id_feats + timestep_feats],
         ]

     def get_state_size(self):
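The size bookkeeping has to match the observation: `get_obs_size` now adds `n_agents` when the one-hot ID is enabled, otherwise replay buffers and network input layers sized from it would be off by `n_agents`. A quick arithmetic check with illustrative feature sizes (not taken from any real map):

    move_feats, enemy_feats, ally_feats, own_feats = 4, 24, 20, 10  # illustrative only
    n_agents = 5
    obs_agent_id, obs_timestep_number = True, True

    all_feats = move_feats + enemy_feats + ally_feats + own_feats  # 58
    agent_id_feats = n_agents if obs_agent_id else 0               # +5: one-hot agent ID
    timestep_feats = 1 if obs_timestep_number else 0               # +1: normalized timestep
    all_feats += agent_id_feats + timestep_feats

    print(all_feats)  # 64; without counting agent_id_feats the env would report 59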
@@ -2095,7 +2107,7 @@ def get_state_size(self):
             all_feats += timestep_feats

         return [
-            all_feats,
+            all_feats * self.stacked_frames if self.use_stacked_frames else all_feats,
             [n_allies, n_ally_feats],
             [self.n_enemies, n_enemy_feats],
             [1, move_feats],
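Both size helpers now multiply the flat feature count by `stacked_frames` when frame stacking is on; previously `get_state_size` reported the unstacked width. A one-line check with illustrative numbers:

    all_feats, stacked_frames, use_stacked_frames = 120, 4, True  # illustrative only
    print(all_feats * stacked_frames if use_stacked_frames else all_feats)  # 480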
