
Commit

Merge with aux_loss branch.
ViktorM committed Sep 8, 2024
2 parents 8b274d1 + 2fe6c5f commit 00cbd3d
Showing 9 changed files with 231 additions and 24 deletions.
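In summary, this merge adds an auxiliary-loss hook to rl_games: BaseModelNetwork, the model wrapper classes in models.py, and the network builder base class gain a get_aux_loss() method (returning None by default); the continuous and discrete A2C agents add any returned auxiliary losses to the total loss and collect them in aux_loss_dict; write_stats logs each entry under 'losses/<name>'; and a testnet_aux_loss network registration, a test config, and an aux_target observation in the RNN test environment exercise the feature. The unused soft-augmentation assertion in a2c_common.py is removed.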
71 changes: 60 additions & 11 deletions rl_games/algos_torch/a2c_continuous.py
@@ -162,9 +162,55 @@ def calc_gradients(self, input_dict):
mu = res_dict['mus']
sigma = res_dict['sigmas']

loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(self.actor_loss_func,
old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip,
value_preds_batch, values, return_batch, mu, entropy, rnn_masks)
loss, a_loss, c_loss, entropy, b_loss, sum_mask = self.calc_losses(
self.actor_loss_func,
old_action_log_probs_batch,
action_log_probs,
advantage,
curr_e_clip,
value_preds_batch,
values,
return_batch,
mu,
entropy,
rnn_masks
)

if self.has_value_loss:
c_loss = common_losses.critic_loss(
self.model, value_preds_batch, values, curr_e_clip, return_batch,
self.clip_value
)
else:
c_loss = torch.zeros(1, device=self.ppo_device)
if self.bound_loss_type == 'regularisation':
b_loss = self.reg_loss(mu)
elif self.bound_loss_type == 'bound':
b_loss = self.bound_loss(mu)
else:
b_loss = torch.zeros(1, device=self.ppo_device)

losses, sum_mask = torch_ext.apply_masks(
[
a_loss.unsqueeze(1),
c_loss,
entropy.unsqueeze(1),
b_loss.unsqueeze(1)
],
rnn_masks
)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
for k, v in aux_loss.items():
loss += v
if k in self.aux_loss_dict:
self.aux_loss_dict[k] = v.detach()
else:
self.aux_loss_dict[k] = [v.detach()]

if self.multi_gpu:
self.optimizer.zero_grad()
@@ -173,22 +219,25 @@ def calc_gradients(self, input_dict):
param.grad = None

self.scaler.scale(loss).backward()
#TODO: Refactor this ugliest code of they year
# TODO: Refactor this ugliest code of they year
self.trancate_gradients_and_step()

with torch.no_grad():
reduce_kl = rnn_masks is None
kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
kl_dist = torch_ext.policy_kl(
mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch,
reduce_kl
)
if rnn_masks is not None:
kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask

self.diagnostics.mini_batch(self,
{
'values' : value_preds_batch,
'returns' : return_batch,
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
'values': value_preds_batch,
'returns': return_batch,
'new_neglogp': action_log_probs,
'old_neglogp': old_action_log_probs_batch,
'masks': rnn_masks
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, \
@@ -214,4 +263,4 @@ def bound_loss(self, mu):
b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
else:
b_loss = 0
return b_loss
return b_loss
11 changes: 10 additions & 1 deletion rl_games/algos_torch/a2c_discrete.py
@@ -175,7 +175,16 @@ def calc_gradients(self, input_dict):
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef

aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
for k, v in aux_loss.items():
loss += v
if k in self.aux_loss_dict:
self.aux_loss_dict[k] = v.detach()
else:
self.aux_loss_dict[k] = [v.detach()]

if self.multi_gpu:
self.optimizer.zero_grad()
else:
24 changes: 23 additions & 1 deletion rl_games/algos_torch/models.py
@@ -57,6 +57,10 @@ def norm_obs(self, observation):
def denorm_value(self, value):
with torch.no_grad():
return self.value_mean_std(value, denorm=True) if self.normalize_value else value


def get_aux_loss(self):
return None


class ModelA2C(BaseModel):
@@ -68,7 +72,10 @@ class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self,**kwargs)
self.a2c_network = a2c_network


def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -126,6 +133,9 @@ def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -196,6 +206,9 @@ def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -254,6 +267,9 @@ def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -312,6 +328,9 @@ def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -350,6 +369,9 @@ def __init__(self, sac_network,**kwargs):
BaseModelNetwork.__init__(self,**kwargs)
self.sac_network = sac_network

def get_aux_loss(self):
return self.sac_network.get_aux_loss()

def critic(self, obs, action):
return self.sac_network.critic(obs, action)

3 changes: 3 additions & 0 deletions rl_games/algos_torch/network_builder.py
@@ -67,6 +67,9 @@ def is_rnn(self):
def get_default_rnn_state(self):
return None

def get_aux_loss(self):
return None

def _calc_input_size(self, input_shape,cnn_layers=None):
if cnn_layers is None:
assert(len(input_shape) == 1)
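The builder base class above only defines the hook; the actual test network built by TestNetAuxLossBuilder lives in rl_games/envs/test_network.py, which is in the part of the diff that did not load. A minimal sketch, with hypothetical names, of a network satisfying the contract visible in this commit: the auxiliary loss is computed during the forward pass and returned by get_aux_loss() as a dict of named scalar tensors, which the A2C trainers add to the total loss and log as 'losses/<name>'.

# Hypothetical sketch only -- not the TestNetAuxLossBuilder implementation.
import torch
import torch.nn as nn

class AuxLossNet(nn.Module):
    def __init__(self, obs_size, num_actions, hidden_size=64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU())
        self.logits = nn.Linear(hidden_size, num_actions)
        self.value = nn.Linear(hidden_size, 1)
        # Auxiliary head predicting the env's normalized squared distance to the goal
        # (the 'aux_target' observation added in rnn_env.py below).
        self.aux_head = nn.Linear(hidden_size, 1)
        self._aux_loss = None

    def forward(self, obs, aux_target=None):
        h = self.trunk(obs)
        if aux_target is not None:
            # Cache the loss so the trainer can fetch it right after the forward pass.
            self._aux_loss = nn.functional.mse_loss(self.aux_head(h), aux_target)
        return self.logits(h), self.value(h)

    def get_aux_loss(self):
        # None (the default in BaseModelNetwork and the builder base class) means
        # "no auxiliary loss"; otherwise every dict value is added to the total loss.
        if self._aux_loss is None:
            return None
        return {'aux_dist_loss': self._aux_loss}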
9 changes: 3 additions & 6 deletions rl_games/common/a2c_common.py
@@ -325,9 +325,7 @@ def __init__(self, base_name, params):
self.algo_observer = config['features']['observer']

self.soft_aug = config['features'].get('soft_augmentation', None)
self.has_soft_aug = self.soft_aug is not None
# soft augmentation not yet supported
assert not self.has_soft_aug
self.aux_loss_dict = {}

def trancate_gradients_and_step(self):
if self.multi_gpu:
@@ -378,6 +376,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time,
self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame)

self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame)
for k,v in self.aux_loss_dict.items():
self.writer.add_scalar('losses/' + k, torch_ext.mean_list(v).item(), frame)
self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame)
self.writer.add_scalar('info/lr_mul', lr_mul, frame)
self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)
@@ -1362,9 +1362,6 @@ def train(self):
if len(b_losses) > 0:
self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame)

if self.has_soft_aug:
self.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame)

if self.game_rewards.current_size > 0:
mean_rewards = self.game_rewards.get_mean()
mean_shaped_rewards = self.game_shaped_rewards.get_mean()
52 changes: 52 additions & 0 deletions rl_games/configs/test/test_discrite_testnet_aux_loss.yaml
@@ -0,0 +1,52 @@
params:
algo:
name: a2c_discrete

model:
name: discrete_a2c

network:
name: testnet_aux_loss
config:
reward_shaper:
scale_value: 1
normalize_advantage: True
gamma: 0.99
tau: 0.9
learning_rate: 2e-4
name: test_md_multi_obs
score_to_win: 0.95
grad_norm: 10.5
entropy_coef: 0.005
truncate_grads: True
env_name: test_env
e_clip: 0.2
clip_value: False
num_actors: 16
horizon_length: 256
minibatch_size: 2048
mini_epochs: 4
critic_coef: 1
lr_schedule: None
kl_threshold: 0.008
normalize_input: False
normalize_value: False
weight_decay: 0.0000
max_epochs: 10000
seq_length: 16
save_best_after: 10
save_frequency: 20

env_config:
name: TestRnnEnv-v0
hide_object: False
apply_dist_reward: False
min_dist: 2
max_dist: 8
use_central_value: True
multi_obs_space: True
multi_head_value: False
aux_loss: True
player:
games_num: 100
deterministic: True
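For context: network name testnet_aux_loss in this config resolves to the TestNetAuxLossBuilder registered below in rl_games/envs/__init__.py, and aux_loss: True under env_config makes TestRnnEnv-v0 emit the extra aux_target observation (see the rnn_env.py changes further down). Assuming the usual rl_games entry point, the config would be run with something like python runner.py --train --file rl_games/configs/test/test_discrite_testnet_aux_loss.yaml; the exact invocation is an assumption and not part of this diff.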
5 changes: 3 additions & 2 deletions rl_games/envs/__init__.py
@@ -1,6 +1,7 @@


from rl_games.envs.test_network import TestNetBuilder
from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder
from rl_games.algos_torch import model_builder

model_builder.register_network('testnet', TestNetBuilder)
model_builder.register_network('testnet', TestNetBuilder)
model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder)
10 changes: 10 additions & 0 deletions rl_games/envs/test/rnn_env.py
@@ -16,6 +16,7 @@ def __init__(self, **kwargs):
self.apply_dist_reward = kwargs.pop('apply_dist_reward', False)
self.apply_exploration_reward = kwargs.pop('apply_exploration_reward', False)
self.multi_head_value = kwargs.pop('multi_head_value', False)
self.aux_loss = kwargs.pop('aux_loss', False)
if self.multi_head_value:
self.value_size = 2
else:
@@ -33,6 +34,8 @@ def __init__(self, **kwargs):
'pos': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.float32),
'info': gym.spaces.Box(low=0, high=1, shape=(4, ), dtype=np.float32),
}
if self.aux_loss:
spaces['aux_target'] = gym.spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32)
self.observation_space = gym.spaces.Dict(spaces)
else:
self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32)
@@ -58,6 +61,9 @@ def reset(self):
'pos': obs[:2],
'info': obs[2:]
}
if self.aux_loss:
aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
if self.use_central_value:
obses = {}
obses["obs"] = obs
@@ -93,6 +99,7 @@ def step_multi_categorical(self, action):
def step(self, action):
info = {}
self._curr_steps += 1
bound = self.max_dist - self.min_dist
if self.multi_discrete_space:
self.step_multi_categorical(action)
else:
@@ -125,6 +132,9 @@ def step(self, action):
'pos': obs[:2],
'info': obs[2:]
}
if self.aux_loss:
aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
if self.use_central_value:
state = np.concatenate([self._current_pos, self._goal_pos, [show_object, self._curr_steps]], axis=None)
obses = {}
(diff for the remaining changed file did not load)
