
Commit 2a336db

fix the per-step termination log on reset to per-episode termination log
1 parent 19b24c7 commit 2a336db
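
Note on the change: before this commit, reset() logged torch.count_nonzero(self._term_dones[key][env_ids]).item() for each term, i.e. a per-step count of how many of the environments being reset had that term active on the current step. After this commit, _term_dones keeps one boolean row per environment recording which term ended its last episode, and reset() logs the mean of each column over all environments, i.e. the fraction of environments whose last episode ended with that term. A minimal sketch of the new logged quantity (term names and values below are made up for illustration):

# Illustration only, not part of the commit: assumed term names and values.
import torch

num_envs = 4
term_names = ["time_out", "base_contact"]

# New layout: one (num_envs, num_terms) boolean buffer; row e records which
# term ended environment e's last episode.
term_dones = torch.zeros((num_envs, len(term_names)), dtype=torch.bool)
term_dones[0, 0] = True   # env 0 timed out
term_dones[1, 1] = True   # env 1 terminated on base contact
term_dones[2, 1] = True   # env 2 terminated on base contact

# Value logged on reset for each term: mean of its column over all envs.
stats = term_dones.float().mean(dim=0)
print({name: stats[i].item() for i, name in enumerate(term_names)})
# {'time_out': 0.25, 'base_contact': 0.5}
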

1 file changed (+10 −10 lines)

source/isaaclab/isaaclab/managers/termination_manager.py

Lines changed: 10 additions & 10 deletions
@@ -61,9 +61,7 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):
         # call the base class constructor (this will parse the terms config)
         super().__init__(cfg, env)
         # prepare extra info to store individual termination term information
-        self._term_dones = dict()
-        for term_name in self._term_names:
-            self._term_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
+        self._term_dones = torch.zeros((self.num_envs, len(self._term_names)), device=self.device, dtype=torch.bool)
         # create buffer for managing termination per environment
         self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
         self._terminated_buf = torch.zeros_like(self._truncated_buf)
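
The storage change in this hunk, sketched with hypothetical names: the per-term dict of (num_envs,) tensors becomes a single (num_envs, num_terms) tensor, and a term's per-environment flags become a column slice.

# Sketch of the old vs. new storage (names and sizes are illustrative).
import torch

num_envs = 3
term_names = ["time_out", "bad_orientation"]
device = "cpu"

# Old: one boolean tensor per term, keyed by term name.
term_dones_dict = {
    name: torch.zeros(num_envs, device=device, dtype=torch.bool) for name in term_names
}

# New: a single (num_envs, num_terms) boolean tensor; term i lives in column i.
term_dones = torch.zeros((num_envs, len(term_names)), device=device, dtype=torch.bool)

# The per-term view the dict used to provide is now a column slice.
assert torch.equal(term_dones[:, term_names.index("time_out")], term_dones_dict["time_out"])
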
@@ -139,9 +137,10 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
             env_ids = slice(None)
         # add to episode dict
         extras = {}
-        for key in self._term_dones.keys():
+        last_episode_done_stats = self._term_dones.float().mean(dim=0)
+        for i, key in enumerate(self._term_names):
             # store information
-            extras["Episode_Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
+            extras["Episode_Termination/" + key] = last_episode_done_stats[i].item()
         # reset all the reward terms
         for term_cfg in self._class_term_cfgs:
             term_cfg.func.reset(env_ids=env_ids)
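
The effect of this hunk on the logged number, sketched with made-up buffers: the old value counted how many of the environments being reset had the term active on that step (an integer per reset call), while the new value is the column mean over all environments (a fraction in [0, 1]).

# Made-up buffers; env_ids marks the environments being reset this step.
import torch

env_ids = torch.tensor([0, 2])

# Old: per-step flags of a single term, one bool per env.
old_term_done = torch.tensor([True, False, False, False])
old_value = torch.count_nonzero(old_term_done[env_ids]).item()   # 1 (a count)

# New: per-episode record, shape (num_envs, num_terms); column 0 is the same term.
new_term_dones = torch.tensor([[True, False],
                               [False, True],
                               [True, False],
                               [False, False]])
new_value = new_term_dones.float().mean(dim=0)[0].item()          # 0.5 (a fraction)
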
@@ -161,15 +160,16 @@ def compute(self) -> torch.Tensor:
         self._truncated_buf[:] = False
         self._terminated_buf[:] = False
         # iterate over all the termination terms
-        for name, term_cfg in zip(self._term_names, self._term_cfgs):
+        for i, term_cfg in enumerate(self._term_cfgs):
             value = term_cfg.func(self._env, **term_cfg.params)
             # store timeout signal separately
             if term_cfg.time_out:
                 self._truncated_buf |= value
             else:
                 self._terminated_buf |= value
             # add to episode dones
-            self._term_dones[name][:] = value
+            self._term_dones[value] = False
+            self._term_dones[value, i] = True
         # return combined termination signal
         return self._truncated_buf | self._terminated_buf
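
How the per-episode record is maintained in compute(), as a small standalone sketch: when term i fires for an environment, that environment's whole row is cleared and only column i is set, so by the time the environment is reset its row names the term that ended the episode (the last firing term wins if several fire on the same step).

# Standalone sketch of the two new lines in compute().
import torch

term_dones = torch.zeros((2, 3), dtype=torch.bool)   # 2 envs, 3 terms
i = 1                                                 # index of the term that fired
value = torch.tensor([True, False])                   # it fired for env 0 only

term_dones[value] = False      # clear the full row of the envs that just terminated
term_dones[value, i] = True    # record only the triggering term for those envs
assert term_dones[0].tolist() == [False, True, False]
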

@@ -182,7 +182,7 @@ def get_term(self, name: str) -> torch.Tensor:
         Returns:
             The corresponding termination term value. Shape is (num_envs,).
         """
-        return self._term_dones[name]
+        return self._term_dones[:, self._term_names.index(name)]

     def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
         """Returns the active terms as iterable sequence of tuples.
@@ -196,8 +196,8 @@ def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequenc
             The active terms.
         """
         terms = []
-        for key in self._term_dones.keys():
-            terms.append((key, [self._term_dones[key][env_idx].float().cpu().item()]))
+        for i, key in enumerate(self._term_names):
+            terms.append((key, [self._term_dones[env_idx, i].float().cpu().item()]))
         return terms

     """
