🚀 [RofuncRL] Fix some bugs
Skylark0924 committed Aug 26, 2023
1 parent 54dee91 commit b8125c2
Showing 9 changed files with 66 additions and 148 deletions.
4 changes: 2 additions & 2 deletions examples/learning_rl/example_Ant_RofuncRL.py
@@ -87,11 +87,11 @@ def inference(custom_args):


 if __name__ == '__main__':
-    gpu_id = 2
+    gpu_id = 3

     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="Ant")
-    parser.add_argument("--agent", type=str, default="td3")  # Available agents: ppo, sac, td3
+    parser.add_argument("--agent", type=str, default="a2c")  # Available agents: ppo, sac, td3, a2c
     parser.add_argument("--num_envs", type=int, default=4096)
     parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
     parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
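As a quick sanity check, a minimal, self-contained sketch (not part of the commit) of how the updated defaults resolve; the `parse_args([])` call and the asserted values are illustrative only and mirror the argparse lines above:

    import argparse

    gpu_id = 3
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Ant")
    parser.add_argument("--agent", type=str, default="a2c")  # Available agents: ppo, sac, td3, a2c
    parser.add_argument("--num_envs", type=int, default=4096)
    parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
    parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))

    # With no command-line overrides, the example now trains A2C on GPU 3.
    args = parser.parse_args([])
    assert args.agent == "a2c"
    assert args.sim_device == args.rl_device == "cuda:3"
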
90 changes: 0 additions & 90 deletions rofunc/config/learning/rl/train/AntPPORofuncRL.yaml

This file was deleted.

28 changes: 7 additions & 21 deletions rofunc/learning/RofuncRL/agents/base_agent.py
@@ -73,40 +73,26 @@ def __init__(self,
         '''Set up'''
         self._lr_scheduler = None
         self._lr_scheduler_kwargs = {}
+        self._state_preprocessor = None
+        self._state_preprocessor_kwargs = {}
+        self._value_preprocessor = None
+        self._value_preprocessor_kwargs = {}

         '''Define state encoder'''
         self.se = encoder_map[cfg.Model.state_encoder.encoder_type](cfg.Model).to(self.device) \
             if hasattr(cfg.Model, "state_encoder") else EmptyEncoder()

     def _set_up(self):
         """
-        Set up optimizer, learning rate scheduler and state/value preprocessors
+        Set up state/value preprocessors
         """
-        assert hasattr(self, "policy"), "Policy is not defined."
-        assert hasattr(self, "value"), "Value is not defined."
-
-        # Set up optimizer and learning rate scheduler
-        if self.policy is self.value:
-            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
-            if self._lr_scheduler is not None:
-                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
-            self.checkpoint_modules["optimizer"] = self.optimizer
-        else:
-            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
-            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
-            if self._lr_scheduler is not None:
-                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
-                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
-            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
-            self.checkpoint_modules["optimizer_value"] = self.optimizer_value

         # set up preprocessors
-        if self._state_preprocessor:
+        if self._state_preprocessor is not None:
             self._state_preprocessor = self._state_preprocessor(**self._state_preprocessor_kwargs)
             self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor
         else:
             self._state_preprocessor = empty_preprocessor
-        if self._value_preprocessor:
+        if self._value_preprocessor is not None:
             self._value_preprocessor = self._value_preprocessor(**self._value_preprocessor_kwargs)
             self.checkpoint_modules["value_preprocessor"] = self._value_preprocessor
         else:
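The net effect is that `BaseAgent` now owns the default preprocessor attributes and its `_set_up()` only instantiates them, while each agent builds its own optimizers and then delegates. A minimal sketch of that subclassing pattern, assuming the rofunc package is installed; `MyAgent` and its attributes (`self.policy`, `self.value`, `self._lr_a`, `self._lr_c`) are illustrative stand-ins rather than repository code:

    import torch
    from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent

    class MyAgent(BaseAgent):
        def _set_up(self):
            # Agent-specific part: optimizers (and optionally LR schedulers).
            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c)
            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
            # Shared part: state/value preprocessor instantiation now lives in BaseAgent._set_up().
            super()._set_up()
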
11 changes: 1 addition & 10 deletions rofunc/learning/RofuncRL/agents/online/a2c_agent.py
@@ -138,16 +138,7 @@ def _set_up(self):
         self.checkpoint_modules["optimizer_value"] = self.optimizer_value

         # set up preprocessors
-        if self._state_preprocessor:
-            self._state_preprocessor = self._state_preprocessor(**self._state_preprocessor_kwargs)
-            self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor
-        else:
-            self._state_preprocessor = empty_preprocessor
-        if self._value_preprocessor:
-            self._value_preprocessor = self._value_preprocessor(**self._value_preprocessor_kwargs)
-            self.checkpoint_modules["value_preprocessor"] = self._value_preprocessor
-        else:
-            self._value_preprocessor = empty_preprocessor
+        super()._set_up()

     def act(self, states: torch.Tensor, deterministic: bool = False):
         if not deterministic:
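With the duplicated branches gone, A2C relies entirely on the base-class logic, including its `empty_preprocessor` fallback when no preprocessor class is configured. A plausible reading of that fallback (assumed here in the spirit of skrl's identity helper, not copied from `standard_scaler.py`) is a simple pass-through:

    # Assumed shape of the fallback: it returns its input unchanged, so call sites can
    # always write `states = self._state_preprocessor(states)` without branching.
    def empty_preprocessor(_input, *args, **kwargs):
        return _input
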
32 changes: 31 additions & 1 deletion rofunc/learning/RofuncRL/agents/online/ppo_agent.py
@@ -27,6 +27,7 @@
 from rofunc.learning.RofuncRL.processors.schedulers import KLAdaptiveRL
 from rofunc.learning.RofuncRL.processors.standard_scaler import RunningStandardScaler
 from rofunc.learning.RofuncRL.utils.memory import Memory
+from rofunc.learning.RofuncRL.processors.normalizers import Normalization


 class PPOAgent(BaseAgent):
@@ -109,10 +110,14 @@ def __init__(self,
         self._clip_predicted_values = self.cfg.Agent.clip_predicted_values
         self._kl_threshold = self.cfg.Agent.kl_threshold
         self._rewards_shaper = self.cfg.get("Agent", {}).get("rewards_shaper", lambda rewards: rewards * 0.01)
-        self._state_preprocessor = None  # TODO: Check
+        # self._state_preprocessor = None  # TODO: Check
+        # self._state_preprocessor = RunningStandardScaler
+        # self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
+        #                                                                 {"size": observation_space, "device": device})
+        # self._state_preprocessor = Normalization
+        # self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
+        #                                                                 {"shape": observation_space, "device": device})

         self._value_preprocessor = RunningStandardScaler
         self._value_preprocessor_kwargs = self.cfg.get("Agent", {}).get("value_preprocessor_kwargs",
                                                                         {"size": 1, "device": device})
@@ -123,6 +128,31 @@ def __init__(self,

         self._set_up()

+    def _set_up(self):
+        """
+        Set up optimizer, learning rate scheduler and state/value preprocessors
+        """
+        assert hasattr(self, "policy"), "Policy is not defined."
+        assert hasattr(self, "value"), "Value is not defined."
+
+        # Set up optimizer and learning rate scheduler
+        if self.policy is self.value:
+            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a)
+            if self._lr_scheduler is not None:
+                self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer"] = self.optimizer
+        else:
+            self.optimizer_policy = torch.optim.Adam(self.policy.parameters(), lr=self._lr_a, eps=self._adam_eps)
+            self.optimizer_value = torch.optim.Adam(self.value.parameters(), lr=self._lr_c, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_policy = self._lr_scheduler(self.optimizer_policy, **self._lr_scheduler_kwargs)
+                self.scheduler_value = self._lr_scheduler(self.optimizer_value, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_policy"] = self.optimizer_policy
+            self.checkpoint_modules["optimizer_value"] = self.optimizer_value
+
+        # set up preprocessors
+        super()._set_up()
+
     def act(self, states: torch.Tensor, deterministic: bool = False):
         if not deterministic:
             # sample stochastic actions
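PPO keeps its own `_set_up()` for the shared-vs-separate optimizer split and leaves the state preprocessor disabled, preserving two alternatives as comments. A hedged sketch of how either could be re-enabled inside `PPOAgent.__init__` (the `size`/`shape` keyword names are taken from the commented lines above and should be verified against the processor classes before use):

    # Option A: running standard scaler over observations.
    self._state_preprocessor = RunningStandardScaler
    self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get(
        "state_preprocessor_kwargs", {"size": observation_space, "device": device})

    # Option B: simple running normalizer.
    self._state_preprocessor = Normalization
    self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get(
        "state_preprocessor_kwargs", {"shape": observation_space, "device": device})
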
19 changes: 7 additions & 12 deletions rofunc/learning/RofuncRL/agents/online/sac_agent.py
@@ -28,7 +28,6 @@
 from rofunc.learning.RofuncRL.models.critic_models import Critic
 from rofunc.learning.RofuncRL.processors.schedulers import KLAdaptiveRL
 from rofunc.learning.RofuncRL.processors.normalizers import Normalization
-from rofunc.learning.RofuncRL.processors.standard_scaler import empty_preprocessor
 from rofunc.learning.RofuncRL.utils.memory import Memory


@@ -109,10 +108,10 @@ def __init__(self,
         self._entropy_learning_rate = self.cfg.Agent.entropy_learning_rate
         self._entropy_coefficient = self.cfg.Agent.initial_entropy_value
         self._target_entropy = self.cfg.Agent.target_entropy
-        # self._state_preprocessor = None  # TODO: Check
-        self._state_preprocessor = Normalization
-        self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
-                                                                        {"shape": observation_space, "device": device})
+        self._state_preprocessor = None  # TODO: Check
+        # self._state_preprocessor = Normalization
+        # self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
+        #                                                                 {"shape": observation_space, "device": device})

         '''Misc variables'''
         self._current_log_prob = None
@@ -154,20 +153,16 @@ def _set_up(self):

         self.checkpoint_modules["entropy_optimizer"] = self.entropy_optimizer

-        # set up preprocessors
-        if self._state_preprocessor:
-            self._state_preprocessor = self._state_preprocessor(**self._state_preprocessor_kwargs)
-            self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor
-        else:
-            self._state_preprocessor = empty_preprocessor
-
         # freeze target networks with respect to optimizers (update via .update_parameters())
         self.target_critic_1.freeze_parameters(True)
         self.target_critic_2.freeze_parameters(True)
         # update target networks (hard update)
         self.target_critic_1.update_parameters(self.critic_1, polyak=1)
         self.target_critic_2.update_parameters(self.critic_2, polyak=1)
+
+        # set up preprocessors
+        super()._set_up()

     def act(self, states: torch.Tensor, deterministic: bool = False):
         if not deterministic:
             # sample stochastic actions
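Moving `super()._set_up()` after the target-network initialization keeps the order: optimizers, hard copy of the critics into their targets, then preprocessors. For reference, a standalone sketch of the polyak rule that `update_parameters(..., polyak=1)` presumably follows (illustrative helper, not the repository implementation):

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def polyak_update(target: nn.Module, source: nn.Module, polyak: float) -> None:
        # polyak=1 -> hard copy (as done once in _set_up); small polyak -> slow soft update.
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.mul_(1.0 - polyak).add_(polyak * s.data)

    critic, target_critic = nn.Linear(8, 1), nn.Linear(8, 1)
    polyak_update(target_critic, critic, polyak=1.0)    # hard update at initialization
    polyak_update(target_critic, critic, polyak=0.005)  # typical soft update during training
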
10 changes: 3 additions & 7 deletions rofunc/learning/RofuncRL/agents/online/td3_agent.py
@@ -141,13 +141,6 @@ def _set_up(self):
         self.checkpoint_modules["actor_optimizer"] = self.actor_optimizer
         self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer

-        # set up preprocessors
-        if self._state_preprocessor:
-            self._state_preprocessor = self._state_preprocessor(**self._state_preprocessor_kwargs)
-            self.checkpoint_modules["state_preprocessor"] = self._state_preprocessor
-        else:
-            self._state_preprocessor = empty_preprocessor
-
         # freeze target networks with respect to optimizers (update via .update_parameters())
         self.target_actor.freeze_parameters(True)
         self.target_critic_1.freeze_parameters(True)
@@ -156,6 +149,9 @@ def _set_up(self):
         self.target_actor.update_parameters(self.actor, polyak=1)
         self.target_critic_1.update_parameters(self.critic_1, polyak=1)
         self.target_critic_2.update_parameters(self.critic_2, polyak=1)
+
+        # set up preprocessors
+        super()._set_up()

     def act(self, states: torch.Tensor, deterministic: bool = False):
         if not deterministic:
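TD3 gets the same reorganisation as SAC: preprocessor setup moves into the base class and runs after the frozen target networks are initialised. Freezing keeps the targets out of the optimizers while the explicit `update_parameters` copies above still apply; a minimal sketch of what `freeze_parameters` typically amounts to (assumed behaviour, not the repository code):

    import torch.nn as nn

    def freeze_parameters(model: nn.Module, freeze: bool = True) -> None:
        # Frozen parameters receive no gradients, so optimizer steps never touch the targets;
        # they only change via explicit hard/soft parameter copies.
        for p in model.parameters():
            p.requires_grad = not freeze

    target_actor = nn.Linear(8, 2)
    freeze_parameters(target_actor, True)
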
11 changes: 8 additions & 3 deletions rofunc/learning/RofuncRL/processors/normalizers.py
@@ -26,9 +26,14 @@ def __init__(self, shape, device):  # shape: the dimension of input data
         self.std = torch.sqrt(self.S).to(device)

     def train(self, x):
-        if len(x.shape) == len(self.shape) + 1:  # Batch data
-            batch_size = x.shape[0]
-            x = torch.sum(x, dim=0) / batch_size
+        if isinstance(self.shape, int):
+            if len(x.shape) == 2:
+                batch_size = x.shape[0]
+                x = torch.sum(x, dim=0) / batch_size
+        else:
+            if len(x.shape) == len(self.shape) + 1:
+                batch_size = x.shape[0]
+                x = torch.sum(x, dim=0) / batch_size

         self.n += 1
         if self.n == 1:
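The fix makes `Normalization.train` accept batched input when `shape` is a plain int (for example a flat state dimension) as well as a tuple: a `(batch, dim)` tensor is first reduced to its batch mean before the running statistics are updated. A self-contained sketch of that reduction together with a Welford-style running update (the update rule is an assumption consistent with the `n`/`mean`/`S` attributes above, not a copy of the file):

    import torch

    dim = 3
    x = torch.randn(4096, dim)                 # batched input with an int `shape`
    if isinstance(dim, int) and len(x.shape) == 2:
        x = torch.sum(x, dim=0) / x.shape[0]   # reduce the batch to its mean, shape (3,)

    # Welford-style running mean/variance over successive (reduced) samples.
    n, mean, S = 0, torch.zeros(dim), torch.zeros(dim)
    for sample in (x, torch.randn(dim), torch.randn(dim)):
        n += 1
        if n == 1:
            mean = sample.clone()
        else:
            old_mean = mean.clone()
            mean = old_mean + (sample - old_mean) / n
            S = S + (sample - old_mean) * (sample - mean)
    std = torch.sqrt(S / n)                    # population std; the repository may normalise differently
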
9 changes: 7 additions & 2 deletions rofunc/learning/RofuncRL/trainers/base_trainer.py
@@ -14,6 +14,7 @@

 import copy
 import datetime
+import json
 import multiprocessing
 import os
 import random
@@ -61,7 +62,7 @@ def __init__(self,

         '''Rofunc logger'''
         self.rofunc_logger = BeautyLogger(self.exp_dir, verbose=self.cfg.Trainer.rofunc_logger_kwargs.verbose)
-        self.rofunc_logger.info(f"Configurations:\n{OmegaConf.to_yaml(self.cfg)}")
+        self.rofunc_logger.info(f"Trainer configurations:\n{OmegaConf.to_yaml(self.cfg)}")

         '''TensorBoard'''
         # main entry to log data for consumption and visualization by TensorBoard
@@ -90,9 +91,12 @@ def __init__(self,
         self._update_times = 0
         self.start_time = None

+        '''Evaluation and inference configurations'''
         self.eval_flag = self.cfg.Trainer.eval_flag if hasattr(self.cfg.Trainer, "eval_flag") else False
         self.eval_freq = self.cfg.Trainer.eval_freq if hasattr(self.cfg.Trainer, "eval_freq") else 0
         self.eval_steps = self.cfg.Trainer.eval_steps if hasattr(self.cfg.Trainer, "eval_steps") else 0
+        self.eval_env_seed = self.cfg.Trainer.eval_env_seed if hasattr(self.cfg.Trainer,
+                                                                       "eval_env_seed") else random.randint(0, 10000)
         self.use_eval_thread = self.cfg.Trainer.use_eval_thread if hasattr(self.cfg.Trainer,
                                                                            "use_eval_thread") else False
         assert self.eval_steps % self.max_episode_steps == 0, \
@@ -104,11 +108,12 @@ def __init__(self,
         '''Environment'''
         env.device = self.device
         self.env = wrap_env(env, logger=self.rofunc_logger, seed=self.cfg.Trainer.seed)
-        self.eval_env = wrap_env(env, logger=self.rofunc_logger, seed=random.randint(0, 10000))
+        self.eval_env = wrap_env(env, logger=self.rofunc_logger, seed=self.eval_env_seed) if self.eval_flag else None
         self.rofunc_logger.info(f"Environment:\n "
                                 f"  action_space: {self.env.action_space.shape}\n "
                                 f"  observation_space: {self.env.observation_space.shape}\n "
                                 f"  num_envs: {self.env.num_envs}")
+        self.rofunc_logger.info(f"Task configurations:\n{self.env._env.cfg}")

         '''Normalization'''
         self.state_norm = Normalization(shape=self.env.observation_space, device=device)
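All of the new evaluation options are read defensively with `hasattr`, so existing Trainer configs keep working unchanged. A hedged sketch of a config that exercises them (the key names come from the attributes above; the nesting and the concrete values are illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "Trainer": {
            "eval_flag": True,         # build self.eval_env and evaluate periodically
            "eval_freq": 2500,         # illustrative value
            "eval_steps": 1000,        # must be a multiple of max_episode_steps (see the assert above)
            "eval_env_seed": 42,       # fixed seed; if omitted, a random seed in [0, 10000] is drawn
            "use_eval_thread": False,  # evaluate in the main process
        }
    })
    print(OmegaConf.to_yaml(cfg))
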
