Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/isaaclab/config/extension.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.44.9"
version = "0.44.10"

# Description
title = "Isaac Lab framework for Robot Learning"
Expand Down
11 changes: 11 additions & 0 deletions source/isaaclab/docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
Changelog
---------

0.44.10 (2025-08-06)
~~~~~~~~~~~~~~~~~~~~

Fixed
^^^^^

* Fixed the termination logging in :class:`~isaaclab.managers.TerminationManager`. Previously, the manager
  logged only the instantaneous termination count at reset, which led to inaccurate aggregation of termination
  counts and obscured what was really happening during training. The manager now logs the episodic termination
  statistics for each term instead.


0.44.9 (2025-07-30)
~~~~~~~~~~~~~~~~~~~

Expand Down
27 changes: 15 additions & 12 deletions source/isaaclab/isaaclab/managers/termination_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):

# call the base class constructor (this will parse the terms config)
super().__init__(cfg, env)
self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}
# prepare extra info to store individual termination term information
self._term_dones = dict()
for term_name in self._term_names:
self._term_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._term_dones = torch.zeros((self.num_envs, len(self._term_names)), device=self.device, dtype=torch.bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to change this as a dict of tensors to a single tensor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it last_episode_done_stats = self._term_dones.float().mean(dim=0) this operation is very nice and optimized, thats why I did it, but if you think dict is more clear I can revert back : ))

# create buffer for managing termination per environment
self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._terminated_buf = torch.zeros_like(self._truncated_buf)
Expand Down Expand Up @@ -139,9 +138,10 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
env_ids = slice(None)
# add to episode dict
extras = {}
for key in self._term_dones.keys():
last_episode_done_stats = self._term_dones.float().mean(dim=0)
for i, key in enumerate(self._term_names):
# store information
extras["Episode_Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
extras["Episode_Termination/" + key] = last_episode_done_stats[i].item()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want the ratio, isn't that simply?

self._term_dones[key][env_ids].sum() / len(env_ids)

Copy link
Collaborator Author

@ooctipus ooctipus Aug 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reveiwing!
I guess, doing with env_ids will be viewing ratio of resetting environments. I thought maybe report stats of all environment can be a bit more nicer as user can verify from the graph that all terms sum up to 1. Of course you can do self._term_dones[key].sum() / self.env.num_envs as well

But it seems like if this approach is what I after, do it in one tensor operation seems quite nice, both speed wise and memory utility wise.

# reset all the reward terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
Expand All @@ -161,15 +161,18 @@ def compute(self) -> torch.Tensor:
self._truncated_buf[:] = False
self._terminated_buf[:] = False
# iterate over all the termination terms
for name, term_cfg in zip(self._term_names, self._term_cfgs):
for i, term_cfg in enumerate(self._term_cfgs):
value = term_cfg.func(self._env, **term_cfg.params)
# store timeout signal separately
if term_cfg.time_out:
self._truncated_buf |= value
else:
self._terminated_buf |= value
# add to episode dones
self._term_dones[name][:] = value
rows = value.nonzero(as_tuple=True)[0] # indexing is cheaper than boolean advance indexing
if rows.numel() > 0:
self._term_dones[rows] = False
self._term_dones[rows, i] = True
# return combined termination signal
return self._truncated_buf | self._terminated_buf

Expand All @@ -182,7 +185,7 @@ def get_term(self, name: str) -> torch.Tensor:
Returns:
The corresponding termination term value. Shape is (num_envs,).
"""
return self._term_dones[name]
return self._term_dones[:, self._term_name_to_term_idx[name]]

def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
"""Returns the active terms as iterable sequence of tuples.
Expand All @@ -196,8 +199,8 @@ def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequenc
The active terms.
"""
terms = []
for key in self._term_dones.keys():
terms.append((key, [self._term_dones[key][env_idx].float().cpu().item()]))
for i, key in enumerate(self._term_names):
terms.append((key, [self._term_dones[env_idx, i].float().cpu().item()]))
return terms

"""
Expand All @@ -217,7 +220,7 @@ def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
if term_name not in self._term_names:
raise ValueError(f"Termination term '{term_name}' not found.")
# set the configuration
self._term_cfgs[self._term_names.index(term_name)] = cfg
self._term_cfgs[self._term_name_to_term_idx[term_name]] = cfg

def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
"""Gets the configuration for the specified term.
Expand All @@ -234,7 +237,7 @@ def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
if term_name not in self._term_names:
raise ValueError(f"Termination term '{term_name}' not found.")
# return the configuration
return self._term_cfgs[self._term_names.index(term_name)]
return self._term_cfgs[self._term_name_to_term_idx[term_name]]

"""
Helper functions.
Expand Down
Loading