Fixes Termination Manager logging to report aggregated percentage of environments done due to each term. #3107
Changes from all commits: c7b9ebd, fc8c511, 842cea0, 7ef71cc, 183b4aa, a0767c6
@@ -60,10 +60,9 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):
         # call the base class constructor (this will parse the terms config)
         super().__init__(cfg, env)
+        self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}
         # prepare extra info to store individual termination term information
-        self._term_dones = dict()
-        for term_name in self._term_names:
-            self._term_dones[term_name] = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
+        self._term_dones = torch.zeros((self.num_envs, len(self._term_names)), device=self.device, dtype=torch.bool)
         # create buffer for managing termination per environment
         self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
         self._terminated_buf = torch.zeros_like(self._truncated_buf)
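For orientation, here is a minimal sketch of the new layout: one boolean column per termination term instead of one dict entry per term. The term names, environment count, and device below are invented for illustration.

```python
import torch

# hypothetical setup: 4 environments and 3 termination terms (names are illustrative)
num_envs, device = 4, "cpu"
term_names = ["time_out", "base_contact", "bad_orientation"]

# name -> column index lookup, and a single (num_envs, num_terms) boolean buffer
term_name_to_term_idx = {name: i for i, name in enumerate(term_names)}
term_dones = torch.zeros((num_envs, len(term_names)), device=device, dtype=torch.bool)

# each term's per-environment dones are one column of the buffer
column = term_dones[:, term_name_to_term_idx["base_contact"]]
print(column.shape)  # torch.Size([4]), i.e. (num_envs,)
```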
@@ -139,9 +138,10 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
             env_ids = slice(None)
         # add to episode dict
         extras = {}
-        for key in self._term_dones.keys():
+        last_episode_done_stats = self._term_dones.float().mean(dim=0)
+        for i, key in enumerate(self._term_names):
             # store information
-            extras["Episode_Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
+            extras["Episode_Termination/" + key] = last_episode_done_stats[i].item()
Contributor:
If you want the ratio, isn't that simply self._term_dones[key][env_ids].sum() / len(env_ids)?

Collaborator (Author):
Thanks for reviewing! That ratio is what I'm after, but computing it for all terms in one tensor operation seems quite nice, both speed-wise and memory-wise.
         # reset all the reward terms
         for term_cfg in self._class_term_cfgs:
             term_cfg.func.reset(env_ids=env_ids)
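As a quick sanity check on the suggestion above, a small self-contained sketch (with made-up done flags) showing that the vectorized per-term mean over environments matches the suggested sum-divided-by-count ratio when all environments are reset:

```python
import torch

# made-up dones for 4 environments and 2 termination terms
term_dones = torch.tensor([[True, False],
                           [False, False],
                           [True, True],
                           [False, False]])

# vectorized: fraction of environments done due to each term, in one reduction
stats = term_dones.float().mean(dim=0)  # tensor([0.5000, 0.2500])

# per-term equivalent of the reviewer's suggestion, with env_ids covering all environments
ratio_term_0 = term_dones[:, 0].sum() / term_dones.shape[0]
assert torch.isclose(stats[0], ratio_term_0)
```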
@@ -161,15 +161,18 @@ def compute(self) -> torch.Tensor:
         self._truncated_buf[:] = False
         self._terminated_buf[:] = False
         # iterate over all the termination terms
-        for name, term_cfg in zip(self._term_names, self._term_cfgs):
+        for i, term_cfg in enumerate(self._term_cfgs):
             value = term_cfg.func(self._env, **term_cfg.params)
             # store timeout signal separately
             if term_cfg.time_out:
                 self._truncated_buf |= value
             else:
                 self._terminated_buf |= value
             # add to episode dones
-            self._term_dones[name][:] = value
+            rows = value.nonzero(as_tuple=True)[0]  # indexing is cheaper than boolean advance indexing
+            if rows.numel() > 0:
+                self._term_dones[rows] = False
+                self._term_dones[rows, i] = True
         # return combined termination signal
         return self._truncated_buf | self._terminated_buf
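To make the clear-then-set update in compute() above concrete, a standalone sketch with invented shapes and an invented firing pattern; the rows where a term fires are cleared first, so only the most recent terminating term stays flagged for those environments:

```python
import torch

# hypothetical buffer: 4 environments, 3 termination terms
term_dones = torch.zeros((4, 3), dtype=torch.bool)
term_dones[1, 2] = True  # env 1 was previously marked done by term 2

# suppose term i=0 fires for environments 1 and 3 this step
value = torch.tensor([False, True, False, True])
i = 0

rows = value.nonzero(as_tuple=True)[0]   # long indices of envs where the term fired
if rows.numel() > 0:
    term_dones[rows] = False             # clear any stale flags for those environments
    term_dones[rows, i] = True           # mark them as done due to term i

print(term_dones[1])  # tensor([ True, False, False]) -- old term-2 flag was cleared
```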
@@ -182,7 +185,7 @@ def get_term(self, name: str) -> torch.Tensor:
         Returns:
             The corresponding termination term value. Shape is (num_envs,).
         """
-        return self._term_dones[name]
+        return self._term_dones[:, self._term_name_to_term_idx[name]]

     def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
         """Returns the active terms as iterable sequence of tuples.
@@ -196,8 +199,8 @@ def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
             The active terms.
         """
         terms = []
-        for key in self._term_dones.keys():
-            terms.append((key, [self._term_dones[key][env_idx].float().cpu().item()]))
+        for i, key in enumerate(self._term_names):
+            terms.append((key, [self._term_dones[env_idx, i].float().cpu().item()]))
         return terms

     """
@@ -217,7 +220,7 @@ def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
         if term_name not in self._term_names:
             raise ValueError(f"Termination term '{term_name}' not found.")
         # set the configuration
-        self._term_cfgs[self._term_names.index(term_name)] = cfg
+        self._term_cfgs[self._term_name_to_term_idx[term_name]] = cfg

     def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
         """Gets the configuration for the specified term.
@@ -234,7 +237,7 @@ def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
         if term_name not in self._term_names:
             raise ValueError(f"Termination term '{term_name}' not found.")
         # return the configuration
-        return self._term_cfgs[self._term_names.index(term_name)]
+        return self._term_cfgs[self._term_name_to_term_idx[term_name]]

     """
     Helper functions.
Reviewer:
Is there a reason to change this from a dict of tensors to a single tensor?

Author:
I thought this operation, last_episode_done_stats = self._term_dones.float().mean(dim=0), is very nice and optimized; that's why I did it. But if you think the dict is clearer, I can revert back : ))
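To make the tradeoff discussed here concrete, a short sketch (random dones, illustrative term names) contrasting how the per-term done fraction would be gathered from a dict of tensors versus from a single (num_envs, num_terms) tensor; the results are identical, the tensor version just does it in one reduction:

```python
import torch

term_names = ["time_out", "base_contact"]  # illustrative names
num_envs = 4

# dict-of-tensors layout (the previous approach): one reduction per term
term_dones_dict = {name: torch.randint(0, 2, (num_envs,)).bool() for name in term_names}
stats_from_dict = {name: dones.float().mean().item() for name, dones in term_dones_dict.items()}

# single (num_envs, num_terms) tensor (the layout in this PR): one reduction for all terms
term_dones = torch.stack([term_dones_dict[name] for name in term_names], dim=1)
stats_from_tensor = term_dones.float().mean(dim=0)

for i, name in enumerate(term_names):
    assert abs(stats_from_dict[name] - stats_from_tensor[i].item()) < 1e-6
```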