Skip to content

Commit

Permalink
Merge pull request #277 from maxspahn/fix-realtime-mujoco
Browse files Browse the repository at this point in the history
Fix realtime mujoco
  • Loading branch information
maxspahn authored Jan 16, 2025
2 parents d5c9879 + cc86ae7 commit 4a1f966
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
4 changes: 1 addition & 3 deletions examples/mujoco_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,14 @@ def run_generic_mujoco(
t = 0.0
history = []
for i in range(n_steps):
t0 = time.perf_counter()
action = action_mag * np.cos(i/20)
action = action_mag * np.cos(env.t)
action[-1] = 0.02
ob, _, terminated, _, info = env.step(action)
#print(ob['robot_0'])
history.append(ob)
if terminated:
print(info)
break
t1 = time.perf_counter()

env.close()
return history
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "urdfenvs"
version = "0.10.1"
version = "0.10.2"
description = "Simple simulation environment for robots, based on the urdf files."
authors = ["Max Spahn <[email protected]>"]
maintainers = [
Expand Down
55 changes: 29 additions & 26 deletions urdfenvs/generic_mujoco/generic_mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class GenericMujocoEnv(Env):
"rgb_array",
"depth_array",
],
"render_fps": 20,
}
_t: float
_number_movable_obstacles: int
Expand All @@ -53,7 +52,8 @@ def __init__(
goals: List[SubGoal],
sensors: Optional[List[Sensor]] = None,
render: Optional[Union[str, bool]] = None,
frame_skip: int = 5,
dt: float = 0.01,
n_sub_steps: int = 1,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
camera_id: Optional[int] = None,
Expand Down Expand Up @@ -83,11 +83,13 @@ def __init__(
self._enforce_real_time = True
else:
self._enforce_real_time = False
self._sleep_offset = 0.0

self.width = width
self.height = height

self.frame_skip = frame_skip
self._dt = dt
self._n_sub_steps = n_sub_steps
self._initialize_simulation()

for sensor in self._sensors:
Expand All @@ -98,11 +100,6 @@ def __init__(
"rgb_array",
"depth_array",
], self.metadata["render_modes"]
if "render_fps" in self.metadata:
assert (
int(np.round(1.0 / self.dt)) == self.metadata["render_fps"]
), f'Expected value: {int(np.round(1.0 / self.dt))}, Actual value: {self.metadata["render_fps"]}'

self.observation_space = self.get_observation_space()
self.action_space = self.get_action_space()

Expand All @@ -117,6 +114,7 @@ def __init__(
height=DEFAULT_SIZE,
width=DEFAULT_SIZE,
)
self._end_last_step_time = time.time()

def render(self):
return self.mujoco_renderer.render(self.render_mode)
Expand Down Expand Up @@ -190,11 +188,12 @@ def _initialize_simulation(
self.model.body_pos[0] = [0, 1, 1]
self.model.vis.global_.offwidth = self.width
self.model.vis.global_.offheight = self.height
self.model.opt.timestep = self._dt / self._n_sub_steps
self.data = mujoco.MjData(self.model)

@property
def dt(self) -> float:
return self.model.opt.timestep * self.frame_skip
return self._dt

@property
def t(self) -> float:
Expand All @@ -221,7 +220,8 @@ def update_goals_position(self):
self.data.site_xpos[i] = goal.position(t=self.t)

def step(self, action: np.ndarray):
step_start = time.perf_counter()
target_end_step_time = self._end_last_step_time + self.dt

self._t += self.dt
truncated = False
info = {}
Expand All @@ -230,7 +230,7 @@ def step(self, action: np.ndarray):
self._done = True
info = {"action_limits": f"{action} not in {self.action_space}"}

self.do_simulation(action, self.frame_skip)
self.do_simulation(action)
for contact in self.data.contact:
body1 = self.model.geom(contact.geom1).name
body2 = self.model.geom(contact.geom2).name
Expand All @@ -256,15 +256,20 @@ def step(self, action: np.ndarray):
except WrongObservationError as e:
self._done = True
info = {"observation_limits": str(e)}
step_end = time.perf_counter()
step_time = step_end - step_start
if self._enforce_real_time:
sleep_time = max(0.0, self.dt - step_time)
time_before_sleep = time.time()
sleep_time = target_end_step_time - time_before_sleep - self._sleep_offset
if self._enforce_real_time and sleep_time > 0:
time.sleep(sleep_time)
step_final_end = time.perf_counter()
total_step_time = step_final_end - step_start
real_time_factor = self.dt / total_step_time
logging.info(f"Real time factor {real_time_factor}")
time_after_sleep = time.time()
# Compute the real-time factor (RTF)
real_time_step = time_after_sleep - self._end_last_step_time
rtf = self.dt / (real_time_step)
if real_time_step < self.dt:
self._sleep_offset -= 0.0001
else:
self._sleep_offset += 0.0001
logging.info(f"Real time factor {rtf:.4f}")
self._end_last_step_time = time.time()
return (
ob,
reward,
Expand Down Expand Up @@ -297,6 +302,7 @@ def reset(
qvel = np.zeros(self.nv)
self.set_state(qpos, qvel)
self._t = 0.0
self._end_last_step_time = time.time()
return self._get_obs(), {}

def set_state(self, qpos, qvel):
Expand All @@ -319,21 +325,18 @@ def nq(self) -> int:
def nv(self) -> int:
return self.model.nv - 6 * self._number_movable_obstacles

def do_simulation(self, ctrl, n_frames) -> None:
"""
Step the simulation n number of frames and applying a control action.
"""
def do_simulation(self, ctrl) -> None:
# Check control input is contained in the action space
if np.array(ctrl).shape != (self.model.nu,):
raise ValueError(
f"Action dimension mismatch. Expected {(self.model.nu,)}, found {np.array(ctrl).shape}"
)
self._step_mujoco_simulation(ctrl, n_frames)
self._step_mujoco_simulation(ctrl)

def _step_mujoco_simulation(self, ctrl, n_frames):
def _step_mujoco_simulation(self, ctrl):
self.data.ctrl[:] = ctrl

mujoco.mj_step(self.model, self.data, nstep=n_frames)
mujoco.mj_step(self.model, self.data, nstep=self._n_sub_steps)

# As of MuJoCo 2.0, force-related quantities like cacc are not computed
# unless there's a force sensor in the model.
Expand Down
1 change: 0 additions & 1 deletion urdfenvs/sensors/mujoco_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def get_obstacle_pose(
if movable:

free_joint_data = self._data.jnt(f"freejoint_{obst_name}").qpos
print(free_joint_data)
return free_joint_data[0:3].tolist(), free_joint_data[3:].tolist()
pos = self._data.body(obst_name).xpos
ori = self._data.body(obst_name).xquat
Expand Down

0 comments on commit 4a1f966

Please sign in to comment.