From b10e93023fa6d50881d928f88ed6a2e8e2d123d1 Mon Sep 17 00:00:00 2001
From: Junjia Liu
Date: Sun, 10 Sep 2023 17:54:49 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20[RofuncRL]=20Fix=20bugs=20in=20A?=
 =?UTF-8?q?MP=20and=20ASE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../example_HumanoidASE_RofuncRL.py           |  6 +++---
 .../RofuncRL/agents/mixline/amp_agent.py      | 19 ++++++++++++++++-
 .../RofuncRL/agents/mixline/ase_agent.py      |  4 ++--
 .../RofuncRL/agents/mixline/ase_hrl_agent.py  | 21 +++++++++++++++++++
 .../learning/RofuncRL/tasks/amp/motion_lib.py |  3 ++-
 5 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/examples/learning_rl/example_HumanoidASE_RofuncRL.py b/examples/learning_rl/example_HumanoidASE_RofuncRL.py
index 26f749daa..c97888734 100644
--- a/examples/learning_rl/example_HumanoidASE_RofuncRL.py
+++ b/examples/learning_rl/example_HumanoidASE_RofuncRL.py
@@ -95,7 +95,7 @@ def inference(custom_args):
 
 
 if __name__ == '__main__':
-    gpu_id = 0
+    gpu_id = 2
 
     parser = argparse.ArgumentParser()
     # Available tasks and motion files:
@@ -105,9 +105,9 @@ def inference(custom_args):
     # HumanoidASEReachSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASELocationSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASEStrikeSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
-    parser.add_argument("--task", type=str, default="HumanoidASEStrikeSwordShield")
+    parser.add_argument("--task", type=str, default="HumanoidASEGetupSwordShield")
     parser.add_argument("--motion_file", type=str,
-                        default="reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy")
+                        default="reallusion_sword_shield/dataset_reallusion_sword_shield.yaml")
     parser.add_argument("--agent", type=str, default="ase")  # Available agent: ase
     parser.add_argument("--num_envs", type=int, default=4096)
     parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
diff --git a/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py b/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
index ed584b1db..3b7c12839 100644
--- a/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
+++ b/rofunc/learning/RofuncRL/agents/mixline/amp_agent.py
@@ -159,14 +159,31 @@ def _set_up(self):
         """
         Set up optimizer, learning rate scheduler and state/value preprocessors
         """
+        assert hasattr(self, "policy"), "Policy is not defined."
+        assert hasattr(self, "value"), "Value is not defined."
+
         # Set up optimizer and learning rate scheduler
-        super()._set_up()
+        if self.policy is self.value:
+            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
+            if self._lr_scheduler is not None:
+                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer"] = self.optimizer
+        else:
+            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
+            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
+                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
+            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
+
         self.optimizer_disc = torch.optim.Adam(self.discriminator.parameters(), lr=self._lr_d, eps=self._adam_eps)
         if self._lr_scheduler is not None:
             self.scheduler_disc = self._lr_scheduler(self.optimizer_disc, **self._lr_scheduler_kwargs)
         self.checkpoint_modules["optimizer_disc"] = self.optimizer_disc
 
         # set up preprocessors
+        super()._set_up()
         if self._amp_state_preprocessor:
             self._amp_state_preprocessor = self._amp_state_preprocessor(**self._amp_state_preprocessor_kwargs)
             self.checkpoint_modules["amp_state_preprocessor"] = self._amp_state_preprocessor
diff --git a/rofunc/learning/RofuncRL/agents/mixline/ase_agent.py b/rofunc/learning/RofuncRL/agents/mixline/ase_agent.py
index 2be855be7..f7a633fcc 100644
--- a/rofunc/learning/RofuncRL/agents/mixline/ase_agent.py
+++ b/rofunc/learning/RofuncRL/agents/mixline/ase_agent.py
@@ -97,7 +97,7 @@ def __init__(self,
             state_tensor_size = (img_channel, img_size, img_size)
             kd = True
         else:
-            state_tensor_size = self.observation_space
+            state_tensor_size = observation_space
             kd = False
         self.memory.create_tensor(name="states", size=state_tensor_size, dtype=torch.float32, keep_dimensions=kd)
         self.memory.create_tensor(name="next_states", size=state_tensor_size, dtype=torch.float32, keep_dimensions=kd)
@@ -333,7 +333,7 @@ def update_net(self):
             if self.encoder is self.discriminator:
                 enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
             else:
-                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states_batch))
+                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
             enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
             enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
             enc_loss = torch.mean(enc_err)
diff --git a/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py b/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
index b87be7843..15cff3096 100644
--- a/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
+++ b/rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py
@@ -174,6 +174,27 @@ def __init__(self,
 
         self._set_up()
 
+    def _set_up(self):
+        assert hasattr(self, "policy"), "Policy is not defined."
+        assert hasattr(self, "value"), "Value is not defined."
+
+        # Set up optimizer and learning rate scheduler
+        if self.policy is self.value:
+            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
+            if self._lr_scheduler is not None:
+                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer"] = self.optimizer
+        else:
+            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
+            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
+                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
+            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
+
+        super()._set_up()
+
     def _build_llc(self):
         from .utils import ase_network_builder
         from .utils import ase_agent
diff --git a/rofunc/learning/RofuncRL/tasks/amp/motion_lib.py b/rofunc/learning/RofuncRL/tasks/amp/motion_lib.py
index ac9bcf408..d5ec49d6a 100644
--- a/rofunc/learning/RofuncRL/tasks/amp/motion_lib.py
+++ b/rofunc/learning/RofuncRL/tasks/amp/motion_lib.py
@@ -31,11 +31,12 @@
 import yaml
 
 from rofunc.utils.datalab.poselib.poselib.skeleton.skeleton3d import SkeletonMotion
+from rofunc.utils.datalab.poselib.poselib.core import quat_mul_norm, quat_inverse, quat_angle_axis
 from rofunc.learning.RofuncRL.tasks.amp.humanoid_amp_base import DOF_BODY_IDS, DOF_OFFSETS
 from rofunc.learning.RofuncRL.tasks.utils.torch_jit_utils import *
 
 
-class MotionLib():
+class MotionLib:
     def __init__(self, motion_file, num_dofs, key_body_ids, device):
         self._num_dof = num_dofs
         self._key_body_ids = key_body_ids
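
Note: the following is a minimal, self-contained sketch (not part of the patch) of the optimizer wiring that the new _set_up methods in amp_agent.py and ase_hrl_agent.py perform: a single shared Adam optimizer when the policy and value networks are the same module, separate policy/value optimizers otherwise, and a dedicated discriminator optimizer in the AMP case. The names and hyperparameter values below (policy_net, value_net, disc_net, lr_a, lr_c, lr_d, adam_eps) are illustrative stand-ins, not rofunc attributes.

# Sketch of the shared-vs-separate optimizer setup; names are assumptions, not rofunc's API.
import torch

policy_net = torch.nn.Linear(64, 8)   # stand-in policy network
value_net = torch.nn.Linear(64, 1)    # stand-in value network
disc_net = torch.nn.Linear(64, 1)     # stand-in AMP discriminator
lr_a, lr_c, lr_d, adam_eps = 5e-5, 5e-5, 1e-4, 1e-5  # illustrative hyperparameters

optimizers = {}
if policy_net is value_net:
    # Shared policy/value network: one optimizer updates both heads.
    optimizers["optimizer"] = torch.optim.Adam(policy_net.parameters(), lr=lr_a)
else:
    # Separate networks: independent optimizers for actor and critic.
    optimizers["optimizer_policy"] = torch.optim.Adam(policy_net.parameters(), lr=lr_a, eps=adam_eps)
    optimizers["optimizer_value"] = torch.optim.Adam(value_net.parameters(), lr=lr_c, eps=adam_eps)

# The discriminator always gets its own optimizer, as in the AMP agent.
optimizers["optimizer_disc"] = torch.optim.Adam(disc_net.parameters(), lr=lr_d, eps=adam_eps)

print(sorted(optimizers))  # which optimizers were registered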