
Commit d599350

committed
🚀 [RofuncRL] Still struggling to find the ASEHRL bug
Now 1-step LLC works, but more steps do not.
1 parent 69949ed commit d599350

36 files changed, +5142 -78 lines changed
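
For context on the commit message: llc_steps_per_high_action (changed in the configs below) controls how many low-level controller (LLC) steps are executed for each high-level action. A minimal, hypothetical sketch of that stepping pattern, assuming generic env / hlc / llc interfaces rather than Rofunc's actual trainer classes:

import torch

def rollout_high_level_action(env, hlc, llc, obs, llc_steps_per_high_action=5):
    """Hypothetical sketch: hold one high-level latent for several LLC steps."""
    # The high-level policy picks a latent z from the current observation.
    z = hlc.act(obs)
    total_reward = torch.zeros(obs.shape[0], device=obs.device)
    done = torch.zeros(obs.shape[0], dtype=torch.bool, device=obs.device)
    # The same latent is reused for every low-level step in this window.
    for _ in range(llc_steps_per_high_action):
        action = llc.act(obs, z)                 # LLC maps (obs, latent) -> joint action
        obs, reward, done, _ = env.step(action)  # reward is accumulated over the window
        total_reward += reward
        if done.any():
            break
    return obs, total_reward, done, z

With llc_steps_per_high_action: 1 this window collapses to a single environment step (the case the message says works); longer held-latent windows are the case that is still failing.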

examples/learning_rl/example_HumanoidASE_RofuncRL.py (+4 -4)
@@ -91,7 +91,7 @@ def inference(custom_args):
 
 
 if __name__ == '__main__':
-    gpu_id = 1
+    gpu_id = 0
 
     parser = argparse.ArgumentParser()
     # Available tasks and motion files:
@@ -101,15 +101,15 @@ def inference(custom_args):
     # HumanoidASEReachSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASELocationSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASEStrikeSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
-    parser.add_argument("--task", type=str, default="HumanoidASEGetupSwordShield")
+    parser.add_argument("--task", type=str, default="HumanoidASEHeadingSwordShield")
     parser.add_argument("--motion_file", type=str,
-                        default="reallusion_sword_shield/dataset_reallusion_sword_shield.yaml")
+                        default="reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy")
     parser.add_argument("--agent", type=str, default="ase")  # Available agent: ase
     parser.add_argument("--num_envs", type=int, default=4096)
     parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
     parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
     parser.add_argument("--graphics_device_id", type=int, default=gpu_id)
-    parser.add_argument("--headless", type=str, default="True")
+    parser.add_argument("--headless", type=str, default="False")
     parser.add_argument("--inference", action="store_true", help="turn to inference mode while adding this argument")
     parser.add_argument("--ckpt_path", type=str, default=None)
     custom_args = parser.parse_args()

rofunc/config/learning/rl/train/HumanoidASEHeadingSwordShieldASERofuncRL.yaml (+1 -1)
@@ -54,7 +54,7 @@ Agent:
   kl_threshold: 0  # Initial coefficient for KL divergence.
 
   llc_ckpt_path:
-  llc_steps_per_high_action: 5
+  llc_steps_per_high_action: 1
 
   # state_preprocessor:  # State preprocessor type.
   # state_preprocessor_kwargs:  # State preprocessor kwargs.

rofunc/config/learning/rl/train/HumanoidASEStrikeSwordShieldASERofuncRL.yaml (+1 -1)
@@ -53,7 +53,7 @@ Agent:
 
   kl_threshold: 0  # Initial coefficient for KL divergence.
 
-  llc_ckpt_path:
+  llc_ckpt_path: /home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/examples/learning_rl/runs/RofuncRL_ASETrainer_HumanoidASEGetupSwordShield_23-06-26_12-49-35-111331/checkpoints/ckpt_87000.pth
   llc_steps_per_high_action: 5
 
   # state_preprocessor:  # State preprocessor type.

rofunc/learning/RofuncRL/agents/base_agent.py (+1 -1)
@@ -118,7 +118,7 @@ def track_data(self, tag: str, value: float) -> None:
     def store_transition(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor,
                          rewards: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: torch.Tensor):
         """
-        Record the transition.
+        Record the transition. (Only rewards, truncated and terminated are used in this base class)
         """
         if self.cumulative_rewards is None:
             self.cumulative_rewards = torch.zeros_like(rewards, dtype=torch.float32)

rofunc/learning/RofuncRL/agents/mixline/ase_agent.py (+35 -14)
@@ -26,6 +26,7 @@
 import rofunc as rf
 from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
 from rofunc.learning.RofuncRL.agents.mixline.amp_agent import AMPAgent
+from rofunc.learning.RofuncRL.models.misc_models import ASEDiscEnc
 from rofunc.learning.RofuncRL.models.base_models import BaseMLP
 from rofunc.learning.RofuncRL.utils.memory import Memory
 
@@ -72,6 +73,13 @@ def __init__(self,
         self._enc_reward_weight = cfg.Agent.enc_reward_weight
 
         '''Define ASE specific models except for AMP'''
+        # self.discriminator = ASEDiscEnc(cfg.Model,
+        #                                 input_dim=amp_observation_space.shape[0],
+        #                                 enc_output_dim=self._ase_latent_dim,
+        #                                 disc_output_dim=1,
+        #                                 cfg_name='encoder').to(device)
+        # self.encoder = self.discriminator
+
         self.encoder = BaseMLP(cfg.Model,
                                input_dim=amp_observation_space.shape[0],
                                output_dim=self._ase_latent_dim,
@@ -95,10 +103,11 @@ def __init__(self,
 
     def _set_up(self):
         super()._set_up()
-        self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
-        if self._lr_scheduler is not None:
-            self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
-        self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
+        if self.encoder is not self.discriminator:
+            self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
 
     def act(self, states: torch.Tensor, deterministic: bool = False, ase_latents: torch.Tensor = None):
         if self._current_states is not None:
@@ -173,7 +182,10 @@ def update_net(self):
         style_rewards *= self._discriminator_reward_scale
 
         # Compute encoder reward
-        enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
+        if self.encoder is self.discriminator:
+            enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
+        else:
+            enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
         enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
         enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
         enc_reward *= self._enc_reward_scale
@@ -311,7 +323,10 @@ def update_net(self):
             discriminator_loss *= self._discriminator_loss_scale
 
             # encoder loss
-            enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
+            if self.encoder is self.discriminator:
+                enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
+            else:
+                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states_batch))
             enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
             enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
             enc_loss = torch.mean(enc_err)
@@ -357,17 +372,21 @@ def update_net(self):
 
             # Update discriminator network
            self.optimizer_disc.zero_grad()
-            discriminator_loss.backward()
+            if self.encoder is self.discriminator:
+                (discriminator_loss + enc_loss).backward()
+            else:
+                discriminator_loss.backward()
             if self._grad_norm_clip > 0:
                 nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
             self.optimizer_disc.step()
 
             # Update encoder network
-            self.optimizer_enc.zero_grad()
-            enc_loss.backward()
-            if self._grad_norm_clip > 0:
-                nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
-            self.optimizer_enc.step()
+            if self.encoder is not self.discriminator:
+                self.optimizer_enc.zero_grad()
+                enc_loss.backward()
+                if self._grad_norm_clip > 0:
+                    nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
+                self.optimizer_enc.step()
 
             # update cumulative losses
             cumulative_policy_loss += policy_loss.item()
@@ -382,7 +401,8 @@ def update_net(self):
             self.scheduler_policy.step()
             self.scheduler_value.step()
             self.scheduler_disc.step()
-            self.scheduler_enc.step()
+            if self.encoder is not self.discriminator:
+                self.scheduler_enc.step()
 
         # update AMP replay buffer
         self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
@@ -407,4 +427,5 @@ def update_net(self):
         self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
         self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
         self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
-        self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
+        if self.encoder is not self.discriminator:
+            self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
