From 7f9cd1e3293e74bbc88063b4a11d46daf9097f94 Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Sun, 8 Sep 2024 17:33:23 -0700 Subject: [PATCH 1/3] added test aux_loss (#303) Co-authored-by: Denys Makoviichuk --- rl_games/algos_torch/a2c_continuous.py | 10 ++- rl_games/algos_torch/a2c_discrete.py | 11 ++- rl_games/algos_torch/models.py | 24 ++++++- rl_games/algos_torch/network_builder.py | 3 + rl_games/common/a2c_common.py | 9 +-- .../test/test_discrite_testnet_aux_loss.yaml | 52 ++++++++++++++ rl_games/envs/__init__.py | 5 +- rl_games/envs/test/rnn_env.py | 10 +++ rl_games/envs/test_network.py | 70 ++++++++++++++++++- 9 files changed, 180 insertions(+), 14 deletions(-) create mode 100644 rl_games/configs/test/test_discrite_testnet_aux_loss.yaml diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index a64de95f..e93ea362 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -151,7 +151,15 @@ def calc_gradients(self, input_dict): a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3] loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef - + aux_loss = self.model.get_aux_loss() + self.aux_loss_dict = {} + if aux_loss is not None: + for k, v in aux_loss.items(): + loss += v + if k in self.aux_loss_dict: + self.aux_loss_dict[k] = v.detach() + else: + self.aux_loss_dict[k] = [v.detach()] if self.multi_gpu: self.optimizer.zero_grad() else: diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 312bd17c..a69ba9bb 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -170,7 +170,16 @@ def calc_gradients(self, input_dict): losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks) a_loss, c_loss, entropy = losses[0], losses[1], losses[2] loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef - + aux_loss = self.model.get_aux_loss() + self.aux_loss_dict = {} + if aux_loss is not None: + for k, v in aux_loss.items(): + loss += v + if k in self.aux_loss_dict: + self.aux_loss_dict[k] = v.detach() + else: + self.aux_loss_dict[k] = [v.detach()] + if self.multi_gpu: self.optimizer.zero_grad() else: diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 9c9dde4d..93d4001d 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -54,6 +54,10 @@ def norm_obs(self, observation): def denorm_value(self, value): with torch.no_grad(): return self.value_mean_std(value, denorm=True) if self.normalize_value else value + + + def get_aux_loss(self): + return None class ModelA2C(BaseModel): def __init__(self, network): @@ -64,7 +68,10 @@ class Network(BaseModelNetwork): def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self,**kwargs) self.a2c_network = a2c_network - + + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -121,6 +128,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -190,6 +200,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -248,6 +261,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -305,6 +321,9 @@ def __init__(self, a2c_network, **kwargs): BaseModelNetwork.__init__(self, **kwargs) self.a2c_network = a2c_network + def get_aux_loss(self): + return self.a2c_network.get_aux_loss() + def is_rnn(self): return self.a2c_network.is_rnn() @@ -344,6 +363,9 @@ def __init__(self, sac_network,**kwargs): BaseModelNetwork.__init__(self,**kwargs) self.sac_network = sac_network + def get_aux_loss(self): + return self.sac_network.get_aux_loss() + def critic(self, obs, action): return self.sac_network.critic(obs, action) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index e5d625c0..289812dd 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -67,6 +67,9 @@ def is_rnn(self): def get_default_rnn_state(self): return None + def get_aux_loss(self): + return None + def _calc_input_size(self, input_shape,cnn_layers=None): if cnn_layers is None: assert(len(input_shape) == 1) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 224bca6b..54b3e1ef 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -321,9 +321,7 @@ def __init__(self, base_name, params): self.algo_observer = config['features']['observer'] self.soft_aug = config['features'].get('soft_augmentation', None) - self.has_soft_aug = self.soft_aug is not None - # soft augmentation not yet supported - assert not self.has_soft_aug + self.aux_loss_dict = {} def trancate_gradients_and_step(self): if self.multi_gpu: @@ -374,6 +372,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame) + for k,v in self.aux_loss_dict.items(): + self.writer.add_scalar('losses/' + k, torch_ext.mean_list(v).item(), frame) self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame) self.writer.add_scalar('info/lr_mul', lr_mul, frame) self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame) @@ -1357,9 +1357,6 @@ def train(self): if len(b_losses) > 0: self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame) - if self.has_soft_aug: - self.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame) - if self.game_rewards.current_size > 0: mean_rewards = self.game_rewards.get_mean() mean_shaped_rewards = self.game_shaped_rewards.get_mean() diff --git a/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml new file mode 100644 index 00000000..0f666f0d --- /dev/null +++ b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml @@ -0,0 +1,52 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: testnet_aux_loss + config: + reward_shaper: + scale_value: 1 + normalize_advantage: True + gamma: 0.99 + tau: 0.9 + learning_rate: 2e-4 + name: test_md_multi_obs + score_to_win: 0.95 + grad_norm: 10.5 + entropy_coef: 0.005 + truncate_grads: True + env_name: test_env + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 256 + minibatch_size: 2048 + mini_epochs: 4 + critic_coef: 1 + lr_schedule: None + kl_threshold: 0.008 + normalize_input: False + normalize_value: False + weight_decay: 0.0000 + max_epochs: 10000 + seq_length: 16 + save_best_after: 10 + save_frequency: 20 + + env_config: + name: TestRnnEnv-v0 + hide_object: False + apply_dist_reward: False + min_dist: 2 + max_dist: 8 + use_central_value: True + multi_obs_space: True + multi_head_value: False + aux_loss: True + player: + games_num: 100 + deterministic: True \ No newline at end of file diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py index 6883b34a..b906c43d 100644 --- a/rl_games/envs/__init__.py +++ b/rl_games/envs/__init__.py @@ -1,6 +1,7 @@ -from rl_games.envs.test_network import TestNetBuilder +from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder from rl_games.algos_torch import model_builder -model_builder.register_network('testnet', TestNetBuilder) \ No newline at end of file +model_builder.register_network('testnet', TestNetBuilder) +model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder) \ No newline at end of file diff --git a/rl_games/envs/test/rnn_env.py b/rl_games/envs/test/rnn_env.py index faa4e17e..5fcf5318 100644 --- a/rl_games/envs/test/rnn_env.py +++ b/rl_games/envs/test/rnn_env.py @@ -16,6 +16,7 @@ def __init__(self, **kwargs): self.apply_dist_reward = kwargs.pop('apply_dist_reward', False) self.apply_exploration_reward = kwargs.pop('apply_exploration_reward', False) self.multi_head_value = kwargs.pop('multi_head_value', False) + self.aux_loss = kwargs.pop('aux_loss', False) if self.multi_head_value: self.value_size = 2 else: @@ -33,6 +34,8 @@ def __init__(self, **kwargs): 'pos': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.float32), 'info': gym.spaces.Box(low=0, high=1, shape=(4, ), dtype=np.float32), } + if self.aux_loss: + spaces['aux_target'] = gym.spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32) self.observation_space = gym.spaces.Dict(spaces) else: self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32) @@ -58,6 +61,9 @@ def reset(self): 'pos': obs[:2], 'info': obs[2:] } + if self.aux_loss: + aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2 + obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0) if self.use_central_value: obses = {} obses["obs"] = obs @@ -93,6 +99,7 @@ def step_multi_categorical(self, action): def step(self, action): info = {} self._curr_steps += 1 + bound = self.max_dist - self.min_dist if self.multi_discrete_space: self.step_multi_categorical(action) else: @@ -125,6 +132,9 @@ def step(self, action): 'pos': obs[:2], 'info': obs[2:] } + if self.aux_loss: + aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2 + obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0) if self.use_central_value: state = np.concatenate([self._current_pos, self._goal_pos, [show_object, self._curr_steps]], axis=None) obses = {} diff --git a/rl_games/envs/test_network.py b/rl_games/envs/test_network.py index 6170ebb7..7adfae90 100644 --- a/rl_games/envs/test_network.py +++ b/rl_games/envs/test_network.py @@ -2,8 +2,9 @@ from torch import nn import torch.nn.functional as F - -class TestNet(nn.Module): +from rl_games.algos_torch.network_builder import NetworkBuilder + +class TestNet(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): nn.Module.__init__(self) actions_num = kwargs.pop('actions_num') @@ -38,7 +39,7 @@ def forward(self, obs): return action, value, None -from rl_games.algos_torch.network_builder import NetworkBuilder + class TestNetBuilder(NetworkBuilder): def __init__(self, **kwargs): @@ -52,3 +53,66 @@ def build(self, name, **kwargs): def __call__(self, name, **kwargs): return self.build(name, **kwargs) + + + +class TestNetWithAuxLoss(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + nn.Module.__init__(self) + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + num_inputs = 0 + + self.target_key = 'aux_target' + assert(type(input_shape) is dict) + for k,v in input_shape.items(): + if self.target_key == k: + self.target_shape = v[0] + else: + num_inputs +=v[0] + + self.central_value = params.get('central_value', False) + self.value_size = kwargs.pop('value_size', 1) + self.linear1 = nn.Linear(num_inputs, 256) + self.linear2 = nn.Linear(256, 128) + self.linear3 = nn.Linear(128, 64) + self.mean_linear = nn.Linear(64, actions_num) + self.value_linear = nn.Linear(64, 1) + self.aux_loss_linear = nn.Linear(64, self.target_shape) + + self.aux_loss_map = { + 'aux_dist_loss' : None + } + def is_rnn(self): + return False + + def get_aux_loss(self): + return self.aux_loss_map + + def forward(self, obs): + obs = obs['obs'] + target_obs = obs[self.target_key] + obs = torch.cat([obs['pos'], obs['info']], axis=-1) + x = F.relu(self.linear1(obs)) + x = F.relu(self.linear2(x)) + x = F.relu(self.linear3(x)) + action = self.mean_linear(x) + value = self.value_linear(x) + y = self.aux_loss_linear(x) + self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs) + if self.central_value: + return value, None + return action, value, None + +class TestNetAuxLossBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + return TestNetWithAuxLoss(self.params, **kwargs) + + def __call__(self, name, **kwargs): + return self.build(name, **kwargs) \ No newline at end of file From 59d4c409302735766251ac8846c1791c828241c1 Mon Sep 17 00:00:00 2001 From: annan_tang <54849345+annan-tang@users.noreply.github.com> Date: Thu, 12 Sep 2024 04:17:22 +0900 Subject: [PATCH 2/3] Add a broadcast for the initial parameters of central_value_net in multi-GPU/node training. (#297) --- rl_games/common/a2c_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 54b3e1ef..e083b0b5 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -1047,8 +1047,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch() @@ -1326,8 +1330,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch() From eca781c2a05c0c986d0595014764c2817499efa5 Mon Sep 17 00:00:00 2001 From: "Yueqian (Ryan) Liu" Date: Sat, 28 Sep 2024 22:01:55 +0200 Subject: [PATCH 3/3] Documation about class implemetations, customization of training/playing loops and models and networks (#305) * add doc about class implementations and customization of training, playing, model and network * Update docs/ISAAC_GYM.md Fix typo. Co-authored-by: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> --------- Co-authored-by: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> --- docs/ISAAC_GYM.md | 172 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/docs/ISAAC_GYM.md b/docs/ISAAC_GYM.md index d5c39f57..144a1794 100644 --- a/docs/ISAAC_GYM.md +++ b/docs/ISAAC_GYM.md @@ -1,4 +1,174 @@ ## Isaac Gym Results https://developer.nvidia.com/isaac-gym -Coming. +## What's Written Below + +Content below is written to complement `HOW_TO_RL_GAMES.md` in the same directory, while focusing more on **explaining the implementations** in the classes and how to **customize the training (testing) loops, models and networks**. Since the AMP implementation in `IsaacGymEnvs` is used as the example, so you are reading me here under this file. + +## Program Entry Point + +The primary entry point for both training and testing within `IsaacGymEnvs` is the `train.py` script. This file initializes an instance of the `rl_games.torch_runner.Runner` class, and depending on the mode selected, either the `run_train` or `run_play` function is executed. Additionally, `train.py` allows for custom implementations of training and testing loops, as well as the integration of custom neural networks and models into the library through the `build_runner` function, a process referred to as "registering." By registering custom code, the library can be configured to execute the user-defined code by specifying the appropriate names within the configuration file. + +In RL Games, the training algorithms are referred to as "agents," while their counterparts for testing are known as "players." In the `run_train` function, an agent is instantiated, and training is initiated through the `agent.train` call. Similarly, in the `run_play` function, a player is created, and testing begins by invoking `player.run`. Thus, the core entry points for training and testing in RL Games are the `train` function for agents and the `run` function for players. + +```python +def run_train(self, args): + """Run the training procedure from the algorithm passed in.""" + + print('Started to train') + agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params) + _restore(agent, args) + _override_sigma(agent, args) + agent.train() + +def run_play(self, args): + """Run the inference procedure from the algorithm passed in.""" + + print('Started to play') + player = self.create_player() + _restore(player, args) + _override_sigma(player, args) + player.run() +``` + +## Training Algorithms + +The creation of an agent is handled by the `algo_factory`, as demonstrated in the code above. By default, the `algo_factory` is registered with continuous-action A2C, discrete-action A2C, and SAC. This default registration is found within the constructor of the `Runner` class, and its implementation is shown below. Our primary focus will be on understanding `A2CAgent`, as it is the primary algorithm used for most continuous-control tasks in `IsaacGymEnvs`. + +```python +self.algo_factory.register_builder( + 'a2c_continuous', + lambda **kwargs: a2c_continuous.A2CAgent(**kwargs) +) +self.algo_factory.register_builder( + 'a2c_discrete', + lambda **kwargs: a2c_discrete.DiscreteA2CAgent(**kwargs) +) +self.algo_factory.register_builder( + 'sac', + lambda **kwargs: sac_agent.SACAgent(**kwargs) +) +``` + +At the base of all RL Games algorithms is the `BaseAlgorithm` class, an abstract class that defines several essential methods, including `train` and `train_epoch`, which are critical for training. The `A2CBase` class inherits from `BaseAlgorithm` and provides many shared functionalities for both continuous and discrete A2C agents. These include methods such as `play_steps` and `play_steps_rnn` for gathering rollout data, and `env_step` and `env_reset` for interacting with the environment. However, functions directly related to training—like `train`, `train_epoch`, `update_epoch`, `prepare_dataset`, `train_actor_critic`, and `calc_gradients`—are left unimplemented at this level. These functions are implemented in `ContinuousA2CBase`, a subclass of `A2CBase`, and further in `A2CAgent`, a subclass of `ContinuousA2CBase`. + +The `ContinuousA2CBase` class is responsible for the core logic of agent training, specifically in the methods `train`, `train_epoch`, and `prepare_dataset`. In the `train` function, the environment is reset once before entering the main training loop. This loop consists of three primary stages: + +1. Calling `update_epoch`. +2. Running `train_epoch`. +3. Logging key information, such as episode length, rewards, and losses. + +The `update_epoch` function, which increments the epoch count, is implemented in `A2CAgent`. The heart of the training process is the `train_epoch` function, which operates as follows: + +1. `play_steps` or `play_steps_rnn` is called to generate rollout data in the form of a dictionary of tensors, `batch_dict`. The number of environment steps collected equals the configured `horizon_length`. +2. `prepare_dataset` modifies the tensors in `batch_dict`, which may include normalizing values and advantages, depending on the configuration. +3. Multiple mini-epochs are executed. In each mini-epoch, the dataset is divided into mini-batches, which are sequentially fed into `train_actor_critic`. Function `train_actor_critic`, implemented in `A2CAgent`, internally calls `calc_grad`, also found in `A2CAgent`. + +The `A2CAgent` class, which inherits from `ContinuousA2CBase`, handles the crucial task of gradient calculation and model parameter optimization in its `calc_grad` function. Specifically, `calc_grad` first performs a forward pass of the policy model with PyTorch’s gradients and computational graph enabled. It then calculates the individual loss terms as well as the total scalar loss, runs the backward pass to compute gradients, truncates gradients if necessary, updates model parameters via the optimizer, and finally logs the relevant training metrics such as loss terms and learning rates. + +With an understanding of the default functions, it becomes straightforward to customize agents by inheriting from `A2CAgent` and overriding specific methods to suit particular needs. A good example of this is the implementation of the AMP algorithm in `IsaacGymEnvs`, where the `AMPAgent` class is created and registered in `train.py`, as shown below. + +```python +_runner.algo_factory.register_builder( + 'amp_continuous', + lambda **kwargs: amp_continuous.AMPAgent(**kwargs) +) +``` + +## Players + +Similar to training algorithms, default players are registered with `player_factory` in the `Runner` class. These include `PPOPlayerContinuous`, `PPOPlayerDiscrete`, and `SACPlayer`. Each of these player classes inherits from the `BasePlayer` class, which provides a common `run` function. The derived player classes implement specific methods for restoring from model checkpoints (`restore`), initializing the RNN (`reset`), and generating actions based on observations through `get_action` and `get_masked_action`. + +The testing loop is simpler compared to the training loop. It starts by resetting the environment to obtain the initial observation. Then, for `max_steps` iterations, the loop feeds the observation into the model to generate an action, which is applied to the environment to retrieve the next observation, reward, and other necessary data. This process is repeated for `n_games` episodes, after which the average reward and episode lengths are calculated and displayed. + +Customizing the testing loop is as straightforward as customizing the training loop. By inheriting from a default player class, one can override specific functions as needed. As with custom training algorithms, customized players must also be registered with `player_factory` in `train.py`, as demonstrated below. + +```python +self.player_factory.register_builder( + 'a2c_continuous', + lambda **kwargs: players.PpoPlayerContinuous(**kwargs) +) +self.player_factory.register_builder( + 'a2c_discrete', + lambda **kwargs: players.PpoPlayerDiscrete(**kwargs) +) +self.player_factory.register_builder( + 'sac', + lambda **kwargs: players.SACPlayer(**kwargs) +) + +_runner.player_factory.register_builder( + 'amp_continuous', + lambda **kwargs: amp_players.AMPPlayerContinuous(**kwargs) +) +``` + +## Models and Networks + +The terminology and implementation of models and networks in RL Games version `1.6.1` can be confusing for new users. Below is a high-level overview of their functionality and relationships: + +- **Network Builder:** Network builder classes, such as `A2CBuilder` and `SACBuilder`, are subclasses of `NetworkBuilder` and can be found in `algos_torch.network_builder`. The core component of a network builder is the nested `Network` class (the "inner network" class), which is typically derived from `torch.nn.Module`. This class receives a dictionary of tensors, such as observations and other necessary inputs, and outputs a tuple of tensors from which actions can be generated. The `forward` function of the `Network` class handles this transformation. + +- **Model:** Model classes, like `ModelA2C` and `ModelSACContinuous`, inherit from `BaseModel` in `algos_torch.models`. They are similar to network builders, as each contains a nested `Network` class (the "model network" class) and a `build` function to construct an instance of this network. + +- **Model & Network in Algorithm:** In a default agent or player algorithm, `self.model` refers to an instance of the model network class, while `self.network` refers to an instance of the model class. + +- **Model Builder:** The `ModelBuilder` class, located in `algos_torch.model_builder`, is responsible for loading and managing models. It provides a `load` function, which creates a model instance based on the specified name. + +Customizing models requires implementing a custom network builder and model class. These custom classes should be registered in the `Runner` class within `train.py`. A good reference example is the AMP implementation, as shown below. + +```python +# algos_torch.model_builder.NetworkBuilder.__init__ +self.network_factory.register_builder( + 'actor_critic', + lambda **kwargs: network_builder.A2CBuilder() +) +self.network_factory.register_builder( + 'resnet_actor_critic', + lambda **kwargs: network_builder.A2CResnetBuilder() +) +self.network_factory.register_builder( + 'rnd_curiosity', + lambda **kwargs: network_builder.RNDCuriosityBuilder() +) +self.network_factory.register_builder( + 'soft_actor_critic', + lambda **kwargs: network_builder.SACBuilder() +) + +# algos_torch.model_builder.ModelBuilder.__init__ +self.model_factory.register_builder( + 'discrete_a2c', + lambda network, **kwargs: models.ModelA2C(network) +) +self.model_factory.register_builder( + 'multi_discrete_a2c', + lambda network, **kwargs: models.ModelA2CMultiDiscrete(network) +) +self.model_factory.register_builder( + 'continuous_a2c', + lambda network, **kwargs: models.ModelA2CContinuous(network) +) +self.model_factory.register_builder( + 'continuous_a2c_logstd', + lambda network, **kwargs: models.ModelA2CContinuousLogStd(network) +) +self.model_factory.register_builder( + 'soft_actor_critic', + lambda network, **kwargs: models.ModelSACContinuous(network) +) +self.model_factory.register_builder( + 'central_value', + lambda network, **kwargs: models.ModelCentralValue(network) +) + +# isaacgymenvs.train.launch_rlg_hydra.build_runner +model_builder.register_model( + 'continuous_amp', + lambda network, **kwargs: amp_models.ModelAMPContinuous(network), +) +model_builder.register_network( + 'amp', + lambda **kwargs: amp_network_builder.AMPBuilder() +) +```