diff --git a/docs/source/tasks/control/index.md b/docs/source/tasks/control/index.md
index c2f15e8f5..2a7136610 100644
--- a/docs/source/tasks/control/index.md
+++ b/docs/source/tasks/control/index.md
@@ -45,4 +45,46 @@ Use the Cartpole robot to swing up a pole on a cart.
 
 **Success Conditions:**
 - No specific success conditions. The task is considered successful if the pole is upright for the whole episode. We can threshold the episode accumulated reward to determine success.
-:::
\ No newline at end of file
+:::
+
+## MS-HopperHop-v1
+![dense-reward][reward-badge]
+
+:::{dropdown} Task Card
+:icon: note
+:color: primary
+
+**Task Description:**
+The Hopper robot stays upright and moves in the positive x direction with a hopping motion.
+
+
+**Supported Robots: Hopper**
+
+**Randomizations:**
+- The Hopper robot is randomly rotated [-pi, pi] radians about the y-axis.
+- Hopper qpos are uniformly sampled within their allowed ranges.
+
+**Success Conditions:**
+- No specific success conditions. We can threshold the episode accumulated reward to determine success.
+:::
+
+## MS-HopperStand-v1
+![dense-reward][reward-badge]
+
+:::{dropdown} Task Card
+:icon: note
+:color: primary
+
+**Task Description:**
+The Hopper robot stands upright.
+
+
+**Supported Robots: Hopper**
+
+**Randomizations:**
+- The Hopper robot is randomly rotated [-pi, pi] radians about the y-axis.
+- Hopper qpos are uniformly sampled within their allowed ranges.
+
+**Success Conditions:**
+- No specific success conditions. We can threshold the episode accumulated reward to determine success.
+:::
\ No newline at end of file
diff --git a/mani_skill/envs/tasks/control/__init__.py b/mani_skill/envs/tasks/control/__init__.py
index 920f2d3f6..59e766758 100644
--- a/mani_skill/envs/tasks/control/__init__.py
+++ b/mani_skill/envs/tasks/control/__init__.py
@@ -1 +1,2 @@
 from .cartpole import CartpoleBalanceEnv, CartpoleSwingUpEnv
+from .hopper import HopperHopEnv, HopperStandEnv
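Since neither hopper task defines an explicit success condition, downstream code can threshold the accumulated episode reward, as the task cards suggest. A minimal sketch of that idea (the random policy, the 0.8 factor, and the rollout loop are illustrative assumptions, not part of this PR):

```python
# Sketch: success-by-return thresholding for MS-HopperStand-v1.
import gymnasium as gym

import mani_skill.envs  # noqa: F401  (importing registers the MS-* tasks)

env = gym.make("MS-HopperStand-v1", obs_mode="state", reward_mode="normalized_dense")
obs, _ = env.reset(seed=0)
episode_return = 0.0
for _ in range(600):  # max_episode_steps for both hopper tasks
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    episode_return += float(reward)
    if terminated or truncated:
        break
# the normalized dense reward is at most 1 per step, so 600 is the best possible return
success = episode_return >= 0.8 * 600  # 0.8 is an arbitrary illustrative threshold
env.close()
```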
diff --git a/mani_skill/envs/tasks/control/assets/hopper.xml b/mani_skill/envs/tasks/control/assets/hopper.xml
new file mode 100644
index 000000000..0aa3be492
--- /dev/null
+++ b/mani_skill/envs/tasks/control/assets/hopper.xml
@@ -0,0 +1,70 @@
+<!-- 70-line MJCF model of the hopper; the XML markup did not survive extraction -->
diff --git a/mani_skill/envs/tasks/control/cartpole.py b/mani_skill/envs/tasks/control/cartpole.py
index dd5f265bd..71b0d1e27 100644
--- a/mani_skill/envs/tasks/control/cartpole.py
+++ b/mani_skill/envs/tasks/control/cartpole.py
@@ -3,7 +3,9 @@
 from typing import Any, Dict, Union
 
 import numpy as np
+import sapien
 import torch
+from transforms3d.euler import euler2quat
 
 from mani_skill.agents.base_agent import BaseAgent
 from mani_skill.agents.controllers import *
@@ -110,6 +112,17 @@ def _load_scene(self, options: dict):
         for a in actor_builders:
             a.build(a.name)
 
+        # background visual wall
+        self.wall = self.scene.create_actor_builder()
+        self.wall.add_box_visual(
+            half_size=(1e-3, 20, 10),
+            pose=sapien.Pose(p=[0, 1, 1], q=euler2quat(0, 0, np.pi / 2)),
+            material=sapien.render.RenderMaterial(
+                base_color=np.array([0.3, 0.3, 0.3, 1])
+            ),
+        )
+        self.wall.build_static(name="wall")
+
     def evaluate(self):
         return dict()
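The cartpole hunk above is the pattern this PR uses for scenery: a static, visual-only actor that carries no collision shape and therefore cannot perturb the dynamics. A generalized sketch (the `build_backdrop` helper and its parameters are illustrative, not part of the PR):

```python
# Illustrative helper (not in this PR): build a thin gray backdrop behind the
# agent. `scene` is assumed to be a ManiSkillScene, e.g. self.scene inside an
# env's _load_scene.
import numpy as np
import sapien
from transforms3d.euler import euler2quat


def build_backdrop(scene, name: str = "wall", distance: float = 1.0):
    builder = scene.create_actor_builder()
    builder.add_box_visual(
        # a thin slab: 1 mm thick, spanning 40 m x 20 m after the yaw rotation
        half_size=(1e-3, 20, 10),
        pose=sapien.Pose(p=[0, distance, 1], q=euler2quat(0, 0, np.pi / 2)),
        material=sapien.render.RenderMaterial(base_color=np.array([0.3, 0.3, 0.3, 1])),
    )
    # build_static fixes the wall in the world; with only a visual shape it is
    # render-only and invisible to the physics engine
    return builder.build_static(name=name)
```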
diff --git a/mani_skill/envs/tasks/control/hopper.py b/mani_skill/envs/tasks/control/hopper.py
new file mode 100644
index 000000000..57cfcadd7
--- /dev/null
+++ b/mani_skill/envs/tasks/control/hopper.py
@@ -0,0 +1,248 @@
+"""Adapted from https://github.com/google-deepmind/dm_control/blob/main/dm_control/suite/hopper.py"""
+
+import os
+from typing import Any, Dict, Union
+
+import numpy as np
+import sapien
+import torch
+
+from mani_skill.agents.base_agent import BaseAgent
+from mani_skill.agents.controllers import *
+from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.envs.utils import randomization, rewards
+from mani_skill.sensors.camera import CameraConfig
+from mani_skill.utils import common, sapien_utils
+from mani_skill.utils.geometry import rotation_conversions
+from mani_skill.utils.registration import register_env
+from mani_skill.utils.scene_builder.control.planar.scene_builder import (
+    PlanarSceneBuilder,
+)
+from mani_skill.utils.structs.pose import Pose
+from mani_skill.utils.structs.types import Array, SceneConfig, SimConfig
+
+MJCF_FILE = f"{os.path.join(os.path.dirname(__file__), 'assets/hopper.xml')}"
+
+# Minimal height of torso over foot above which stand reward is 1.
+_STAND_HEIGHT = 0.6
+
+# Hopping speed above which hop reward is 1.
+_HOP_SPEED = 2
+
+
+class HopperRobot(BaseAgent):
+    uid = "hopper"
+    mjcf_path = MJCF_FILE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @property
+    def _controller_configs(self):
+        # NOTE: joints in [rootx, rooty, rootz] are for planar tracking, not control joints
+        max_delta = 2  # gains tuned empirically
+        stiffness = 100
+        damping = 10
+        pd_joint_delta_pos_body = PDJointPosControllerConfig(
+            ["hip", "knee", "waist"],
+            lower=-max_delta,
+            upper=max_delta,
+            damping=damping,
+            stiffness=stiffness,
+            use_delta=True,
+        )
+        pd_joint_delta_pos_ankle = PDJointPosControllerConfig(
+            ["ankle"],
+            lower=-max_delta / 2.5,
+            upper=max_delta / 2.5,
+            damping=damping,
+            stiffness=stiffness,
+            use_delta=True,
+        )
+        rest = PassiveControllerConfig(
+            [j.name for j in self.robot.active_joints if "root" in j.name],
+            damping=0,
+            friction=0,
+        )
+        return deepcopy_dict(
+            dict(
+                pd_joint_delta_pos=dict(
+                    body=pd_joint_delta_pos_body,
+                    ankle=pd_joint_delta_pos_ankle,
+                    rest=rest,
+                    balance_passive_force=False,
+                ),
+            )
+        )
+
+    def _load_articulation(self):
+        """
+        Load the robot articulation
+        """
+        loader = self.scene.create_mjcf_loader()
+        asset_path = str(self.mjcf_path)
+
+        loader.name = self.uid
+
+        self.robot = loader.parse(asset_path)[0][0].build()
+        assert self.robot is not None, f"Failed to load MJCF from {asset_path}"
+        self.robot_link_ids = [link.name for link in self.robot.get_links()]
+
+        # cache robot link masses for center-of-mass computation
+        self.robot_links_mass = [link.mass[0].item() for link in self.robot.get_links()]
+        self.robot_mass = np.sum(self.robot_links_mass[3:])
+
+    # planar agent has root joints in range [-inf, inf], not ideal in obs space
+    def get_proprioception(self):
+        return dict(
+            # don't include xslider qpos, for x-translation invariance
+            # (x increases throughout a successful episode)
+            qpos=self.robot.get_qpos()[:, 1:],
+            qvel=self.robot.get_qvel(),
+        )
+
+
+class HopperEnv(BaseEnv):
+    agent: HopperRobot
+
+    def __init__(self, *args, robot_uids=HopperRobot, **kwargs):
+        super().__init__(*args, robot_uids=robot_uids, **kwargs)
+
+    @property
+    def _default_sim_config(self):
+        return SimConfig(
+            scene_cfg=SceneConfig(
+                solver_position_iterations=4, solver_velocity_iterations=1
+            ),
+            sim_freq=100,
+            control_freq=25,
+        )
+
+    @property
+    def _default_sensor_configs(self):
+        return [
+            # replicated from the xml file
+            CameraConfig(
+                uid="cam0",
+                pose=sapien_utils.look_at(eye=[0, -2.8, 0.8], target=[0, 0, 0]),
+                width=128,
+                height=128,
+                fov=np.pi / 4,
+                near=0.01,
+                far=100,
+                mount=self.agent.robot.links_map["torso_dummy_1"],
+            ),
+        ]
+
+    @property
+    def _default_human_render_camera_configs(self):
+        return [
+            # replicated from the xml file
+            CameraConfig(
+                uid="render_cam",
+                pose=sapien_utils.look_at(eye=[0, -2.8, 0.8], target=[0, 0, 0]),
+                width=512,
+                height=512,
+                fov=np.pi / 4,
+                near=0.01,
+                far=100,
+                mount=self.agent.robot.links_map["torso_dummy_1"],
+            ),
+        ]
+
+    def _load_scene(self, options: dict):
+        loader = self.scene.create_mjcf_loader()
+        articulation_builders, actor_builders, sensor_configs = loader.parse(MJCF_FILE)
+        for a in actor_builders:
+            a.build(a.name)
+
+        self.planar_scene = PlanarSceneBuilder(env=self)
+        self.planar_scene.build()
+
+    def _initialize_episode(self, env_idx: torch.Tensor, options: Dict):
+        with torch.device(self.device):
+            b = len(env_idx)
+            # qpos sampled same as dm_control, but ensure no self-intersection explicitly here
+            random_qpos = torch.rand(b, self.agent.robot.dof[0])
+            q_lims = self.agent.robot.get_qlimits()
+            q_ranges = q_lims[..., 1] - q_lims[..., 0]
+            random_qpos *= q_ranges
+            random_qpos += q_lims[..., 0]
+
+            # overwrite planar joint qpos - these are special for planar robots
+            # first two joints are dummy rootx and rootz
+            random_qpos[:, :2] = 0
+            # y is the axis of rotation of our planar robot (xz plane), so we randomize around it
+            random_qpos[:, 2] = torch.pi * (2 * torch.rand(b) - 1)  # (-pi, pi)
+            self.agent.reset(random_qpos)
+
+    @property  # dm_control mjc function adapted for maniskill
+    def height(self):
+        """Returns the height of the torso relative to the foot heel"""
+        return (
+            self.agent.robot.links_map["torso"].pose.p[:, -1]
+            - self.agent.robot.links_map["foot_heel"].pose.p[:, -1]
+        ).view(-1, 1)
+
+    @property  # dm_control mjc function adapted for maniskill
+    def subtreelinvelx(self):
+        """Returns the x component of the robot's center-of-mass linear velocity"""
+        links = self.agent.robot.get_links()[3:]  # skip first three dummy links
+        vels = torch.stack(
+            [link.get_linear_velocity() * link.mass[0].item() for link in links], dim=0
+        )  # (num_links, b, 3)
+        com_vel = vels.sum(dim=0) / self.agent.robot_mass  # (b, 3)
+        return com_vel[:, 0]
+
+    # dm_control mjc function adapted for maniskill
+    def touch(self, link_name):
+        """Returns the log of one plus the contact force magnitude on the given link"""
+        force_vec = self.agent.robot.get_net_contact_forces([link_name])
+        force_mag = torch.linalg.norm(force_vec, dim=-1)
+        return torch.log1p(force_mag)
+
+    # dm_control also includes foot pressures in the state obs space
+    def _get_obs_state_dict(self, info: Dict):
+        return dict(
+            agent=self._get_obs_agent(),
+            toe_touch=self.touch("foot_toe"),
+            heel_touch=self.touch("foot_heel"),
+        )
+
+
+@register_env("MS-HopperStand-v1", max_episode_steps=600)
+class HopperStandEnv(HopperEnv):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
+        standing = rewards.tolerance(self.height, lower=_STAND_HEIGHT, upper=2.0)
+        return standing.view(-1)
+
+    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
+        # this should be equal to compute_dense_reward / max possible reward
+        max_reward = 1.0
+        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
+
+
+@register_env("MS-HopperHop-v1", max_episode_steps=600)
+class HopperHopEnv(HopperEnv):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
+        standing = rewards.tolerance(self.height, lower=_STAND_HEIGHT, upper=2.0)
+        hopping = rewards.tolerance(
+            self.subtreelinvelx,
+            lower=_HOP_SPEED,
+            upper=float("inf"),
+            margin=_HOP_SPEED / 2,
+            value_at_margin=0.5,
+            sigmoid="linear",
+        )
+
+        return standing.view(-1) * hopping.view(-1)
+
+    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
+        max_reward = 1.0
+        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
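A quick sanity check of the new registrations is a short random rollout; a minimal sketch assuming mani_skill is installed (the observation keys follow `_get_obs_state_dict` above):

```python
# Sketch: random rollout of MS-HopperHop-v1. Importing mani_skill.envs is
# assumed to trigger the @register_env registrations above.
import gymnasium as gym

import mani_skill.envs  # noqa: F401

env = gym.make(
    "MS-HopperHop-v1",
    obs_mode="state_dict",
    control_mode="pd_joint_delta_pos",  # the controller group defined in HopperRobot
)
obs, _ = env.reset(seed=0)
print(list(obs.keys()))  # expect: agent, toe_touch, heel_touch
for _ in range(20):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```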
diff --git a/mani_skill/envs/utils/rewards/common.py b/mani_skill/envs/utils/rewards/common.py
index 23f9ee127..26cd78d04 100644
--- a/mani_skill/envs/utils/rewards/common.py
+++ b/mani_skill/envs/utils/rewards/common.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 
@@ -23,7 +24,7 @@ def tolerance(
         'linear', 'hyperbolic', 'long_tail', 'cosine', 'tanh_squared'.
       value_at_margin: A float between 0 and 1 specifying the output value when
         the distance from `x` to the nearest bound is equal to `margin`. Ignored
-        if `margin == 0`. todo: not implemented yet
+        if `margin == 0`.
 
     Returns:
       A torch array with values between 0.0 and 1.0.
@@ -31,7 +32,24 @@ def tolerance(
     Raises:
       ValueError: If `bounds[0] > bounds[1]`.
       ValueError: If `margin` is negative.
+      ValueError: If not 0 < `value_at_margin` < 1, except for the `linear`,
+        `cosine` and `quadratic` sigmoids, which also allow `value_at_margin` == 0.
+      ValueError: If `sigmoid` is of an unknown type.
     """
+
+    if sigmoid in ("cosine", "linear", "quadratic"):
+        if not 0 <= value_at_margin < 1:
+            raise ValueError(
+                "`value_at_margin` must be nonnegative and smaller than 1, "
+                "got {}.".format(value_at_margin)
+            )
+    else:
+        if not 0 < value_at_margin < 1:
+            raise ValueError(
+                "`value_at_margin` must be strictly between 0 and 1, "
+                "got {}.".format(value_at_margin)
+            )
+
     if lower > upper:
         raise ValueError("Lower bound must be <= upper bound.")
@@ -45,13 +63,25 @@ def tolerance(
     else:
         d = torch.where(x < lower, lower - x, x - upper) / margin
         if sigmoid == "gaussian":
+            scale = np.sqrt(-2 * np.log(value_at_margin))
             value = torch.where(
-                in_bounds, torch.tensor(1.0), torch.exp(-0.5 * (d**2))
+                in_bounds, torch.tensor(1.0), torch.exp(-0.5 * (d * scale) ** 2)
             )
         elif sigmoid == "hyperbolic":
-            value = torch.where(in_bounds, torch.tensor(1.0), 1 / (1 + torch.exp(d)))
+            scale = np.arccosh(1 / value_at_margin)
+            value = torch.where(
+                in_bounds, torch.tensor(1.0), 1 / (1 + torch.exp(d * scale))
+            )
         elif sigmoid == "quadratic":
-            value = torch.where(in_bounds, torch.tensor(1.0), 1 - d**2)
+            scale = np.sqrt(1 - value_at_margin)
+            scaled_d = d * scale
+            capped = torch.where(scaled_d.abs() < 1, 1 - scaled_d**2, torch.tensor(0.0))
+            value = torch.where(in_bounds, torch.tensor(1.0), capped)
+        elif sigmoid == "linear":
+            scale = 1 - value_at_margin
+            scaled_d = d * scale
+            capped = torch.where(scaled_d.abs() < 1, 1 - scaled_d, torch.tensor(0.0))
+            value = torch.where(in_bounds, torch.tensor(1.0), capped)
         else:
             raise ValueError(f"Unknown sigmoid type {sigmoid!r}.")
 
     return value
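The `scale` factors make each sigmoid evaluate to exactly `value_at_margin` at a distance of `margin` from the nearest bound. A quick check with the hop-reward settings used in hopper.py (the specific speeds and printout are illustrative):

```python
# Sketch: tolerance() should equal value_at_margin exactly one margin below
# the lower bound. Settings mirror HopperHopEnv: lower=2, margin=1,
# value_at_margin=0.5, linear sigmoid.
import torch

from mani_skill.envs.utils import rewards

speed = torch.tensor([0.0, 1.0, 2.0, 3.0])  # candidate x velocities
hop = rewards.tolerance(
    speed,
    lower=2.0,
    upper=float("inf"),
    margin=1.0,
    value_at_margin=0.5,
    sigmoid="linear",
)
print(hop)  # expected: tensor([0.0000, 0.5000, 1.0000, 1.0000])
# speed 1.0 is exactly one margin below the bound -> 0.5; speed 0.0 is two
# margins away, past the point where the scaled linear ramp reaches zero.
```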
diff --git a/mani_skill/utils/building/ground.py b/mani_skill/utils/building/ground.py
index 550a6597d..f1543c841 100644
--- a/mani_skill/utils/building/ground.py
+++ b/mani_skill/utils/building/ground.py
@@ -17,6 +17,8 @@
 def build_ground(
     scene: ManiSkillScene,
     floor_width: int = 100,
+    floor_length: int = None,
+    xy_origin: tuple = (0, 0),
     altitude=0,
     name="ground",
 ):
@@ -35,14 +37,17 @@
     actor = ground.build_static(name=name)
 
     # generate a grid of right triangles that form 1x1 meter squares centered at (0, 0, 0)
-    num_verts = (floor_width + 1) ** 2
+    floor_length = floor_width if floor_length is None else floor_length
+    num_verts = (floor_width + 1) * (floor_length + 1)
     vertices = np.zeros((num_verts, 3))
     floor_half_width = floor_width / 2
-    ranges = np.arange(start=-floor_half_width, stop=floor_half_width + 1)
-    xx, yy = np.meshgrid(ranges, ranges)
+    floor_half_length = floor_length / 2
+    xrange = np.arange(start=-floor_half_width, stop=floor_half_width + 1)
+    yrange = np.arange(start=-floor_half_length, stop=floor_half_length + 1)
+    xx, yy = np.meshgrid(xrange, yrange)
     xys = np.stack((yy, xx), axis=2).reshape(-1, 2)
-    vertices[:, 0] = xys[:, 0]
-    vertices[:, 1] = xys[:, 1]
+    vertices[:, 0] = xys[:, 0] + xy_origin[0]
+    vertices[:, 1] = xys[:, 1] + xy_origin[1]
     vertices[:, 2] = altitude
     normals = np.zeros((len(vertices), 3))
     normals[:, 2] = 1
@@ -61,7 +66,7 @@
 
     # TODO: This is fast but still two for loops which is a little annoying
     triangles = []
-    for i in range(floor_width):
+    for i in range(floor_length):
         triangles.append(
             np.stack(
                 [
@@ -72,7 +77,7 @@
                 axis=1,
             )
         )
-    for i in range(floor_width):
+    for i in range(floor_length):
         triangles.append(
             np.stack(
                 [
diff --git a/mani_skill/utils/scene_builder/control/planar/scene_builder.py b/mani_skill/utils/scene_builder/control/planar/scene_builder.py
new file mode 100644
index 000000000..9827cd498
--- /dev/null
+++ b/mani_skill/utils/scene_builder/control/planar/scene_builder.py
@@ -0,0 +1,34 @@
+from typing import List
+
+import numpy as np
+import sapien
+import sapien.render
+import torch
+from transforms3d.euler import euler2quat
+
+from mani_skill.utils.building.ground import build_ground
+from mani_skill.utils.scene_builder import SceneBuilder
+
+
+class PlanarSceneBuilder(SceneBuilder):
+    def build(self, build_config_idxs: List[int] = None):
+        # ground - a strip with length along +x
+        self.ground = build_ground(
+            self.scene,
+            floor_width=2,
+            floor_length=100,
+            altitude=0,
+            xy_origin=(50 - 2, 0),
+        )
+
+        # background visual wall
+        self.wall = self.scene.create_actor_builder()
+        self.wall.add_box_visual(
+            half_size=(1e-3, 65, 10),
+            pose=sapien.Pose(p=[(50 - 2), 2, 0], q=euler2quat(0, 0, np.pi / 2)),
+            material=sapien.render.RenderMaterial(
+                base_color=np.array([0.3, 0.3, 0.3, 1])
+            ),
+        )
+        self.wall.build_static(name="wall")
+        self.scene_objects: List[sapien.Entity] = [self.ground, self.wall]
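For reference, this is how the hopper environment consumes the builder; a minimal sketch mirroring `HopperEnv._load_scene` above (the method is shown outside its class for brevity):

```python
# Sketch: an env's _load_scene delegating scenery to PlanarSceneBuilder.
from mani_skill.utils.scene_builder.control.planar.scene_builder import (
    PlanarSceneBuilder,
)


def _load_scene(self, options: dict):
    # builds the 2 m wide, 100 m long ground strip plus the backdrop wall
    self.planar_scene = PlanarSceneBuilder(env=self)
    self.planar_scene.build()
```

The `xy_origin=(50 - 2, 0)` offset centers the strip at x = 48, so it spans roughly x in [-2, 98]: a couple of meters behind the spawn point and far along +x, matching tasks whose progress is measured in the positive x direction.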