Skip to content

Commit 33c8779

Browse files
Added BODY_VEL_WORLD as observation type in Mujoco
- now we provide also the opportunity to get the body velocity info in world frame
1 parent 495af65 commit 33c8779

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

mushroom_rl/environments/mujoco_envs/air_hockey/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
5757
("robot_1/joint_3_vel", "planar_robot_1/joint_3", ObservationType.JOINT_VEL)]
5858

5959
additional_data += [("robot_1/ee_pos", "planar_robot_1/body_ee", ObservationType.BODY_POS),
60-
("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL)]
60+
("robot_1/ee_vel", "planar_robot_1/body_ee", ObservationType.BODY_VEL_WORLD)]
6161

6262
collision_spec += [("robot_1/ee", ["planar_robot_1/ee"])]
6363

@@ -76,7 +76,7 @@ def __init__(self, n_agents=1, env_noise=False, obs_noise=False, gamma=0.99, hor
7676
("robot_2/joint_3_vel", "planar_robot_2/joint_3", ObservationType.JOINT_VEL)]
7777

7878
additional_data += [("robot_2/ee_pos", "planar_robot_2/body_ee", ObservationType.BODY_POS),
79-
("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL)]
79+
("robot_2/ee_vel", "planar_robot_2/body_ee", ObservationType.BODY_VEL_WORLD)]
8080

8181
collision_spec += [("robot_2/ee", ["planar_robot_2/ee"])]
8282
else:

mushroom_rl/environments/mujoco_envs/ball_in_a_cup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self):
3636
("palm_yaw_pos", "wam/palm_yaw_joint", ObservationType.JOINT_POS),
3737
("palm_yaw_vel", "wam/palm_yaw_joint", ObservationType.JOINT_VEL),
3838
("ball_pos", "ball", ObservationType.BODY_POS),
39-
("ball_vel", "ball", ObservationType.BODY_VEL)]
39+
("ball_vel", "ball", ObservationType.BODY_VEL_WORLD)]
4040

4141
additional_data_spec = [("ball_pos", "ball", ObservationType.BODY_POS),
4242
("goal_pos", "cup_goal_final", ObservationType.SITE_POS)]

mushroom_rl/utils/mujoco/observation_helper.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,22 @@ class ObservationType(Enum):
1111
The Observation have the following returns:
1212
BODY_POS: (3,) x, y, z position of the body
1313
BODY_ROT: (4,) quaternion of the body
14-
BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z
14+
BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in local frame
15+
BODY_VEL_WORLD: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z, in world frame
1516
JOINT_POS: (1,) rotation of the joint OR (7,) position, quaternion of a free joint
1617
JOINT_VEL: (1,) velocity of the joint OR (6,) FIRST linear then angular velocity !different to BODY_VEL!
1718
SITE_POS: (3,) x, y, z position of the body
1819
SITE_ROT: (9,) rotation matrix of the site
1920
"""
20-
__order__ = "BODY_POS BODY_ROT BODY_VEL JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
21+
__order__ = "BODY_POS BODY_ROT BODY_VEL BODY_VEL_WORLD JOINT_POS JOINT_VEL SITE_POS SITE_ROT"
2122
BODY_POS = 0
2223
BODY_ROT = 1
2324
BODY_VEL = 2
24-
JOINT_POS = 3
25-
JOINT_VEL = 4
26-
SITE_POS = 5
27-
SITE_ROT = 6
25+
BODY_VEL_WORLD = 3
26+
JOINT_POS = 4
27+
JOINT_VEL = 5
28+
SITE_POS = 6
29+
SITE_ROT = 7
2830

2931

3032
class ObservationHelper:
@@ -190,9 +192,10 @@ def get_state(self, model, data, name, o_type):
190192
obs = data.body(name).xpos
191193
elif o_type == ObservationType.BODY_ROT:
192194
obs = data.body(name).xquat
193-
elif o_type == ObservationType.BODY_VEL:
195+
elif o_type == ObservationType.BODY_VEL or o_type == ObservationType.BODY_VEL_WORLD:
196+
local = o_type == ObservationType.BODY_VEL
194197
obs = np.empty(6)
195-
mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, True)
198+
mujoco.mj_objectVelocity(model, data, mujoco.mjtObj.mjOBJ_XBODY, data.body(name).id, obs, local)
196199
elif o_type == ObservationType.JOINT_POS:
197200
obs = data.joint(name).qpos
198201
elif o_type == ObservationType.JOINT_VEL:

0 commit comments

Comments
 (0)