Added masks for envs #195

Open
wants to merge 2 commits into base: master
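This PR replaces the RNN-only `rnn_masks` with more general `env_masks` that any vectorised environment can provide (via a new `get_env_masks()` hook on `IVecEnv` and an `env_masks` tensor in the experience buffer), so selected envs or timesteps can be excluded from the training update. As a rough illustration of what such a mask does to a loss term, here is a minimal masked-mean sketch; it is not the actual `torch_ext.apply_masks` implementation:

```python
import torch

def masked_mean(loss, mask=None):
    # Average a per-sample loss, counting only entries whose mask is 1.
    if mask is None:
        return loss.mean()
    mask = mask.float()
    return (loss * mask).sum() / mask.sum().clamp(min=1.0)

# Toy example: losses from masked-out envs do not influence the update.
loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
env_masks = torch.tensor([1, 1, 0, 0], dtype=torch.uint8)
print(masked_mean(loss, env_masks))  # tensor(1.5000)
```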
14 changes: 7 additions & 7 deletions rl_games/algos_torch/a2c_continuous.py
@@ -95,9 +95,9 @@ def calc_gradients(self, input_dict):
'obs' : obs_batch,
}

rnn_masks = None
env_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
env_masks = input_dict['env_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len

@@ -121,7 +121,7 @@ def calc_gradients(self, input_dict):
b_loss = self.bound_loss(mu)
else:
b_loss = torch.zeros(1, device=self.ppo_device)
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], env_masks)
a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
@@ -137,18 +137,18 @@ def calc_gradients(self, input_dict):
self.trancate_gradients_and_step()

with torch.no_grad():
reduce_kl = rnn_masks is None
reduce_kl = env_masks is None
kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
if rnn_masks is not None:
kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask
if env_masks is not None:
kl_dist = (kl_dist * env_masks).sum() / env_masks.numel() #/ sum_mask

self.diagnostics.mini_batch(self,
{
'values' : value_preds_batch,
'returns' : return_batch,
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
'masks' : env_masks
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, \
12 changes: 6 additions & 6 deletions rl_games/algos_torch/a2c_discrete.py
@@ -126,9 +126,9 @@ def calc_gradients(self, input_dict):
}
if self.use_action_masks:
batch_dict['action_masks'] = input_dict['action_masks']
rnn_masks = None
env_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
env_masks = input_dict['env_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len
batch_dict['bptt_len'] = self.bptt_len
@@ -147,7 +147,7 @@ def calc_gradients(self, input_dict):
c_loss = torch.zeros(1, device=self.ppo_device)


losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], env_masks)
a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef

@@ -162,8 +162,8 @@

with torch.no_grad():
kl_dist = 0.5 * ((old_action_log_probs_batch - action_log_probs)**2)
if rnn_masks is not None:
kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() # / sum_mask
if env_masks is not None:
kl_dist = (kl_dist * env_masks).sum() / env_masks.numel() # / sum_mask
else:
kl_dist = kl_dist.mean()

@@ -173,7 +173,7 @@
'returns' : return_batch,
'new_neglogp' : action_log_probs,
'old_neglogp' : old_action_log_probs_batch,
'masks' : rnn_masks
'masks' : env_masks
}, curr_e_clip, 0)

self.train_result = (a_loss, c_loss, entropy, kl_dist,self.last_lr, lr_mul)
10 changes: 5 additions & 5 deletions rl_games/algos_torch/central_value.py
@@ -113,7 +113,7 @@ def update_dataset(self, batch_dict):
returns = batch_dict['returns']
actions = batch_dict['actions']
dones = batch_dict['dones']
rnn_masks = batch_dict['rnn_masks']
env_masks = batch_dict['env_masks']
if self.num_agents > 1:
res = self.update_multiagent_tensors(value_preds, returns, actions, dones)
batch_dict['old_values'] = res[0]
@@ -130,8 +130,8 @@

batch_dict['rnn_states'] = states
if self.num_agents > 1:
rnn_masks = res[3]
batch_dict['rnn_masks'] = rnn_masks
env_masks = res[3]
batch_dict['env_masks'] = env_masks
self.dataset.update_values_dict(batch_dict)

def _preproc_obs(self, obs_batch):
@@ -222,7 +222,7 @@ def calc_gradients(self, batch):
returns_batch = batch['returns']
actions_batch = batch['actions']
dones_batch = batch['dones']
rnn_masks_batch = batch.get('rnn_masks')
env_masks_batch = batch.get('env_masks')

batch_dict = {'obs' : obs_batch,
'actions' : actions_batch,
@@ -234,7 +234,7 @@
res_dict = self.model(batch_dict)
values = res_dict['values']
loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
losses, _ = torch_ext.apply_masks([loss], env_masks_batch)
loss = losses[0]
if self.multi_gpu:
self.optimizer.zero_grad()
2 changes: 1 addition & 1 deletion rl_games/algos_torch/torch_ext.py
@@ -36,7 +36,7 @@ def policy_kl(p0_mu, p0_sigma, p1_mu, p1_sigma, reduce=True):
return kl

def mean_mask(input, mask, sum_mask):
return (input * rnn_masks).sum() / sum_mask
return (input * mask).sum() / sum_mask

def shape_whc_to_cwh(shape):
if len(shape) == 3:
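A quick sanity check of the corrected `mean_mask` helper above, with illustrative values (the old body referenced the undefined name `rnn_masks`):

```python
import torch
from rl_games.algos_torch import torch_ext

values = torch.tensor([2.0, 4.0, 6.0])
mask = torch.tensor([1.0, 1.0, 0.0])
sum_mask = mask.sum()
# Mean over the unmasked entries only: (2 + 4) / 2
print(torch_ext.mean_mask(values, mask, sum_mask))  # tensor(3.)
```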
64 changes: 38 additions & 26 deletions rl_games/common/a2c_common.py
@@ -107,6 +107,8 @@ def __init__(self, base_name, params):
self.observation_space = self.env_info['observation_space']
self.weight_decay = config.get('weight_decay', 0.0)
self.use_action_masks = config.get('use_action_masks', False)

self.has_env_masks = self.env_info.get('env_masks', False)
self.is_train = config.get('is_train', True)

self.central_value_config = self.config.get('central_value_config', None)
@@ -453,6 +455,19 @@ def preprocess_actions(self, actions):
actions = actions.cpu().numpy()
return actions

def get_env_masks(self):
'''
Returns env masks from the vectorised env.
For example: return torch.ones((self.num_actors,), device=self.device, dtype=torch.uint8)
'''
if self.is_tensor_obses:
return self.vec_env.get_env_masks()
else:
return torch.from_numpy(self.vec_env.get_env_masks()).to(self.device)  # convert numpy masks and move them to the training device




def env_step(self, actions):
actions = self.preprocess_actions(actions)
obs, rewards, dones, infos = self.vec_env.step(actions)
@@ -624,6 +639,9 @@ def play_steps(self):
res_dict = self.get_masked_action_values(self.obs, masks)
else:
res_dict = self.get_action_values(self.obs)
if self.has_env_masks:
env_masks = self.get_env_masks()
self.experience_buffer.update_data('env_masks', n, env_masks)
self.experience_buffer.update_data('obses', n, self.obs['obs'])
self.experience_buffer.update_data('dones', n, self.dones)

@@ -692,6 +710,9 @@ def play_steps_rnn(self):
res_dict = self.get_masked_action_values(self.obs, masks)
else:
res_dict = self.get_action_values(self.obs)
if self.has_env_masks:
env_masks = self.get_env_masks()
self.experience_buffer.update_data('env_masks', n, env_masks)
self.rnn_states = res_dict['rnn_states']
self.experience_buffer.update_data('obses', n, self.obs['obs'])
self.experience_buffer.update_data('dones', n, self.dones.byte())
@@ -793,7 +814,7 @@ def train_epoch(self):

play_time_end = time.time()
update_time_start = time.time()
rnn_masks = batch_dict.get('rnn_masks', None)
env_masks = batch_dict.get('env_masks', None)

self.curr_frames = batch_dict.pop('played_frames')
self.prepare_dataset(batch_dict)
@@ -835,7 +856,7 @@
return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul

def prepare_dataset(self, batch_dict):
rnn_masks = batch_dict.get('rnn_masks', None)
env_masks = batch_dict.get('env_masks', None)

returns = batch_dict['returns']
values = batch_dict['values']
@@ -855,16 +876,11 @@
advantages = torch.sum(advantages, axis=1)

if self.normalize_advantage:
if self.is_rnn:
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages, mask=rnn_masks)
else:
advantages = torch_ext.normalization_with_masks(advantages, rnn_masks)
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages, mask=env_masks)
else:
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages)
else:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
advantages = torch_ext.normalization_with_masks(advantages, env_masks)


dataset_dict = {}
dataset_dict['old_values'] = values
@@ -875,7 +891,7 @@
dataset_dict['obs'] = obses
dataset_dict['dones'] = dones
dataset_dict['rnn_states'] = rnn_states
dataset_dict['rnn_masks'] = rnn_masks
dataset_dict['env_masks'] = env_masks

if self.use_action_masks:
dataset_dict['action_masks'] = batch_dict['action_masks']
@@ -890,7 +906,7 @@
dataset_dict['actions'] = actions
dataset_dict['dones'] = dones
dataset_dict['obs'] = batch_dict['states']
dataset_dict['rnn_masks'] = rnn_masks
dataset_dict['env_masks'] = env_masks
self.central_value_net.update_dataset(dataset_dict)

def train(self):
@@ -1024,6 +1040,8 @@ def init_tensors(self):
A2CBase.init_tensors(self)
self.update_list = ['actions', 'neglogpacs', 'values', 'mus', 'sigmas']
self.tensor_list = self.update_list + ['obses', 'states', 'dones']
if self.has_env_masks:
self.tensor_list += ['env_masks']

def train_epoch(self):
super().train_epoch()
@@ -1038,7 +1056,7 @@

play_time_end = time.time()
update_time_start = time.time()
rnn_masks = batch_dict.get('rnn_masks', None)
env_masks = batch_dict.get('env_masks', None)

self.set_train()
self.curr_frames = batch_dict.pop('played_frames')
@@ -1103,7 +1121,7 @@ def prepare_dataset(self, batch_dict):
mus = batch_dict['mus']
sigmas = batch_dict['sigmas']
rnn_states = batch_dict.get('rnn_states', None)
rnn_masks = batch_dict.get('rnn_masks', None)
env_masks = batch_dict.get('env_masks', None)

advantages = returns - values

@@ -1116,16 +1134,10 @@
advantages = torch.sum(advantages, axis=1)

if self.normalize_advantage:
if self.is_rnn:
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages, mask=rnn_masks)
else:
advantages = torch_ext.normalization_with_masks(advantages, rnn_masks)
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages, mask=env_masks)
else:
if self.normalize_rms_advantage:
advantages = self.advantage_mean_std(advantages)
else:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
advantages = torch_ext.normalization_with_masks(advantages, env_masks)

dataset_dict = {}
dataset_dict['old_values'] = values
@@ -1136,7 +1148,7 @@
dataset_dict['obs'] = obses
dataset_dict['dones'] = dones
dataset_dict['rnn_states'] = rnn_states
dataset_dict['rnn_masks'] = rnn_masks
dataset_dict['env_masks'] = env_masks
dataset_dict['mu'] = mus
dataset_dict['sigma'] = sigmas

@@ -1150,7 +1162,7 @@
dataset_dict['actions'] = actions
dataset_dict['obs'] = batch_dict['states']
dataset_dict['dones'] = dones
dataset_dict['rnn_masks'] = rnn_masks
dataset_dict['env_masks'] = env_masks
self.central_value_net.update_dataset(dataset_dict)

def train(self):
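With these changes, `prepare_dataset` normalises advantages with `torch_ext.normalization_with_masks(advantages, env_masks)`. A minimal sketch of what masked standardisation amounts to, assuming the statistics should come from unmasked samples only (illustrative, not the rl_games implementation):

```python
import torch

def normalize_with_mask(advantages, mask):
    # Standardise advantages using mean/std computed over unmasked samples only.
    mask = mask.float()
    n = mask.sum().clamp(min=1.0)
    mean = (advantages * mask).sum() / n
    var = (((advantages - mean) ** 2) * mask).sum() / n
    return (advantages - mean) / (var.sqrt() + 1e-8)

adv = torch.tensor([1.0, 2.0, 10.0, -5.0])
env_masks = torch.tensor([1, 1, 1, 0], dtype=torch.uint8)
# The masked-out entry is still rescaled but does not bias the statistics.
print(normalize_with_mask(adv, env_masks))
```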
6 changes: 5 additions & 1 deletion rl_games/common/experience.py
@@ -294,11 +294,13 @@ def __init__(self, env_info, algo_info, device, aux_tensor_dict=None):

self.num_agents = env_info.get('agents', 1)
self.action_space = env_info['action_space']

self.has_env_masks = env_info.get('env_masks', False)
print('self.has_env_masks:', self.has_env_masks)
self.num_actors = algo_info['num_actors']
self.horizon_length = algo_info['horizon_length']
self.has_central_value = algo_info['has_central_value']
self.use_action_masks = algo_info.get('use_action_masks', False)

batch_size = self.num_actors * self.num_agents
self.is_discrete = False
self.is_multi_discrete = False
@@ -337,6 +339,8 @@ def _init_from_env_info(self, env_info):
self.tensor_dict['values'] = self._create_tensor_from_space(val_space, obs_base_shape)
self.tensor_dict['neglogpacs'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.float32), obs_base_shape)
self.tensor_dict['dones'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.uint8), obs_base_shape)
if self.has_env_masks:
self.tensor_dict['env_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.uint8), obs_base_shape)
if self.is_discrete or self.is_multi_discrete:
self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.long), obs_base_shape)
if self.use_action_masks:
6 changes: 6 additions & 0 deletions rl_games/common/ivecenv.py
@@ -31,6 +31,12 @@ def get_env_state(self):
Can be used for stateful training sessions, i.e. with adaptive curriculums.
"""
return None

def get_env_masks(self):
"""
Return env masks to prevent training in particular states or envs.
"""
return None

def set_env_state(self, env_state):
pass
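For context, a sketch of how a vectorised env could implement the new hook. The class name, the masking criterion, and the trimmed-down `get_env_info` are hypothetical, and `step`/`reset` are omitted for brevity:

```python
import torch
from rl_games.common.ivecenv import IVecEnv

class MaskedVecEnv(IVecEnv):
    # Hypothetical vectorised env that masks out actors which should not
    # contribute to the training update (any per-actor criterion works).
    def __init__(self, num_actors, device='cpu'):
        self.num_actors = num_actors
        self.device = device
        self.frozen = torch.zeros(num_actors, dtype=torch.bool, device=device)

    def get_env_masks(self):
        # 1 -> transitions from this env are trained on, 0 -> they are ignored
        return (~self.frozen).to(torch.uint8)

    def get_env_info(self):
        # Advertise mask support so the algo sets has_env_masks = True
        # (observation_space / action_space omitted for brevity).
        return {'env_masks': True}
```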