Skip to content

Commit f498245

Browse files
ooctipus and kellyguo11 authored
Separates per-step termination and last-episode termination bookkeeping (#3745)
# Description This PR fixes the issue where get_done_term returned last episode value rather than current step value. This PR realizes values used for get_term should be different from that used for logging, and mixed usage leads to non-intuitive behavior. Using per-step value for logging leads to overcounting and undercounting reported in #2977; using last-episode value for get_term leads to misalignment with expectation reported in #3720 Fixes #2977 #3720 --- The logging behavior remains *mostly* the same as #3107, and also got rid of the weird overwriting behavior (yay). I get exactly the same termination curve as #3107 when run on `Isaac-Velocity-Rough-Anymal-C-v0` Here is a benchmark summary with 1000 steps running `Isaac-Velocity-Rough-Anymal-C-v0` with 4096 envs Before #3107: `| termination.compute | 0.229 ms|` `| termination.reset | 0.007 ms|` PR #3107: `| termination.compute | 0.274 ms|` `| termination.reset | 0.004 ms|` This PR: `| termination.compute | 0.258 ms|` `| termination.reset | 0.004 ms|` We actually see improvement, this is due to the fact that expensive maintenance of last_episode_value is only computed once per compute (#3107 computes last_episode_value for every term) ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have read and understood the [contribution guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html) - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --------- Signed-off-by: Kelly Guo <[email protected]> Co-authored-by: Kelly Guo 
<[email protected]>
1 parent c8b6d22 commit f498245

File tree

4 files changed

+165
-8
lines changed

4 files changed

+165
-8
lines changed

source/isaaclab/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.47.8"
4+
version = "0.47.9"
55

66
# Description
77
title = "Isaac Lab framework for Robot Learning"

source/isaaclab/docs/CHANGELOG.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
Changelog
22
---------
33

4+
0.47.9 (2025-11-06)
5+
~~~~~~~~~~~~~~~~~~~
6+
7+
Changed
8+
^^^^^^^
9+
10+
* Fixed termination term bookkeeping in :class:`~isaaclab.managers.TerminationManager`:
11+
per-step termination and last-episode termination bookkeeping are now separated.
12+
Last-episode dones are now updated once per step from all term outputs, avoiding per-term overwrites
13+
and ensuring Episode_Termination metrics reflect the actual triggering terms.
14+
15+
416
0.47.8 (2025-11-06)
517
~~~~~~~~~~~~~~~~~~~
618

source/isaaclab/isaaclab/managers/termination_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):
6363
self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}
6464
# prepare extra info to store individual termination term information
6565
self._term_dones = torch.zeros((self.num_envs, len(self._term_names)), device=self.device, dtype=torch.bool)
66+
# prepare extra info to store last episode done per termination term information
67+
self._last_episode_dones = torch.zeros_like(self._term_dones)
6668
# create buffer for managing termination per environment
6769
self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
6870
self._terminated_buf = torch.zeros_like(self._truncated_buf)
@@ -138,7 +140,7 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
138140
env_ids = slice(None)
139141
# add to episode dict
140142
extras = {}
141-
last_episode_done_stats = self._term_dones.float().mean(dim=0)
143+
last_episode_done_stats = self._last_episode_dones.float().mean(dim=0)
142144
for i, key in enumerate(self._term_names):
143145
# store information
144146
extras["Episode_Termination/" + key] = last_episode_done_stats[i].item()
@@ -169,15 +171,17 @@ def compute(self) -> torch.Tensor:
169171
else:
170172
self._terminated_buf |= value
171173
# add to episode dones
172-
rows = value.nonzero(as_tuple=True)[0] # indexing is cheaper than boolean advance indexing
173-
if rows.numel() > 0:
174-
self._term_dones[rows] = False
175-
self._term_dones[rows, i] = True
174+
self._term_dones[:, i] = value
175+
# update last-episode dones once per compute: for any env where a term fired,
176+
# reflect exactly which term(s) fired this step and clear others
177+
rows = self._term_dones.any(dim=1).nonzero(as_tuple=True)[0]
178+
if rows.numel() > 0:
179+
self._last_episode_dones[rows] = self._term_dones[rows]
176180
# return combined termination signal
177181
return self._truncated_buf | self._terminated_buf
178182

