Commit

added test aux_loss

DenSumy committed Sep 8, 2024
1 parent 2606eff commit 2fe6c5f
Showing 9 changed files with 180 additions and 14 deletions.
10 changes: 9 additions & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -151,7 +151,15 @@ def calc_gradients(self, input_dict):
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef

aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
    for k, v in aux_loss.items():
        loss += v
        if k in self.aux_loss_dict:
            self.aux_loss_dict[k].append(v.detach())
        else:
            self.aux_loss_dict[k] = [v.detach()]
if self.multi_gpu:
    self.optimizer.zero_grad()
else:
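The pattern above (used identically in the continuous and discrete trainers) adds each auxiliary term to the total loss while storing a detached copy for logging. A minimal, self-contained sketch of that pattern; the values and names here are illustrative, not part of the commit:

import torch

# Illustrative stand-in for model.get_aux_loss(): a dict of named scalar losses.
aux_loss = {'aux_dist_loss': torch.tensor(0.25, requires_grad=True)}

loss = torch.tensor(1.0, requires_grad=True)
aux_loss_dict = {}
if aux_loss is not None:
    for k, v in aux_loss.items():
        loss = loss + v                          # aux term participates in backprop
        if k in aux_loss_dict:
            aux_loss_dict[k].append(v.detach())  # logging copy, no gradient
        else:
            aux_loss_dict[k] = [v.detach()]

print(loss.item(), aux_loss_dict)  # 1.25 {'aux_dist_loss': [tensor(0.2500)]}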
11 changes: 10 additions & 1 deletion rl_games/algos_torch/a2c_discrete.py
@@ -170,7 +170,16 @@ def calc_gradients(self, input_dict):
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef

aux_loss = self.model.get_aux_loss()
self.aux_loss_dict = {}
if aux_loss is not None:
    for k, v in aux_loss.items():
        loss += v
        if k in self.aux_loss_dict:
            self.aux_loss_dict[k].append(v.detach())
        else:
            self.aux_loss_dict[k] = [v.detach()]

if self.multi_gpu:
    self.optimizer.zero_grad()
else:
24 changes: 23 additions & 1 deletion rl_games/algos_torch/models.py
@@ -54,6 +54,10 @@ def norm_obs(self, observation):
    def denorm_value(self, value):
        with torch.no_grad():
            return self.value_mean_std(value, denorm=True) if self.normalize_value else value

    def get_aux_loss(self):
        return None

class ModelA2C(BaseModel):
    def __init__(self, network):
@@ -64,7 +68,10 @@ class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.a2c_network = a2c_network

def get_aux_loss(self):
    return self.a2c_network.get_aux_loss()

def is_rnn(self):
    return self.a2c_network.is_rnn()

@@ -121,6 +128,9 @@ def __init__(self, a2c_network, **kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.a2c_network = a2c_network

def get_aux_loss(self):
    return self.a2c_network.get_aux_loss()

def is_rnn(self):
    return self.a2c_network.is_rnn()

@@ -190,6 +200,9 @@ def __init__(self, a2c_network, **kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.a2c_network = a2c_network

def get_aux_loss(self):
    return self.a2c_network.get_aux_loss()

def is_rnn(self):
    return self.a2c_network.is_rnn()

@@ -248,6 +261,9 @@ def __init__(self, a2c_network, **kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.a2c_network = a2c_network

def get_aux_loss(self):
    return self.a2c_network.get_aux_loss()

def is_rnn(self):
    return self.a2c_network.is_rnn()

@@ -305,6 +321,9 @@ def __init__(self, a2c_network, **kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.a2c_network = a2c_network

def get_aux_loss(self):
    return self.a2c_network.get_aux_loss()

def is_rnn(self):
    return self.a2c_network.is_rnn()

@@ -344,6 +363,9 @@ def __init__(self, sac_network,**kwargs):
    BaseModelNetwork.__init__(self, **kwargs)
    self.sac_network = sac_network

def get_aux_loss(self):
    return self.sac_network.get_aux_loss()

def critic(self, obs, action):
    return self.sac_network.critic(obs, action)

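Every wrapper in models.py adds the same one-line delegation, so only the wrapped network ever has to implement the hook. A minimal mock (plain Python, hypothetical names) showing the contract:

class InnerNet:
    def get_aux_loss(self):
        # A network returns either a dict of named losses or None (the base default).
        return {'aux_example_loss': 0.0}

class ModelWrapper:
    def __init__(self, a2c_network):
        self.a2c_network = a2c_network

    def get_aux_loss(self):
        # Same one-line delegation as each Network class above.
        return self.a2c_network.get_aux_loss()

print(ModelWrapper(InnerNet()).get_aux_loss())  # {'aux_example_loss': 0.0}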
3 changes: 3 additions & 0 deletions rl_games/algos_torch/network_builder.py
@@ -67,6 +67,9 @@ def is_rnn(self):
def get_default_rnn_state(self):
    return None

def get_aux_loss(self):
    return None

def _calc_input_size(self, input_shape, cnn_layers=None):
    if cnn_layers is None:
        assert(len(input_shape) == 1)
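With the default returning None, trainers can call get_aux_loss() unconditionally; a custom network opts in by overriding it. A minimal sketch, assuming rl_games is importable; the class is hypothetical, not part of the commit:

from torch import nn
from rl_games.algos_torch.network_builder import NetworkBuilder

class TinyAuxNet(NetworkBuilder.BaseNetwork):  # hypothetical example network
    def __init__(self):
        nn.Module.__init__(self)
        self.head = nn.Linear(4, 1)
        self.aux_loss_map = {'aux_example_loss': None}

    def get_aux_loss(self):
        return self.aux_loss_map

    def forward(self, x, target):
        pred = self.head(x)
        # Refresh the stored loss on every forward pass, as the test network below does.
        self.aux_loss_map['aux_example_loss'] = nn.functional.mse_loss(pred, target)
        return pred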
9 changes: 3 additions & 6 deletions rl_games/common/a2c_common.py
@@ -321,9 +321,7 @@ def __init__(self, base_name, params):
self.algo_observer = config['features']['observer']

self.soft_aug = config['features'].get('soft_augmentation', None)
-self.has_soft_aug = self.soft_aug is not None
-# soft augmentation not yet supported
-assert not self.has_soft_aug
+self.aux_loss_dict = {}

def trancate_gradients_and_step(self):
    if self.multi_gpu:
@@ -374,6 +372,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time,
self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame)

self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame)
for k, v in self.aux_loss_dict.items():
    self.writer.add_scalar('losses/' + k, torch_ext.mean_list(v).item(), frame)
self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame)
self.writer.add_scalar('info/lr_mul', lr_mul, frame)
self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)
@@ -1357,9 +1357,6 @@ def train(self):
if len(b_losses) > 0:
    self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame)

-if self.has_soft_aug:
-    self.writer.add_scalar('losses/aug_loss', np.mean(aug_losses), frame)

if self.game_rewards.current_size > 0:
    mean_rewards = self.game_rewards.get_mean()
    mean_shaped_rewards = self.game_shaped_rewards.get_mean()
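Each auxiliary entry is thus logged under losses/<name> by averaging a list of detached tensors. A short sketch of the data shape being assumed, where torch_ext.mean_list is taken to stack the list and average it (matching its use with c_losses and entropies above):

import torch

aux_loss_dict = {'aux_dist_loss': [torch.tensor(0.2), torch.tensor(0.4)]}
for k, v in aux_loss_dict.items():
    scalar = torch.mean(torch.stack(v)).item()  # what mean_list(v).item() is assumed to do
    print('losses/' + k, scalar)                # -> losses/aux_dist_loss ~0.3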
52 changes: 52 additions & 0 deletions rl_games/configs/test/test_discrite_testnet_aux_loss.yaml
@@ -0,0 +1,52 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: discrete_a2c

  network:
    name: testnet_aux_loss

  config:
    reward_shaper:
      scale_value: 1
    normalize_advantage: True
    gamma: 0.99
    tau: 0.9
    learning_rate: 2e-4
    name: test_md_multi_obs
    score_to_win: 0.95
    grad_norm: 10.5
    entropy_coef: 0.005
    truncate_grads: True
    env_name: test_env
    e_clip: 0.2
    clip_value: False
    num_actors: 16
    horizon_length: 256
    minibatch_size: 2048
    mini_epochs: 4
    critic_coef: 1
    lr_schedule: None
    kl_threshold: 0.008
    normalize_input: False
    normalize_value: False
    weight_decay: 0.0000
    max_epochs: 10000
    seq_length: 16
    save_best_after: 10
    save_frequency: 20

    env_config:
      name: TestRnnEnv-v0
      hide_object: False
      apply_dist_reward: False
      min_dist: 2
      max_dist: 8
      use_central_value: True
      multi_obs_space: True
      multi_head_value: False
      aux_loss: True

    player:
      games_num: 100
      deterministic: True
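Assuming the standard rl_games entry point, this config can be exercised with the usual runner invocation (the --file path comes from this commit; the CLI shape is the common rl_games usage and may differ by version):

python runner.py --train --file rl_games/configs/test/test_discrite_testnet_aux_loss.yaml

The new term then shows up in TensorBoard as losses/aux_dist_loss, alongside losses/a_loss and losses/c_loss.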
5 changes: 3 additions & 2 deletions rl_games/envs/__init__.py
@@ -1,6 +1,7 @@


-from rl_games.envs.test_network import TestNetBuilder
+from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder
from rl_games.algos_torch import model_builder

model_builder.register_network('testnet', TestNetBuilder)
+model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder)
10 changes: 10 additions & 0 deletions rl_games/envs/test/rnn_env.py
@@ -16,6 +16,7 @@ def __init__(self, **kwargs):
self.apply_dist_reward = kwargs.pop('apply_dist_reward', False)
self.apply_exploration_reward = kwargs.pop('apply_exploration_reward', False)
self.multi_head_value = kwargs.pop('multi_head_value', False)
self.aux_loss = kwargs.pop('aux_loss', False)
if self.multi_head_value:
    self.value_size = 2
else:
@@ -33,6 +34,8 @@ def __init__(self, **kwargs):
        'pos': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.float32),
        'info': gym.spaces.Box(low=0, high=1, shape=(4, ), dtype=np.float32),
    }
    if self.aux_loss:
        spaces['aux_target'] = gym.spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32)
    self.observation_space = gym.spaces.Dict(spaces)
else:
    self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32)
@@ -58,6 +61,9 @@ def reset(self):
    'pos': obs[:2],
    'info': obs[2:]
}
if self.aux_loss:
    aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
    obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
if self.use_central_value:
    obses = {}
    obses["obs"] = obs
@@ -93,6 +99,7 @@ def step_multi_categorical(self, action):
def step(self, action):
    info = {}
    self._curr_steps += 1
    bound = self.max_dist - self.min_dist
    if self.multi_discrete_space:
        self.step_multi_categorical(action)
    else:
@@ -125,6 +132,9 @@ def step(self, action):
    'pos': obs[:2],
    'info': obs[2:]
}
if self.aux_loss:
    aux_target = np.sum((self._goal_pos - self._current_pos)**2) / bound**2
    obs['aux_target'] = np.expand_dims(aux_target.astype(np.float32), axis=0)
if self.use_central_value:
    state = np.concatenate([self._current_pos, self._goal_pos, [show_object, self._curr_steps]], axis=None)
    obses = {}
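The auxiliary target is the squared agent-to-goal distance normalized by the square of the spawn-distance range, so the regression label stays on a roughly unit scale. A worked numeric sketch (positions invented for illustration; min_dist and max_dist as in the test config above):

import numpy as np

goal_pos = np.array([4.0, 3.0])
current_pos = np.array([1.0, 1.0])
min_dist, max_dist = 2, 8
bound = max_dist - min_dist                                   # 6

aux_target = np.sum((goal_pos - current_pos)**2) / bound**2   # 13 / 36 ~= 0.361
obs_entry = np.expand_dims(aux_target.astype(np.float32), axis=0)
print(obs_entry, obs_entry.shape)                             # [0.3611111] (1,)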
70 changes: 67 additions & 3 deletions rl_games/envs/test_network.py
@@ -2,8 +2,9 @@
from torch import nn
import torch.nn.functional as F

-class TestNet(nn.Module):
+from rl_games.algos_torch.network_builder import NetworkBuilder
+
+class TestNet(NetworkBuilder.BaseNetwork):
    def __init__(self, params, **kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
@@ -38,7 +39,7 @@ def forward(self, obs):
        return action, value, None

-from rl_games.algos_torch.network_builder import NetworkBuilder

class TestNetBuilder(NetworkBuilder):
    def __init__(self, **kwargs):
@@ -52,3 +53,66 @@ def build(self, name, **kwargs):

    def __call__(self, name, **kwargs):
        return self.build(name, **kwargs)


class TestNetWithAuxLoss(NetworkBuilder.BaseNetwork):
    def __init__(self, params, **kwargs):
        nn.Module.__init__(self)
        actions_num = kwargs.pop('actions_num')
        input_shape = kwargs.pop('input_shape')
        num_inputs = 0

        self.target_key = 'aux_target'
        assert(type(input_shape) is dict)
        for k, v in input_shape.items():
            if self.target_key == k:
                self.target_shape = v[0]
            else:
                num_inputs += v[0]

        self.central_value = params.get('central_value', False)
        self.value_size = kwargs.pop('value_size', 1)
        self.linear1 = nn.Linear(num_inputs, 256)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(128, 64)
        self.mean_linear = nn.Linear(64, actions_num)
        self.value_linear = nn.Linear(64, 1)
        self.aux_loss_linear = nn.Linear(64, self.target_shape)

        self.aux_loss_map = {
            'aux_dist_loss': None
        }

    def is_rnn(self):
        return False

    def get_aux_loss(self):
        return self.aux_loss_map

    def forward(self, obs):
        obs = obs['obs']
        target_obs = obs[self.target_key]
        obs = torch.cat([obs['pos'], obs['info']], axis=-1)
        x = F.relu(self.linear1(obs))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        action = self.mean_linear(x)
        value = self.value_linear(x)
        y = self.aux_loss_linear(x)
        self.aux_loss_map['aux_dist_loss'] = torch.nn.functional.mse_loss(y, target_obs)
        if self.central_value:
            return value, None
        return action, value, None


class TestNetAuxLossBuilder(NetworkBuilder):
    def __init__(self, **kwargs):
        NetworkBuilder.__init__(self)

    def load(self, params):
        self.params = params

    def build(self, name, **kwargs):
        return TestNetWithAuxLoss(self.params, **kwargs)

    def __call__(self, name, **kwargs):
        return self.build(name, **kwargs)
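A quick way to see the plumbing end to end: build the network with the observation shapes used by the test env, run a forward pass, and read the freshly computed term back through get_aux_loss. A sketch under stated assumptions (batch size and values invented; assumes the class above is importable from this module):

import torch
from rl_games.envs.test_network import TestNetWithAuxLoss

net = TestNetWithAuxLoss(
    {'central_value': False},
    actions_num=4,
    input_shape={'pos': (2,), 'info': (4,), 'aux_target': (1,)},
    value_size=1)

obs = {'obs': {'pos': torch.rand(8, 2),
               'info': torch.rand(8, 4),
               'aux_target': torch.rand(8, 1)}}

action, value, _ = net(obs)
print(net.get_aux_loss())  # {'aux_dist_loss': tensor(...)}, a 0-dim MSE term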
