Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Add back pytorch kinematics which is likely more stable #485

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 20 additions & 96 deletions mani_skill/agents/controllers/pd_ee_pose.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
from dataclasses import dataclass
from typing import List, Literal, Sequence, Union
from typing import Literal, Sequence, Union

try:
import fast_kinematics
except:
# not all systems support the fast_kinematics package at the moment
fast_kinematics = None
import numpy as np
import sapien.physx as physx
import torch
from gymnasium import spaces

from mani_skill import logger
from mani_skill.utils import common, gym_utils, sapien_utils
from mani_skill.agents.controllers.utils.kinematics import Kinematics
from mani_skill.utils import gym_utils
from mani_skill.utils.geometry.rotation_conversions import (
euler_angles_to_matrix,
matrix_to_quaternion,
quaternion_apply,
quaternion_multiply,
)
from mani_skill.utils.structs import ArticulationJoint, Pose
from mani_skill.utils.structs import Pose
from mani_skill.utils.structs.types import Array, DriveMode

from .base_controller import ControllerConfig
Expand All @@ -46,57 +40,16 @@ def _check_gpu_sim_works(self):
def _initialize_joints(self):
self.initial_qpos = None
super()._initialize_joints()
if self.config.ee_link:
self.ee_link = sapien_utils.get_obj_by_name(
self.articulation.get_links(), self.config.ee_link
)
else:
# The child link of last joint is assumed to be the end-effector.
self.ee_link = self.joints[-1].get_child_link()
logger.warn(
"Configuration did not define a ee_link name, using the child link of the last joint"
)
self.ee_link_idx = self.articulation.get_links().index(self.ee_link)

if physx.is_gpu_enabled():
assert (
fast_kinematics is not None
), "fast_kinematics is not installed. This is likely because your system does not support the fast_kinematics library which provides GPU accelerated inverse kinematics solvers"
if self.device.type == "cuda":
self._check_gpu_sim_works()
self.fast_kinematics_model = fast_kinematics.FastKinematics(
self.config.urdf_path, self.scene.num_envs, self.config.ee_link
)
# note that everything past the end-effector is ignored. Any joint whose ancestor is self.ee_link is ignored
# get_joints returns the joints in level order
# for joint in joints
cur_link = self.ee_link.joint.parent_link
active_ancestor_joints: List[ArticulationJoint] = []
while cur_link is not None:
if cur_link.joint.active_index is not None:
active_ancestor_joints.append(cur_link.joint)
cur_link = cur_link.joint.parent_link
active_ancestor_joints = active_ancestor_joints[::-1]
self.active_ancestor_joints = active_ancestor_joints

# initially self.active_joint_indices references active joints that are controlled.
# we also make the assumption that the active index is the same across all parallel managed joints
self.active_ancestor_joint_idxs = [
(x.active_index[0]).cpu().item() for x in self.active_ancestor_joints
]
controlled_joints_idx_in_qmask = [
self.active_ancestor_joint_idxs.index(idx)
for idx in self.active_joint_indices
]
self.qmask = torch.zeros(
len(self.active_ancestor_joints), dtype=bool, device=self.device
)
self.qmask[controlled_joints_idx_in_qmask] = 1
else:
self.qmask = torch.zeros(
self.articulation.max_dof, dtype=bool, device=self.device
)
self.pmodel = self.articulation._objs[0].create_pinocchio_model()
self.qmask[self.active_joint_indices] = 1
self.kinematics = Kinematics(
self.config.urdf_path,
self.config.ee_link,
self.articulation,
self.active_joint_indices,
)

self.ee_link = self.kinematics.end_link

def _initialize_action_space(self):
low = np.float32(np.broadcast_to(self.config.pos_lower, 3))
Expand Down Expand Up @@ -126,41 +79,6 @@ def reset(self):
self.scene._reset_mask
] = self.ee_pose_at_base.raw_pose[self.scene._reset_mask]

def compute_ik(
self, target_pose: Pose, action: Array, pos_only=True, max_iterations=100
):
# NOTE (stao): it is a bit strange code wise that target_pose and action are both given since
# GPU sim can only use the delta action directly and cannot generate joint targets via a target pose
if physx.is_gpu_enabled():
## GPU IK mixed frame is basically all relative to base frame...
## CPU depends...
jacobian = (
self.fast_kinematics_model.jacobian_mixed_frame_pytorch(
self.articulation.get_qpos()[:, self.active_ancestor_joint_idxs]
)
.view(-1, len(self.active_ancestor_joints), 6)
.permute(0, 2, 1)
)
jacobian = jacobian[:, :, self.qmask]
if pos_only:
jacobian = jacobian[:, 0:3]

# NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
return self.qpos + delta_joint_pos.squeeze(-1)
else:
result, success, error = self.pmodel.compute_inverse_kinematics(
self.ee_link_idx,
target_pose.sp,
initial_qpos=common.to_numpy(self.articulation.get_qpos()).squeeze(0),
active_qmask=self.qmask,
max_iterations=max_iterations,
)
if success:
return common.to_tensor([result[self.active_joint_indices]])
else:
return None

def compute_target_pose(self, prev_ee_pose_at_base, action):
# Keep the current rotation and change the position
if self.config.use_delta:
Expand All @@ -187,7 +105,13 @@ def set_action(self, action: Array):
prev_ee_pose_at_base = self.ee_pose_at_base

self._target_pose = self.compute_target_pose(prev_ee_pose_at_base, action)
self._target_qpos = self.compute_ik(self._target_pose, action)
pos_only = type(self.config) == PDEEPosControllerConfig
self._target_qpos = self.kinematics.compute_ik(
self._target_pose,
self.articulation.get_qpos(),
pos_only=pos_only,
action=action,
)
if self._target_qpos is None:
self._target_qpos = self._start_qpos
if self.config.interpolate:
Expand Down
146 changes: 146 additions & 0 deletions mani_skill/agents/controllers/utils/kinematics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Code for kinematics utilities on CPU/GPU
"""
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from os import devnull
from typing import List

import pytorch_kinematics as pk
import torch
from sapien.wrapper.pinocchio_model import PinocchioModel

from mani_skill.utils import common
from mani_skill.utils.structs.articulation import Articulation
from mani_skill.utils.structs.articulation_joint import ArticulationJoint
from mani_skill.utils.structs.pose import Pose

# currently fast_kinematics has some bugs on some systems so we use the slower pytorch kinematics package instead.
# try:
# import fast_kinematics
# except:
# # not all systems support the fast_kinematics package at the moment
# fast_kinematics = None


class Kinematics:
    """Inverse-kinematics helper for a single end-effector link of an articulation.

    On a CUDA device the solver uses a ``pytorch_kinematics`` serial chain built
    from the URDF and performs one pseudo-inverse Jacobian step per call. On any
    other device it uses the Pinocchio model created from the first underlying
    articulation object and runs an iterative solve towards a target pose.
    """

    def __init__(
        self,
        urdf_path: str,
        end_link_name: str,
        articulation: Articulation,
        active_joint_indices: torch.Tensor,
    ):
        """
        Initialize the kinematics solver. It will be run on whichever device the articulation is on.

        Args:
            urdf_path (str): path to the URDF file
            end_link_name (str): name of the end-effector link
            articulation (Articulation): the articulation object
            active_joint_indices (torch.Tensor): indices of the active joints that can be controlled
        """
        self.urdf_path = urdf_path
        self.end_link = articulation.links_map[end_link_name]
        self.end_link_idx = articulation.links.index(self.end_link)
        self.active_joint_indices = active_joint_indices
        self.articulation = articulation
        self.device = articulation.device
        # Each setup method records which backend is active in self.use_gpu_ik.
        if self.device.type == "cuda":
            self._setup_gpu()
        else:
            self._setup_cpu()

    def _setup_cpu(self):
        """setup the kinematics solvers on the CPU"""
        self.use_gpu_ik = False
        # NOTE (stao): currently using the pinocchio that comes packaged with SAPIEN
        self.pmodel: PinocchioModel = self.articulation._objs[
            0
        ].create_pinocchio_model()
        # Boolean mask over every DOF of the articulation; only the controlled
        # joints may be moved by the solver.
        self.qmask = torch.zeros(
            self.articulation.max_dof, dtype=bool, device=self.device
        )
        self.qmask[self.active_joint_indices] = 1

    def _setup_gpu(self):
        """setup the kinematics solvers on the GPU"""
        self.use_gpu_ik = True
        with open(self.urdf_path, "r") as f:
            urdf_str = f.read()

        # NOTE (stao): it seems that the pk library currently always outputs some complaints if there are unknown attributes in a URDF. Hide it with this contextmanager here
        @contextmanager
        def suppress_stdout_stderr():
            """A context manager that redirects stdout and stderr to devnull"""
            with open(devnull, "w") as fnull:
                with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
                    yield (err, out)

        with suppress_stdout_stderr():
            self.pk_chain = pk.build_serial_chain_from_urdf(
                urdf_str,
                end_link_name=self.end_link.name,
            ).to(device=self.device)

        # note that everything past the end-link is ignored. Any joint whose ancestor is self.end_link is ignored
        # NOTE(review): the walk starts at the *parent* of the end link, so the end
        # link's own joint is never collected -- confirm that joint is always fixed.
        cur_link = self.end_link.joint.parent_link
        active_ancestor_joints: List[ArticulationJoint] = []
        while cur_link is not None:
            if cur_link.joint.active_index is not None:
                active_ancestor_joints.append(cur_link.joint)
            cur_link = cur_link.joint.parent_link
        # Collected leaf-to-root; reverse so the order runs root-to-leaf.
        self.active_ancestor_joints = active_ancestor_joints[::-1]

        # initially self.active_joint_indices references active joints that are controlled.
        # we also make the assumption that the active index is the same across all parallel managed joints
        self.active_ancestor_joint_idxs = [
            (x.active_index[0]).cpu().item() for x in self.active_ancestor_joints
        ]
        controlled_joints_idx_in_qmask = [
            self.active_ancestor_joint_idxs.index(idx)
            for idx in self.active_joint_indices
        ]
        # Mask over the ancestor joints marking which ones are actually controlled.
        self.qmask = torch.zeros(
            len(self.active_ancestor_joints), dtype=bool, device=self.device
        )
        self.qmask[controlled_joints_idx_in_qmask] = 1

    def compute_ik(
        self,
        target_pose: Pose,
        q0: torch.Tensor,
        pos_only: bool = False,
        action=None,
        max_iterations: int = 100,
    ):
        """Given a target pose, via inverse kinematics compute the target joint positions that will achieve the target pose

        Args:
            target_pose (Pose): desired end-link pose. Only used by the CPU solver.
            q0 (torch.Tensor): current joint positions of the articulation, batched
                along the first dimension.
            pos_only (bool): GPU only. When True only the first three rows of the
                Jacobian are used, so ``action`` must then have 3 components.
            action: GPU only and required there. The end-effector delta that is
                mapped to a joint delta through the Jacobian pseudo-inverse.
            max_iterations (int): CPU only. Iteration budget handed to the
                Pinocchio solver.

        Returns:
            A tensor of target joint positions, or None when the CPU solver
            reports failure.

        Raises:
            ValueError: if the GPU solver is active and ``action`` is None.
        """
        if self.use_gpu_ik:
            if action is None:
                raise ValueError(
                    "compute_ik requires `action` when running on the GPU, since the "
                    "joint targets are computed from the delta action via the Jacobian"
                )
            q0 = q0[:, self.active_ancestor_joint_idxs]
            jacobian = self.pk_chain.jacobian(q0)
            # NOTE(review): the previous fast_kinematics implementation also
            # restricted the Jacobian columns with self.qmask; that step is not
            # applied here, so every ancestor joint takes part in the solve.
            if pos_only:
                jacobian = jacobian[:, 0:3]

            # NOTE (stao): this method of IK is from https://mathweb.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf by Samuel R. Buss
            delta_joint_pos = torch.linalg.pinv(jacobian) @ action.unsqueeze(-1)
            return q0 + delta_joint_pos.squeeze(-1)

        # The Pinocchio solver works on one unbatched numpy configuration, so
        # convert the tensor and drop the leading batch dimension if present.
        initial_qpos = common.to_numpy(q0)
        if initial_qpos.ndim > 1:
            initial_qpos = initial_qpos[0]
        result, success, _error = self.pmodel.compute_inverse_kinematics(
            self.end_link_idx,
            target_pose.sp,
            initial_qpos=initial_qpos,
            active_qmask=self.qmask,
            max_iterations=max_iterations,
        )
        if not success:
            return None
        # self.active_ancestor_joint_idxs only exists after _setup_gpu, so the
        # CPU path selects the controlled joints directly from the solution.
        return common.to_tensor(
            [result[self.active_joint_indices]], device=self.device
        )
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
"mplib==0.1.1;platform_system=='Linux'",
"fast_kinematics==0.2.2;platform_system=='Linux'",
"IPython",
"pytorch_kinematics_ms==0.7.2", # pytorch kinematics package for ManiSkill forked from https://github.com/UM-ARM-Lab/pytorch_kinematics
"tyro==0.8.5", # nice, typed, command line arg parser
"huggingface_hub", # we use HF to version control some assets/datasets more easily
],
# Glob patterns do not automatically match dotfiles
Expand Down