From 394ff209fe00a1e83fe7757e9aa5c79f18389a73 Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Tue, 16 Aug 2022 16:26:32 -0700 Subject: [PATCH 1/2] added masked envs --- rl_games/algos_torch/a2c_continuous.py | 14 +++---- rl_games/algos_torch/a2c_discrete.py | 12 +++--- rl_games/algos_torch/central_value.py | 10 ++--- rl_games/algos_torch/torch_ext.py | 2 +- rl_games/common/a2c_common.py | 55 ++++++++++++++------------ rl_games/common/experience.py | 3 ++ 6 files changed, 52 insertions(+), 44 deletions(-) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 7a7c762d..ec1bfb88 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -95,9 +95,9 @@ def calc_gradients(self, input_dict): 'obs' : obs_batch, } - rnn_masks = None + env_masks = None if self.is_rnn: - rnn_masks = input_dict['rnn_masks'] + env_masks = input_dict['env_masks'] batch_dict['rnn_states'] = input_dict['rnn_states'] batch_dict['seq_length'] = self.seq_len @@ -121,7 +121,7 @@ def calc_gradients(self, input_dict): b_loss = self.bound_loss(mu) else: b_loss = torch.zeros(1, device=self.ppo_device) - losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks) + losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], env_masks) a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3] loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef @@ -137,10 +137,10 @@ def calc_gradients(self, input_dict): self.trancate_gradients_and_step() with torch.no_grad(): - reduce_kl = rnn_masks is None + reduce_kl = env_masks is None kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl) - if rnn_masks is not None: - kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() #/ sum_mask + if env_masks is not None: + kl_dist = (kl_dist * env_masks).sum() / env_masks.numel() #/ sum_mask self.diagnostics.mini_batch(self, { @@ -148,7 +148,7 @@ def calc_gradients(self, input_dict): 'returns' : return_batch, 'new_neglogp' : action_log_probs, 'old_neglogp' : old_action_log_probs_batch, - 'masks' : rnn_masks + 'masks' : env_masks }, curr_e_clip, 0) self.train_result = (a_loss, c_loss, entropy, \ diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 3f50ae9e..db253864 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -126,9 +126,9 @@ def calc_gradients(self, input_dict): } if self.use_action_masks: batch_dict['action_masks'] = input_dict['action_masks'] - rnn_masks = None + env_masks = None if self.is_rnn: - rnn_masks = input_dict['rnn_masks'] + env_masks = input_dict['env_masks'] batch_dict['rnn_states'] = input_dict['rnn_states'] batch_dict['seq_length'] = self.seq_len batch_dict['bptt_len'] = self.bptt_len @@ -147,7 +147,7 @@ def calc_gradients(self, input_dict): c_loss = torch.zeros(1, device=self.ppo_device) - losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks) + losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], env_masks) a_loss, c_loss, entropy = losses[0], losses[1], losses[2] loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef @@ -162,8 +162,8 @@ def calc_gradients(self, input_dict): with torch.no_grad(): 
kl_dist = 0.5 * ((old_action_log_probs_batch - action_log_probs)**2) - if rnn_masks is not None: - kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() # / sum_mask + if env_masks is not None: + kl_dist = (kl_dist * env_masks).sum() / env_masks.numel() # / sum_mask else: kl_dist = kl_dist.mean() @@ -173,7 +173,7 @@ def calc_gradients(self, input_dict): 'returns' : return_batch, 'new_neglogp' : action_log_probs, 'old_neglogp' : old_action_log_probs_batch, - 'masks' : rnn_masks + 'masks' : env_masks }, curr_e_clip, 0) self.train_result = (a_loss, c_loss, entropy, kl_dist,self.last_lr, lr_mul) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index d0e1cd0e..7aaeb797 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -113,7 +113,7 @@ def update_dataset(self, batch_dict): returns = batch_dict['returns'] actions = batch_dict['actions'] dones = batch_dict['dones'] - rnn_masks = batch_dict['rnn_masks'] + env_masks = batch_dict['env_masks'] if self.num_agents > 1: res = self.update_multiagent_tensors(value_preds, returns, actions, dones) batch_dict['old_values'] = res[0] @@ -130,8 +130,8 @@ def update_dataset(self, batch_dict): batch_dict['rnn_states'] = states if self.num_agents > 1: - rnn_masks = res[3] - batch_dict['rnn_masks'] = rnn_masks + env_masks = res[3] + batch_dict['env_masks'] = env_masks self.dataset.update_values_dict(batch_dict) def _preproc_obs(self, obs_batch): @@ -222,7 +222,7 @@ def calc_gradients(self, batch): returns_batch = batch['returns'] actions_batch = batch['actions'] dones_batch = batch['dones'] - rnn_masks_batch = batch.get('rnn_masks') + env_masks_batch = batch.get('env_masks') batch_dict = {'obs' : obs_batch, 'actions' : actions_batch, @@ -234,7 +234,7 @@ def calc_gradients(self, batch): res_dict = self.model(batch_dict) values = res_dict['values'] loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value) - losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch) + losses, _ = torch_ext.apply_masks([loss], env_masks_batch) loss = losses[0] if self.multi_gpu: self.optimizer.zero_grad() diff --git a/rl_games/algos_torch/torch_ext.py b/rl_games/algos_torch/torch_ext.py index 168d9b8c..c3f61606 100644 --- a/rl_games/algos_torch/torch_ext.py +++ b/rl_games/algos_torch/torch_ext.py @@ -36,7 +36,7 @@ def policy_kl(p0_mu, p0_sigma, p1_mu, p1_sigma, reduce=True): return kl def mean_mask(input, mask, sum_mask): - return (input * rnn_masks).sum() / sum_mask + return (input * mask).sum() / sum_mask def shape_whc_to_cwh(shape): if len(shape) == 3: diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 3c0be1a7..12f02a62 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -107,6 +107,7 @@ def __init__(self, base_name, params): self.observation_space = self.env_info['observation_space'] self.weight_decay = config.get('weight_decay', 0.0) self.use_action_masks = config.get('use_action_masks', False) + self.has_env_masks = self.env_info.get('env_masks', False) self.is_train = config.get('is_train', True) self.central_value_config = self.config.get('central_value_config', None) @@ -453,6 +454,12 @@ def preprocess_actions(self, actions): actions = actions.cpu().numpy() return actions + def get_env_masks(self): + if self.is_tensor_obses: + return self.vec_env.get_env_masks() + else: + return torch.from_numpy(self.vec_env.get_env_masks()) + def env_step(self, actions): actions = 
self.preprocess_actions(actions) obs, rewards, dones, infos = self.vec_env.step(actions) @@ -624,6 +631,9 @@ def play_steps(self): res_dict = self.get_masked_action_values(self.obs, masks) else: res_dict = self.get_action_values(self.obs) + if self.has_env_masks: + env_masks = self.get_env_masks() + self.experience_buffer.update_data('env_masks', n, env_masks) self.experience_buffer.update_data('obses', n, self.obs['obs']) self.experience_buffer.update_data('dones', n, self.dones) @@ -692,6 +702,9 @@ def play_steps_rnn(self): res_dict = self.get_masked_action_values(self.obs, masks) else: res_dict = self.get_action_values(self.obs) + if self.has_env_masks: + env_masks = self.get_env_masks() + self.experience_buffer.update_data('env_masks', n, env_masks) self.rnn_states = res_dict['rnn_states'] self.experience_buffer.update_data('obses', n, self.obs['obs']) self.experience_buffer.update_data('dones', n, self.dones.byte()) @@ -793,7 +806,7 @@ def train_epoch(self): play_time_end = time.time() update_time_start = time.time() - rnn_masks = batch_dict.get('rnn_masks', None) + env_masks = batch_dict.get('env_masks', None) self.curr_frames = batch_dict.pop('played_frames') self.prepare_dataset(batch_dict) @@ -835,7 +848,7 @@ def train_epoch(self): return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul def prepare_dataset(self, batch_dict): - rnn_masks = batch_dict.get('rnn_masks', None) + env_masks = batch_dict.get('env_masks', None) returns = batch_dict['returns'] values = batch_dict['values'] @@ -855,16 +868,11 @@ def prepare_dataset(self, batch_dict): advantages = torch.sum(advantages, axis=1) if self.normalize_advantage: - if self.is_rnn: - if self.normalize_rms_advantage: - advantages = self.advantage_mean_std(advantages, mask=rnn_masks) - else: - advantages = torch_ext.normalization_with_masks(advantages, rnn_masks) + if self.normalize_rms_advantage: + advantages = self.advantage_mean_std(advantages, mask=env_masks) else: - if self.normalize_rms_advantage: - advantages = self.advantage_mean_std(advantages) - else: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = torch_ext.normalization_with_masks(advantages, env_masks) + dataset_dict = {} dataset_dict['old_values'] = values @@ -875,7 +883,7 @@ def prepare_dataset(self, batch_dict): dataset_dict['obs'] = obses dataset_dict['dones'] = dones dataset_dict['rnn_states'] = rnn_states - dataset_dict['rnn_masks'] = rnn_masks + dataset_dict['env_masks'] = env_masks if self.use_action_masks: dataset_dict['action_masks'] = batch_dict['action_masks'] @@ -890,7 +898,7 @@ def prepare_dataset(self, batch_dict): dataset_dict['actions'] = actions dataset_dict['dones'] = dones dataset_dict['obs'] = batch_dict['states'] - dataset_dict['rnn_masks'] = rnn_masks + dataset_dict['env_masks'] = env_masks self.central_value_net.update_dataset(dataset_dict) def train(self): @@ -1024,6 +1032,8 @@ def init_tensors(self): A2CBase.init_tensors(self) self.update_list = ['actions', 'neglogpacs', 'values', 'mus', 'sigmas'] self.tensor_list = self.update_list + ['obses', 'states', 'dones'] + if self.has_env_masks: + self.tensor_list += ['env_masks'] def train_epoch(self): super().train_epoch() @@ -1038,7 +1048,7 @@ def train_epoch(self): play_time_end = time.time() update_time_start = time.time() - rnn_masks = batch_dict.get('rnn_masks', None) + env_masks = batch_dict.get('env_masks', None) self.set_train() self.curr_frames = batch_dict.pop('played_frames') @@ -1103,7 
+1113,7 @@ def prepare_dataset(self, batch_dict): mus = batch_dict['mus'] sigmas = batch_dict['sigmas'] rnn_states = batch_dict.get('rnn_states', None) - rnn_masks = batch_dict.get('rnn_masks', None) + env_masks = batch_dict.get('env_masks', None) advantages = returns - values @@ -1117,15 +1127,10 @@ def prepare_dataset(self, batch_dict): if self.normalize_advantage: if self.is_rnn: - if self.normalize_rms_advantage: - advantages = self.advantage_mean_std(advantages, mask=rnn_masks) - else: - advantages = torch_ext.normalization_with_masks(advantages, rnn_masks) + if self.normalize_rms_advantage: + advantages = self.advantage_mean_std(advantages, mask=env_masks) else: - if self.normalize_rms_advantage: - advantages = self.advantage_mean_std(advantages) - else: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = torch_ext.normalization_with_masks(advantages, env_masks) dataset_dict = {} dataset_dict['old_values'] = values @@ -1136,7 +1141,7 @@ def prepare_dataset(self, batch_dict): dataset_dict['obs'] = obses dataset_dict['dones'] = dones dataset_dict['rnn_states'] = rnn_states - dataset_dict['rnn_masks'] = rnn_masks + dataset_dict['env_masks'] = env_masks dataset_dict['mu'] = mus dataset_dict['sigma'] = sigmas @@ -1150,7 +1155,7 @@ def prepare_dataset(self, batch_dict): dataset_dict['actions'] = actions dataset_dict['obs'] = batch_dict['states'] dataset_dict['dones'] = dones - dataset_dict['rnn_masks'] = rnn_masks + dataset_dict['env_masks'] = env_masks self.central_value_net.update_dataset(dataset_dict) def train(self): diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py index c62fd3c9..e1cace39 100644 --- a/rl_games/common/experience.py +++ b/rl_games/common/experience.py @@ -299,6 +299,7 @@ def __init__(self, env_info, algo_info, device, aux_tensor_dict=None): self.horizon_length = algo_info['horizon_length'] self.has_central_value = algo_info['has_central_value'] self.use_action_masks = algo_info.get('use_action_masks', False) + self.has_env_masks = algo_info.get('has_env_masks', False) batch_size = self.num_actors * self.num_agents self.is_discrete = False self.is_multi_discrete = False @@ -337,6 +338,8 @@ def _init_from_env_info(self, env_info): self.tensor_dict['values'] = self._create_tensor_from_space(val_space, obs_base_shape) self.tensor_dict['neglogpacs'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.float32), obs_base_shape) self.tensor_dict['dones'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.uint8), obs_base_shape) + if self.has_env_masks: + self.tensor_dict['env_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=(), dtype=np.uint8), obs_base_shape) if self.is_discrete or self.is_multi_discrete: self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.long), obs_base_shape) if self.use_action_masks: From 2f917718c5b557bac6aebf3337a8d7efa5d115a6 Mon Sep 17 00:00:00 2001 From: Denys Makoviichuk Date: Tue, 16 Aug 2022 18:03:49 -0700 Subject: [PATCH 2/2] env masks --- rl_games/common/a2c_common.py | 11 +++++++++-- rl_games/common/experience.py | 5 +++-- rl_games/common/ivecenv.py | 6 ++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 12f02a62..8cf76cf8 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -107,6 +107,7 @@ def __init__(self, 
base_name, params):
         self.observation_space = self.env_info['observation_space']
         self.weight_decay = config.get('weight_decay', 0.0)
         self.use_action_masks = config.get('use_action_masks', False)
+        self.has_env_masks = self.env_info.get('env_masks', False)
         self.is_train = config.get('is_train', True)
 
@@ -455,10 +456,17 @@ def preprocess_actions(self, actions):
         return actions
 
     def get_env_masks(self):
+        '''
+        Returns env masks from the vectorised env.
+        For example: return torch.ones((self.num_actors,), device=self.device, dtype=torch.uint8)
+        '''
         if self.is_tensor_obses:
             return self.vec_env.get_env_masks()
         else:
-            return torch.from_numpy(self.vec_env.get_env_masks())
+            return torch.from_numpy(self.vec_env.get_env_masks()).to(self.device)
+
+
+
 
     def env_step(self, actions):
         actions = self.preprocess_actions(actions)
@@ -1126,7 +1134,6 @@ def prepare_dataset(self, batch_dict):
         advantages = torch.sum(advantages, axis=1)
 
         if self.normalize_advantage:
-            if self.is_rnn:
             if self.normalize_rms_advantage:
                 advantages = self.advantage_mean_std(advantages, mask=env_masks)
             else:
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index e1cace39..4c8b8be8 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -294,12 +294,13 @@ def __init__(self, env_info, algo_info, device, aux_tensor_dict=None):
         self.num_agents = env_info.get('agents', 1)
         self.action_space = env_info['action_space']
-        
+        self.has_env_masks = env_info.get('env_masks', False)
+        print('self.has_env_masks', env_info )
         self.num_actors = algo_info['num_actors']
         self.horizon_length = algo_info['horizon_length']
         self.has_central_value = algo_info['has_central_value']
         self.use_action_masks = algo_info.get('use_action_masks', False)
-        self.has_env_masks = algo_info.get('has_env_masks', False)
+
         batch_size = self.num_actors * self.num_agents
         self.is_discrete = False
         self.is_multi_discrete = False
diff --git a/rl_games/common/ivecenv.py b/rl_games/common/ivecenv.py
index 97c43cb4..8f9acf88 100644
--- a/rl_games/common/ivecenv.py
+++ b/rl_games/common/ivecenv.py
@@ -31,6 +31,12 @@ def get_env_state(self):
         Can be used for stateful training sessions, i.e. with adaptive curriculums.
         """
         return None
+    
+    def get_env_masks(self):
+        """
+        Return env masks to prevent training in particular states or envs.
+        """
+        return None
 
     def set_env_state(self, env_state):
         pass
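
Not part of the patch: the sketch below shows what a vectorized env could look like once it advertises env_info['env_masks'] = True and implements the new get_env_masks() hook. The class name MaskedDummyVecEnv and the disable_training_for() helper are hypothetical; the only contract the patch relies on is that get_env_masks() returns one 0/1 value per actor, which the algorithm collects every rollout step and later passes to torch_ext.apply_masks so that masked envs are excluded from the loss.

import numpy as np
import gym
import torch

class MaskedDummyVecEnv:
    # Hypothetical example; a real env would subclass rl_games.common.ivecenv.IVecEnv
    # and wrap an actual simulator.

    def __init__(self, num_envs, obs_dim=8, act_dim=2, device='cpu'):
        self.num_envs = num_envs
        self.device = device
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, (obs_dim,), dtype=np.float32)
        self.action_space = gym.spaces.Box(-1.0, 1.0, (act_dim,), dtype=np.float32)
        # 1 = env contributes to the loss, 0 = env is excluded from training
        self.env_masks = torch.ones((num_envs,), device=device, dtype=torch.uint8)

    def get_env_info(self):
        return {
            'observation_space': self.observation_space,
            'action_space': self.action_space,
            'agents': 1,
            'env_masks': True,  # enables mask collection in the experience buffer
        }

    def get_env_masks(self):
        # queried once per rollout step by play_steps / play_steps_rnn
        return self.env_masks

    def disable_training_for(self, env_ids):
        # hypothetical helper, e.g. to mark evaluation-only or warm-up envs
        self.env_masks[env_ids] = 0

    def reset(self):
        return torch.zeros((self.num_envs,) + self.observation_space.shape, device=self.device)

    def step(self, actions):
        obs = torch.zeros((self.num_envs,) + self.observation_space.shape, device=self.device)
        rewards = torch.zeros((self.num_envs,), device=self.device)
        dones = torch.zeros((self.num_envs,), device=self.device, dtype=torch.uint8)
        return obs, rewards, dones, {}

The reduction applied to the collected masks is effectively a masked mean (cf. mean_mask in torch_ext above), so a zero mask removes an env's samples from a_loss, c_loss and entropy without changing batch shapes:

# Effect of a zero mask on the reduction (masked mean):
loss = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
mask = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
print((loss * mask).sum() / mask.sum())  # tensor(2.) - envs at indices 1 and 3 are ignored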