diff --git a/mani_skill/envs/tasks/__init__.py b/mani_skill/envs/tasks/__init__.py index d91008968..ac7a80c5d 100644 --- a/mani_skill/envs/tasks/__init__.py +++ b/mani_skill/envs/tasks/__init__.py @@ -16,3 +16,4 @@ from .stack_cube import StackCubeEnv from .two_robot_pick_cube import TwoRobotPickCube from .two_robot_stack_cube import TwoRobotStackCube +from .rotate_cube import RotateCubeEnv \ No newline at end of file diff --git a/mani_skill/envs/tasks/rotate_cube.py b/mani_skill/envs/tasks/rotate_cube.py index 03d551b8e..e938a4c1a 100644 --- a/mani_skill/envs/tasks/rotate_cube.py +++ b/mani_skill/envs/tasks/rotate_cube.py @@ -11,18 +11,13 @@ from mani_skill.sensors.camera import CameraConfig from mani_skill.utils.building import ActorBuilder, actors from mani_skill.utils.building.ground import build_ground -from mani_skill.utils.geometry.rotation_conversions import ( - euler_angles_to_matrix, - matrix_to_quaternion, - random_quaternions, -) from mani_skill.utils.registration import register_env from mani_skill.utils.sapien_utils import look_at from mani_skill.utils.structs.actor import Actor from mani_skill.utils.structs.articulation import Articulation from mani_skill.utils.structs.pose import Pose from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig - +from mani_skill.envs.utils.randomization.pose import random_quaternions class RotateCubeEnv(BaseEnv): """ @@ -189,7 +184,7 @@ def random_z(min_height: float, max_height: float) -> torch.Tensor: # For initialization pos_x, pos_y = random_xy() pos_z = self.size / 2 - orientation = random_roll_orientation(b, self.device) + orientation = random_quaternions(b, lock_x=True, lock_y=True, device=self.device) elif difficulty == 2: # Fixed goal position in the air with x,y = 0. No orientation. pos_x, pos_y = 0.0, 0.0 @@ -209,7 +204,7 @@ def random_z(min_height: float, max_height: float) -> torch.Tensor: # in the cirriculum pos_x, pos_y = random_xy() pos_z = random_z(min_height=self.radius_3d, max_height=self.max_height) - orientation = random_quaternions(b, dtype=torch.float, device=self.device) + orientation = random_quaternions(b, device=self.device) else: msg = f"Invalid difficulty index for task: {difficulty}." raise ValueError(msg) @@ -350,19 +345,6 @@ def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict): return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward -@torch.jit.script -def random_roll_orientation(num: int, device: torch.device) -> torch.Tensor: - """Returns sampled rotation around z-axis.""" - roll = 2 * np.pi * torch.rand(num, dtype=torch.float, device=device) - pitch = torch.zeros(num, dtype=torch.float, device=device) - yaw = torch.zeros(num, dtype=torch.float, device=device) - euler_angles = torch.stack([roll, pitch, yaw], dim=-1) - rotation_matrix = euler_angles_to_matrix(euler_angles, "XYZ") - quat = matrix_to_quaternion(rotation_matrix) - return quat - - -@torch.jit.script def quat_diff_rad(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Get the difference in radians between two quaternions.