Pick clutter task (GPU) (#254)
* init

* Update pick_clutter_ycb.py

* fix bug with reset mask

* support more indexing options for Pose; add actor_views and articulation_views; get the pick clutter task working; fix a bug with the pd_ee_pose controller

* return in reset info whether we reconfigured; this is necessary to know when to delete the trajectory buffer in the record wrapper

* Update actor.py

* Update sapien_env.py

* Update pick_clutter_ycb.py

* docs

* Update pick_clutter_ycb.py
StoneT2000 authored Apr 2, 2024
1 parent 1752391 commit b309aef
Showing 16 changed files with 296 additions and 48 deletions.
3 changes: 2 additions & 1 deletion docs/source/user_guide/tutorials/custom_tasks_advanced.md
@@ -131,9 +131,11 @@ class GPUMemoryConfig:
temp_buffer_capacity: int = 2**24
"""Increase this if you get 'PxgPinnedHostLinearMemoryAllocator: overflowing initial allocation size, increase capacity to at least %.' """
max_rigid_contact_count: int = 2**19
"""Increase this if you get 'Contact buffer overflow detected'"""
max_rigid_patch_count: int = (
2**18
) # 81920 is SAPIEN default but most tasks work with 2**18
"""Increase this if you get 'Patch buffer overflow detected'"""
heap_capacity: int = 2**26
found_lost_pairs_capacity: int = (
2**25
@@ -144,7 +146,6 @@ class GPUMemoryConfig:
def dict(self):
return {k: v for k, v in asdict(self).items()}
@dataclass
class SceneConfig:
gravity: np.ndarray = field(default_factory=lambda: np.array([0, 0, -9.81]))
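For context, a minimal sketch of how a custom task can raise these capacities, mirroring the `_default_sim_cfg` pattern the new pick clutter task uses later in this commit (the class name is illustrative):

```python
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils.structs.types import GPUMemoryConfig, SimConfig


class MyClutteredTaskEnv(BaseEnv):
    @property
    def _default_sim_cfg(self):
        # raise the contact/patch buffers for scenes with many colliding objects
        return SimConfig(
            gpu_memory_cfg=GPUMemoryConfig(
                max_rigid_contact_count=2**21, max_rigid_patch_count=2**19
            )
        )
```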
7 changes: 6 additions & 1 deletion mani_skill/agents/controllers/pd_ee_pose.py
@@ -104,7 +104,7 @@ def reset(self):
] = self.ee_pose_at_base.raw_pose[self.scene._reset_mask]

def compute_ik(
self, target_pose: Pose, action: Array, pos_only=False, max_iterations=100
self, target_pose: Pose, action: Array, pos_only=True, max_iterations=100
):
# Assume the target pose is defined in the base frame
if physx.is_gpu_enabled():
@@ -242,6 +242,11 @@ def _clip_and_scale_action(self, action):
rot_action = rot_action * self.config.rot_bound
return torch.hstack([pos_action, rot_action])

def compute_ik(self, target_pose: Pose, action: Array, max_iterations=100):
return super().compute_ik(
target_pose, action, pos_only=False, max_iterations=max_iterations
)

def compute_target_pose(self, prev_ee_pose_at_base: Pose, action):
if self.config.use_delta:
delta_pos, delta_rot = action[:, 0:3], action[:, 3:6]
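A stripped-down sketch (not the real controller code) of the override pattern this change establishes: the position-only controller relies on the new `pos_only=True` default, while the full 6-DoF pose controller flips it back so orientation error is solved as well.

```python
class PosOnlyIKController:
    def compute_ik(self, target_pose, action, pos_only=True, max_iterations=100):
        # placeholder IK solve; with pos_only=True the solver ignores
        # orientation error and only tracks the target position
        return None


class FullPoseIKController(PosOnlyIKController):
    def compute_ik(self, target_pose, action, max_iterations=100):
        # re-enable orientation tracking for the full 6-DoF pose controller
        return super().compute_ik(
            target_pose, action, pos_only=False, max_iterations=max_iterations
        )
```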
14 changes: 13 additions & 1 deletion mani_skill/envs/sapien_env.py
@@ -628,6 +628,14 @@ def reset(self, seed=None, options=None):
options["reconfigure"] is True, will call self._reconfigure() which deletes the entire physx scene and reconstructs everything.
Users building custom tasks generally do not need to override this function.
Returns the first observation and an info dictionary. The info dictionary is of the form
```
{
"reconfigure": bool (True if the environment reconfigured, False otherwise)
}
```
Note that ManiSkill always holds two RNG states, a main RNG, and an episode RNG. The main RNG is used purely to sample an episode seed which
helps with reproducibility of episodes and is for internal use only. The episode RNG is used by the environment/task itself to
e.g. randomize object positions, randomize assets etc. Episode RNG is accessible by using torch.rand (recommended) which is seeded with a
@@ -655,6 +663,10 @@
torch.manual_seed(seed=self._episode_seed)
self._reconfigure(options)
self._after_reconfigure(options)

# TODO (stao): Reconfiguration combined with a partial reset may not make sense and is currently broken here.
# One solution would be to ensure tasks that reconfigure more than once are single-env / CPU sim only,
# or to explicitly disable partial resets for tasks that have a reconfiguration frequency
if "env_idx" in options:
env_idx = options["env_idx"]
self._scene._reset_mask = torch.zeros(
@@ -694,7 +706,7 @@
if not physx.is_gpu_enabled():
obs = sapien_utils.to_numpy(sapien_utils.unbatch(obs))
self._elapsed_steps = 0
return obs, {}
return obs, dict(reconfigure=reconfigure)

def _set_main_rng(self, seed):
"""Set the main random generator which is only used to set the seed of the episode RNG to improve reproducibility.
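A sketch of how a caller might consume the new `reconfigure` info field; the env id is registered later in this commit, and the buffer handling is illustrative of the record-wrapper use case from the commit message.

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401  (importing registers the tasks)

env = gym.make("PickClutterYCB-v1")
obs, info = env.reset(seed=0)
if info["reconfigure"]:
    # the physx scene was rebuilt, so any buffers tied to the previous
    # scene (e.g. a record wrapper's trajectory buffer) are stale and
    # should be discarded before recording new data
    trajectory_buffer = []
```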
7 changes: 6 additions & 1 deletion mani_skill/envs/scene.py
@@ -49,6 +49,11 @@ def __init__(
self.actors: Dict[str, Actor] = OrderedDict()
self.articulations: Dict[str, Articulation] = OrderedDict()

self.actor_views: Dict[str, Actor] = OrderedDict()
"""views of actors in any sub-scenes created by using Actor.merge and queryable as if it were a single Actor"""
self.articulation_views: Dict[str, Articulation] = OrderedDict()
"""views of articulations in any sub-scenes created by using Articulation.merge and queryable as if it were a single Articulation"""

self.sensors: Dict[str, BaseSensor] = OrderedDict()
self.human_render_cameras: Dict[str, Camera] = OrderedDict()

@@ -527,7 +532,7 @@ def _setup_gpu(self):
# Since physx_system.gpu_init() was called, a single physx step was also taken, so we need to reset
# all the actors and articulations to their original poses as they likely have collided
for actor in self.non_static_actors:
actor.set_pose(actor._builder_initial_pose)
actor.set_pose(actor.inital_pose)
self.px.cuda_rigid_body_data.torch()[:, 7:] = (
self.px.cuda_rigid_body_data.torch()[:, 7:] * 0
) # zero out all velocities
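A small hedged helper illustrating the view pattern; `make_object_view` and its argument are illustrative, while `Actor.merge` itself is the API this hunk documents.

```python
from typing import List

from mani_skill.utils.structs import Actor


def make_object_view(per_scene_objects: List[Actor]) -> Actor:
    # per_scene_objects: one Actor per sub-scene, each built with
    # builder.set_scene_idxs([i]) as done in pick_clutter_ycb.py below
    view = Actor.merge(per_scene_objects, name="objects_view")
    # the view is then queried like a single batched Actor, e.g. view.pose
    return view
```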
3 changes: 2 additions & 1 deletion mani_skill/envs/tasks/__init__.py
@@ -5,13 +5,14 @@
from .lift_peg_upright import LiftPegUprightEnv
from .open_cabinet_drawer import OpenCabinetDoorEnv, OpenCabinetDrawerEnv
from .peg_insertion_side import PegInsertionSideEnv
from .pick_clutter_ycb import PickClutterYCBEnv
from .pick_cube import PickCubeEnv
from .pick_single_ycb import PickSingleYCBEnv
from .pull_cube import PullCubeEnv
from .push_cube import PushCubeEnv
from .quadruped_run import QuadrupedRunEnv
from .quadruped_stand import QuadrupedStandEnv
from .rotate_cube import RotateCubeEnv
from .stack_cube import StackCubeEnv
from .two_robot_pick_cube import TwoRobotPickCube
from .two_robot_stack_cube import TwoRobotStackCube
200 changes: 200 additions & 0 deletions mani_skill/envs/tasks/pick_clutter_ycb.py
@@ -0,0 +1,200 @@
import os
from collections import OrderedDict
from typing import Any, Dict, List, Union

import numpy as np
import sapien
import torch

from mani_skill import ASSET_DIR
from mani_skill.agents.robots import Fetch, Panda
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.building.actor_builder import ActorBuilder
from mani_skill.utils.io_utils import load_json
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table.table_scene_builder import TableSceneBuilder
from mani_skill.utils.structs import Actor, Pose
from mani_skill.utils.structs.types import GPUMemoryConfig, SimConfig


class PickClutterEnv(BaseEnv):
"""Base environment picking items out of clutter type of tasks. Flexibly supports using different configurations and object datasets"""

SUPPORTED_REWARD_MODES = ["sparse", "none"]
SUPPORTED_ROBOTS = ["panda", "fetch"]
agent: Union[Panda, Fetch]

DEFAULT_EPISODE_JSON: str
DEFAULT_ASSET_ROOT: str
DEFAULT_MODEL_JSON: str

def __init__(
self,
*args,
robot_uids="panda",
robot_init_qpos_noise=0.02,
episode_json: str = None,
**kwargs,
):
self.robot_init_qpos_noise = robot_init_qpos_noise

if episode_json is None:
episode_json = self.DEFAULT_EPISODE_JSON
if not os.path.exists(episode_json):
raise FileNotFoundError(
f"Episode json ({episode_json}) is not found."
"To download default json:"
"`python -m mani_skill2.utils.download_asset pick_clutter_ycb`."
)
self._episodes: List[Dict] = load_json(episode_json)

super().__init__(*args, robot_uids=robot_uids, **kwargs)

if self.num_envs == 1:
# with just one environment there isn't going to be a lot of geometrical variation
# so setting the freq below to 1 ensures each reset (while a little slower) changes the loaded geometries
self.reconfiguration_freq = 1

@property
def _default_sim_cfg(self):
return SimConfig(
gpu_memory_cfg=GPUMemoryConfig(
max_rigid_contact_count=2**21, max_rigid_patch_count=2**19
)
)

@property
def _sensor_configs(self):
pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
return [
CameraConfig(
"base_camera",
pose=pose,
width=128,
height=128,
fov=np.pi / 2,
near=0.01,
far=100,
)
]

@property
def _human_render_camera_configs(self):
pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
return CameraConfig(
"render_camera", pose=pose, width=512, height=512, fov=1, near=0.01, far=100
)

def _load_model(self, model_id: str) -> ActorBuilder:
raise NotImplementedError()

def _load_scene(self, options: dict):
self.scene_builder = TableSceneBuilder(
self, robot_init_qpos_noise=self.robot_init_qpos_noise
)
self.scene_builder.build()

# sample some clutter configurations
eps_idxs = np.arange(0, len(self._episodes))
rand_idx = torch.randperm(len(eps_idxs), device=torch.device("cpu"))
eps_idxs = eps_idxs[rand_idx]
eps_idxs = np.concatenate(
[eps_idxs] * np.ceil(self.num_envs / len(eps_idxs)).astype(int)
)[: self.num_envs]

self.selectable_target_objects: List[List[Actor]] = []
"""for each sub-scene, a list of objects that can be selected as targets"""
all_objects = []

for i, eps_idx in enumerate(eps_idxs):
self.selectable_target_objects.append([])
episode = self._episodes[eps_idx]
for actor_cfg in episode["actors"]:
builder = self._load_model(actor_cfg["model_id"])
init_pose = actor_cfg["pose"]
builder.initial_pose = sapien.Pose(p=init_pose[:3], q=init_pose[3:])
builder.set_scene_idxs([i])
obj = builder.build(name=f"set_{i}_{actor_cfg['model_id']}")
all_objects.append(obj)
if actor_cfg["rep_pts"] is not None:
# TODO (stao): what is rep_pts? This is carried over from the ManiSkill2 code
self.selectable_target_objects[-1].append(obj)

self.all_objects = Actor.merge(all_objects, name="all_objects")

self.goal_site = actors.build_sphere(
self._scene,
radius=0.01,
color=[0, 1, 0, 1],
name="goal_site",
body_type="kinematic",
add_collision=False,
)
self._hidden_objects.append(self.goal_site)

self._sample_target_objects()

def _sample_target_objects(self):
# note this samples new target objects for every sub-scene
target_objects = []
selected_obj_idxs = torch.randint(low=0, high=99999, size=(self.num_envs,))
for i in range(self.num_envs):
# pick a random candidate from sub-scene i's own selectable objects
obj_idx = selected_obj_idxs[i] % len(self.selectable_target_objects[i])
target_objects.append(self.selectable_target_objects[i][obj_idx])
self.target_object = Actor.merge(target_objects, name="target_object")

def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
self.scene_builder.initialize(env_idx)
goal_pos = torch.rand(size=(b, 3)) * torch.tensor(
[0.3, 0.5, 0.1]
) + torch.tensor([-0.15, -0.25, 0.35])
self.goal_pos = goal_pos
self.goal_site.set_pose(Pose.create_from_pq(self.goal_pos))

# reset objects to original poses
if b == self.num_envs:
# if all envs reset
self.all_objects.pose = self.all_objects.inital_pose
else:
# if only some envs reset, we unfortunately still have to do some mask wrangling
mask = torch.isin(self.all_objects._scene_idxs, env_idx)
self.all_objects.pose = self.all_objects.inital_pose[mask]

def evaluate(self):
return {
"success": torch.zeros(self.num_envs, device=self.device, dtype=bool),
"fail": torch.zeros(self.num_envs, device=self.device, dtype=bool),
}

def _get_obs_extra(self, info: Dict):
return OrderedDict()

def compute_dense_reward(self, obs: Any, action: torch.Tensor, info: Dict):
return torch.zeros(self.num_envs, device=self.device)

def compute_normalized_dense_reward(
self, obs: Any, action: torch.Tensor, info: Dict
):
max_reward = 1.0
return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward


@register_env("PickClutterYCB-v1", max_episode_steps=100)
class PickClutterYCBEnv(PickClutterEnv):
DEFAULT_EPISODE_JSON = f"{ASSET_DIR}/tasks/pick_clutter/ycb_train_5k.json.gz"

def _load_model(self, model_id):
builder, _ = actors.build_actor_ycb(
model_id, self._scene, name=model_id, return_builder=True
)
return builder
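A quick usage sketch for the newly registered task (assumes the `pick_clutter_ycb` assets are downloaded; `num_envs > 1` requires the GPU simulator):

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401  (importing registers PickClutterYCB-v1)

env = gym.make("PickClutterYCB-v1", num_envs=16, obs_mode="state")
obs, info = env.reset(seed=0)
for _ in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```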
1 change: 1 addition & 0 deletions mani_skill/envs/tasks/pick_single_ycb.py
@@ -97,6 +97,7 @@ def _load_scene(self, options: dict):
actors: List[Actor] = []
self.obj_heights = []
for i, model_id in enumerate(model_ids):
# TODO: before official release we will finalize a metadata dataclass that these build functions should return.
builder, obj_height = build_actor_ycb(
model_id, self._scene, name=model_id, return_builder=True
)
6 changes: 2 additions & 4 deletions mani_skill/utils/building/actor_builder.py
@@ -213,10 +213,8 @@ def build(self, name):
and initial_pose_b == 1
and physx.is_gpu_enabled()
):
actor._builder_initial_pose = Pose.create(
initial_pose.raw_pose.repeat(num_actors, 1)
)
actor.inital_pose = Pose.create(initial_pose.raw_pose.repeat(num_actors, 1))
else:
actor._builder_initial_pose = initial_pose
actor.inital_pose = initial_pose
self.scene.actors[self.name] = actor
return actor
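The repeat above simply broadcasts a single builder pose across all parallel actors; a standalone sketch of the shape handling (values are illustrative):

```python
import torch

num_actors = 4
# one raw pose as a (1, 7) tensor: xyz position followed by a wxyz quaternion
raw_pose = torch.tensor([[0.0, 0.0, 0.1, 1.0, 0.0, 0.0, 0.0]])
batched = raw_pose.repeat(num_actors, 1)  # (num_actors, 7), one row per actor
assert batched.shape == (num_actors, 7)
```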