From 165652cddd88685cf3445986e0c4f4d092d2e511 Mon Sep 17 00:00:00 2001 From: Anish Diwan <56624586+anishhdiwan@users.noreply.github.com> Date: Mon, 29 Jan 2024 19:36:26 +0100 Subject: [PATCH 01/13] Ad/documentation (#272) * Add docs * Update docs; Rename file * Add author --------- Co-authored-by: anishdiwan --- docs/HOW_TO_RL_GAMES.md | 263 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 docs/HOW_TO_RL_GAMES.md diff --git a/docs/HOW_TO_RL_GAMES.md b/docs/HOW_TO_RL_GAMES.md new file mode 100644 index 00000000..8050df1c --- /dev/null +++ b/docs/HOW_TO_RL_GAMES.md @@ -0,0 +1,263 @@ +# Introduction to [rl_games](https://github.com/Denys88/rl_games/) - new envs, and new algorithms built on rl_games +**Author** - [Anish Diwan](https://www.anishdiwan.com/) + +This write-up describes some elements of the general functioning of the [rl_games](https://github.com/Denys88/rl_games/) reinforcement learning library. It also provides a guide on extending rl_games with new environments and algorithms using a structure similar to the [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs) package. Topics covered in this write-up are +1. The various components of rl_games (runner, algorthms, environments ...) +2. Using rl_games for your own work + - Adding new gym-like environments to rl_games + - Using non-gym environments and simulators with the algorithms in rl_games (refer to [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs) for examples) + - Adding new algorithms to rl_games + +## General setup in rl_games +rl_games uses the main python script called `runner.py` along with flags for either training (`--train`) or executing policies (`--play`) and a mandatory argument for passing training/playing configurations (`--file`). A basic example of training and then playing for PPO in Pong can be executed with the following. You can also checkout the PPO config file at `rl_games/configs/atari/ppo_pong.yaml`. + +``` +python runner.py --train --file rl_games/configs/atari/ppo_pong.yaml +python runner.py --play --file rl_games/configs/atari/ppo_pong.yaml --checkpoint nn/PongNoFrameskip.pth +``` + +rl_games uses the following base classes to define algorithms, instantiate environments, and log metrics. + +1. **Main Script** - `rl_games.torch_runner.Runner` + - This is the main class that instantiates the algorithm as per the given configuration and executes either training or playing + - When instantiated, algorithm instances for all algos in rl_games are automatically added using `rl_games.common.Objectfactory()`'s `register_builder()` method. The same is also done for the player instances for all algos. + - Depending on the args given, either `self.run_train()` or `self.run_play()` is executed + - The Runner also sets up the algorithm observer that logs training metrics. If one is not provided, it automatically uses the `DefaultAlgoObserver()` which logs metrics available to the algo using the tensorboard summarywriter. + - Logs and checkpoints are automatically created in a directory called nn (by default). + - Custom algorithms and observers can also be provided based on your requirements (more on this below). + + +2. **Instantiating Algos** - `rl_games.common.Objectfactory()` + - Creates algorithms or players. Has the `register_builder(self, name, builder)` method that adds a function that returns whatever is being built (name is a str). For example the following line adds the name a2c_continuous to a lambda function that returns the A2CAgent + ```python + register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs)) + ``` + - Also has a `create(self, name, **kwargs)` method that simply returns one of the registered builders by name + +3. **RL Algorithms** + - rl_games has several reinforcement learning algorithms. Most of these inherit from some sort of base algorithm class, for example, `rl_games.algos_torch.A2CBase`. + - In rl_games environments are instantiated by the algorithm. Depending on the config setup, you can also run multiple envs in parallel. + +4. **Environments** - `rl_games.common.vecenv` & `rl_games.common.env_configurations` + - The `vecenv` script holds classes to instantiate different environments based on their type. Since rl_games is quite a broad library, it supports multiple environment types (such as openAI gym envs, brax envs, cule envs etc). These environment types and their base classes are stored in the `rl_games.common.vecenv.vecenv_config` dictionary. The environment class enables stuff like running multiple parallel environments, or running multi-agent environments. By default, all available environments are already added. Adding new environments is explained below. + + - `rl_games.common.env_configurations.configurations` is another dictionary that stores `env_name: {'vecenv_type', 'env_creator}` information. For example, the following stores the environment name "CartPole-v1" with a value for its type and a lambda function that instantiates the respective gym env. + ```python + 'CartPole-v1' : { + 'vecenv_type' : 'RAY', + 'env_creator' : lambda **kwargs : gym.make('CartPole-v1'),} + ``` + - The general idea here is that the algorithm base class (for example `A2CAgent`) instantiates a new environment by looking at the env_name (for example 'CartPole-v1') in the config file. Internally, the name 'CartPole-v1' is used to get the env type from `rl_games.common.env_configurations.configurations`. The type then goes into the `vecenv.vecenv_config` dict which returns the actual environment class (such as RayVecEnv).Note, the env class (such as RayVecEnv) then internally uses the 'env_creator' key to instantiate the environment using whatever function was given to it (for example, `lambda **kwargs : gym.make('CartPole-v1')`) + - While being a bit convoluted, this allows us to directly pass an env name in the config to run experiments + +## Extending rl_games for your own work +While rl_games provides a great baseline implementation of several environments and algorithms, it is also a great starting point for your own work. The rest of this write-up explains how new environments or algorithms can be added. It is based on the setup from [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs), the NVIDIA repository for RL simulations and training. We use [hydra](https://hydra.cc/docs/intro/) for easier configuration management. Further, instead of directly using `runner.py` we use another similar script called `train.py` which allows us to dynamically add new environments and insert out own algorithms. + +With this considered, our final file structure is something like this. + +``` +project dir +│ train.py (replacement to the runner.py script) +│ +└───tasks dir (sometimes also called envs dir) +│ │ customenv.py +│ │ customenv_utils.py +| +└───cfg dir (main hydra configs) +│ │ config.yaml (main config for the setting up simulators etc. if needed) +│ │ +│ └─── task dir (configs for the env) +│ │ customenv.yaml +│ │ otherenv.yaml +│ │ ... +| +│ └─── train dir (configs for training the algorithm) +│ │ customenvPPO.yaml +│ │ otherenvAlgo.yaml +│ │ ... +| +└───algos dir (custom wrappers for training algorithms in rl_games) +| │ custom_network_builder.py +| │ custom_algo.py +| | ... +| +└───runs dir (generated automatically on executing train.py) +│ └─── env_name_alg_name_datetime dir (train logs) +│ └─── nn +| | checkpoints.pth +│ └─── summaries + | events.out... +``` + +### Adding new gym-like environments +New environments can be used with the rl_games setup by first defining the TYPE of the new env. A new environment TYPE can be added by calling the `vecenv.register(config_name, func)` function that simply adds the `config_name:func` pair to the dictionary. For example the following line adds a 'RAY' type env with a lambda function that then instantiates the RayVecEnv class. The "RayVecEnv" holds "RayWorkers" that internally store the environment. This automatically allows for multi-env training. + +```python +register('RAY', lambda config_name, num_actors, **kwargs: RayVecEnv(config_name, num_actors, **kwargs)) +``` + +For gym-like envs (that inherit from the gym base class), the TYPE can simply be `RayVecEnv` from rl_games. Adding a gym-like environment essentially translates to creating a class that inherits from gym.Env and adding this under the type 'RAY' to `rl_games.common.env_configurations`. Ideally, this needs to be done by adding the key value pair `env_name: {'vecenv_type', 'env_creator}` to `env_configurations.configurations`. However, this requires modifying the rl_games library. If you do not wish to do that then you can instead use the register method to add your new env to the dictionary, then make a copy of the RayVecEnv and RayWorked classes and change the `__init__` method to instead take in the modified env configurations dict. For example + +**Within train.py** +```python +@hydra.main(version_base="1.1", config_name="custom_config", config_path="./cfg") +def launch_rlg_hydra(cfg: DictConfig): + + from custom_envs.custom_env import SomeEnv + from custom_envs.customenv_utils import CustomRayVecEnv + from rl_games.common import env_configurations, vecenv + + def create_pusht_env(**kwargs): + # Instantiate new env + env = SomeEnv() + + #Alternate example, env = gym.make('LunarLanderContinuous-v2') + return env + + # Register the TYPE + env_configurations.register('pushT', { + 'vecenv_type': 'CUSTOMRAY', + 'env_creator': lambda **kwargs: create_pusht_env(**kwargs), + }) + + # Provide the TYPE:func pair + vecenv.register('CUSTOMRAY', lambda config_name, num_actors, **kwargs: CustomRayVecEnv(env_configurations.configurations, config_name, num_actors, **kwargs)) +``` + +-------------------------------- + +**Custom Env TYPEs (enables adding new envs dynamically)** +```python +# Make a copy of RayVecEnv + +class CustomRayVecEnv(IVecEnv): + import ray + + def __init__(self, config_dict, config_name, num_actors, **kwargs): + ### ADDED CHANGE ### + # Explicityly passing in the dictionary containing env_name: {vecenv_type, env_creator} + self.config_dict = config_dict + + self.config_name = config_name + self.num_actors = num_actors + self.use_torch = False + self.seed = kwargs.pop('seed', None) + + + self.remote_worker = self.ray.remote(CustomRayWorker) + self.workers = [self.remote_worker.remote(self.config_dict, self.config_name, kwargs) for i in range(self.num_actors)] + + ... + ... + +# Make a copy of RayWorker + +class CustomRayWorker: + ## ADDED CHANGE ### + # Add config_dict to init + def __init__(self, config_dict, config_name, config): + self.env = config_dict[config_name]['env_creator'](**config) + + ... + ... +``` + +### Adding non-gym environments & simulators +Non-gym environments can be added in the same way. However, now you also need to create your own TYPE class. [IsaacGymEnvs](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs) does this by defining a new RLGPU type that uses the IsaacGym simulation environment. An example of this can be found in the IsaacGymEnvs library (checkout `RLGPUEnv` [here](https://github.com/NVIDIA-Omniverse/IsaacGymEnvs/blob/main/isaacgymenvs/utils/rlgames_utils.py)). + + +### New algorithms and observers within rl_games + +Adding a custom algorithm essentially translates to registering your own builder and player within the `rl_games.torch_runner.Runner`. IsaacGymEnvs does this by adding the following within the dydra-decorated main function (their algo is called AMP). + +**Within train.py** +```python +# register new AMP network builder and agent +def build_runner(algo_observer): + runner = Runner(algo_observer) + runner.algo_factory.register_builder('amp_continuous', lambda **kwargs : amp_continuous.AMPAgent(**kwargs)) + runner.player_factory.register_builder('amp_continuous', lambda **kwargs : amp_players.AMPPlayerContinuous(**kwargs)) + model_builder.register_model('continuous_amp', lambda network, **kwargs : amp_models.ModelAMPContinuous(network)) + model_builder.register_network('amp', lambda **kwargs : amp_network_builder.AMPBuilder()) + + return runner +``` + +As you might have noticed from above, you can also add a custom observer to log whatever data you need. You can make your own by inheriting from `rl_games.common.algo_observer.AlgoObserver`. If you wish to log scores, your custom environment must have a "scores" key in the info dictionary (the info dict is returned when the environment is stepped). + + +### A complete example +Here's a complete example of a custom `train.py` script that makes a new gym-like env called pushT and uses a custom observer to log metrics. + +```python +import hydra + +from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf + + +# Hydra decorator to pass in the config. Looks for a config file in the specified path. This file in turn has links to other configs +@hydra.main(version_base="1.1", config_name="custom_config", config_path="./cfg") +def launch_rlg_hydra(cfg: DictConfig): + + import logging + import os + + from hydra.utils import to_absolute_path + import gym + from isaacgymenvs.utils.reformat import omegaconf_to_dict, print_dict + from rl_games.common import env_configurations, vecenv + from rl_games.torch_runner import Runner + + + # Naming the run + time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_name = f"{cfg.run_name}_{time_str}" + + # ensure checkpoints can be specified as relative paths + if cfg.checkpoint: + cfg.checkpoint = to_absolute_path(cfg.checkpoint) + + + # Creating a new function to return a pushT environment. This will then be added to rl_games env_configurations so that an env can be created from its name in the config + from custom_envs.pusht_single_env import PushTEnv + from custom_envs.customenv_utils import CustomRayVecEnv, PushTAlgoObserver + + def create_pusht_env(**kwargs): + env = PushTEnv() + return env + + # env_configurations.register adds the env to the list of rl_games envs. + env_configurations.register('pushT', { + 'vecenv_type': 'CUSTOMRAY', + 'env_creator': lambda **kwargs: create_pusht_env(**kwargs), + }) + + # vecenv register calls the following lambda function which then returns an instance of CUSTOMRAY. + vecenv.register('CUSTOMRAY', lambda config_name, num_actors, **kwargs: CustomRayVecEnv(env_configurations.configurations, config_name, num_actors, **kwargs)) + + # Convert to a big dictionary + rlg_config_dict = omegaconf_to_dict(cfg.train) + + # Build an rl_games runner. You can add other algos and builders here + def build_runner(): + runner = Runner(algo_observer=PushTAlgoObserver()) + return runner + + # create runner and set the settings + runner = build_runner() + runner.load(rlg_config_dict) + runner.reset() + + # Run either training or playing via the rl_games runner + runner.run({ + 'train': not cfg.test, + 'play': cfg.test, + # 'checkpoint': cfg.checkpoint, + # 'sigma': cfg.sigma if cfg.sigma != '' else None + }) + + +if __name__ == "__main__": + launch_rlg_hydra() +``` From cba782ceb772795628e52a3da3d5dc8c20ecb779 Mon Sep 17 00:00:00 2001 From: Anish Diwan <56624586+anishhdiwan@users.noreply.github.com> Date: Wed, 31 Jan 2024 03:41:58 +0100 Subject: [PATCH 02/13] Add docstrings to core library components (#273) Documentation style https://google.github.io/styleguide/pyguide.html Co-authored-by: anishdiwan --- rl_games/algos_torch/a2c_continuous.py | 20 +++++++++++ rl_games/algos_torch/a2c_discrete.py | 19 +++++++++++ rl_games/common/env_configurations.py | 8 +++++ rl_games/common/object_factory.py | 25 ++++++++++++++ rl_games/common/vecenv.py | 47 ++++++++++++++++++++++++++ rl_games/torch_runner.py | 41 ++++++++++++++++++++++ 6 files changed, 160 insertions(+) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index b731a4ed..78795a2b 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -10,8 +10,20 @@ class A2CAgent(a2c_common.ContinuousA2CBase): + """Continuous PPO Agent + The A2CAgent class inerits from the continuous asymmetric actor-critic class and makes modifications for PPO. + + """ def __init__(self, base_name, params): + """Initialise the algorithm with passed params + + Args: + base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc. + params (:obj `dict`): Algorithm parameters + + """ + a2c_common.ContinuousA2CBase.__init__(self, base_name, params) obs_shape = self.obs_shape build_config = { @@ -75,6 +87,14 @@ def get_masked_action_values(self, obs, action_masks): assert False def calc_gradients(self, input_dict): + """Compute gradients needed to step the networks of the algorithm. + + Core algo logic is defined here + + Args: + input_dict (:obj:`dict`): Algo inputs as a dict. + + """ value_preds_batch = input_dict['old_values'] old_action_log_probs_batch = input_dict['old_logp_actions'] advantage = input_dict['advantages'] diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 467d8d86..312bd17c 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -13,8 +13,19 @@ class DiscreteA2CAgent(a2c_common.DiscreteA2CBase): + """Discrete PPO Agent + The DiscreteA2CAgent class inerits from the discrete asymmetric actor-critic class and makes modifications for PPO. + + """ def __init__(self, base_name, params): + """Initialise the algorithm with passed params + + Args: + base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc. + params (:obj `dict`): Algorithm parameters + + """ a2c_common.DiscreteA2CBase.__init__(self, base_name, params) obs_shape = self.obs_shape @@ -108,6 +119,14 @@ def train_actor_critic(self, input_dict): return self.train_result def calc_gradients(self, input_dict): + """Compute gradients needed to step the networks of the algorithm. + + Core algo logic is defined here + + Args: + input_dict (:obj:`dict`): Algo inputs as a dict. + + """ value_preds_batch = input_dict['old_values'] old_action_log_probs_batch = input_dict['old_logp_actions'] advantage = input_dict['advantages'] diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index d8b335e3..43c8ebe1 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -253,6 +253,7 @@ def create_env(name, **kwargs): env = wrappers.TimeLimit(env, steps_limit) return env +# Dictionary of env_name as key and a sub-dict containing env_type and a env-creator function configurations = { 'CartPole-v1' : { 'vecenv_type' : 'RAY', @@ -458,4 +459,11 @@ def get_obs_and_action_spaces_from_config(config): def register(name, config): + """Add a new key-value pair to the known environments (configurations dict). + + Args: + name (:obj:`str`): Name of the env to be added. + config (:obj:`dict`): Dictionary with env type and a creator function. + + """ configurations[name] = config \ No newline at end of file diff --git a/rl_games/common/object_factory.py b/rl_games/common/object_factory.py index 4cd97f98..45728098 100644 --- a/rl_games/common/object_factory.py +++ b/rl_games/common/object_factory.py @@ -1,14 +1,39 @@ class ObjectFactory: + """General-purpose class to instantiate some other base class from rl_games. Usual use case it to instantiate algos, players etc. + + The ObjectFactory class is used to dynamically create any other object using a builder function (typically a lambda function). + + """ + def __init__(self): + """Initialise a dictionary of builders with keys as `str` and values as functions. + + """ self._builders = {} def register_builder(self, name, builder): + """Register a passed builder by adding to the builders dict. + + Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory` + + Args: + name (:obj:`str`): Key of the added builder. + builder (:obj `func`): Function to return the requested object + + """ self._builders[name] = builder def set_builders(self, builders): self._builders = builders def create(self, name, **kwargs): + """Create the requested object by calling a registered builder function. + + Args: + name (:obj:`str`): Key of the requested builder. + **kwargs: Arbitrary kwargs needed for the builder function + + """ builder = self._builders.get(name) if not builder: raise ValueError(name) diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py index 01016723..c29fd4be 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -8,7 +8,20 @@ import torch class RayWorker: + """Wrapper around a third-party (gym for example) environment class that enables parallel training. + + The RayWorker class wraps around another environment class to enable the use of this + environment within an asynchronous parallel training setup + + """ def __init__(self, config_name, config): + """Initialise the class. Sets up the environment creator using the `rl_games.common.env_configurations.configuraitons` dict + + Args: + config_name (:obj:`str`): Key of the environment to create. + config: Misc. kwargs passed on to the environment creator function + + """ self.env = configurations[config_name]['env_creator'](**config) def _obs_to_fp32(self, obs): @@ -27,6 +40,12 @@ def _obs_to_fp32(self, obs): return obs def step(self, action): + """Step the environment and reset if done + + Args: + action (type depends on env): Action to take. + + """ next_state, reward, is_done, info = self.env.step(action) if np.isscalar(is_done): @@ -95,9 +114,23 @@ def get_env_info(self): class RayVecEnv(IVecEnv): + """Main env class that manages several `rl_games.common.vecenv.Rayworker` objects for parallel training + + The RayVecEnv class manages a set of individual environments and wraps around the methods from RayWorker. + Each worker is executed asynchronously. + + """ import ray def __init__(self, config_name, num_actors, **kwargs): + """Initialise the class. Sets up the config for the environment and creates individual workers to manage. + + Args: + config_name (:obj:`str`): Key of the environment to create. + num_actors (:obj:`int`): Number of environments (actors) to create + **kwargs: Misc. kwargs passed on to the environment creator function within the RayWorker __init__ + + """ self.config_name = config_name self.num_actors = num_actors self.use_torch = False @@ -131,6 +164,14 @@ def __init__(self, config_name, num_actors, **kwargs): self.concat_func = np.concatenate def step(self, actions): + """Step all individual environments (using the created workers). + Returns a concatenated array of observations, rewards, done states, and infos if the env allows concatenation. + Else returns a nested dict. + + Args: + action (type depends on env): Action to take. + + """ newobs, newstates, newrewards, newdones, newinfos = [], [], [], [], [] res_obs = [] if self.num_agents == 1: @@ -218,6 +259,12 @@ def reset(self): vecenv_config = {} def register(config_name, func): + """Add an environment type (for example RayVecEnv) to the list of available types `rl_games.common.vecenv.vecenv_config` + Args: + config_name (:obj:`str`): Key of the environment to create. + func (:obj:`func`): Function that creates the environment + + """ vecenv_config[config_name] = func def create_vec_env(config_name, num_actors, **kwargs): diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 83bb8c07..4377d29b 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -31,8 +31,24 @@ def _override_sigma(agent, args): class Runner: + """Runs training/inference (playing) procedures as per a given configuration. + + The Runner class provides a high-level API for instantiating agents for either training or playing + with an RL algorithm. It further logs training metrics. + + """ def __init__(self, algo_observer=None): + """Initialise the runner instance with algorithms and observers. + + Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory` + + Args: + algo_observer (:obj:`rl_games.common.algo_observer.AlgoObserver`, optional): Algorithm observer that logs training metrics. + Defaults to `rl_games.common.algo_observer.DefaultAlgoObserver` + + """ + self.algo_factory = object_factory.ObjectFactory() self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs)) self.algo_factory.register_builder('a2c_discrete', lambda **kwargs : a2c_discrete.DiscreteA2CAgent(**kwargs)) @@ -55,6 +71,13 @@ def reset(self): pass def load_config(self, params): + """Loads passed config params. + + Args: + params (:obj:`dict`): Parameters passed in as a dict obtained from a yaml file or some other config format. + + """ + self.seed = params.get('seed', None) if self.seed is None: self.seed = int(time.time()) @@ -109,6 +132,12 @@ def load(self, yaml_config): self.load_config(params=self.default_config) def run_train(self, args): + """Run the training procedure from the algorithm passed in. + + Args: + args (:obj:`dict`): Train specific args passed in as a dict obtained from a yaml file or some other config format. + + """ print('Started to train') agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params) _restore(agent, args) @@ -116,6 +145,12 @@ def run_train(self, args): agent.train() def run_play(self, args): + """Run the inference procedure from the algorithm passed in. + + Args: + args (:obj:`dict`): Playing specific args passed in as a dict obtained from a yaml file or some other config format. + + """ print('Started to play') player = self.create_player() _restore(player, args) @@ -129,6 +164,12 @@ def reset(self): pass def run(self, args): + """Run either train/play depending on the args. + + Args: + args (:obj:`dict`): Args passed in as a dict obtained from a yaml file or some other config format. + + """ if args['train']: self.run_train(args) elif args['play']: From 684df64e9e20156e0e5a4a7cf45ff242064cfe9f Mon Sep 17 00:00:00 2001 From: Nikita Kachaev <79598074+tttonyalpha@users.noreply.github.com> Date: Tue, 28 May 2024 01:27:42 -0500 Subject: [PATCH 03/13] fixed bug with multi-gpu training a2c_common.py (#284) Fixed bug with missed torch.cuda.set_device(self.local_rank), which causes a problem when two different parallel processes try to use the same GPU --- rl_games/common/a2c_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 63b90c07..30383c29 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -1308,6 +1308,7 @@ def train(self): self.curr_frames = self.batch_size_envs if self.multi_gpu: + torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] dist.broadcast_object_list(model_params, 0) From 66970f8d0e3da642daf957e0587e02e19323ad89 Mon Sep 17 00:00:00 2001 From: Viktor Makoviychuk Date: Wed, 12 Jun 2024 17:01:55 -0700 Subject: [PATCH 04/13] Fixed applying minibatch_size_per_env (#287) --- rl_games/common/a2c_common.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 30383c29..f9bd5a14 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -236,7 +236,6 @@ def __init__(self, base_name, params): self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device) self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device) self.obs = None - self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation self.batch_size = self.horizon_length * self.num_actors * self.num_agents self.batch_size_envs = self.horizon_length * self.num_actors @@ -245,6 +244,16 @@ def __init__(self, base_name, params): self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env) + # either minibatch_size_per_env or minibatch_size should be present in a config + # if both are present, minibatch_size is used + # otherwise minibatch_size_per_env is used minibatch_size_per_env is used to calculate minibatch_size + self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) + self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env) + + assert(self.minibatch_size > 0) + + self.games_num = self.minibatch_size // self.seq_length # it is used only for current rnn implementation + self.num_minibatches = self.batch_size // self.minibatch_size assert(self.batch_size % self.minibatch_size == 0) From dec7275dd444d1de682df1bc0073a4167d7c8c5a Mon Sep 17 00:00:00 2001 From: Lukas Linauer <85884720+llinauer@users.noreply.github.com> Date: Mon, 24 Jun 2024 19:25:29 +0200 Subject: [PATCH 05/13] Fix conditional for choosing deterministic action in SACPlayer get_action method (#290) --- rl_games/algos_torch/players.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py index 3dfcdadd..2f82519b 100644 --- a/rl_games/algos_torch/players.py +++ b/rl_games/algos_torch/players.py @@ -230,7 +230,7 @@ def get_action(self, obs, is_deterministic=False): if self.has_batch_dimension == False: obs = unsqueeze_obs(obs) dist = self.model.actor(obs) - actions = dist.sample() if is_deterministic else dist.mean + actions = dist.sample() if not is_deterministic else dist.mean actions = actions.clamp(*self.action_range).to(self.device) if self.has_batch_dimension == False: actions = torch.squeeze(actions.detach()) From 07043a3c9880b18f49f1bd19d0dea18d260b38ff Mon Sep 17 00:00:00 2001 From: Lukas Linauer <85884720+llinauer@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:51:46 +0200 Subject: [PATCH 06/13] Fix SAC with input normalization (#291) * If self.normalize_input is True in SACAgent class, add the weights of the running_mean_std layer in get_weights method * Allow getting normalize_input from config and use self.model.norm_obs in get_action method --------- Co-authored-by: Lukas Linauer --- rl_games/algos_torch/players.py | 3 ++- rl_games/algos_torch/sac_agent.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py index 2f82519b..69df7913 100644 --- a/rl_games/algos_torch/players.py +++ b/rl_games/algos_torch/players.py @@ -199,7 +199,7 @@ def __init__(self, params): ] obs_shape = self.obs_shape - self.normalize_input = False + self.normalize_input = self.config.get('normalize_input', False) config = { 'obs_dim': self.env_info["observation_space"].shape[0], 'action_dim': self.env_info["action_space"].shape[0], @@ -229,6 +229,7 @@ def restore(self, fn): def get_action(self, obs, is_deterministic=False): if self.has_batch_dimension == False: obs = unsqueeze_obs(obs) + obs = self.model.norm_obs(obs) dist = self.model.actor(obs) actions = dist.sample() if not is_deterministic else dist.mean actions = actions.clamp(*self.action_range).to(self.device) diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index dad8de0c..fd79fb7a 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ -208,6 +208,8 @@ def get_weights(self): state = {'actor': self.model.sac_network.actor.state_dict(), 'critic': self.model.sac_network.critic.state_dict(), 'critic_target': self.model.sac_network.critic_target.state_dict()} + if self.normalize_input: + state['running_mean_std'] = self.model.running_mean_std.state_dict() return state def save(self, fn): From 7a2b25fc12ebf2a9ba9a05df56e210b35872ebb0 Mon Sep 17 00:00:00 2001 From: Viktor Makoviychuk Date: Wed, 3 Jul 2024 23:22:40 -0700 Subject: [PATCH 07/13] Increased time resolution for more precision performance tracking. (#295) * Increased time resolution for more precision performance tracking. * Updated recommended pytorch version. --- README.md | 6 +++--- rl_games/algos_torch/sac_agent.py | 15 +++++++-------- rl_games/common/a2c_common.py | 28 ++++++++++++++-------------- rl_games/torch_runner.py | 2 +- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 9cf65d70..591166f5 100644 --- a/README.md +++ b/README.md @@ -67,10 +67,10 @@ Explore RL Games quick and easily in colab notebooks: ## Installation -For maximum training performance a preliminary installation of Pytorch 1.9+ with CUDA 11.1+ is highly recommended: +For maximum training performance a preliminary installation of Pytorch 2.2 or newer with CUDA 12.1 or newer is highly recommended: -```conda install pytorch torchvision cudatoolkit=11.3 -c pytorch -c nvidia``` or: -```pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html``` +```conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia``` or: +```pip install pip3 install torch torchvision``` Then: diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index fd79fb7a..d4010fc4 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ -441,7 +441,7 @@ def clear_stats(self): self.algo_observer.after_clear_stats() def play_steps(self, random_exploration = False): - total_time_start = time.time() + total_time_start = time.perf_counter() total_update_time = 0 total_time = 0 step_time = 0.0 @@ -466,11 +466,10 @@ def play_steps(self, random_exploration = False): with torch.no_grad(): action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True) - step_start = time.time() - + step_start = time.perf_counter() with torch.no_grad(): next_obs, rewards, dones, infos = self.env_step(action) - step_end = time.time() + step_end = time.perf_counter() self.current_rewards += rewards self.current_lengths += 1 @@ -500,7 +499,6 @@ def play_steps(self, random_exploration = False): self.obs = next_obs.clone() rewards = self.rewards_shaper(rewards) - self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1)) if isinstance(obs, dict): @@ -508,9 +506,10 @@ def play_steps(self, random_exploration = False): if not random_exploration: self.set_train() - update_time_start = time.time() + + update_time_start = time.perf_counter() actor_loss_info, critic1_loss, critic2_loss = self.update(self.epoch_num) - update_time_end = time.time() + update_time_end = time.perf_counter() update_time = update_time_end - update_time_start self.extract_actor_stats(actor_losses, entropies, alphas, alpha_losses, actor_loss_info) @@ -521,7 +520,7 @@ def play_steps(self, random_exploration = False): total_update_time += update_time - total_time_end = time.time() + total_time_end = time.perf_counter() total_time = total_time_end - total_time_start play_time = total_time - total_update_time diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index f9bd5a14..19b95985 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -757,9 +757,9 @@ def play_steps(self): if self.has_central_value: self.experience_buffer.update_data('states', n, self.obs['states']) - step_time_start = time.time() + step_time_start = time.perf_counter() self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) - step_time_end = time.time() + step_time_end = time.perf_counter() step_time += (step_time_end - step_time_start) @@ -830,9 +830,9 @@ def play_steps_rnn(self): if self.has_central_value: self.experience_buffer.update_data('states', n, self.obs['states']) - step_time_start = time.time() + step_time_start = time.perf_counter() self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions']) - step_time_end = time.time() + step_time_end = time.perf_counter() step_time += (step_time_end - step_time_start) @@ -920,7 +920,7 @@ def train_epoch(self): super().train_epoch() self.set_eval() - play_time_start = time.time() + play_time_start = time.perf_counter() with torch.no_grad(): if self.is_rnn: @@ -930,8 +930,8 @@ def train_epoch(self): self.set_train() - play_time_end = time.time() - update_time_start = time.time() + play_time_end = time.perf_counter() + update_time_start = time.perf_counter() rnn_masks = batch_dict.get('rnn_masks', None) self.curr_frames = batch_dict.pop('played_frames') @@ -966,7 +966,7 @@ def train_epoch(self): if self.normalize_input: self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch - update_time_end = time.time() + update_time_end = time.perf_counter() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start total_time = update_time_end - play_time_start @@ -1034,7 +1034,7 @@ def prepare_dataset(self, batch_dict): def train(self): self.init_tensors() self.mean_rewards = self.last_mean_rewards = -100500 - start_time = time.time() + start_time = time.perf_counter() total_time = 0 rep_count = 0 # self.frame = 0 # loading from checkpoint @@ -1183,15 +1183,15 @@ def train_epoch(self): super().train_epoch() self.set_eval() - play_time_start = time.time() + play_time_start = time.perf_counter() with torch.no_grad(): if self.is_rnn: batch_dict = self.play_steps_rnn() else: batch_dict = self.play_steps() - play_time_end = time.time() - update_time_start = time.time() + play_time_end = time.perf_counter() + update_time_start = time.perf_counter() rnn_masks = batch_dict.get('rnn_masks', None) self.set_train() @@ -1240,7 +1240,7 @@ def train_epoch(self): if self.normalize_input: self.model.running_mean_std.eval() # don't need to update statstics more than one miniepoch - update_time_end = time.time() + update_time_end = time.perf_counter() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start total_time = update_time_end - play_time_start @@ -1310,7 +1310,7 @@ def prepare_dataset(self, batch_dict): def train(self): self.init_tensors() self.last_mean_rewards = -100500 - start_time = time.time() + start_time = time.perf_counter() total_time = 0 rep_count = 0 self.obs = self.env_reset() diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 4377d29b..86be48ac 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -63,7 +63,7 @@ def __init__(self, algo_observer=None): self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver() torch.backends.cudnn.benchmark = True - ### it didnot help for lots for openai gym envs anyway :( + ### it did not help for lots for openai gym envs anyway :( #torch.backends.cudnn.deterministic = True #torch.use_deterministic_algorithms(True) From 2606effbc2ecbee93ff2cc313b25dd5b4a7f0e54 Mon Sep 17 00:00:00 2001 From: iakinola23 <147214266+iakinola23@users.noreply.github.com> Date: Fri, 12 Jul 2024 04:39:32 -0400 Subject: [PATCH 08/13] Update for tacsl release: CNN tower processing, critic weights loading and freezing. (#298) * fix missing import copy * adding ability to post-process the output of a conv tower with the spatial soft argmax or flatten layer * enable loading the weights of the critic network from a PPO checkpoint, without the actor weights * add flag to freeze critic while training actor --- rl_games/algos_torch/a2c_continuous.py | 4 ++ rl_games/algos_torch/central_value.py | 3 + rl_games/algos_torch/network_builder.py | 14 ++++- rl_games/algos_torch/spatial_softmax.py | 83 +++++++++++++++++++++++++ rl_games/common/a2c_common.py | 8 ++- rl_games/torch_runner.py | 4 ++ 6 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 rl_games/algos_torch/spatial_softmax.py diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 78795a2b..a64de95f 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -83,6 +83,10 @@ def restore(self, fn, set_epoch=True): checkpoint = torch_ext.load_checkpoint(fn) self.set_full_state_weights(checkpoint, set_epoch=set_epoch) + def restore_central_value_function(self, fn): + checkpoint = torch_ext.load_checkpoint(fn) + self.set_central_value_function_weights(checkpoint) + def get_masked_action_values(self, obs, action_masks): assert False diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index d75c687c..c06d9a18 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -1,4 +1,5 @@ import os +import copy import torch from torch import nn import torch.distributed as dist @@ -219,6 +220,8 @@ def train_net(self): self.train() loss = 0 for _ in range(self.mini_epoch): + if self.config.get('freeze_critic', False): + break for idx in range(len(self.dataset)): loss += self.train_critic(self.dataset[idx]) if self.normalize_input: diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index ab047920..e5d625c0 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -8,6 +8,7 @@ from rl_games.algos_torch.sac_helper import SquashedNormal from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue +from rl_games.algos_torch.spatial_softmax import SpatialSoftArgmax def _create_initializer(func, **kwargs): @@ -130,12 +131,17 @@ def _build_conv(self, ctype, **kwargs): if ctype == 'conv2d': return self._build_cnn2d(**kwargs) + if ctype == 'conv2d_spatial_softargmax': + return self._build_cnn2d(add_spatial_softmax=True, **kwargs) + if ctype == 'conv2d_flatten': + return self._build_cnn2d(add_flatten=True, **kwargs) if ctype == 'coord_conv2d': return self._build_cnn2d(conv_func=torch_ext.CoordConv2d, **kwargs) if ctype == 'conv1d': return self._build_cnn1d(**kwargs) - def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None): + def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d, norm_func_name=None, + add_spatial_softmax=False, add_flatten=False): in_channels = input_shape[0] layers = [] for conv in convs: @@ -150,7 +156,11 @@ def _build_cnn2d(self, input_shape, convs, activation, conv_func=torch.nn.Conv2d if norm_func_name == 'layer_norm': layers.append(torch_ext.LayerNorm2d(in_channels)) elif norm_func_name == 'batch_norm': - layers.append(torch.nn.BatchNorm2d(in_channels)) + layers.append(torch.nn.BatchNorm2d(in_channels)) + if add_spatial_softmax: + layers.append(SpatialSoftArgmax(normalize=True)) + if add_flatten: + layers.append(torch.nn.Flatten()) return nn.Sequential(*layers) def _build_cnn1d(self, input_shape, convs, activation, norm_func_name=None): diff --git a/rl_games/algos_torch/spatial_softmax.py b/rl_games/algos_torch/spatial_softmax.py new file mode 100644 index 00000000..862efed9 --- /dev/null +++ b/rl_games/algos_torch/spatial_softmax.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Adopted from https://gist.github.com/kevinzakka/dd9fa5177cda13593524f4d71eb38ad5 +class SpatialSoftArgmax(nn.Module): + """Spatial softmax as defined in [1]. + + Concretely, the spatial softmax of each feature + map is used to compute a weighted mean of the pixel + locations, effectively performing a soft arg-max + over the feature dimension. + + References: + [1]: End-to-End Training of Deep Visuomotor Policies, + https://arxiv.org/abs/1504.00702 + """ + + def __init__(self, normalize=False): + """Constructor. + + Args: + normalize (bool): Whether to use normalized + image coordinates, i.e. coordinates in + the range `[-1, 1]`. + """ + super().__init__() + + self.normalize = normalize + + def _coord_grid(self, h, w, device): + if self.normalize: + return torch.stack( + torch.meshgrid( + torch.linspace(-1, 1, w, device=device), + torch.linspace(-1, 1, h, device=device), + ) + ) + return torch.stack( + torch.meshgrid( + torch.arange(0, w, device=device), + torch.arange(0, h, device=device), + ) + ) + + def forward(self, x): + assert x.ndim == 4, "Expecting a tensor of shape (B, C, H, W)." + + # compute a spatial softmax over the input: + # given an input of shape (B, C, H, W), + # reshape it to (B*C, H*W) then apply + # the softmax operator over the last dimension + b, c, h, w = x.shape + softmax = F.softmax(x.reshape(-1, h * w), dim=-1) + + # create a meshgrid of pixel coordinates + # both in the x and y axes + xc, yc = self._coord_grid(h, w, x.device) + + # element-wise multiply the x and y coordinates + # with the softmax, then sum over the h*w dimension + # this effectively computes the weighted mean of x + # and y locations + x_mean = (softmax * xc.flatten()).sum(dim=1, keepdims=True) + y_mean = (softmax * yc.flatten()).sum(dim=1, keepdims=True) + + # concatenate and reshape the result + # to (B, C*2) where for every feature + # we have the expected x and y pixel + # locations + return torch.cat([x_mean, y_mean], dim=1).view(-1, c * 2) + + +if __name__ == "__main__": + b, c, h, w = 32, 64, 12, 12 + x = torch.zeros(b, c, h, w) + true_max = torch.randint(0, 10, size=(b, c, 2)) + for i in range(b): + for j in range(c): + x[i, j, true_max[i, j, 0], true_max[i, j, 1]] = 1000 + soft_max = SpatialSoftArgmax()(x).reshape(b, c, 2) + assert torch.allclose(true_max.float(), soft_max) \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 19b95985..224bca6b 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -643,6 +643,9 @@ def set_full_state_weights(self, weights, set_epoch=True): env_state = weights.get('env_state', None) self.vec_env.set_env_state(env_state) + def set_central_value_function_weights(self, weights): + self.central_value_net.load_state_dict(weights['assymetric_vf_nets']) + def get_weights(self): state = self.get_stats_weights() state['model'] = self.model.state_dict() @@ -1262,7 +1265,10 @@ def prepare_dataset(self, batch_dict): advantages = returns - values if self.normalize_value: - self.value_mean_std.train() + if self.config.get('freeze_critic', False): + self.value_mean_std.eval() + else: + self.value_mean_std.train() values = self.value_mean_std(values) returns = self.value_mean_std(returns) self.value_mean_std.eval() diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 86be48ac..0f7a9ac8 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -17,6 +17,10 @@ def _restore(agent, args): if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='': + if args['train'] and args.get('load_critic_only', False): + assert agent.has_central_value, 'This should only work for asymmetric actor critic' + agent.restore_central_value_function(args['checkpoint']) + return agent.restore(args['checkpoint']) def _override_sigma(agent, args): From 7f9cd1e3293e74bbc88063b4a11d46daf9097f94 Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Sun, 8 Sep 2024 17:33:23 -0700 Subject: [PATCH 09/13] added test aux_loss (#303) Co-authored-by: Denys Makoviichuk --- rl_games/algos_torch/a2c_continuous.py | 10 ++- rl_games/algos_torch/a2c_discrete.py | 11 ++- rl_games/algos_torch/models.py | 24 ++++++- rl_games/algos_torch/network_builder.py | 3 + rl_games/common/a2c_common.py | 9 +-- .../test/test_discrite_testnet_aux_loss.yaml | 52 ++++++++++++++ rl_games/envs/__init__.py | 5 +- rl_games/envs/test/rnn_env.py | 10 +++ rl_games/envs/test_network.py | 70 ++++++++++++++++++- 9 files changed, 180 insertions(+), 14 deletions(-) create mode 100644 rl_games/configs/test/test_discrite_testnet_aux_loss.yaml diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index a64de95f..e93ea362 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -151,7 +151,15 @@ def calc_gradients(self, input_dict): a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3] loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef - + aux_loss = self.model.get_aux_loss() + self.aux_loss_dict = {} + if aux_loss is not None: + for k, v in aux_loss.items(): + loss += v + if k in self.aux_loss_dict: + self.aux_loss_dict[k] = v.detach() + else: + self.aux_loss_dict[k] = [v.detach()] if self.multi_gpu: self.optimizer.zero_grad() else: diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 312bd17c..a69ba9bb 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -170,7 +170,16 @@ def calc_gradients(self, input_dict): losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks) a_loss, c_loss, entropy = losses[0], losses[1], losses[2] loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef - + aux_loss = self.model.get_aux_loss() + self.aux_loss_dict = {} + if aux_loss is not None: + for k, v in aux_loss.items(): + loss += v + if k in self.aux_loss_dict: + self.aux_loss_dict[k] = v.detach() + else: + self.aux_loss_dict[k] = [v.detach()] + if self.multi_gpu: self.optimizer.zero_grad() else: diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 9c9dde4d..93d4001d 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -54,6 +54,10 @@ def norm_obs(self, observation): def denorm_value(self, value): with torch.no_grad(): return self.value_mean_std(value, denorm=True) if self.normalize_value else value + + + def get_aux_loss(self): + return None class ModelA2C(BaseModel): def __init__(self, network): @@ -64,7 +68,10 @@ class Network(BaseModelNetwork): def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self,**kwargs) self.a2c_network = a2c_network - + + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -121,6 +128,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -190,6 +200,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -248,6 +261,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -305,6 +321,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -344,6 +363,9 @@ def __init__(self, sac_network,**kwargs): BaseModelNetwork.__init__(self,**kwargs) self.sac_network = sac_network + def get_aux_loss(self): + return self.sac_network.get_aux_loss() + def critic(self, obs, action): return self.sac_network.critic(obs, action) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index e5d625c0..289812dd 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -67,6 +67,9 @@ def is_rnn(self): def get_default_rnn_state(self): return None + def get_aux_loss(self): + return None + def _calc_input_size(self, input_shape,cnn_layers=None): if cnn_layers is None: assert(len(input_shape) == 1) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 224bca6b..54b3e1ef 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -321,9 +321,7 @@ def __init__(self, base_name, params): self.algo_observer = config['features']['observer'] self.soft_aug = config['features'].get('soft_augmentation', None) - self.has_soft_aug = self.soft_aug is not None - # soft augmentation not yet supported - assert not self.has_soft_aug + self.aux_loss_dict = {} def trancate_gradients_and_step(self): if self.multi_gpu: @@ -374,6 +372,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame) + for k,v in self.aux_loss_dict.items(): + self.writer.add_scalar('losses/' + k, torch_ext.mean_list(v).item(), frame) self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame) self.writer.add_scalar('info/lr_mul', lr_mul, frame) self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame) @@ -1357,9 +1357,6 @@ def train(self): if len(b_losses) > 0: self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame) - if self.has_soft_aug: - self.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame) - if self.game_rewards.current_size > 0: mean_rewards = self.game_rewards.get_mean() mean_shaped_rewards = self.game_shaped_rewards.get_mean() diff --git a/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml new file mode 100644 index 00000000..0f666f0d --- /dev/null +++ b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml @@ -0,0 +1,52 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: testnet_aux_loss + config: + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.9 + learning_rate: 2e-4 + name: test_md_multi_obs + score_to_win: 0.95 + grad_norm: 10.5 + entropy_coef: 0.005 + truncate_grads: True + env_name: test_env + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 256 + minibatch_size: 2048 + mini_epochs: 4 + critic_coef: 1 + lr_schedule: None + kl_threshold: 0.008 + normalize_input: False + normalize_value: False + weight_decay: 0.0000 + max_epochs: 10000 + seq_length: 16 + save_best_after: 10 + save_frequency: 20 + + env_config: + name: TestRnnEnv-v0 + hide_object: False + apply_dist_reward: False + min_dist: 2 + max_dist: 8 + use_central_value: True + multi_obs_space: True + multi_head_value: False + aux_loss: True + player: + games_num: 100 + deterministic: True \ No newline at end of file diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py index 6883b34a..b906c43d 100644 --- a/rl_games/envs/__init__.py +++ b/rl_games/envs/__init__.py @@ -1,6 +1,7 @@ -from rl_games.envs.test_network import TestNetBuilder +from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder from rl_games.algos_torch import model_builder -model_builder.register_network('testnet', TestNetBuilder) \ No newline at end of file +model_builder.register_network('testnet', TestNetBuilder) +model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder) \ No newline at end of file diff --git a/rl_games/envs/test/rnn_env.py b/rl_games/envs/test/rnn_env.py index faa4e17e..5fcf5318 100644 --- a/rl_games/envs/test/rnn_env.py +++ b/rl_games/envs/test/rnn_env.py @@ -16,6 +16,7 @@ def __init__(self, **kwargs): self.apply_dist_reward = kwargs.pop('apply_dist_reward', False) self.apply_exploration_reward = kwargs.pop('apply_exploration_reward', False) self.multi_head_value = kwargs.pop('multi_head_value', False) + self.aux_loss = kwargs.pop('aux_loss', False) if self.multi_head_value: self.value_size = 2 else: @@ -33,6 +34,8 @@ def __init__(self, **kwargs): 'pos': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.float32), 'info': gym.spaces.Box(low=0, high=1, shape=(4, ), dtype=np.float32), } + if self.aux_loss: + spaces['aux_target'] = gym.spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32) self.observation_space = gym.spaces.Dict(spaces) else: self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32) @@ -58,6 +61,9 @@ def reset(self): 'pos': obs[:2], 'info': obs[2:] } + if self.aux_loss: + aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2 + obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0) if self.use_central_value: obses = {} obses["obs"] = obs @@ -93,6 +99,7 @@ def step_multi_categorical(self, action): def step(self, action): info = {} self._curr_steps += 1 + bound = self.max_dist - self.min_dist if self.multi_discrete_space: self.step_multi_categorical(action) else: @@ -125,6 +132,9 @@ def step(self, action): 'pos': obs[:2], 'info': obs[2:] } + if self.aux_loss: + aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2 + obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0) if self.use_central_value: state = np.concatenate([self._current_pos, self._goal_pos, [show_object, self._curr_steps]], axis=None) obses = {} diff --git a/rl_games/envs/test_network.py b/rl_games/envs/test_network.py index 6170ebb7..7adfae90 100644 --- a/rl_games/envs/test_network.py +++ b/rl_games/envs/test_network.py @@ -2,8 +2,9 @@ from torch import nn import torch.nn.functional as F - -class TestNet(nn.Module): +from rl_games.algos_torch.network_builder import NetworkBuilder + +class TestNet(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): nn.Module.__init__(self) actions_num = kwargs.pop('actions_num') @@ -38,7 +39,7 @@ def forward(self, obs): return action, value, None -from rl_games.algos_torch.network_builder import NetworkBuilder + class TestNetBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -52,3 +53,66 @@ def build(self, name, **kwargs): def __call__(self, name, **kwargs): return self.build(name, **kwargs) + + + +class TestNetWithAuxLoss(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + nn.Module.__init__(self) + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + num_inputs = 0 + + self.target_key = 'aux_target' + assert(type(input_shape) is dict) + for k,v in input_shape.items(): + if self.target_key == k: + self.target_shape = v[0] + else: + num_inputs +=v[0] + + self.central_value = params.get('central_value', False) + self.value_size = kwargs.pop('value_size', 1) + self.linear1 = nn.Linear(num_inputs, 256) + self.linear2 = nn.Linear(256, 128) + self.linear3 = nn.Linear(128, 64) + self.mean_linear = nn.Linear(64, actions_num) + self.value_linear = nn.Linear(64, 1) + self.aux_loss_linear = nn.Linear(64, self.target_shape) + + self.aux_loss_map = { + 'aux_dist_loss' : None + } + def is_rnn(self): + return False + + def get_aux_loss(self): + return self.aux_loss_map + + def forward(self, obs): + obs = obs['obs'] + target_obs = obs[self.target_key] + obs = torch.cat([obs['pos'], obs['info']], axis=-1) + x = F.relu(self.linear1(obs)) + x = F.relu(self.linear2(x)) + x = F.relu(self.linear3(x)) + action = self.mean_linear(x) + value = self.value_linear(x) + y = self.aux_loss_linear(x) + self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs) + if self.central_value: + return value, None + return action, value, None + +class TestNetAuxLossBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + return TestNetWithAuxLoss(self.params, **kwargs) + + def __call__(self, name, **kwargs): + return self.build(name, **kwargs) \ No newline at end of file From 59d4c409302735766251ac8846c1791c828241c1 Mon Sep 17 00:00:00 2001 From: annan_tang <54849345+annan-tang@users.noreply.github.com> Date: Thu, 12 Sep 2024 04:17:22 +0900 Subject: [PATCH 10/13] Add a broadcast for the initial parameters of central_value_net in multi-GPU/node training. (#297) --- rl_games/common/a2c_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 54b3e1ef..e083b0b5 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -1047,8 +1047,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch() @@ -1326,8 +1330,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch() From eca781c2a05c0c986d0595014764c2817499efa5 Mon Sep 17 00:00:00 2001 From: "Yueqian (Ryan) Liu" Date: Sat, 28 Sep 2024 22:01:55 +0200 Subject: [PATCH 11/13] Documation about class implemetations, customization of training/playing loops and models and networks (#305) * add doc about class implementations and customization of training, playing, model and network * Update docs/ISAAC_GYM.md Fix typo. Co-authored-by: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> --------- Co-authored-by: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> --- docs/ISAAC_GYM.md | 172 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/docs/ISAAC_GYM.md b/docs/ISAAC_GYM.md index d5c39f57..144a1794 100644 --- a/docs/ISAAC_GYM.md +++ b/docs/ISAAC_GYM.md @@ -1,4 +1,174 @@ ## Isaac Gym Results https://developer.nvidia.com/isaac-gym -Coming. +## What's Written Below + +Content below is written to complement `HOW_TO_RL_GAMES.md` in the same directory, while focusing more on **explaining the implementations** in the classes and how to **customize the training (testing) loops, models and networks**. Since the AMP implementation in `IsaacGymEnvs` is used as the example, so you are reading me here under this file. + +## Program Entry Point + +The primary entry point for both training and testing within `IsaacGymEnvs` is the `train.py` script. This file initializes an instance of the `rl_games.torch_runner.Runner` class, and depending on the mode selected, either the `run_train` or `run_play` function is executed. Additionally, `train.py` allows for custom implementations of training and testing loops, as well as the integration of custom neural networks and models into the library through the `build_runner` function, a process referred to as "registering." By registering custom code, the library can be configured to execute the user-defined code by specifying the appropriate names within the configuration file. + +In RL Games, the training algorithms are referred to as "agents," while their counterparts for testing are known as "players." In the `run_train` function, an agent is instantiated, and training is initiated through the `agent.train` call. Similarly, in the `run_play` function, a player is created, and testing begins by invoking `player.run`. Thus, the core entry points for training and testing in RL Games are the `train` function for agents and the `run` function for players. + +```python +def run_train(self, args): + """Run the training procedure from the algorithm passed in.""" + + print('Started to train') + agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params) + _restore(agent, args) + _override_sigma(agent, args) + agent.train() + +def run_play(self, args): + """Run the inference procedure from the algorithm passed in.""" + + print('Started to play') + player = self.create_player() + _restore(player, args) + _override_sigma(player, args) + player.run() +``` + +## Training Algorithms + +The creation of an agent is handled by the `algo_factory`, as demonstrated in the code above. By default, the `algo_factory` is registered with continuous-action A2C, discrete-action A2C, and SAC. This default registration is found within the constructor of the `Runner` class, and its implementation is shown below. Our primary focus will be on understanding `A2CAgent`, as it is the primary algorithm used for most continuous-control tasks in `IsaacGymEnvs`. + +```python +self.algo_factory.register_builder( + 'a2c_continuous', + lambda **kwargs: a2c_continuous.A2CAgent(**kwargs) +) +self.algo_factory.register_builder( + 'a2c_discrete', + lambda **kwargs: a2c_discrete.DiscreteA2CAgent(**kwargs) +) +self.algo_factory.register_builder( + 'sac', + lambda **kwargs: sac_agent.SACAgent(**kwargs) +) +``` + +At the base of all RL Games algorithms is the `BaseAlgorithm` class, an abstract class that defines several essential methods, including `train` and `train_epoch`, which are critical for training. The `A2CBase` class inherits from `BaseAlgorithm` and provides many shared functionalities for both continuous and discrete A2C agents. These include methods such as `play_steps` and `play_steps_rnn` for gathering rollout data, and `env_step` and `env_reset` for interacting with the environment. However, functions directly related to training—like `train`, `train_epoch`, `update_epoch`, `prepare_dataset`, `train_actor_critic`, and `calc_gradients`—are left unimplemented at this level. These functions are implemented in `ContinuousA2CBase`, a subclass of `A2CBase`, and further in `A2CAgent`, a subclass of `ContinuousA2CBase`. + +The `ContinuousA2CBase` class is responsible for the core logic of agent training, specifically in the methods `train`, `train_epoch`, and `prepare_dataset`. In the `train` function, the environment is reset once before entering the main training loop. This loop consists of three primary stages: + +1. Calling `update_epoch`. +2. Running `train_epoch`. +3. Logging key information, such as episode length, rewards, and losses. + +The `update_epoch` function, which increments the epoch count, is implemented in `A2CAgent`. The heart of the training process is the `train_epoch` function, which operates as follows: + +1. `play_steps` or `play_steps_rnn` is called to generate rollout data in the form of a dictionary of tensors, `batch_dict`. The number of environment steps collected equals the configured `horizon_length`. +2. `prepare_dataset` modifies the tensors in `batch_dict`, which may include normalizing values and advantages, depending on the configuration. +3. Multiple mini-epochs are executed. In each mini-epoch, the dataset is divided into mini-batches, which are sequentially fed into `train_actor_critic`. Function `train_actor_critic`, implemented in `A2CAgent`, internally calls `calc_grad`, also found in `A2CAgent`. + +The `A2CAgent` class, which inherits from `ContinuousA2CBase`, handles the crucial task of gradient calculation and model parameter optimization in its `calc_grad` function. Specifically, `calc_grad` first performs a forward pass of the policy model with PyTorch’s gradients and computational graph enabled. It then calculates the individual loss terms as well as the total scalar loss, runs the backward pass to compute gradients, truncates gradients if necessary, updates model parameters via the optimizer, and finally logs the relevant training metrics such as loss terms and learning rates. + +With an understanding of the default functions, it becomes straightforward to customize agents by inheriting from `A2CAgent` and overriding specific methods to suit particular needs. A good example of this is the implementation of the AMP algorithm in `IsaacGymEnvs`, where the `AMPAgent` class is created and registered in `train.py`, as shown below. + +```python +_runner.algo_factory.register_builder( + 'amp_continuous', + lambda **kwargs: amp_continuous.AMPAgent(**kwargs) +) +``` + +## Players + +Similar to training algorithms, default players are registered with `player_factory` in the `Runner` class. These include `PPOPlayerContinuous`, `PPOPlayerDiscrete`, and `SACPlayer`. Each of these player classes inherits from the `BasePlayer` class, which provides a common `run` function. The derived player classes implement specific methods for restoring from model checkpoints (`restore`), initializing the RNN (`reset`), and generating actions based on observations through `get_action` and `get_masked_action`. + +The testing loop is simpler compared to the training loop. It starts by resetting the environment to obtain the initial observation. Then, for `max_steps` iterations, the loop feeds the observation into the model to generate an action, which is applied to the environment to retrieve the next observation, reward, and other necessary data. This process is repeated for `n_games` episodes, after which the average reward and episode lengths are calculated and displayed. + +Customizing the testing loop is as straightforward as customizing the training loop. By inheriting from a default player class, one can override specific functions as needed. As with custom training algorithms, customized players must also be registered with `player_factory` in `train.py`, as demonstrated below. + +```python +self.player_factory.register_builder( + 'a2c_continuous', + lambda **kwargs: players.PpoPlayerContinuous(**kwargs) +) +self.player_factory.register_builder( + 'a2c_discrete', + lambda **kwargs: players.PpoPlayerDiscrete(**kwargs) +) +self.player_factory.register_builder( + 'sac', + lambda **kwargs: players.SACPlayer(**kwargs) +) + +_runner.player_factory.register_builder( + 'amp_continuous', + lambda **kwargs: amp_players.AMPPlayerContinuous(**kwargs) +) +``` + +## Models and Networks + +The terminology and implementation of models and networks in RL Games version `1.6.1` can be confusing for new users. Below is a high-level overview of their functionality and relationships: + +- **Network Builder:** Network builder classes, such as `A2CBuilder` and `SACBuilder`, are subclasses of `NetworkBuilder` and can be found in `algos_torch.network_builder`. The core component of a network builder is the nested `Network` class (the "inner network" class), which is typically derived from `torch.nn.Module`. This class receives a dictionary of tensors, such as observations and other necessary inputs, and outputs a tuple of tensors from which actions can be generated. The `forward` function of the `Network` class handles this transformation. + +- **Model:** Model classes, like `ModelA2C` and `ModelSACContinuous`, inherit from `BaseModel` in `algos_torch.models`. They are similar to network builders, as each contains a nested `Network` class (the "model network" class) and a `build` function to construct an instance of this network. + +- **Model & Network in Algorithm:** In a default agent or player algorithm, `self.model` refers to an instance of the model network class, while `self.network` refers to an instance of the model class. + +- **Model Builder:** The `ModelBuilder` class, located in `algos_torch.model_builder`, is responsible for loading and managing models. It provides a `load` function, which creates a model instance based on the specified name. + +Customizing models requires implementing a custom network builder and model class. These custom classes should be registered in the `Runner` class within `train.py`. A good reference example is the AMP implementation, as shown below. + +```python +# algos_torch.model_builder.NetworkBuilder.__init__ +self.network_factory.register_builder( + 'actor_critic', + lambda **kwargs: network_builder.A2CBuilder() +) +self.network_factory.register_builder( + 'resnet_actor_critic', + lambda **kwargs: network_builder.A2CResnetBuilder() +) +self.network_factory.register_builder( + 'rnd_curiosity', + lambda **kwargs: network_builder.RNDCuriosityBuilder() +) +self.network_factory.register_builder( + 'soft_actor_critic', + lambda **kwargs: network_builder.SACBuilder() +) + +# algos_torch.model_builder.ModelBuilder.__init__ +self.model_factory.register_builder( + 'discrete_a2c', + lambda network, **kwargs: models.ModelA2C(network) +) +self.model_factory.register_builder( + 'multi_discrete_a2c', + lambda network, **kwargs: models.ModelA2CMultiDiscrete(network) +) +self.model_factory.register_builder( + 'continuous_a2c', + lambda network, **kwargs: models.ModelA2CContinuous(network) +) +self.model_factory.register_builder( + 'continuous_a2c_logstd', + lambda network, **kwargs: models.ModelA2CContinuousLogStd(network) +) +self.model_factory.register_builder( + 'soft_actor_critic', + lambda network, **kwargs: models.ModelSACContinuous(network) +) +self.model_factory.register_builder( + 'central_value', + lambda network, **kwargs: models.ModelCentralValue(network) +) + +# isaacgymenvs.train.launch_rlg_hydra.build_runner +model_builder.register_model( + 'continuous_amp', + lambda network, **kwargs: amp_models.ModelAMPContinuous(network), +) +model_builder.register_network( + 'amp', + lambda **kwargs: amp_network_builder.AMPBuilder() +) +``` From 0ed7c3a344f735d0e244beea7d2a733c8b92d35b Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Sat, 28 Sep 2024 15:53:35 -0700 Subject: [PATCH 12/13] Added myo suite support (#306) * Added myosuite support with Ray * Added training config example --------- Co-authored-by: Denys Makoviichuk --- rl_games/common/env_configurations.py | 10 ++++ rl_games/common/wrappers.py | 85 +++++++++++++++++++++++++++ rl_games/configs/ppo_myo.yaml | 68 +++++++++++++++++++++ 3 files changed, 163 insertions(+) create mode 100644 rl_games/configs/ppo_myo.yaml diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 43c8ebe1..08170847 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -86,6 +86,12 @@ def create_slime_gym_env(**kwargs): env = gym.make(name, **kwargs) return env +def create_myo(**kwargs): + from myosuite.utils import gym + name = kwargs.pop('name') + env = gym.make(name, **kwargs) + env = wrappers.OldGymWrapper(env) + return env def create_atari_gym_env(**kwargs): #frames = kwargs.pop('frames', 1) @@ -427,6 +433,10 @@ def create_env(name, **kwargs): 'env_creator': lambda **kwargs: create_cule(**kwargs), 'vecenv_type': 'CULE' }, + 'myo_gym' : { + 'env_creator' : lambda **kwargs : create_myo(**kwargs), + 'vecenv_type' : 'RAY' + }, } def get_env_info(env): diff --git a/rl_games/common/wrappers.py b/rl_games/common/wrappers.py index a62e0855..dab4a648 100644 --- a/rl_games/common/wrappers.py +++ b/rl_games/common/wrappers.py @@ -1,3 +1,4 @@ +import gymnasium import numpy as np from numpy.random import randint @@ -626,6 +627,90 @@ def __init__(self, env, name): def observation(self, observation): return observation * self.mask +class OldGymWrapper(gym.Env): + def __init__(self, env): + self.env = env + + # Convert Gymnasium spaces to Gym spaces + self.observation_space = self.convert_space(env.observation_space) + self.action_space = self.convert_space(env.action_space) + + def convert_space(self, space): + """Recursively convert Gymnasium spaces to Gym spaces.""" + if isinstance(space, gymnasium.spaces.Box): + return gym.spaces.Box( + low=space.low, + high=space.high, + shape=space.shape, + dtype=space.dtype + ) + elif isinstance(space, gymnasium.spaces.Discrete): + return gym.spaces.Discrete(n=space.n) + elif isinstance(space, gymnasium.spaces.MultiDiscrete): + return gym.spaces.MultiDiscrete(nvec=space.nvec) + elif isinstance(space, gymnasium.spaces.MultiBinary): + return gym.spaces.MultiBinary(n=space.n) + elif isinstance(space, gymnasium.spaces.Tuple): + return gym.spaces.Tuple([self.convert_space(s) for s in space.spaces]) + elif isinstance(space, gymnasium.spaces.Dict): + return gym.spaces.Dict({k: self.convert_space(s) for k, s in space.spaces.items()}) + else: + raise NotImplementedError(f"Space type {type(space)} is not supported.") + + def reset(self): + result = self.env.reset() + if isinstance(result, tuple): + # Gymnasium returns (observation, info) + observation, _ = result + else: + observation = result + # Flatten the observation + observation = gym.spaces.flatten(self.observation_space, observation) + return observation # Old Gym API returns only the observation + + def step(self, action): + # Unflatten the action + action = gym.spaces.unflatten(self.action_space, action) + result = self.env.step(action) + + if len(result) == 5: + # Gymnasium returns (obs, reward, terminated, truncated, info) + observation, reward, terminated, truncated, info = result + done = terminated or truncated # Combine for old Gym API + else: + # Old Gym returns (obs, reward, done, info) + observation, reward, done, info = result + + # Flatten the observation + observation = gym.spaces.flatten(self.observation_space, observation) + return observation, reward, done, info + + def render(self, mode='human'): + return self.env.render(mode=mode) + + def close(self): + return self.env.close() + +# Example usage: +if __name__ == "__main__": + # Create a MyoSuite environment + env = myosuite.make('myoChallengeDieReorientP2-v0') + + # Wrap it with the old Gym-style wrapper + env = OldGymWrapper(env) + + # Use the environment as usual + observation = env.reset() + done = False + while not done: + # Sample a random action + action = env.action_space.sample() + # Step the environment + observation, reward, done, info = env.step(action) + # Optionally render the environment + env.render() + env.close() + def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directory=None, **kwargs): env = gym.make(env_id, **kwargs) diff --git a/rl_games/configs/ppo_myo.yaml b/rl_games/configs/ppo_myo.yaml new file mode 100644 index 00000000..297a014b --- /dev/null +++ b/rl_games/configs/ppo_myo.yaml @@ -0,0 +1,68 @@ +params: + seed: 8 + algo: + name: a2c_continuous + + model: + name: continuous_a2c_logstd + + network: + name: actor_critic + separate: False + space: + continuous: + mu_activation: None + sigma_activation: None + mu_init: + name: default + sigma_init: + name: const_initializer + val: 0 + fixed_sigma: True + mlp: + units: [256,128,64] + d2rl: False + activation: elu + initializer: + name: default + scale: 2 + config: + env_name: myo_gym + name: myo + reward_shaper: + min_val: -1 + scale_value: 0.1 + + normalize_advantage: True + gamma: 0.995 + tau: 0.95 + learning_rate: 3e-4 + lr_schedule: adaptive + kl_threshold: 0.008 + save_best_after: 10 + score_to_win: 10000 + grad_norm: 1.5 + entropy_coef: 0 + truncate_grads: True + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 128 + minibatch_size: 1024 + mini_epochs: 4 + critic_coef: 2 + normalize_input: True + bounds_loss_coef: 0.00 + max_epochs: 10000 + normalize_value: True + use_diagnostics: True + value_bootstrap: True + #weight_decay: 0.0001 + use_smooth_clamp: True + env_config: + name: 'myoElbowPose1D6MRandom-v0' + player: + + render: True + deterministic: True + games_num: 200 From 90af59b858943672d64341dc9c36910b7a6cbb96 Mon Sep 17 00:00:00 2001 From: paLeziart <47817000+paLeziart@users.noreply.github.com> Date: Mon, 30 Sep 2024 07:05:01 +0900 Subject: [PATCH 13/13] Remove duplicate lines from a2c_common.py (#288) --- rl_games/common/a2c_common.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index e083b0b5..54a5cda1 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -240,13 +240,10 @@ def __init__(self, base_name, params): self.batch_size = self.horizon_length * self.num_actors * self.num_agents self.batch_size_envs = self.horizon_length * self.num_actors - assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config)) - self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) - self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env) - # either minibatch_size_per_env or minibatch_size should be present in a config # if both are present, minibatch_size is used # otherwise minibatch_size_per_env is used minibatch_size_per_env is used to calculate minibatch_size + assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config)) self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)