Skip to content

Commit

Permalink
change min_frequency to global step counter
Browse files Browse the repository at this point in the history
  • Loading branch information
kellyguo11 committed Jul 30, 2024
1 parent 337646f commit 92e12dc
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ def _reset_idx(self, env_ids: Sequence[int]):
# apply events such as randomization for environments that need a reset
if self.cfg.events:
if "reset" in self.event_manager.available_modes:
self.event_manager.apply(env_ids=env_ids, mode="reset", dt=self.step_dt)
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)

# reset noise models
if self.cfg.action_noise_model:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ def _reset_idx(self, env_ids: Sequence[int]):
self.scene.reset(env_ids)
# apply events such as randomizations for environments that need a reset
if "reset" in self.event_manager.available_modes:
self.event_manager.apply(env_ids=env_ids, mode="reset", dt=self.step_dt)
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)

# iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def _reset_idx(self, env_ids: Sequence[int]):
self.scene.reset(env_ids)
# apply events such as randomizations for environments that need a reset
if "reset" in self.event_manager.available_modes:
self.event_manager.apply(env_ids=env_ids, mode="reset", dt=self.step_dt)
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)

# iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# nothing to log here
return {}

def apply(self, mode: str, env_ids: Sequence[int] | None = None, dt: float | None = None):
def apply(
self,
mode: str,
env_ids: Sequence[int] | None = None,
dt: float | None = None,
global_env_step_count: int | None = None,
):
"""Calls each event term in the specified mode.
Note:
Expand All @@ -132,6 +138,8 @@ def apply(self, mode: str, env_ids: Sequence[int] | None = None, dt: float | Non
Defaults to None, in which case the event is applied to all environments.
dt: The time step of the environment. This is only used for the "interval" mode.
Defaults to None to simplify the call for other modes.
global_env_step_count: The environment step count of the task. This is only used for the "reset" mode.
Defaults to None to simplify the call for other modes.
Raises:
ValueError: If the mode is ``"interval"`` and the time step is not provided.
Expand Down Expand Up @@ -173,17 +181,18 @@ def apply(self, mode: str, env_ids: Sequence[int] | None = None, dt: float | Non
time_left[env_ids] = torch.rand(len(env_ids), device=self.device) * (upper - lower) + lower
# check for minimum frequency for reset
elif mode == "reset":
if dt is None:
if global_env_step_count is None:
raise ValueError(
f"Event mode '{mode}' requires the time step of the environment"
f"Event mode '{mode}' requires the step count of the environment"
" to be passed to the event manager."
)
self._reset_mode_time_until_next_reset[index] -= dt

if env_ids is not None and len(env_ids) > 0:
time_left = self._reset_mode_time_until_next_reset[index]
env_ids = env_ids[time_left[env_ids] <= 0.0]
last_reset_step = self._reset_mode_last_reset_step_count[index]
steps_since_last_reset = global_env_step_count - last_reset_step
env_ids = env_ids[steps_since_last_reset[env_ids] >= term_cfg.min_step_count_between_reset]
if len(env_ids) > 0:
time_left[env_ids] = term_cfg.min_frequency
last_reset_step[env_ids] = global_env_step_count
else:
# no need to call func to sample
continue
Expand Down Expand Up @@ -250,8 +259,8 @@ def _prepare_terms(self):
self._interval_mode_time_left: list[torch.Tensor] = list()
# global timer for "interval" mode for global properties
self._interval_mode_time_global: list[torch.Tensor] = list()
# buffer to store the time until next reset for each environment for "reset" mode with minimum frequency
self._reset_mode_time_until_next_reset: list[torch.Tensor] = list()
# buffer to store the step count when reset was last performed for each environment for "reset" mode
self._reset_mode_last_reset_step_count: list[torch.Tensor] = list()

# check if config is dict already
if isinstance(self.cfg, dict):
Expand Down Expand Up @@ -304,5 +313,5 @@ def _prepare_terms(self):
self._interval_mode_time_left.append(time_left)

elif term_cfg.mode == "reset":
time_left = torch.zeros(self.num_envs, device=self.device)
self._reset_mode_time_until_next_reset.append(time_left)
step_count = torch.zeros(self.num_envs, device=self.device, dtype=torch.int32)
self._reset_mode_last_reset_step_count.append(step_count)
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ class EventTermCfg(ManagerTermBaseCfg):
This is only used if the mode is ``"interval"``.
"""

min_frequency: float = 0.0
"""The minimum time in seconds between when term is applied.
min_step_count_between_reset: float = 0.0
"""The minimum number of environment steps between when term is applied.
When mode is "reset", the term will not be applied on the next reset unless
the time since the last application of the term has exceeded this.
the number of steps since the last application of the term has exceeded this.
Note:
This is only used if the mode is ``"reset"``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class EventCfg:
robot_physics_material = EventTerm(
func=mdp.randomize_rigid_body_material,
mode="reset",
min_frequency=36.0,
min_step_count_between_reset=720,
params={
"asset_cfg": SceneEntityCfg("robot"),
"static_friction_range": (0.7, 1.3),
Expand All @@ -40,7 +40,7 @@ class EventCfg:
)
robot_joint_stiffness_and_damping = EventTerm(
func=mdp.randomize_actuator_gains,
min_frequency=36.0,
min_step_count_between_reset=720,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("robot", joint_names=".*"),
Expand All @@ -52,7 +52,7 @@ class EventCfg:
)
robot_joint_limits = EventTerm(
func=mdp.randomize_joint_parameters,
min_frequency=36.0,
min_step_count_between_reset=720,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("robot", joint_names=".*"),
Expand All @@ -64,7 +64,7 @@ class EventCfg:
)
robot_tendon_properties = EventTerm(
func=mdp.randomize_fixed_tendon_parameters,
min_frequency=36.0,
min_step_count_between_reset=720,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("robot", fixed_tendon_names=".*"),
Expand All @@ -78,7 +78,7 @@ class EventCfg:
# -- object
object_physics_material = EventTerm(
func=mdp.randomize_rigid_body_material,
min_frequency=36.0,
min_step_count_between_reset=720,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("object"),
Expand All @@ -90,7 +90,7 @@ class EventCfg:
)
object_scale_mass = EventTerm(
func=mdp.randomize_rigid_body_mass,
min_frequency=36.0,
min_step_count_between_reset=720,
mode="reset",
params={
"asset_cfg": SceneEntityCfg("object"),
Expand Down

0 comments on commit 92e12dc

Please sign in to comment.