🚀 [RofuncRL] Fix bugs in AMP and ASE
Skylark0924 committed Sep 10, 2023
1 parent 766f4e3 · commit b10e930
Showing 5 changed files with 46 additions and 7 deletions.
examples/learning_rl/example_HumanoidASE_RofuncRL.py (6 changes: 3 additions & 3 deletions)

```diff
@@ -95,7 +95,7 @@ def inference(custom_args):
 
 
 if __name__ == '__main__':
-    gpu_id = 0
+    gpu_id = 2
 
     parser = argparse.ArgumentParser()
     # Available tasks and motion files:
@@ -105,9 +105,9 @@ def inference(custom_args):
     # HumanoidASEReachSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASELocationSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASEStrikeSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
-    parser.add_argument("--task", type=str, default="HumanoidASEStrikeSwordShield")
+    parser.add_argument("--task", type=str, default="HumanoidASEGetupSwordShield")
     parser.add_argument("--motion_file", type=str,
-                        default="reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy")
+                        default="reallusion_sword_shield/dataset_reallusion_sword_shield.yaml")
     parser.add_argument("--agent", type=str, default="ase")  # Available agent: ase
    parser.add_argument("--num_envs", type=int, default=4096)
    parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
```
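Switching the default `motion_file` from a single clip to dataset_reallusion_sword_shield.yaml means training samples from a weighted library of reference motions instead of one idle animation. Below is a minimal sketch of how such a dataset descriptor might be parsed; the schema (a top-level `motions` list whose entries carry `file` and `weight`) follows the common AMP/ASE convention and, like the helper name, is an assumption rather than code from this commit:

```python
# Hypothetical loader for a motion-dataset YAML; the "motions"/"file"/
# "weight" schema is an assumption based on the AMP/ASE convention.
import os
import yaml

def load_motion_dataset(motion_file):
    """Return (files, weights) for a .yaml dataset or a single .npy clip."""
    if os.path.splitext(motion_file)[1] == ".yaml":
        with open(motion_file, "r") as f:
            config = yaml.safe_load(f)
        root = os.path.dirname(motion_file)
        files = [os.path.join(root, m["file"]) for m in config["motions"]]
        weights = [float(m.get("weight", 1.0)) for m in config["motions"]]
    else:  # a single clip trains on that one motion with weight 1.0
        files, weights = [motion_file], [1.0]
    return files, weights
```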
rofunc/learning/RofuncRL/agents/mixline/amp_agent.py (19 changes: 18 additions & 1 deletion)

```diff
@@ -159,14 +159,31 @@ def _set_up(self):
         """
         Set up optimizer, learning rate scheduler and state/value preprocessors
         """
+        assert hasattr(self, "policy"), "Policy is not defined."
+        assert hasattr(self, "value"), "Value is not defined."
+
         # Set up optimizer and learning rate scheduler
-        super()._set_up()
+        if self.policy is self.value:
+            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
+            if self._lr_scheduler is not None:
+                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer"] = self.optimizer
+        else:
+            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
+            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
+                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
+            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
 
         self.optimizer_disc = torch.optim.Adam(self.discriminator.parameters(), lr=self._lr_d, eps=self._adam_eps)
         if self._lr_scheduler is not None:
             self.scheduler_disc = self._lr_scheduler(self.optimizer_disc, **self._lr_scheduler_kwargs)
         self.checkpoint_modules["optimizer_disc"] = self.optimizer_disc
 
         # set up preprocessors
+        super()._set_up()
         if self._amp_state_preprocessor:
             self._amp_state_preprocessor = self._amp_state_preprocessor(**self._amp_state_preprocessor_kwargs)
             self.checkpoint_modules["amp_state_preprocessor"] = self._amp_state_preprocessor
```
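The substantive change: optimizer construction no longer defers entirely to the base class but branches on whether the policy and value heads are literally the same network object, so separate networks get separate Adam optimizers and schedulers with their own learning rates. A standalone sketch of that shared-vs-separate pattern; the tiny networks and learning rates below are stand-ins for illustration, and only the `policy is value` identity check mirrors the agent code:

```python
# Shared-vs-separate optimizer pattern; networks/lrs are illustrative.
import torch
import torch.nn as nn

policy = nn.Linear(64, 8)   # stand-in policy network
value = policy              # shared network: same object, not a copy
# value = nn.Linear(64, 1)  # separate network: uncomment to get two optimizers

if policy is value:
    # One optimizer, so the shared parameters are updated exactly once per step.
    optimizers = [torch.optim.Adam(policy.parameters(), lr=3e-4)]
else:
    # Distinct networks get distinct optimizers (and learning rates).
    optimizers = [torch.optim.Adam(policy.parameters(), lr=3e-4),
                  torch.optim.Adam(value.parameters(), lr=1e-3)]

print(len(optimizers))  # 1 when shared, 2 when separate
```

The identity check matters because handing the same parameters to two Adam instances would apply two updates per step, each with independent moment estimates.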
rofunc/learning/RofuncRL/agents/mixline/ase_agent.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -97,7 +97,7 @@ def __init__(self,
             state_tensor_size = (img_channel, img_size, img_size)
             kd = True
         else:
-            state_tensor_size = self.observation_space
+            state_tensor_size = observation_space
             kd = False
         self.memory.create_tensor(name="states", size=state_tensor_size, dtype=torch.float32, keep_dimensions=kd)
         self.memory.create_tensor(name="next_states", size=state_tensor_size, dtype=torch.float32, keep_dimensions=kd)
@@ -333,7 +333,7 @@ def update_net(self):
         if self.encoder is self.discriminator:
             enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
         else:
-            enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states_batch))
+            enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
         enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
         enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
         enc_loss = torch.mean(enc_err)
```
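Both changes swap in the correct in-scope names (the `observation_space` argument, `sampled_amp_states`). The surrounding encoder objective is worth unpacking: the encoder output is L2-normalized and dotted with the sampled ASE latents, so `enc_loss` is a negative cosine similarity, minimized when the encoder recovers the latent that generated the motion. A numerical sketch; the shapes and values are made up, and the unit-norm latents follow ASE's convention of sampling from the unit hypersphere:

```python
# Illustration of the encoder loss above: with unit-norm latents, the
# negative dot product is a negative cosine similarity in [-1, 1].
import torch
import torch.nn.functional as F

enc_output = torch.randn(4096, 24)                         # raw encoder output (batch, latent_dim)
ase_latents = F.normalize(torch.randn(4096, 24), dim=-1)   # unit latents

enc_output = F.normalize(enc_output, dim=-1)
enc_err = -torch.sum(enc_output * ase_latents, dim=-1, keepdim=True)
enc_loss = torch.mean(enc_err)
print(enc_loss.item())  # ~0 for random vectors, -1 if perfectly aligned
```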
rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py (21 changes: 21 additions & 0 deletions)

```diff
@@ -174,6 +174,27 @@ def __init__(self,
 
         self._set_up()
 
+    def _set_up(self):
+        assert hasattr(self, "policy"), "Policy is not defined."
+        assert hasattr(self, "value"), "Value is not defined."
+
+        # Set up optimizer and learning rate scheduler
+        if self.policy is self.value:
+            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
+            if self._lr_scheduler is not None:
+                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer"] = self.optimizer
+        else:
+            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
+            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
+                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
+            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
+
+        super()._set_up()
+
     def _build_llc(self):
         from .utils import ase_network_builder
         from .utils import ase_agent
```
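The HRL agent gets the same optimizer override as the AMP agent, duplicated rather than shared. Registering each optimizer in `checkpoint_modules` is what keeps its state inside the agent's save/load cycle. A sketch of what that registration enables, assuming a simple dict-of-modules checkpoint scheme; this loop is illustrative, not the actual rofunc checkpoint code:

```python
# Hypothetical checkpoint loop over a checkpoint_modules dict: anything
# registered there (networks, optimizers, preprocessors) is saved and
# restored uniformly via its state_dict().
import torch

def save_checkpoint(checkpoint_modules, path):
    torch.save({name: module.state_dict()
                for name, module in checkpoint_modules.items()}, path)

def load_checkpoint(checkpoint_modules, path):
    state = torch.load(path)
    for name, module in checkpoint_modules.items():
        module.load_state_dict(state[name])
```

An optimizer left out of such a dict would silently restart with fresh moment estimates after every resume, which is exactly what registering `optimizer_policy` and `optimizer_value` here avoids.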
rofunc/learning/RofuncRL/tasks/amp/motion_lib.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -31,11 +31,12 @@
 import yaml
 
 from rofunc.utils.datalab.poselib.poselib.skeleton.skeleton3d import SkeletonMotion
+from rofunc.utils.datalab.poselib.poselib.core import quat_mul_norm, quat_inverse, quat_angle_axis
 from rofunc.learning.RofuncRL.tasks.amp.humanoid_amp_base import DOF_BODY_IDS, DOF_OFFSETS
 from rofunc.learning.RofuncRL.tasks.utils.torch_jit_utils import *
 
 
-class MotionLib():
+class MotionLib:
     def __init__(self, motion_file, num_dofs, key_body_ids, device):
         self._num_dof = num_dofs
         self._key_body_ids = key_body_ids
```
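This adds an explicit import for three poselib quaternion helpers and tidies the class declaration (`class MotionLib:` instead of the redundant `class MotionLib():`). Helpers of this kind are typically used to turn consecutive frame rotations into angular velocities. A pure-torch sketch of that use with stand-in implementations; the (w, x, y, z) component order and these exact semantics are assumptions, and the real poselib API may differ:

```python
# Stand-ins for quat_mul_norm / quat_inverse / quat_angle_axis,
# assuming unit quaternions in (w, x, y, z) order.
import torch

def quat_inverse(q):                      # conjugate of a unit quaternion
    return torch.cat([q[..., :1], -q[..., 1:]], dim=-1)

def quat_mul_norm(a, b):                  # Hamilton product, renormalized
    w1, v1 = a[..., :1], a[..., 1:]
    w2, v2 = b[..., :1], b[..., 1:]
    w = w1 * w2 - (v1 * v2).sum(-1, keepdim=True)
    v = w1 * v2 + w2 * v1 + torch.cross(v1, v2, dim=-1)
    q = torch.cat([w, v], dim=-1)
    return q / q.norm(dim=-1, keepdim=True)

def quat_angle_axis(q):                   # rotation angle and unit axis
    angle = 2.0 * torch.atan2(q[..., 1:].norm(dim=-1), q[..., 0])
    axis = torch.nn.functional.normalize(q[..., 1:], dim=-1)
    return angle, axis

q0 = torch.tensor([1.0, 0.0, 0.0, 0.0])          # identity orientation
q1 = torch.tensor([0.9659, 0.2588, 0.0, 0.0])    # ~30 deg about x
dq = quat_mul_norm(q1, quat_inverse(q0))         # delta rotation between frames
angle, axis = quat_angle_axis(dq)
omega = angle * axis / (1.0 / 30.0)              # rad/s at 30 fps
print(omega)                                      # ~[15.7, 0, 0]
```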
