Skip to content

Commit

Permalink
Update to refactored mujoco renderer in Gymnasium v0.27.0 (#60)
Browse files Browse the repository at this point in the history
* update mujoco renderer gymnasium v27

* remove rgb_array_list from metadata

* skip render check pytest

* pre-commit black
  • Loading branch information
rodrigodelazcano authored Dec 17, 2022
1 parent 3436709 commit a2e17df
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 62 deletions.
29 changes: 21 additions & 8 deletions gymnasium_robotics/envs/fetch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
from gymnasium_robotics.envs.robot_env import MujocoPyRobotEnv, MujocoRobotEnv
from gymnasium_robotics.utils import rotations

DEFAULT_CAMERA_CONFIG = {
"distance": 2.5,
"azimuth": 132.0,
"elevation": -14.0,
"lookat": np.array([1.3, 0.75, 0.55]),
}


def goal_distance(goal_a, goal_b):
assert goal_a.shape == goal_b.shape
Expand Down Expand Up @@ -143,14 +150,6 @@ def get_gripper_xpos(self):

raise NotImplementedError

def _viewer_setup(self):
lookat = self.get_gripper_xpos()
for idx, value in enumerate(lookat):
self.viewer.cam.lookat[idx] = value
self.viewer.cam.distance = 2.5
self.viewer.cam.azimuth = 132.0
self.viewer.cam.elevation = -14.0

def _sample_goal(self):
if self.has_object:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
Expand Down Expand Up @@ -238,6 +237,17 @@ def _render_callback(self):
self.sim.model.site_pos[site_id] = self.goal - sites_offset[0]
self.sim.forward()

def _viewer_setup(self):
lookat = self.get_gripper_xpos()
for idx, value in enumerate(lookat):
self.viewer.cam.lookat[idx] = value
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value
else:
setattr(self.viewer.cam, key, value)

def _reset_sim(self):
self.sim.set_state(self.initial_state)

Expand Down Expand Up @@ -279,6 +289,9 @@ def _env_setup(self, initial_qpos):


class MujocoFetchEnv(get_base_fetch_env(MujocoRobotEnv)):
def __init__(self, default_camera_config: dict = DEFAULT_CAMERA_CONFIG, **kwargs):
super().__init__(default_camera_config=default_camera_config, **kwargs)

def _step_callback(self):
if self.block_gripper:
self._utils.set_joint_qpos(
Expand Down
34 changes: 20 additions & 14 deletions gymnasium_robotics/envs/hand_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

from gymnasium_robotics.envs.robot_env import MujocoPyRobotEnv, MujocoRobotEnv

DEFAULT_CAMERA_CONFIG = {
"distance": 0.5,
"azimuth": 55.0,
"elevation": -25.0,
"lookat": np.array([1, 0.96, 0.14]),
}


def get_base_hand_env(
RobotEnvClass: Union[MujocoPyRobotEnv, MujocoRobotEnv]
Expand All @@ -21,24 +28,19 @@ def __init__(self, relative_control, **kwargs):

# RobotEnv methods
# ----------------------------
def _get_palm_xpos(self):
raise NotImplementedError

def _set_action(self, action):
assert action.shape == (20,)

def _viewer_setup(self):
lookat = self._get_palm_xpos()
for idx, value in enumerate(lookat):
self.viewer.cam.lookat[idx] = value
self.viewer.cam.distance = 0.5
self.viewer.cam.azimuth = 55.0
self.viewer.cam.elevation = -25.0

return BaseHandEnv


class MujocoHandEnv(get_base_hand_env(MujocoRobotEnv)):
def __init__(
self, default_camera_config: dict = DEFAULT_CAMERA_CONFIG, **kwargs
) -> None:
super().__init__(default_camera_config=default_camera_config, **kwargs)

def _set_action(self, action):
super()._set_action(action)
ctrlrange = self.model.actuator_ctrlrange
Expand All @@ -60,10 +62,6 @@ def _set_action(self, action):
self.data.ctrl[:] = actuation_center + action * actuation_range
self.data.ctrl[:] = np.clip(self.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1])

def _get_palm_xpos(self):
body_id = self._model_names.body_name2id["robot0:palm"]
return self.data.xpos[body_id]


class MujocoPyHandEnv(get_base_hand_env(MujocoPyRobotEnv)):
def _set_action(self, action):
Expand Down Expand Up @@ -92,3 +90,11 @@ def _set_action(self, action):
def _get_palm_xpos(self):
body_id = self.sim.model.body_name2id("robot0:palm")
return self.sim.data.body_xpos[body_id]

def _viewer_setup(self):
lookat = self._get_palm_xpos()
for idx, value in enumerate(lookat):
self.viewer.cam.lookat[idx] = value
self.viewer.cam.distance = 0.5
self.viewer.cam.azimuth = 55.0
self.viewer.cam.elevation = -25.0
61 changes: 22 additions & 39 deletions gymnasium_robotics/envs/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from typing import Optional, Union

import gymnasium as gym
import numpy as np
from gymnasium import error, logger, spaces

Expand Down Expand Up @@ -36,7 +35,6 @@ class BaseRobotEnv(GoalEnv):
"render_modes": [
"human",
"rgb_array",
"rgb_array_list",
],
"render_fps": 25,
}
Expand Down Expand Up @@ -202,12 +200,6 @@ def _env_setup(self, initial_qpos):
"""
pass

def _viewer_setup(self):
"""Initial configuration of the viewer. Can be used to set the camera position,
for example.
"""
pass

def _render_callback(self):
"""A custom callback that is called before rendering. Can be used
to implement custom visualizations.
Expand All @@ -222,7 +214,7 @@ def _step_callback(self):


class MujocoRobotEnv(BaseRobotEnv):
def __init__(self, **kwargs):
def __init__(self, default_camera_config: Optional[dict] = None, **kwargs):
if MUJOCO_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{MUJOCO_IMPORT_ERROR}. (HINT: you need to install mujoco)"
Expand All @@ -233,6 +225,12 @@ def __init__(self, **kwargs):

super().__init__(**kwargs)

from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer

self.mujoco_renderer = MujocoRenderer(
self.model, self.data, default_camera_config
)

def _initialize_simulation(self):
self.model = self._mujoco.MjModel.from_xml_path(self.fullpath)
self.data = self._mujoco.MjData(self.model)
Expand All @@ -258,35 +256,11 @@ def _reset_sim(self):

def render(self):
self._render_callback()
if self.render_mode == "rgb_array":
self._get_viewer(self.render_mode).render()
data = self._get_viewer(self.render_mode).read_pixels(depth=False)
# original image is upside-down, so flip it
return data[::-1, :, :]
elif self.render_mode == "human":
self._get_viewer(self.render_mode).render()

def _get_viewer(
self, mode
) -> Union["gym.envs.mujoco.Viewer", "gym.envs.mujoco.RenderContextOffscreen"]:
self.viewer = self._viewers.get(mode)
if self.viewer is None:
if mode == "human":
from gymnasium.envs.mujoco.mujoco_rendering import Viewer

self.viewer = Viewer(self.model, self.data)
elif mode in {
"rgb_array",
"rgb_array_list",
}:
from gymnasium.envs.mujoco.mujoco_rendering import (
RenderContextOffscreen,
)
return self.mujoco_renderer.render(self.render_mode)

self.viewer = RenderContextOffscreen(model=self.model, data=self.data)
self._viewer_setup()
self._viewers[mode] = self.viewer
return self.viewer
def close(self):
if self.mujoco_renderer is not None:
self.mujoco_renderer.close()

@property
def dt(self):
Expand Down Expand Up @@ -334,7 +308,6 @@ def render(self):
self._render_callback()
if self.render_mode in {
"rgb_array",
"rgb_array_list",
}:
self._get_viewer(self.render_mode).render(width, height)
# window size used for old mujoco-py:
Expand All @@ -346,6 +319,11 @@ def render(self):
elif self.render_mode == "human":
self._get_viewer(self.render_mode).render()

def close(self):
if self.viewer is not None:
self.viewer = None
self._viewers = {}

def _get_viewer(
self, mode
) -> Union["mujoco_py.MjViewer", "mujoco_py.MjRenderContextOffscreen"]:
Expand All @@ -356,7 +334,6 @@ def _get_viewer(

elif mode in {
"rgb_array",
"rgb_array_list",
}:
self.viewer = self._mujoco_py.MjRenderContextOffscreen(self.sim, -1)
self._viewer_setup()
Expand All @@ -369,3 +346,9 @@ def dt(self):

def _mujoco_step(self, action):
self.sim.step()

def _viewer_setup(self):
"""Initial configuration of the viewer. Can be used to set the camera position,
for example.
"""
pass
4 changes: 3 additions & 1 deletion tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
def test_env(spec):
# Capture warnings
env = spec.make(disable_env_checker=True).unwrapped

warnings.simplefilter("always")
# Test if env adheres to Gym API
with warnings.catch_warnings(record=True) as w:
check_env(env)
check_env(env, skip_render_check=True)
env.close()
for warning in w:
if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
raise Error(f"Unexpected warning: {warning.message}")
Expand Down

0 comments on commit a2e17df

Please sign in to comment.