Skip to content

Commit 9d27d70

Browse files
authored
[MISC] reformat examples/drone (#379)
reformat examples/drone
1 parent ae8f556 commit 9d27d70

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

examples/drone/hover_env.py

+27-26
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import genesis as gs
44
from genesis.utils.geom import quat_to_xyz, transform_by_quat, inv_quat, transform_quat_by_quat
55

6+
67
def gs_rand_float(lower, upper, shape, device):
78
return (upper - lower) * torch.rand(size=shape, device=device) + lower
89

10+
911
class HoverEnv:
1012
def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_viewer=False, device="cuda"):
1113
self.device = torch.device(device)
@@ -52,18 +54,19 @@ def __init__(self, num_envs, env_cfg, obs_cfg, reward_cfg, command_cfg, show_vie
5254

5355
# add target
5456
if self.env_cfg["visualize_target"]:
55-
self.target = self.scene.add_entity(morph=gs.morphs.Mesh(
56-
file="meshes/sphere.obj",
57-
scale=0.05,
58-
fixed=True,
59-
collision=False,
60-
),
61-
surface=gs.surfaces.Rough(
62-
diffuse_texture=gs.textures.ColorTexture(
63-
color=(1.0, 0.5, 0.5),
64-
),
65-
),
66-
)
57+
self.target = self.scene.add_entity(
58+
morph=gs.morphs.Mesh(
59+
file="meshes/sphere.obj",
60+
scale=0.05,
61+
fixed=True,
62+
collision=False,
63+
),
64+
surface=gs.surfaces.Rough(
65+
diffuse_texture=gs.textures.ColorTexture(
66+
color=(1.0, 0.5, 0.5),
67+
),
68+
),
69+
)
6770
else:
6871
self.target = None
6972

@@ -120,9 +123,7 @@ def _resample_commands(self, envs_idx):
120123

121124
def _at_target(self):
122125
at_target = (
123-
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"])
124-
.nonzero(as_tuple=False)
125-
.flatten()
126+
(torch.norm(self.rel_pos, dim=1) < self.env_cfg["at_target_threshold"]).nonzero(as_tuple=False).flatten()
126127
)
127128
return at_target
128129

@@ -134,7 +135,7 @@ def step(self, actions):
134135
# self.drone.control_dofs_position(target_dof_pos)
135136

136137
# 14468 is hover rpm
137-
self.drone.set_propellels_rpm((1 + exec_actions*0.8) * 14468.429183500699)
138+
self.drone.set_propellels_rpm((1 + exec_actions * 0.8) * 14468.429183500699)
138139
self.scene.step()
139140

140141
# update buffers
@@ -157,12 +158,12 @@ def step(self, actions):
157158

158159
# check termination and reset
159160
self.crash_condition = (
160-
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"]) |
161-
(torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"]) |
162-
(torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"]) |
163-
(torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"]) |
164-
(torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"]) |
165-
(self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
161+
(torch.abs(self.base_euler[:, 1]) > self.env_cfg["termination_if_pitch_greater_than"])
162+
| (torch.abs(self.base_euler[:, 0]) > self.env_cfg["termination_if_roll_greater_than"])
163+
| (torch.abs(self.rel_pos[:, 0]) > self.env_cfg["termination_if_x_greater_than"])
164+
| (torch.abs(self.rel_pos[:, 1]) > self.env_cfg["termination_if_y_greater_than"])
165+
| (torch.abs(self.rel_pos[:, 2]) > self.env_cfg["termination_if_z_greater_than"])
166+
| (self.base_pos[:, 2] < self.env_cfg["termination_if_close_to_ground"])
166167
)
167168
self.reset_buf = (self.episode_length_buf > self.max_episode_length) | self.crash_condition
168169

@@ -248,15 +249,15 @@ def _reward_smooth(self):
248249

249250
def _reward_yaw(self):
250251
yaw = self.base_euler[:, 2]
251-
yaw = torch.where(yaw > 180, yaw - 360, yaw)/180*3.14159 # use rad for yaw_reward
252+
yaw = torch.where(yaw > 180, yaw - 360, yaw) / 180 * 3.14159 # use rad for yaw_reward
252253
yaw_rew = torch.exp(self.reward_cfg["yaw_lambda"] * torch.abs(yaw))
253254
return yaw_rew
254-
255+
255256
def _reward_angular(self):
256-
angular_rew = torch.norm(self.base_ang_vel/3.14159, dim=1)
257+
angular_rew = torch.norm(self.base_ang_vel / 3.14159, dim=1)
257258
return angular_rew
258259

259260
def _reward_crash(self):
260261
crash_rew = torch.zeros((self.num_envs,), device=self.device, dtype=gs.tc_float)
261262
crash_rew[self.crash_condition] = 1
262-
return crash_rew
263+
return crash_rew

examples/drone/hover_eval.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def main():
4545

4646
obs, _ = env.reset()
4747

48-
max_sim_step = int(env_cfg["episode_length_s"]*env_cfg["max_visualize_FPS"])
48+
max_sim_step = int(env_cfg["episode_length_s"] * env_cfg["max_visualize_FPS"])
4949
with torch.no_grad():
5050
if args.record:
5151
env.cam.start_recording()
@@ -59,6 +59,7 @@ def main():
5959
actions = policy(obs)
6060
obs, _, rews, dones, infos = env.step(actions)
6161

62+
6263
if __name__ == "__main__":
6364
main()
6465

examples/drone/hover_train.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111

1212
def get_train_cfg(exp_name, max_iterations):
13-
1413
train_cfg_dict = {
1514
"algorithm": {
1615
"clip_param": 0.2,
@@ -95,7 +94,7 @@ def get_cfgs():
9594
"yaw": 0.01,
9695
"angular": -2e-4,
9796
"crash": -10.0,
98-
}
97+
},
9998
}
10099
command_cfg = {
101100
"num_commands": 3,

0 commit comments

Comments
 (0)