3
3
import genesis as gs
4
4
from genesis .utils .geom import quat_to_xyz , transform_by_quat , inv_quat , transform_quat_by_quat
5
5
6
+
6
7
def gs_rand_float(lower, upper, shape, device):
    """Draw uniform random values between ``lower`` and ``upper``.

    Args:
        lower: Lower bound of the sampling range.
        upper: Upper bound of the sampling range.
        shape: Shape of the returned tensor, forwarded to ``torch.rand`` as ``size``.
        device: Device on which the tensor is created.

    Returns:
        A tensor of the requested shape holding ``torch.rand`` samples
        rescaled from the unit interval to ``[lower, upper)``.
    """
    # NOTE(review): the bounds are presumably scalars; tensor bounds would
    # need to broadcast against ``shape`` -- confirm against callers.
    span = upper - lower
    unit_samples = torch.rand(size=shape, device=device)
    return span * unit_samples + lower
8
9
10
+
9
11
class HoverEnv :
10
12
def __init__ (self , num_envs , env_cfg , obs_cfg , reward_cfg , command_cfg , show_viewer = False , device = "cuda" ):
11
13
self .device = torch .device (device )
@@ -52,18 +54,19 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
52
54
53
55
# add target
54
56
if self .env_cfg ["visualize_target" ]:
55
- self .target = self .scene .add_entity (morph = gs .morphs .Mesh (
56
- file = "meshes/sphere.obj" ,
57
- scale = 0.05 ,
58
- fixed = True ,
59
- collision = False ,
60
- ),
61
- surface = gs .surfaces .Rough (
62
- diffuse_texture = gs .textures .ColorTexture (
63
- color = (1.0 , 0.5 , 0.5 ),
64
- ),
65
- ),
66
- )
57
+ self .target = self .scene .add_entity (
58
+ morph = gs .morphs .Mesh (
59
+ file = "meshes/sphere.obj" ,
60
+ scale = 0.05 ,
61
+ fixed = True ,
62
+ collision = False ,
63
+ ),
64
+ surface = gs .surfaces .Rough (
65
+ diffuse_texture = gs .textures .ColorTexture (
66
+ color = (1.0 , 0.5 , 0.5 ),
67
+ ),
68
+ ),
69
+ )
67
70
else :
68
71
self .target = None
69
72
@@ -120,9 +123,7 @@ def _resample_commands(self, envs_idx):
120
123
121
124
def _at_target (self ):
122
125
at_target = (
123
- (torch .norm (self .rel_pos , dim = 1 ) < self .env_cfg ["at_target_threshold" ])
124
- .nonzero (as_tuple = False )
125
- .flatten ()
126
+ (torch .norm (self .rel_pos , dim = 1 ) < self .env_cfg ["at_target_threshold" ]).nonzero (as_tuple = False ).flatten ()
126
127
)
127
128
return at_target
128
129
@@ -134,7 +135,7 @@ def step(self, actions):
134
135
# self.drone.control_dofs_position(target_dof_pos)
135
136
136
137
# 14468 is hover rpm
137
- self .drone .set_propellels_rpm ((1 + exec_actions * 0.8 ) * 14468.429183500699 )
138
+ self .drone .set_propellels_rpm ((1 + exec_actions * 0.8 ) * 14468.429183500699 )
138
139
self .scene .step ()
139
140
140
141
# update buffers
@@ -157,12 +158,12 @@ def step(self, actions):
157
158
158
159
# check termination and reset
159
160
self .crash_condition = (
160
- (torch .abs (self .base_euler [:, 1 ]) > self .env_cfg ["termination_if_pitch_greater_than" ]) |
161
- (torch .abs (self .base_euler [:, 0 ]) > self .env_cfg ["termination_if_roll_greater_than" ]) |
162
- (torch .abs (self .rel_pos [:, 0 ]) > self .env_cfg ["termination_if_x_greater_than" ]) |
163
- (torch .abs (self .rel_pos [:, 1 ]) > self .env_cfg ["termination_if_y_greater_than" ]) |
164
- (torch .abs (self .rel_pos [:, 2 ]) > self .env_cfg ["termination_if_z_greater_than" ]) |
165
- (self .base_pos [:, 2 ] < self .env_cfg ["termination_if_close_to_ground" ])
161
+ (torch .abs (self .base_euler [:, 1 ]) > self .env_cfg ["termination_if_pitch_greater_than" ])
162
+ | (torch .abs (self .base_euler [:, 0 ]) > self .env_cfg ["termination_if_roll_greater_than" ])
163
+ | (torch .abs (self .rel_pos [:, 0 ]) > self .env_cfg ["termination_if_x_greater_than" ])
164
+ | (torch .abs (self .rel_pos [:, 1 ]) > self .env_cfg ["termination_if_y_greater_than" ])
165
+ | (torch .abs (self .rel_pos [:, 2 ]) > self .env_cfg ["termination_if_z_greater_than" ])
166
+ | (self .base_pos [:, 2 ] < self .env_cfg ["termination_if_close_to_ground" ])
166
167
)
167
168
self .reset_buf = (self .episode_length_buf > self .max_episode_length ) | self .crash_condition
168
169
@@ -248,15 +249,15 @@ def _reward_smooth(self):
248
249
249
250
def _reward_yaw (self ):
250
251
yaw = self .base_euler [:, 2 ]
251
- yaw = torch .where (yaw > 180 , yaw - 360 , yaw )/ 180 * 3.14159 # use rad for yaw_reward
252
+ yaw = torch .where (yaw > 180 , yaw - 360 , yaw ) / 180 * 3.14159 # use rad for yaw_reward
252
253
yaw_rew = torch .exp (self .reward_cfg ["yaw_lambda" ] * torch .abs (yaw ))
253
254
return yaw_rew
254
-
255
+
255
256
def _reward_angular (self ):
256
- angular_rew = torch .norm (self .base_ang_vel / 3.14159 , dim = 1 )
257
+ angular_rew = torch .norm (self .base_ang_vel / 3.14159 , dim = 1 )
257
258
return angular_rew
258
259
259
260
def _reward_crash(self):
    """Compute the crash reward term.

    Returns:
        A tensor of shape ``(self.num_envs,)`` with dtype ``gs.tc_float``
        on ``self.device``, set to 1 where ``self.crash_condition``
        selects an entry and 0 elsewhere.
    """
    # NOTE(review): crash_condition is presumably a boolean mask over
    # environments, as built in step() -- confirm.
    crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
    crash_rew[self.crash_condition] = 1
    return crash_rew
0 commit comments