diff --git a/mani_skill/envs/tasks/tabletop/turn_faucet.py b/mani_skill/envs/tasks/tabletop/turn_faucet.py
index aa45911f4..8d564e4f0 100644
--- a/mani_skill/envs/tasks/tabletop/turn_faucet.py
+++ b/mani_skill/envs/tasks/tabletop/turn_faucet.py
@@ -135,10 +135,7 @@ def _load_scene(self, options: dict):
         self.target_angle[torch.isinf(qmax)] = torch.pi / 2
         # the angle to go
         self.target_angle_diff = self.target_angle - self.init_angle
-        joint_pose = (
-            self.target_switch_link.joint.get_global_pose().to_transformation_matrix()
-        )
-        self.target_joint_axis = joint_pose[:, :3, 0]
+        self.target_joint_axis = torch.zeros((self.num_envs, 3), device=self.device)
 
     def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
         with torch.device(self.device):
@@ -165,7 +162,10 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
                 self.target_switch_link.pose * self.target_switch_link.cmass_local_pose
             )
             self.target_link_pos = cmass_pose.p
-
+            joint_pose = (
+                self.target_switch_link.joint.get_global_pose().to_transformation_matrix()
+            )
+            self.target_joint_axis[env_idx] = joint_pose[env_idx, :3, 0]
             # self.handle_link_goal.set_pose(cmass_pose)
 
     @property
diff --git a/mani_skill/examples/demo_manual_control.py b/mani_skill/examples/demo_manual_control.py
index b034c1ddf..2eb5912dc 100644
--- a/mani_skill/examples/demo_manual_control.py
+++ b/mani_skill/examples/demo_manual_control.py
@@ -1,10 +1,13 @@
 import argparse
+import signal
 
 import gymnasium as gym
+from matplotlib import pyplot as plt
 import numpy as np
 
+signal.signal(signal.SIGINT, signal.SIG_DFL) # allow ctrl+c
+
 from mani_skill.envs.sapien_env import BaseEnv
-from mani_skill.utils.visualization.cv2_utils import OpenCVViewer
+from mani_skill.utils import visualization
 from mani_skill.utils.wrappers import RecordEpisode
 
 
@@ -58,7 +61,19 @@ def main():
     # Viewer
     if args.enable_sapien_viewer:
         env.render_human()
-    opencv_viewer = OpenCVViewer(exit_on_esc=False)
+    renderer = visualization.ImageRenderer()
+    # disable all default plt shortcuts that are lowercase letters
+    plt.rcParams["keymap.fullscreen"].remove("f")
+    plt.rcParams["keymap.home"].remove("h")
+    plt.rcParams["keymap.home"].remove("r")
+    plt.rcParams["keymap.back"].remove("c")
+    plt.rcParams["keymap.forward"].remove("v")
+    plt.rcParams["keymap.pan"].remove("p")
+    plt.rcParams["keymap.zoom"].remove("o")
+    plt.rcParams["keymap.save"].remove("s")
+    plt.rcParams["keymap.grid"].remove("g")
+    plt.rcParams["keymap.yscale"].remove("l")
+    plt.rcParams["keymap.xscale"].remove("k")
 
     def render_wait():
         if not args.enable_sapien_viewer:
@@ -83,21 +98,22 @@ def render_wait():
 
         if args.enable_sapien_viewer:
             env.render_human()
-        render_frame = env.render()
+        render_frame = env.render().cpu().numpy()[0]
 
         if after_reset:
             after_reset = False
             # Re-focus on opencv viewer
             if args.enable_sapien_viewer:
-                opencv_viewer.close()
-                opencv_viewer = OpenCVViewer(exit_on_esc=False)
-
+                renderer.close()
+                renderer = visualization.ImageRenderer()
+            pass
         # -------------------------------------------------------------------------- #
         # Interaction
        # -------------------------------------------------------------------------- #
         # Input
-        key = opencv_viewer.imshow(render_frame)
-
+        renderer(render_frame)
+        # key = opencv_viewer.imshow(render_frame.cpu().numpy()[0])
+        key = renderer.last_event.key if renderer.last_event is not None else None
 
         body_action = np.zeros([3])
         base_action = np.zeros([3]) # hardcoded for fetch robot
@@ -192,18 +208,7 @@ def render_wait():
 
         # Visualize observation
         if key == "v":
-            if "rgbd" in env.obs_mode:
-                from itertools import chain
-
-                from mani_skill.utils.visualization.misc import (
-                    observations_to_images, tile_images)
-
-                images = list(
-                    chain(*[observations_to_images(x) for x in obs["image"].values()])
-                )
-                render_frame = tile_images(images)
-                opencv_viewer.imshow(render_frame)
-            elif "pointcloud" in env.obs_mode:
+            if "pointcloud" in env.obs_mode:
                 import trimesh
 
                 xyzw = obs["pointcloud"]["xyzw"]
diff --git a/mani_skill/examples/demo_random_action.py b/mani_skill/examples/demo_random_action.py
index fe65acf68..bdf770b81 100644
--- a/mani_skill/examples/demo_random_action.py
+++ b/mani_skill/examples/demo_random_action.py
@@ -14,7 +14,7 @@ def parse_args(args=None):
     parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
     parser.add_argument("--reward-mode", type=str)
     parser.add_argument("-c", "--control-mode", type=str)
-    parser.add_argument("--render-mode", type=str)
+    parser.add_argument("--render-mode", type=str, default="rgb_array")
     parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
     parser.add_argument("--record-dir", type=str)
     parser.add_argument("-p", "--pause", action="store_true", help="If using human render mode, auto pauses the simulation upon loading")
@@ -58,7 +58,7 @@ def main(args):
     record_dir = args.record_dir
     if record_dir:
         record_dir = record_dir.format(env_id=args.env_id)
-        env = RecordEpisode(env, record_dir)
+        env = RecordEpisode(env, record_dir, info_on_video=True)
 
     if verbose:
         print("Observation space", env.observation_space)
diff --git a/mani_skill/examples/demo_vis_rgbd.py b/mani_skill/examples/demo_vis_rgbd.py
index 84101a8e9..a91b34928 100644
--- a/mani_skill/examples/demo_vis_rgbd.py
+++ b/mani_skill/examples/demo_vis_rgbd.py
@@ -1,13 +1,15 @@
 import signal
+import sys
+
+from matplotlib import pyplot as plt
 from mani_skill.utils import common
-from mani_skill.utils.visualization.misc import tile_images
+from mani_skill.utils import visualization
 
 signal.signal(signal.SIGINT, signal.SIG_DFL) # allow ctrl+c
 
 import argparse
 
 import gymnasium as gym
-import cv2
 import numpy as np
 
 from mani_skill.envs.sapien_env import BaseEnv
@@ -27,6 +29,12 @@ def parse_args(args=None):
     args = parser.parse_args()
     return args
 
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+
+
 
 def main(args):
     if args.seed is not None:
@@ -50,6 +58,8 @@ def main(args):
             n_cams += 1
     print(f"Visualizing {n_cams} RGBD cameras")
 
+    renderer = visualization.ImageRenderer()
+
     while True:
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)
@@ -66,11 +76,8 @@ def main(args):
                 depth_rgb[..., :] = depth*255
                 imgs.append(depth_rgb)
             cam_num += 1
-        img = tile_images(imgs, nrows=n_cams)
-
-        cv2.imshow('image',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
-        cv2.waitKey(0)
-
+        img = visualization.tile_images(imgs, nrows=n_cams)
+        renderer(img)
 
 if __name__ == "__main__":
     main(parse_args())
diff --git a/mani_skill/examples/demo_vis_segmentation.py b/mani_skill/examples/demo_vis_segmentation.py
index f496a4c06..4cb5264dc 100644
--- a/mani_skill/examples/demo_vis_segmentation.py
+++ b/mani_skill/examples/demo_vis_segmentation.py
@@ -1,13 +1,12 @@
 import signal
 from mani_skill.utils import common
-from mani_skill.utils.visualization.misc import tile_images
+from mani_skill.utils import visualization
 
 signal.signal(signal.SIGINT, signal.SIG_DFL) # allow ctrl+c
 
 import argparse
 
 import gymnasium as gym
-import cv2
 import numpy as np
 # color pallete generated via https://medialab.github.io/iwanthue/
 color_pallete = np.array([[164,74,82],
@@ -118,7 +117,7 @@ def main(args):
     if selected_id is not None and not isinstance(selected_id, int):
         selected_id = reverse_seg_id_map[selected_id]
 
-
+    renderer = visualization.ImageRenderer()
     while True:
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)
@@ -138,11 +137,8 @@ def main(args):
                 seg_rgb[(seg == id)[..., 0]] = color
                 imgs.append(seg_rgb)
             cam_num += 1
-        img = tile_images(imgs, nrows=n_cams)
-
-        cv2.imshow('image',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
-        cv2.waitKey(0)
-
+        img = visualization.tile_images(imgs, nrows=n_cams)
+        renderer(img)
 
 if __name__ == "__main__":
     main(parse_args())
diff --git a/mani_skill/utils/visualization/UbuntuSansMono-Regular.ttf b/mani_skill/utils/visualization/UbuntuSansMono-Regular.ttf
new file mode 100644
index 000000000..9a8ccca89
Binary files /dev/null and b/mani_skill/utils/visualization/UbuntuSansMono-Regular.ttf differ
diff --git a/mani_skill/utils/visualization/__init__.py b/mani_skill/utils/visualization/__init__.py
index f3bbaaf75..7b2d36824 100644
--- a/mani_skill/utils/visualization/__init__.py
+++ b/mani_skill/utils/visualization/__init__.py
@@ -1,7 +1,5 @@
-from .cv2_utils import OpenCVViewer, images_to_video_cv2
 from .jupyter_utils import display_images
 from .misc import (
-    append_text_to_image,
     images_to_video,
     normalize_depth,
     observations_to_images,
@@ -9,3 +7,4 @@
     put_text_on_image,
     tile_images,
 )
+from .renderer import ImageRenderer
diff --git a/mani_skill/utils/visualization/cv2_utils.py b/mani_skill/utils/visualization/cv2_utils.py
deleted file mode 100644
index 01623a329..000000000
--- a/mani_skill/utils/visualization/cv2_utils.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import os
-from typing import List
-
-import cv2
-import numpy as np
-import tqdm
-
-
-def images_to_video_cv2(
-    images: List[np.ndarray],
-    output_dir: str,
-    video_name: str,
-    fps: int = 10,
-    verbose: bool = True,
-    is_rgb=True,
-):
-    if not os.path.exists(output_dir):
-        os.makedirs(output_dir)
-    video_name = video_name.replace(" ", "_").replace("\n", "_") + ".mp4"
-    output_path = os.path.join(output_dir, video_name)
-    image_shape = images[0].shape
-    frame_size = (image_shape[1], image_shape[0])
-    fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
-    writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
-    if verbose:
-        print(f"Video created: {output_path}")
-        images_iter = tqdm.tqdm(images)
-    else:
-        images_iter = images
-    for im in images_iter:
-        im = im[..., 0:3]
-        if is_rgb:
-            im = im[..., ::-1]
-        writer.write(im)
-    writer.release()
-
-
-class OpenCVViewer:
-    def __init__(self, name="OpenCVViewer", is_rgb=True, exit_on_esc=True):
-        self.name = name
-        cv2.namedWindow(name, cv2.WINDOW_AUTOSIZE)
-        self.is_rgb = is_rgb
-        self.exit_on_esc = exit_on_esc
-
-    def imshow(self, image: np.ndarray, is_rgb=None, non_blocking=False, delay=0):
-        if image.ndim == 2:
-            image = np.tile(image[..., np.newaxis], (1, 1, 3))
-        elif image.ndim == 3 and image.shape[-1] == 1:
-            image = np.tile(image, (1, 1, 3))
-        assert image.ndim == 3, image.shape
-
-        if self.is_rgb or is_rgb:
-            image = image[..., ::-1]
-        cv2.imshow(self.name, image)
-
-        if non_blocking:
-            return
-        else:
-            key = cv2.waitKey(delay)
-            if key == 27:  # escape
-                if self.exit_on_esc:
-                    exit(0)
-                else:
-                    return None
-            elif key == -1:  # timeout
-                pass
-            else:
-                return chr(key)
-
-    def close(self):
-        cv2.destroyWindow(self.name)
-
-    def __del__(self):
-        self.close()
diff --git a/mani_skill/utils/visualization/misc.py b/mani_skill/utils/visualization/misc.py
index 5a89710a0..6539b6e22 100644
--- a/mani_skill/utils/visualization/misc.py
+++ b/mani_skill/utils/visualization/misc.py
@@ -1,11 +1,11 @@
 import os
 from typing import Dict, List, Optional
 
-import cv2
 import imageio
 import numpy as np
 import torch
 import tqdm
+from PIL import Image, ImageDraw, ImageFont
 
 from mani_skill.utils.structs.types import Array
 
@@ -170,75 +170,32 @@ def tile_images(images: List[Array], nrows=1) -> Array:
     return output_image
 
 
-def put_text_on_image(image: Array, lines: List[str]):
-    assert image.dtype == np.uint8, image.dtype
-    image = image.copy()
+TEXT_FONT = None
 
-    font_size = 0.5
-    font_thickness = 1
-    font = cv2.FONT_HERSHEY_SIMPLEX
 
-    y = 0
-    for line in lines:
-        textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
-        y += textsize[1] + 10
-        x = 10
-        cv2.putText(
-            image,
-            line,
-            (x, y),
-            font,
-            font_size,
-            (0, 255, 0),
-            font_thickness,
-            lineType=cv2.LINE_AA,
+def put_text_on_image(image: np.ndarray, lines: List[str]):
+    global TEXT_FONT
+    assert image.dtype == np.uint8, image.dtype
+    image = image.copy()
+    image = Image.fromarray(image)
+    draw = ImageDraw.Draw(image)
+    if TEXT_FONT is None:
+        TEXT_FONT = ImageFont.truetype(
+            os.path.join(os.path.dirname(__file__), "UbuntuSansMono-Regular.ttf"),
+            size=16,
         )
-    return image
-
-
-def append_text_to_image(image: Array, lines: List[str]):
-    r"""Appends text left to an image of size (height, width, channels).
-    The returned image has white text on a black background.
-    Args:
-        image: the image to put text
-        text: a string to display
-    Returns:
-        A new image with text inserted left to the input image
-    See also:
-        habitat.utils.visualization.utils
-    """
-    # h, w, c = image.shape
-    font_size = 0.5
-    font_thickness = 1
-    font = cv2.FONT_HERSHEY_SIMPLEX
-    blank_image = np.zeros(image.shape, dtype=np.uint8)
-
-    y = 0
+    y = -10
     for line in lines:
-        textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
-        y += textsize[1] + 10
+        bbox = draw.textbbox((0, 0), text=line)
+        textheight = bbox[3] - bbox[1]
+        y += textheight + 10
         x = 10
-        cv2.putText(
-            blank_image,
-            line,
-            (x, y),
-            font,
-            font_size,
-            (255, 255, 255),
-            font_thickness,
-            lineType=cv2.LINE_AA,
-        )
-    # text_image = blank_image[0 : y + 10, 0:w]
-    # final = np.concatenate((image, text_image), axis=0)
-    final = np.concatenate((blank_image, image), axis=1)
-    return final
+        draw.text((x, y), text=line, fill=(0, 255, 0), font=TEXT_FONT)
+    return np.array(image)
 
 
 def put_info_on_image(image, info: Dict[str, float], extras=None, overlay=True):
     lines = [f"{k}: {v:.3f}" for k, v in info.items()]
     if extras is not None:
         lines.extend(extras)
-    if overlay:
-        return put_text_on_image(image, lines)
-    else:
-        return append_text_to_image(image, lines)
+    return put_text_on_image(image, lines)
diff --git a/mani_skill/utils/visualization/renderer.py b/mani_skill/utils/visualization/renderer.py
new file mode 100644
index 000000000..01295ccc7
--- /dev/null
+++ b/mani_skill/utils/visualization/renderer.py
@@ -0,0 +1,39 @@
+import sys
+
+from matplotlib import pyplot as plt
+
+
+class ImageRenderer:
+    def __init__(self, wait_for_button_press=True):
+        """
+        Create a very light-weight image renderer.
+
+        Args:
+            wait_for_button_press (bool): If True, each call to this renderer will pause the process until the user presses any key.
+            event_handler: Code to run given an event / button press. If None the default is mapping 'escape' and 'q' to sys.exit(0)
+        """
+        self._image = None
+        self.last_event = None
+
+    def event_handler(self, event):
+        self.last_event = event
+        if event.key in ["q", "escape"]:
+            sys.exit(0)
+
+    def __call__(self, buffer):
+        if not self._image:
+            plt.ion()
+            self.fig = plt.figure()
+            self._image = plt.imshow(buffer, animated=True)
+            self.fig.canvas.mpl_connect("key_press_event", self.event_handler)
+        else:
+            self._image.set_data(buffer)
+        plt.waitforbuttonpress()
+        plt.draw()
+
+    def __del__(self):
+        self.close()
+
+    def close(self):
+        plt.ioff()
+        plt.close()
diff --git a/mani_skill/utils/wrappers/record.py b/mani_skill/utils/wrappers/record.py
index cb7d64dbe..b5efed349 100644
--- a/mani_skill/utils/wrappers/record.py
+++ b/mani_skill/utils/wrappers/record.py
@@ -473,6 +473,7 @@ def step(self, action):
         image = self.capture_image()
 
         if self.info_on_video:
+            info = common.to_numpy(info)
             scalar_info = gym_utils.extract_scalars_from_info(info)
             if isinstance(rew, torch.Tensor) and len(rew.shape) > 1:
                 rew = rew[0]
diff --git a/tests/test_gpu_envs.py b/tests/test_gpu_envs.py
index 08a6c0dd5..efe231c64 100644
--- a/tests/test_gpu_envs.py
+++ b/tests/test_gpu_envs.py
@@ -394,18 +394,20 @@ def test_fn():
             hide_obj._body_data_index, :7
         ].clone()[..., :3],
         p,
+        atol=1e-6,
     ).all()
     assert torch.isclose(
         hide_obj.px.cuda_rigid_body_data.torch()[
             hide_obj._body_data_index, :7
         ].clone()[..., 3:],
         q,
+        atol=1e-6,
    ).all()
 
     # 4. check that direct calls to raw_pose, pos, and rot same as before
-    assert torch.isclose(hide_obj.pose.raw_pose, raw_pose).all()
-    assert torch.isclose(hide_obj.pose.p, p).all()
-    assert torch.isclose(hide_obj.pose.q, q).all()
+    assert torch.isclose(hide_obj.pose.raw_pose, raw_pose, atol=1e-6).all()
+    assert torch.isclose(hide_obj.pose.p, p, atol=1e-6).all()
+    assert torch.isclose(hide_obj.pose.q, q, atol=1e-6).all()
 
     # Test after reset
     hide_obj.hide_visual()
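Reviewer note: for quick manual testing of this patch outside the bundled demo scripts, a minimal sketch of how the new ImageRenderer is meant to be driven is below. The env id and the reset/step loop are illustrative only (not part of this patch); the render() tensor conversion and the last_event.key polling mirror the demo_manual_control.py changes above.

    import gymnasium as gym

    import mani_skill.envs  # noqa: F401 -- importing registers the ManiSkill environments
    from mani_skill.utils import visualization

    # The matplotlib-backed viewer added by this patch. Each call blocks in
    # plt.waitforbuttonpress() until a key is pressed; 'q' or escape exits the process.
    renderer = visualization.ImageRenderer()

    env = gym.make("PushCube-v1", render_mode="rgb_array")  # illustrative env id
    env.reset(seed=0)
    while True:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        # render() returns a batched torch tensor, so take environment 0 as a
        # numpy image, exactly as demo_manual_control.py now does
        frame = env.render().cpu().numpy()[0]
        renderer(frame)
        key = renderer.last_event.key if renderer.last_event is not None else None
        if key == "r":  # example binding, not part of the patch
            env.reset()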