@@ -61,9 +61,7 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv):
6161 # call the base class constructor (this will parse the terms config)
6262 super ().__init__ (cfg , env )
6363 # prepare extra info to store individual termination term information
64- self ._term_dones = dict ()
65- for term_name in self ._term_names :
66- self ._term_dones [term_name ] = torch .zeros (self .num_envs , device = self .device , dtype = torch .bool )
64+ self ._term_dones = torch .zeros ((self .num_envs , len (self ._term_names )), device = self .device , dtype = torch .bool )
6765 # create buffer for managing termination per environment
6866 self ._truncated_buf = torch .zeros (self .num_envs , device = self .device , dtype = torch .bool )
6967 self ._terminated_buf = torch .zeros_like (self ._truncated_buf )
@@ -139,9 +137,10 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]
139137 env_ids = slice (None )
140138 # add to episode dict
141139 extras = {}
142- for key in self ._term_dones .keys ():
140+ last_episode_done_stats = self ._term_dones .float ().mean (dim = 0 )
141+ for i , key in enumerate (self ._term_names ):
143142 # store information
144- extras ["Episode_Termination/" + key ] = torch . count_nonzero ( self . _term_dones [ key ][ env_ids ]) .item ()
143+ extras ["Episode_Termination/" + key ] = last_episode_done_stats [ i ] .item ()
145144 # reset all the reward terms
146145 for term_cfg in self ._class_term_cfgs :
147146 term_cfg .func .reset (env_ids = env_ids )
@@ -161,15 +160,16 @@ def compute(self) -> torch.Tensor:
161160 self ._truncated_buf [:] = False
162161 self ._terminated_buf [:] = False
163162 # iterate over all the termination terms
164- for name , term_cfg in zip ( self . _term_names , self ._term_cfgs ):
163+ for i , term_cfg in enumerate ( self ._term_cfgs ):
165164 value = term_cfg .func (self ._env , ** term_cfg .params )
166165 # store timeout signal separately
167166 if term_cfg .time_out :
168167 self ._truncated_buf |= value
169168 else :
170169 self ._terminated_buf |= value
171170 # add to episode dones
172- self ._term_dones [name ][:] = value
171+ self ._term_dones [value ] = False
172+ self ._term_dones [value , i ] = True
173173 # return combined termination signal
174174 return self ._truncated_buf | self ._terminated_buf
175175
@@ -182,7 +182,7 @@ def get_term(self, name: str) -> torch.Tensor:
182182 Returns:
183183 The corresponding termination term value. Shape is (num_envs,).
184184 """
185- return self ._term_dones [name ]
185+ return self ._term_dones [: , self . _term_names . index ( name ) ]
186186
187187 def get_active_iterable_terms (self , env_idx : int ) -> Sequence [tuple [str , Sequence [float ]]]:
188188 """Returns the active terms as iterable sequence of tuples.
@@ -196,8 +196,8 @@ def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequenc
196196 The active terms.
197197 """
198198 terms = []
199- for key in self ._term_dones . keys ( ):
200- terms .append ((key , [self ._term_dones [key ][ env_idx ].float ().cpu ().item ()]))
199+ for i , key in enumerate ( self ._term_names ):
200+ terms .append ((key , [self ._term_dones [env_idx , i ].float ().cpu ().item ()]))
201201 return terms
202202
203203 """
0 commit comments