179183
def get_term(self, name: str) -> torch.Tensor:
180-
"""Returns the termination term with the specified name.
184+
"""Returns the termination term value at current step with the specified name.
181185
182186
Args:
183187
name: The name of the termination term.
@@ -190,7 +194,8 @@ def get_term(self, name: str) -> torch.Tensor:
190194
def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
191195
"""Returns the active terms as iterable sequence of tuples.
192196
193-
The first element of the tuple is the name of the term and the second element is the raw value(s) of the term.
197+
The first element of the tuple is the name of the term and the second element is the raw value(s) of the term
198+
recorded at current step.
194199
195200
Args:
196201
env_idx: The specific environment to pull the active terms from.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
"""Launch Isaac Sim Simulator first."""
7+
8+
from isaaclab.app import AppLauncher
9+
10+
# launch omniverse app
11+
simulation_app = AppLauncher(headless=True).app
12+
13+
"""Rest everything follows."""
14+
15+
import torch
16+
17+
import pytest
18+
19+
from isaaclab.managers import TerminationManager, TerminationTermCfg
20+
from isaaclab.sim import SimulationContext
21+
22+
23+
class DummyEnv:
24+
"""Minimal mutable env stub for the termination manager tests."""
25+
26+
def __init__(self, num_envs: int, device: str, sim: SimulationContext):
27+
self.num_envs = num_envs
28+
self.device = device
29+
self.sim = sim
30+
self.counter = 0 # mutable step counter used by test terms
31+
32+
33+
def fail_every_5_steps(env) -> torch.Tensor:
34+
"""Returns True for all envs when counter is a positive multiple of 5."""
35+
cond = env.counter > 0 and (env.counter % 5 == 0)
36+
return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)
37+
38+
39+
def fail_every_10_steps(env) -> torch.Tensor:
40+
"""Returns True for all envs when counter is a positive multiple of 10."""
41+
cond = env.counter > 0 and (env.counter % 10 == 0)
42+
return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)
43+
44+
45+
def fail_every_3_steps(env) -> torch.Tensor:
46+
"""Returns True for all envs when counter is a positive multiple of 3."""
47+
cond = env.counter > 0 and (env.counter % 3 == 0)
48+
return torch.full((env.num_envs,), cond, dtype=torch.bool, device=env.device)
49+
50+
51+
@pytest.fixture
52+
def env():
53+
sim = SimulationContext()
54+
return DummyEnv(num_envs=20, device="cpu", sim=sim)
55+
56+
57+
def test_initial_state_and_shapes(env):
58+
cfg = {
59+
"term_5": TerminationTermCfg(func=fail_every_5_steps),
60+
"term_10": TerminationTermCfg(func=fail_every_10_steps),
61+
}
62+
tm = TerminationManager(cfg, env)
63+
64+
# Active term names
65+
assert tm.active_terms == ["term_5", "term_10"]
66+
67+
# Internal buffers have expected shapes and start as all False
68+
assert tm._term_dones.shape == (env.num_envs, 2)
69+
assert tm._last_episode_dones.shape == (env.num_envs, 2)
70+
assert tm.dones.shape == (env.num_envs,)
71+
assert tm.time_outs.shape == (env.num_envs,)
72+
assert tm.terminated.shape == (env.num_envs,)
73+
assert torch.all(~tm._term_dones) and torch.all(~tm._last_episode_dones)
74+
75+
76+
def test_term_transitions_and_persistence(env):
77+
"""Concise transitions: single fire, persist, switch, both, persist.
78+
79+
Uses 3-step and 5-step terms and verifies current-step values and last-episode persistence.
80+
"""
81+
cfg = {
82+
"term_3": TerminationTermCfg(func=fail_every_3_steps, time_out=False),
83+
"term_5": TerminationTermCfg(func=fail_every_5_steps, time_out=False),
84+
}
85+
tm = TerminationManager(cfg, env)
86+
87+
# step 3: only term_3 -> last_episode [True, False]
88+
env.counter = 3
89+
out = tm.compute()
90+
assert torch.all(tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
91+
assert torch.all(out)
92+
assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(~tm._last_episode_dones[:, 1])
93+
94+
# step 4: none -> last_episode persists [True, False]
95+
env.counter = 4
96+
out = tm.compute()
97+
assert torch.all(~out)
98+
assert torch.all(~tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
99+
assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(~tm._last_episode_dones[:, 1])
100+
101+
# step 5: only term_5 -> last_episode [False, True]
102+
env.counter = 5
103+
out = tm.compute()
104+
assert torch.all(~tm.get_term("term_3")) and torch.all(tm.get_term("term_5"))
105+
assert torch.all(out)
106+
assert torch.all(~tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])
107+
108+
# step 15: both -> last_episode [True, True]
109+
env.counter = 15
110+
out = tm.compute()
111+
assert torch.all(tm.get_term("term_3")) and torch.all(tm.get_term("term_5"))
112+
assert torch.all(out)
113+
assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])
114+
115+
# step 16: none -> persist [True, True]
116+
env.counter = 16
117+
out = tm.compute()
118+
assert torch.all(~out)
119+
assert torch.all(~tm.get_term("term_3")) and torch.all(~tm.get_term("term_5"))
120+
assert torch.all(tm._last_episode_dones[:, 0]) and torch.all(tm._last_episode_dones[:, 1])
121+
122+
123+
def test_time_out_vs_terminated_split(env):
124+
cfg = {
125+
"term_5": TerminationTermCfg(func=fail_every_5_steps, time_out=False), # terminated
126+
"term_10": TerminationTermCfg(func=fail_every_10_steps, time_out=True), # timeout
127+
}
128+
tm = TerminationManager(cfg, env)
129+
130+
# Step 5: terminated fires, not timeout
131+
env.counter = 5
132+
out = tm.compute()
133+
assert torch.all(out)
134+
assert torch.all(tm.terminated) and torch.all(~tm.time_outs)
135+
136+
# Step 10: both fire; timeout and terminated both True
137+
env.counter = 10
138+
out = tm.compute()
139+
assert torch.all(out)
140+
assert torch.all(tm.terminated) and torch.all(tm.time_outs)

0 commit comments

Comments
 (0)