diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 7dafb2d5..a37137ef 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -162,9 +162,55 @@ def calc_gradients(self, input_dict):
             mu = res_dict['mus']
             sigma = res_dict['sigmas']
 
-            loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(self.actor_loss_func,
-                old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip,
-                value_preds_batch, values, return_batch, mu, entropy, rnn_masks)
+            loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(
+                self.actor_loss_func,
+                old_action_log_probs_batch,
+                action_log_probs,
+                advantage,
+                curr_e_clip,
+                value_preds_batch,
+                values,
+                return_batch,
+                mu,
+                entropy,
+                rnn_masks
+            )
+
+            if self.has_value_loss:
+                c_loss = common_losses.critic_loss(
+                    self.model, value_preds_batch, values, curr_e_clip, return_batch,
+                    self.clip_value
+                )
+            else:
+                c_loss = torch.zeros(1, device=self.ppo_device)
+            if self.bound_loss_type == 'regularisation':
+                b_loss = self.reg_loss(mu)
+            elif self.bound_loss_type == 'bound':
+                b_loss = self.bound_loss(mu)
+            else:
+                b_loss = torch.zeros(1, device=self.ppo_device)
+
+            losses, sum_mask = torch_ext.apply_masks(
+                [
+                    a_loss.unsqueeze(1),
+                    c_loss,
+                    entropy.unsqueeze(1),
+                    b_loss.unsqueeze(1)
+                ],
+                rnn_masks
+            )
+            a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]
+
+            loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
+            aux_loss = self.model.get_aux_loss()
+            self.aux_loss_dict = {}
+            if aux_loss is not None:
+                for k, v in aux_loss.items():
+                    loss += v
+                    if k in self.aux_loss_dict:
+                        self.aux_loss_dict[k] = v.detach()
+                    else:
+                        self.aux_loss_dict[k] = [v.detach()]
 
         if self.multi_gpu:
             self.optimizer.zero_grad()
@@ -173,22 +219,25 @@ def calc_gradients(self, input_dict):
                 param.grad = None
 
         self.scaler.scale(loss).backward()
-        #TODO: Refactor this ugliest code of they year
+        # TODO: Refactor this ugliest code of they year
        self.trancate_gradients_and_step()
 
         with torch.no_grad():
             reduce_kl = rnn_masks is None
-            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
+            kl_dist = torch_ext.policy_kl(
+                mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch,
+                reduce_kl
+            )
             if rnn_masks is not None:
                 kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask
 
         self.diagnostics.mini_batch(self, {
-            'values' : value_preds_batch,
-            'returns' : return_batch,
-            'new_neglogp' : action_log_probs,
-            'old_neglogp' : old_action_log_probs_batch,
-            'masks' : rnn_masks
+            'values': value_preds_batch,
+            'returns': return_batch,
+            'new_neglogp': action_log_probs,
+            'old_neglogp': old_action_log_probs_batch,
+            'masks': rnn_masks
         }, curr_e_clip, 0)
 
         self.train_result = (a_loss, c_loss, entropy, \
@@ -214,4 +263,4 @@ def bound_loss(self, mu):
             b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
         else:
             b_loss = 0
-        return b_loss
+        return b_loss
\ No newline at end of file
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index b72d0506..781a85d6 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -175,7 +175,16 @@ def calc_gradients(self, input_dict):
             losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
             a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
             loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef
-
+            aux_loss = self.model.get_aux_loss()
+            self.aux_loss_dict = {}
+            if aux_loss is not None:
+                for k, v in aux_loss.items():
+                    loss += v
+                    if k in self.aux_loss_dict:
+                        self.aux_loss_dict[k] = v.detach()
+                    else:
+                        self.aux_loss_dict[k] = [v.detach()]
+
         if self.multi_gpu:
             self.optimizer.zero_grad()
         else:
diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index 4a15183c..db2fcd1e 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -57,6 +57,10 @@ def norm_obs(self, observation):
     def denorm_value(self, value):
         with torch.no_grad():
             return self.value_mean_std(value, denorm=True) if self.normalize_value else value
+
+
+    def get_aux_loss(self):
+        return None
 
 
 class ModelA2C(BaseModel):
@@ -68,7 +72,10 @@ class Network(BaseModelNetwork):
         def __init__(self, a2c_network, **kwargs):
             BaseModelNetwork.__init__(self,**kwargs)
             self.a2c_network = a2c_network
-
+
+        def get_aux_loss(self):
+            return self.a2c_network.get_aux_loss()
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
@@ -126,6 +133,9 @@ def __init__(self, a2c_network, **kwargs):
             BaseModelNetwork.__init__(self, **kwargs)
             self.a2c_network = a2c_network
 
+        def get_aux_loss(self):
+            return self.a2c_network.get_aux_loss()
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
@@ -196,6 +206,9 @@ def __init__(self, a2c_network, **kwargs):
             BaseModelNetwork.__init__(self, **kwargs)
             self.a2c_network = a2c_network
 
+        def get_aux_loss(self):
+            return self.a2c_network.get_aux_loss()
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
@@ -254,6 +267,9 @@ def __init__(self, a2c_network, **kwargs):
             BaseModelNetwork.__init__(self, **kwargs)
             self.a2c_network = a2c_network
 
+        def get_aux_loss(self):
+            return self.a2c_network.get_aux_loss()
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
@@ -312,6 +328,9 @@ def __init__(self, a2c_network, **kwargs):
             BaseModelNetwork.__init__(self, **kwargs)
             self.a2c_network = a2c_network
 
+        def get_aux_loss(self):
+            return self.a2c_network.get_aux_loss()
+
         def is_rnn(self):
             return self.a2c_network.is_rnn()
 
@@ -350,6 +369,9 @@ def __init__(self, sac_network,**kwargs):
             BaseModelNetwork.__init__(self,**kwargs)
             self.sac_network = sac_network
 
+        def get_aux_loss(self):
+            return self.sac_network.get_aux_loss()
+
         def critic(self, obs, action):
             return self.sac_network.critic(obs, action)
 
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index 7d39c118..d0571167 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -67,6 +67,9 @@ def is_rnn(self):
         def get_default_rnn_state(self):
             return None
 
+        def get_aux_loss(self):
+            return None
+
         def _calc_input_size(self, input_shape,cnn_layers=None):
             if cnn_layers is None:
                 assert(len(input_shape) == 1)
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index d754e595..1c47a37e 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -325,9 +325,7 @@ def __init__(self, base_name, params):
         self.algo_observer = config['features']['observer']
 
         self.soft_aug = config['features'].get('soft_augmentation', None)
-        self.has_soft_aug = self.soft_aug is not None
-        # soft augmentation not yet supported
-        assert not self.has_soft_aug
+        self.aux_loss_dict = {}
 
     def trancate_gradients_and_step(self):
         if self.multi_gpu:
@@ -378,6 +376,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time,
         self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame)
         self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame)
+        for k,v in self.aux_loss_dict.items():
+            self.writer.add_scalar('losses/' + k, torch_ext.mean_list(v).item(), frame)
         self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame)
         self.writer.add_scalar('info/lr_mul', lr_mul, frame)
         self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)
@@ -1362,9 +1362,6 @@ def train(self):
             if len(b_losses) > 0:
                 self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame)
 
-            if self.has_soft_aug:
-                self.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame)
-
             if self.game_rewards.current_size > 0:
                 mean_rewards = self.game_rewards.get_mean()
                 mean_shaped_rewards = self.game_shaped_rewards.get_mean()
diff --git a/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml
new file mode 100644
index 00000000..0f666f0d
--- /dev/null
+++ b/rl_games/configs/test/test_discrite_testnet_aux_loss.yaml
@@ -0,0 +1,52 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: discrete_a2c
+
+  network:
+    name: testnet_aux_loss
+  config:
+    reward_shaper:
+      scale_value: 1
+    normalize_advantage: True
+    gamma: 0.99
+    tau: 0.9
+    learning_rate: 2e-4
+    name: test_md_multi_obs
+    score_to_win: 0.95
+    grad_norm: 10.5
+    entropy_coef: 0.005
+    truncate_grads: True
+    env_name: test_env
+    e_clip: 0.2
+    clip_value: False
+    num_actors: 16
+    horizon_length: 256
+    minibatch_size: 2048
+    mini_epochs: 4
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.008
+    normalize_input: False
+    normalize_value: False
+    weight_decay: 0.0000
+    max_epochs: 10000
+    seq_length: 16
+    save_best_after: 10
+    save_frequency: 20
+
+    env_config:
+      name: TestRnnEnv-v0
+      hide_object: False
+      apply_dist_reward: False
+      min_dist: 2
+      max_dist: 8
+      use_central_value: True
+      multi_obs_space: True
+      multi_head_value: False
+      aux_loss: True
+    player:
+      games_num: 100
+      deterministic: True
\ No newline at end of file
diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py
index 6883b34a..b906c43d 100644
--- a/rl_games/envs/__init__.py
+++ b/rl_games/envs/__init__.py
@@ -1,6 +1,7 @@
-from rl_games.envs.test_network import TestNetBuilder
+from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder
 from rl_games.algos_torch import model_builder
 
-model_builder.register_network('testnet', TestNetBuilder)
\ No newline at end of file
+model_builder.register_network('testnet', TestNetBuilder)
+model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder)
\ No newline at end of file
diff --git a/rl_games/envs/test/rnn_env.py b/rl_games/envs/test/rnn_env.py
index faa4e17e..5fcf5318 100644
--- a/rl_games/envs/test/rnn_env.py
+++ b/rl_games/envs/test/rnn_env.py
@@ -16,6 +16,7 @@ def __init__(self, **kwargs):
         self.apply_dist_reward = kwargs.pop('apply_dist_reward', False)
         self.apply_exploration_reward = kwargs.pop('apply_exploration_reward', False)
         self.multi_head_value = kwargs.pop('multi_head_value', False)
+        self.aux_loss = kwargs.pop('aux_loss', False)
         if self.multi_head_value:
             self.value_size = 2
         else:
@@ -33,6 +34,8 @@ def __init__(self, **kwargs):
                 'pos': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.float32),
                 'info': gym.spaces.Box(low=0, high=1, shape=(4, ), dtype=np.float32),
             }
+            if self.aux_loss:
+                spaces['aux_target'] = gym.spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32)
             self.observation_space = gym.spaces.Dict(spaces)
         else:
             self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32)
@@ -58,6 +61,9 @@ def reset(self):
                 'pos': obs[:2],
                 'info': obs[2:]
             }
+            if self.aux_loss:
+                aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
+                obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
             if self.use_central_value:
                 obses = {}
                 obses["obs"] = obs
@@ -93,6 +99,7 @@ def step_multi_categorical(self, action):
     def step(self, action):
         info = {}
         self._curr_steps += 1
+        bound = self.max_dist - self.min_dist
         if self.multi_discrete_space:
             self.step_multi_categorical(action)
         else:
@@ -125,6 +132,9 @@ def step(self, action):
                 'pos': obs[:2],
                 'info': obs[2:]
             }
+            if self.aux_loss:
+                aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
+                obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
             if self.use_central_value:
                 state = np.concatenate([self._current_pos, self._goal_pos, [show_object, self._curr_steps]], axis=None)
                 obses = {}
diff --git a/rl_games/envs/test_network.py b/rl_games/envs/test_network.py
index 6170ebb7..7adfae90 100644
--- a/rl_games/envs/test_network.py
+++ b/rl_games/envs/test_network.py
@@ -2,8 +2,9 @@
 from torch import nn
 import torch.nn.functional as F
 
-
-class TestNet(nn.Module):
+from rl_games.algos_torch.network_builder import NetworkBuilder
+
+class TestNet(NetworkBuilder.BaseNetwork):
     def __init__(self, params, **kwargs):
         nn.Module.__init__(self)
         actions_num = kwargs.pop('actions_num')
@@ -38,7 +39,7 @@ def forward(self, obs):
         return action, value, None
 
 
-from rl_games.algos_torch.network_builder import NetworkBuilder
+
 
 class TestNetBuilder(NetworkBuilder):
     def __init__(self, **kwargs):
@@ -52,3 +53,66 @@ def build(self, name, **kwargs):
 
     def __call__(self, name, **kwargs):
         return self.build(name, **kwargs)
+
+
+
+class TestNetWithAuxLoss(NetworkBuilder.BaseNetwork):
+    def __init__(self, params, **kwargs):
+        nn.Module.__init__(self)
+        actions_num = kwargs.pop('actions_num')
+        input_shape = kwargs.pop('input_shape')
+        num_inputs = 0
+
+        self.target_key = 'aux_target'
+        assert(type(input_shape) is dict)
+        for k,v in input_shape.items():
+            if self.target_key == k:
+                self.target_shape = v[0]
+            else:
+                num_inputs +=v[0]
+
+        self.central_value = params.get('central_value', False)
+        self.value_size = kwargs.pop('value_size', 1)
+        self.linear1 = nn.Linear(num_inputs, 256)
+        self.linear2 = nn.Linear(256, 128)
+        self.linear3 = nn.Linear(128, 64)
+        self.mean_linear = nn.Linear(64, actions_num)
+        self.value_linear = nn.Linear(64, 1)
+        self.aux_loss_linear = nn.Linear(64, self.target_shape)
+
+        self.aux_loss_map = {
+            'aux_dist_loss' : None
+        }
+    def is_rnn(self):
+        return False
+
+    def get_aux_loss(self):
+        return self.aux_loss_map
+
+    def forward(self, obs):
+        obs = obs['obs']
+        target_obs = obs[self.target_key]
+        obs = torch.cat([obs['pos'], obs['info']], axis=-1)
+        x = F.relu(self.linear1(obs))
+        x = F.relu(self.linear2(x))
+        x = F.relu(self.linear3(x))
+        action = self.mean_linear(x)
+        value = self.value_linear(x)
+        y = self.aux_loss_linear(x)
+        self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
+        if self.central_value:
+            return value, None
+        return action, value, None
+
+class TestNetAuxLossBuilder(NetworkBuilder):
+    def __init__(self, **kwargs):
+        NetworkBuilder.__init__(self)
+
+    def load(self, params):
+        self.params = params
+
+    def build(self, name, **kwargs):
+        return TestNetWithAuxLoss(self.params, **kwargs)
+
+    def __call__(self, name, **kwargs):
+        return self.build(name, **kwargs)
\ No newline at end of file
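
Usage sketch: the contract this change introduces is that a network exposes get_aux_loss() returning a dict of name -> scalar tensor; calc_gradients() adds each entry to the PPO loss and write_stats() logs it as 'losses/<name>'. The example below is a minimal, hypothetical custom network wired into that hook. The names MyAuxNet, MyAuxNetBuilder and 'my_aux_net' are illustrative only, and the auxiliary task (regressing the mean of the observation) is a stand-in, not part of this patch; it simply mirrors the TestNetWithAuxLoss pattern above for a plain (non-dict) observation space.

from torch import nn
import torch.nn.functional as F

from rl_games.algos_torch.network_builder import NetworkBuilder
from rl_games.algos_torch import model_builder


class MyAuxNet(NetworkBuilder.BaseNetwork):
    """Tiny actor-critic MLP with one auxiliary regression head (illustrative)."""

    def __init__(self, params, **kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
        input_shape = kwargs.pop('input_shape')  # assumes a flat Box obs space, e.g. (6,)
        self.trunk = nn.Sequential(nn.Linear(input_shape[0], 64), nn.ReLU())
        self.mean_linear = nn.Linear(64, actions_num)
        self.value_linear = nn.Linear(64, 1)
        self.aux_head = nn.Linear(64, 1)
        # filled in during forward(), read back by the trainer via get_aux_loss()
        self.aux_loss_map = {'my_aux_loss': None}

    def is_rnn(self):
        return False

    def get_aux_loss(self):
        return self.aux_loss_map

    def forward(self, obs_dict):
        obs = obs_dict['obs']
        x = self.trunk(obs)
        # hypothetical auxiliary task: regress the per-sample mean of the observation
        aux_target = obs.mean(dim=-1, keepdim=True)
        self.aux_loss_map['my_aux_loss'] = F.mse_loss(self.aux_head(x), aux_target)
        logits = self.mean_linear(x)
        value = self.value_linear(x)
        return logits, value, None


class MyAuxNetBuilder(NetworkBuilder):
    def __init__(self, **kwargs):
        NetworkBuilder.__init__(self)

    def load(self, params):
        self.params = params

    def build(self, name, **kwargs):
        return MyAuxNet(self.params, **kwargs)

    def __call__(self, name, **kwargs):
        return self.build(name, **kwargs)


model_builder.register_network('my_aux_net', MyAuxNetBuilder)

With the network registered, a config only has to point params.network.name at 'my_aux_net' (exactly as the new test YAML does for testnet_aux_loss); the auxiliary term is then added to the training loss and appears in TensorBoard under losses/my_aux_loss with no further changes to the algorithm code.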