[Feature] Remove cv2 #335

Merged · 5 commits · May 17, 2024
10 changes: 5 additions & 5 deletions mani_skill/envs/tasks/tabletop/turn_faucet.py
@@ -135,10 +135,7 @@ def _load_scene(self, options: dict):
         self.target_angle[torch.isinf(qmax)] = torch.pi / 2
         # the angle to go
         self.target_angle_diff = self.target_angle - self.init_angle
-        joint_pose = (
-            self.target_switch_link.joint.get_global_pose().to_transformation_matrix()
-        )
-        self.target_joint_axis = joint_pose[:, :3, 0]
+        self.target_joint_axis = torch.zeros((self.num_envs, 3), device=self.device)

     def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
         with torch.device(self.device):
@@ -165,7 +162,10 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
                 self.target_switch_link.pose * self.target_switch_link.cmass_local_pose
             )
             self.target_link_pos = cmass_pose.p
-
+            joint_pose = (
+                self.target_switch_link.joint.get_global_pose().to_transformation_matrix()
+            )
+            self.target_joint_axis[env_idx] = joint_pose[env_idx, :3, 0]
             # self.handle_link_goal.set_pose(cmass_pose)

     @property
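The net effect of this change: the joint axis is no longer computed once at scene load but allocated as a zeros buffer and filled per episode, so partial resets only refresh the rows selected by `env_idx`. A minimal sketch of the indexing, with illustrative tensors (the names and the x-axis convention are inferred from the diff, not from documented API):

```python
import torch

num_envs = 4
# A batch of 4x4 joint-frame transforms, one per parallel environment
# (identity matrices here purely for illustration).
joint_pose = torch.eye(4).repeat(num_envs, 1, 1)   # (num_envs, 4, 4)

# Column 0 of the rotation block is each joint frame's x-axis, which the
# diff treats as the revolute joint's rotation axis.
target_joint_axis = torch.zeros((num_envs, 3))

# On reset, only the environments being re-initialized are overwritten,
# mirroring self.target_joint_axis[env_idx] = joint_pose[env_idx, :3, 0].
env_idx = torch.tensor([0, 2])
target_joint_axis[env_idx] = joint_pose[env_idx, :3, 0]
print(target_joint_axis.shape)  # torch.Size([4, 3])
```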
45 changes: 25 additions & 20 deletions mani_skill/examples/demo_manual_control.py
@@ -1,10 +1,13 @@
 import argparse
 import signal

 import gymnasium as gym
+from matplotlib import pyplot as plt
 import numpy as np

 signal.signal(signal.SIGINT, signal.SIG_DFL)  # allow ctrl+c
 from mani_skill.envs.sapien_env import BaseEnv
-from mani_skill.utils.visualization.cv2_utils import OpenCVViewer
+from mani_skill.utils import visualization
 from mani_skill.utils.wrappers import RecordEpisode
+
+
@@ -58,7 +61,19 @@ def main():
     # Viewer
     if args.enable_sapien_viewer:
         env.render_human()
-    opencv_viewer = OpenCVViewer(exit_on_esc=False)
+    renderer = visualization.ImageRenderer()
+    # disable all default plt shortcuts that are lowercase letters
+    plt.rcParams["keymap.fullscreen"].remove("f")
+    plt.rcParams["keymap.home"].remove("h")
+    plt.rcParams["keymap.home"].remove("r")
+    plt.rcParams["keymap.back"].remove("c")
+    plt.rcParams["keymap.forward"].remove("v")
+    plt.rcParams["keymap.pan"].remove("p")
+    plt.rcParams["keymap.zoom"].remove("o")
+    plt.rcParams["keymap.save"].remove("s")
+    plt.rcParams["keymap.grid"].remove("g")
+    plt.rcParams["keymap.yscale"].remove("l")
+    plt.rcParams["keymap.xscale"].remove("k")

     def render_wait():
         if not args.enable_sapien_viewer:
@@ -83,21 +98,22 @@ def render_wait():
         if args.enable_sapien_viewer:
             env.render_human()

-        render_frame = env.render()
+        render_frame = env.render().cpu().numpy()[0]

         if after_reset:
             after_reset = False
             # Re-focus on opencv viewer
             if args.enable_sapien_viewer:
-                opencv_viewer.close()
-                opencv_viewer = OpenCVViewer(exit_on_esc=False)
-
+                renderer.close()
+                renderer = visualization.ImageRenderer()
+                pass
         # -------------------------------------------------------------------------- #
         # Interaction
         # -------------------------------------------------------------------------- #
         # Input
-        key = opencv_viewer.imshow(render_frame)
-
+        renderer(render_frame)
+        # key = opencv_viewer.imshow(render_frame.cpu().numpy()[0])
+        key = renderer.last_event.key if renderer.last_event is not None else None
         body_action = np.zeros([3])
         base_action = np.zeros([3])  # hardcoded for fetch robot

@@ -192,18 +208,7 @@ def render_wait():

         # Visualize observation
         if key == "v":
-            if "rgbd" in env.obs_mode:
-                from itertools import chain
-
-                from mani_skill.utils.visualization.misc import (
-                    observations_to_images, tile_images)
-
-                images = list(
-                    chain(*[observations_to_images(x) for x in obs["image"].values()])
-                )
-                render_frame = tile_images(images)
-                opencv_viewer.imshow(render_frame)
-            elif "pointcloud" in env.obs_mode:
+            if "pointcloud" in env.obs_mode:
                 import trimesh

                 xyzw = obs["pointcloud"]["xyzw"]
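The script now drives a matplotlib-backed `visualization.ImageRenderer` instead of a cv2 window, and strips matplotlib's default single-letter shortcuts so keys like "s" or "v" reach the teleop logic rather than triggering the save or zoom tools. A minimal sketch of what such a renderer could look like, inferred only from how the script uses it (callable with a frame, exposes `last_event`, has `close()`); the actual ManiSkill implementation may differ, e.g. it may not block:

```python
import matplotlib.pyplot as plt
import numpy as np

class MinimalImageRenderer:
    """Hypothetical stand-in for visualization.ImageRenderer."""

    def __init__(self):
        self.last_event = None  # most recent matplotlib key_press_event
        self.fig, self.ax = plt.subplots()
        self.ax.axis("off")
        self.im = None
        self.fig.canvas.mpl_connect("key_press_event", self._on_key)

    def _on_key(self, event):
        self.last_event = event

    def __call__(self, frame: np.ndarray):
        # Draw the frame, then wait for a key press, mimicking the
        # cv2.waitKey(0) contract the old OpenCVViewer provided
        # (the blocking behavior is an assumption).
        self.last_event = None
        if self.im is None:
            self.im = self.ax.imshow(frame)
        else:
            self.im.set_data(frame)
        self.fig.canvas.draw_idle()
        while self.last_event is None:
            plt.pause(0.01)

    def close(self):
        plt.close(self.fig)
```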
4 changes: 2 additions & 2 deletions mani_skill/examples/demo_random_action.py
@@ -14,7 +14,7 @@ def parse_args(args=None):
     parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
     parser.add_argument("--reward-mode", type=str)
     parser.add_argument("-c", "--control-mode", type=str)
-    parser.add_argument("--render-mode", type=str)
+    parser.add_argument("--render-mode", type=str, default="rgb_array")
     parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
     parser.add_argument("--record-dir", type=str)
     parser.add_argument("-p", "--pause", action="store_true", help="If using human render mode, auto pauses the simulation upon loading")
@@ -58,7 +58,7 @@ def main(args):
     record_dir = args.record_dir
     if record_dir:
         record_dir = record_dir.format(env_id=args.env_id)
-    env = RecordEpisode(env, record_dir)
+    env = RecordEpisode(env, record_dir, info_on_video=True)

     if verbose:
         print("Observation space", env.observation_space)
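With these defaults the demo renders offscreen ("rgb_array") and, when `--record-dir` is given, overlays per-step info on the saved video. A hedged sketch of the equivalent wrapper usage; the env id and output directory are illustrative:

```python
import gymnasium as gym

import mani_skill.envs  # noqa: F401 (registers ManiSkill environments)
from mani_skill.utils.wrappers import RecordEpisode

env = gym.make("PickCube-v1", render_mode="rgb_array")
env = RecordEpisode(env, "videos", info_on_video=True)

env.reset(seed=0)
for _ in range(50):
    env.step(env.action_space.sample())
env.close()  # flushes the recorded episode to disk
```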
21 changes: 14 additions & 7 deletions mani_skill/examples/demo_vis_rgbd.py
@@ -1,13 +1,15 @@
 import signal
 import sys
+
+from matplotlib import pyplot as plt
+
 from mani_skill.utils import common
-from mani_skill.utils.visualization.misc import tile_images
+from mani_skill.utils import visualization
 signal.signal(signal.SIGINT, signal.SIG_DFL)  # allow ctrl+c

 import argparse

 import gymnasium as gym
-import cv2
 import numpy as np

 from mani_skill.envs.sapien_env import BaseEnv
@@ -27,6 +29,12 @@ def parse_args(args=None):
     args = parser.parse_args()
     return args

+import matplotlib.pyplot as plt
+import numpy as np
+
+
+
+

 def main(args):
     if args.seed is not None:
@@ -50,6 +58,8 @@ def main(args):
             n_cams += 1
     print(f"Visualizing {n_cams} RGBD cameras")

+    renderer = visualization.ImageRenderer()
+
     while True:
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)
@@ -66,11 +76,8 @@ def main(args):
                 depth_rgb[..., :] = depth*255
                 imgs.append(depth_rgb)
                 cam_num += 1
-        img = tile_images(imgs, nrows=n_cams)
-
-        cv2.imshow('image',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
-        cv2.waitKey(0)
-
+        img = visualization.tile_images(imgs, nrows=n_cams)
+        renderer(img)

 if __name__ == "__main__":
     main(parse_args())
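The loop converts each camera's depth map to a displayable image, tiles all panels into a grid, and hands the grid to the renderer. A self-contained sketch of that pattern with synthetic data (image sizes and values are illustrative):

```python
import numpy as np

from mani_skill.utils import visualization

rgb = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)

depth = np.random.rand(128, 128, 1)                 # depth normalized to [0, 1]
depth_rgb = np.zeros((128, 128, 3), dtype=np.uint8)
depth_rgb[..., :] = (depth * 255).astype(np.uint8)  # broadcast gray depth to RGB

# One row per camera; here a single camera contributes an RGB and a depth panel.
img = visualization.tile_images([rgb, depth_rgb], nrows=1)

renderer = visualization.ImageRenderer()
renderer(img)  # press a key to advance, as in the demo loop
```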
12 changes: 4 additions & 8 deletions mani_skill/examples/demo_vis_segmentation.py
@@ -1,13 +1,12 @@
 import signal

 from mani_skill.utils import common
-from mani_skill.utils.visualization.misc import tile_images
+from mani_skill.utils import visualization
 signal.signal(signal.SIGINT, signal.SIG_DFL)  # allow ctrl+c

 import argparse

 import gymnasium as gym
-import cv2
 import numpy as np
 # color pallete generated via https://medialab.github.io/iwanthue/
 color_pallete = np.array([[164,74,82],
@@ -118,7 +117,7 @@ def main(args):
     if selected_id is not None and not isinstance(selected_id, int):
         selected_id = reverse_seg_id_map[selected_id]

-
+    renderer = visualization.ImageRenderer()
     while True:
         action = env.action_space.sample()
         obs, reward, terminated, truncated, info = env.step(action)
@@ -138,11 +137,8 @@ def main(args):
                 seg_rgb[(seg == id)[..., 0]] = color
                 imgs.append(seg_rgb)
                 cam_num += 1
-        img = tile_images(imgs, nrows=n_cams)
-
-        cv2.imshow('image',cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
-        cv2.waitKey(0)
-
+        img = visualization.tile_images(imgs, nrows=n_cams)
+        renderer(img)

 if __name__ == "__main__":
     main(parse_args())
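The segmentation demo paints every object id with a palette color before tiling and rendering, which is the masking idiom shown in the hunk above. A toy version with synthetic data (palette values and the id range are illustrative):

```python
import numpy as np

palette = np.array([[164, 74, 82], [80, 120, 200], [100, 180, 90]], dtype=np.uint8)

seg = np.random.randint(0, 3, (64, 64, 1))       # HxWx1 integer segmentation ids
seg_rgb = np.zeros((64, 64, 3), dtype=np.uint8)  # output color image

for seg_id in np.unique(seg):
    # (seg == seg_id) is HxWx1; [..., 0] drops the channel to get an HxW mask
    seg_rgb[(seg == seg_id)[..., 0]] = palette[seg_id % len(palette)]
```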
Binary file not shown.
3 changes: 1 addition & 2 deletions mani_skill/utils/visualization/__init__.py
@@ -1,11 +1,10 @@
-from .cv2_utils import OpenCVViewer, images_to_video_cv2
 from .jupyter_utils import display_images
 from .misc import (
     append_text_to_image,
     images_to_video,
     normalize_depth,
     observations_to_images,
     put_info_on_image,
     put_text_on_image,
     tile_images,
 )
+from .renderer import ImageRenderer
74 changes: 0 additions & 74 deletions mani_skill/utils/visualization/cv2_utils.py

This file was deleted.
