diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index f5f52fc4..bd1c0e78 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -51,6 +51,7 @@ def __init__(self, base_name, config):
             }
             self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
         self.use_experimental_cv = self.config.get('use_experimental_cv', True)
+        self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
 
         self.algo_observer.after_init(self)
 
@@ -71,6 +72,7 @@ def get_masked_action_values(self, obs, action_masks):
 
     def calc_gradients(self, input_dict, opt_step):
         self.set_train()
+
         value_preds_batch = input_dict['old_values']
         old_action_log_probs_batch = input_dict['old_logp_actions']
         advantage = input_dict['advantages']
@@ -89,16 +91,16 @@ def calc_gradients(self, input_dict, opt_step):
         batch_dict = {
             'is_train': True,
             'prev_actions': actions_batch,
-            'obs' : obs_batch,
+            'obs': obs_batch,
         }
-        rnn_masks = None
-        if self.is_rnn:
-            rnn_masks = input_dict['rnn_masks']
-            batch_dict['rnn_states'] = input_dict['rnn_states']
-            batch_dict['seq_length'] = self.seq_len
-
         with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+            rnn_masks = None
+            if self.is_rnn:
+                rnn_masks = input_dict['rnn_masks']
+                batch_dict['rnn_states'] = input_dict['rnn_states']
+                batch_dict['seq_length'] = self.seq_len
+
             res_dict = self.model(batch_dict)
             action_log_probs = res_dict['prev_neglogp']
             values = res_dict['value']
@@ -133,11 +135,12 @@ def calc_gradients(self, input_dict, opt_step):
             self.scaler.update()
 
         with torch.no_grad():
-            reduce_kl = not self.is_rnn
-            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
-            if self.is_rnn:
-                kl_dist = (kl_dist * rnn_masks).sum() / sum_mask
-            kl_dist = kl_dist.item()
+            with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+                reduce_kl = not self.is_rnn
+                kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl)
+                if self.is_rnn:
+                    kl_dist = (kl_dist * rnn_masks).sum() / sum_mask
+                kl_dist = kl_dist.item()
 
         self.train_result = (a_loss.item(), c_loss.item(), entropy.item(), \
                     kl_dist, self.last_lr, lr_mul, \
diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py
index f61283dc..5ae30974 100644
--- a/rl_games/algos_torch/central_value.py
+++ b/rl_games/algos_torch/central_value.py
@@ -31,30 +31,38 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, num_steps, n
         self.clip_value = config['clip_value']
         self.normalize_input = config['normalize_input']
         self.normalize_value = config.get('normalize_value', False)
+        self.running_mean_std = None
+        if self.normalize_input:
+            self.running_mean_std = RunningMeanStd(state_shape)
+
         self.writter = writter
         self.use_joint_obs_actions = config.get('use_joint_obs_actions', False)
         self.optimizer = torch.optim.Adam(self.model.parameters(), float(self.lr), eps=1e-07)
         self.frame = 0
-        self.running_mean_std = None
+
         self.grad_norm = config.get('grad_norm', 1)
         self.truncate_grads = config.get('truncate_grads', False)
         self.e_clip = config.get('e_clip', 0.2)
-        if self.normalize_input:
-            self.running_mean_std = RunningMeanStd(state_shape)
+
+        # todo - from the main config!
+        self.mixed_precision = self.config.get('mixed_precision', True)
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)
 
         self.is_rnn = self.model.is_rnn()
         self.rnn_states = None
         self.batch_size = self.num_steps * self.num_actors
+
         if self.is_rnn:
-            self.rnn_states = self.model.get_default_rnn_state()
-            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
-            num_seqs = self.num_steps * self.num_actors // self.seq_len
-            assert((self.num_steps * self.num_actors // self.num_minibatches) % self.seq_len == 0)
-            self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]
+            with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+                self.rnn_states = self.model.get_default_rnn_state()
+                self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
+                num_seqs = self.num_steps * self.num_actors // self.seq_len
+                assert((self.num_steps * self.num_actors // self.num_minibatches) % self.seq_len == 0)
+                self.mb_rnn_states = [torch.zeros((s.size()[0], num_seqs, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]
 
         self.dataset = datasets.PPODataset(self.batch_size, self.mini_batch, True, self.is_rnn, self.ppo_device, self.seq_len)
 
-    def get_stats_weights(self): 
+    def get_stats_weights(self):
         if self.normalize_input:
             return self.running_mean_std.state_dict()
         else:
@@ -62,18 +70,19 @@ def set_stats_weights(self, weights):
         self.running_mean_std.load_state_dict(weights)
-    
+
     def update_dataset(self, batch_dict):
         value_preds = batch_dict['old_values']
         returns = batch_dict['returns']
         actions = batch_dict['actions']
         rnn_masks = batch_dict['rnn_masks']
+
         if self.num_agents > 1:
             res = self.update_multiagent_tensors(value_preds, returns, actions, rnn_masks)
             batch_dict['old_values'] = res[0]
             batch_dict['returns'] = res[1]
             batch_dict['actions'] = res[2]
-    
+
         if self.is_rnn:
             batch_dict['rnn_states'] = self.mb_rnn_states
             if self.num_agents > 1:
@@ -90,9 +99,11 @@ def _preproc_obs(self, obs_batch):
             obs_batch = obs_batch.permute((0, 3, 1, 2))
         if self.normalize_input:
             obs_batch = self.running_mean_std(obs_batch)
+
         return obs_batch
 
     def pre_step_rnn(self, rnn_indices, state_indices):
+        #with torch.cuda.amp.autocast(enabled=self.mixed_precision):
         if self.num_agents > 1:
             rnn_indices = rnn_indices[::self.num_agents]
         shifts = rnn_indices % (self.num_steps // self.seq_len)
@@ -105,10 +116,12 @@ def pre_step_rnn(self, rnn_indices, state_indices):
     def post_step_rnn(self, all_done_indices):
         all_done_indices = all_done_indices[::self.num_agents] // self.num_agents
         for s in self.rnn_states:
-            s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0
+            s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0
 
     def forward(self, input_dict):
+        #with torch.cuda.amp.autocast(enabled=self.mixed_precision):
         value, rnn_states = self.model(input_dict)
+
         return value, rnn_states
 
     def get_value(self, input_dict):
@@ -118,7 +131,7 @@ def get_value(self, input_dict):
         actions = input_dict.get('actions', None)
         obs_batch = self._preproc_obs(obs_batch)
 
-        value, self.rnn_states = self.forward({'obs' : obs_batch, 'actions': actions,
+        value, self.rnn_states = self.forward({'obs': obs_batch, 'actions': actions,
             'rnn_states': self.rnn_states})
         if self.num_agents > 1:
             value = value.repeat(1, self.num_agents)
@@ -135,19 +148,20 @@ def train_critic(self, input_dict, opt_step = True):
     def update_multiagent_tensors(self, value_preds, returns, actions, rnn_masks):
         batch_size = self.batch_size
         ma_batch_size = self.num_actors * self.num_agents * self.num_steps
-        value_preds = value_preds.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0,1)
-        returns = returns.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0,1)
+        value_preds = value_preds.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0, 1)
+        returns = returns.view(self.num_actors, self.num_agents, self.num_steps, self.value_size).transpose(0, 1)
         value_preds = value_preds.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
         returns = returns.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
 
         if self.use_joint_obs_actions:
-            assert(len(actions.size()) == 2, 'use_joint_obs_actions not yet supported in continuous environment for central value')
-            actions = actions.view(self.num_actors, self.num_agents, self.num_steps).transpose(0,1)
+            assert len(actions.size()) == 2, 'use_joint_obs_actions not yet supported in continuous environment for central value'
+            actions = actions.view(self.num_actors, self.num_agents, self.num_steps).transpose(0, 1)
             actions = actions.contiguous().view(batch_size, self.num_agents)
-    
+
         if self.is_rnn:
-            rnn_masks = rnn_masks.view(self.num_actors, self.num_agents, self.num_steps).transpose(0,1)
-            rnn_masks = rnn_masks.flatten(0)[:batch_size]
+            rnn_masks = rnn_masks.view(self.num_actors, self.num_agents, self.num_steps).transpose(0, 1)
+            rnn_masks = rnn_masks.flatten(0)[:batch_size]
+
         return value_preds, returns, actions, rnn_masks
 
     def train_net(self):
@@ -157,33 +171,42 @@ def train_net(self):
             for idx in range(len(self.dataset)):
                 loss += self.train_critic(self.dataset[idx])
         avg_loss = loss / (self.mini_epoch * self.num_minibatches)
+
         self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)
         self.frame += self.batch_size
+
         return avg_loss
 
     def calc_gradients(self, batch, opt_step):
-        obs_batch = self._preproc_obs(batch['obs'])
-        value_preds_batch = batch['old_values']
-        returns_batch = batch['returns']
-        actions_batch = batch['actions']
-        rnn_masks_batch = batch.get('rnn_masks')
-
-        batch_dict = {'obs' : obs_batch,
-            'actions' : actions_batch,
-            'seq_length' : self.seq_len }
-        if self.is_rnn:
-            batch_dict['rnn_states'] = batch['rnn_states']
-
-        values, _ = self.forward(batch_dict)
-        loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
-        losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
-        loss = losses[0]
+        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
+            obs_batch = self._preproc_obs(batch['obs'])
+            value_preds_batch = batch['old_values']
+            returns_batch = batch['returns']
+            actions_batch = batch['actions']
+            rnn_masks_batch = batch.get('rnn_masks')
+
+            batch_dict = {'obs': obs_batch,
+                'actions': actions_batch,
+                'seq_length': self.seq_len }
+
+            if self.is_rnn:
+                batch_dict['rnn_states'] = batch['rnn_states']
+
+            values, _ = self.forward(batch_dict)
+            loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value)
+            losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch)
+            loss = losses[0]
 
         for param in self.model.parameters():
             param.grad = None
-        loss.backward()
-        if self.truncate_grads:
+
+        self.scaler.scale(loss).backward()
+        if self.config['truncate_grads']:
+            self.scaler.unscale_(self.optimizer)
             nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
+
         if opt_step:
-            self.optimizer.step()
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+
         return loss
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 249e515e..5b2202a9 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -50,7 +50,7 @@ def __init__(self, base_name, config):
         self.ppo_device = config.get('device', 'cuda:0')
         print('Env info:')
         print(self.env_info)
-        self.value_size = self.env_info.get('value_size',1)
+        self.value_size = self.env_info.get('value_size', 1)
         self.observation_space = self.env_info['observation_space']
         self.weight_decay = config.get('weight_decay', 0.0)
         self.use_action_masks = config.get('use_action_masks', False)
@@ -86,7 +86,7 @@ def __init__(self, base_name, config):
             self.scheduler = schedulers.AdaptiveScheduler(self.lr_threshold)
         elif self.linear_lr:
             self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']),
-                max_steps=self.max_epochs, 
+                max_steps=self.max_epochs,
                 apply_to_entropy=config.get('schedule_entropy', False),
                 start_entropy_coef=config.get('entropy_coef'))
         else:
@@ -175,8 +175,8 @@ def get_action_values(self, obs):
         input_dict = {
             'is_train': False,
             'prev_actions': None,
-            'obs' : processed_obs,
-            'rnn_states' : self.rnn_states
+            'obs': processed_obs,
+            'rnn_states': self.rnn_states
         }
 
         with torch.no_grad():
@@ -185,7 +185,7 @@ def get_action_values(self, obs):
             states = obs['states']
             input_dict = {
                 'is_train': False,
-                'states' : states,
+                'states': states,
                 #'actions' : res_dict['action'],
                 #'rnn_states' : self.rnn_states
             }
@@ -213,8 +213,8 @@ def get_values(self, obs):
         input_dict = {
             'is_train': False,
             'prev_actions': None,
-            'obs' : processed_obs,
-            'rnn_states' : self.rnn_states
+            'obs': processed_obs,
+            'rnn_states': self.rnn_states
         }
         result = self.model(input_dict)
         value = result['value']
@@ -285,6 +285,7 @@ def process_rnn_indices(self, mb_rnn_masks, indices, steps_mask, steps_state, mb
 
         self.last_rnn_indices = rnn_indices
         self.last_state_indices = state_indices
+
         return seq_indices, False
 
     def process_rnn_dones(self, all_done_indices, indices, seq_indices):
@@ -295,7 +296,6 @@ def process_rnn_dones(self, all_done_indices, indices, seq_indices):
             s[:,all_done_indices,:] = s[:,all_done_indices,:] * 0.0
         indices += 1
-
 
     def cast_obs(self, obs):
         if isinstance(obs, torch.Tensor):
             self.is_tensor_obses = True
@@ -445,7 +445,7 @@ def set_stats_weights(self, weights):
             self.value_mean_std.load_state_dict(weights['reward_mean_std'])
         if self.has_central_value:
             self.central_value_net.set_stats_weights(state['assymetric_vf_mean_std'])
-    
+
     def set_weights(self, weights):
         self.model.load_state_dict(weights['model'])
         if self.normalize_input:
@@ -490,18 +490,18 @@ def play_steps(self):
             else:
                 res_dict = self.get_action_values(self.obs)
 
-            mb_obs[n,:] = self.obs['obs']
-            mb_dones[n,:] = self.dones
+            mb_obs[n, :] = self.obs['obs']
+            mb_dones[n, :] = self.dones
 
             for k in update_list:
-                tensors_dict[k][n,:] = res_dict[k]
+                tensors_dict[k][n, :] = res_dict[k]
 
             if self.has_central_value:
-                mb_vobs[n,:] = self.obs['states']
+                mb_vobs[n, :] = self.obs['states']
 
             self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])
             shaped_rewards = self.rewards_shaper(rewards)
-            mb_rewards[n,:] = shaped_rewards
+            mb_rewards[n, :] = shaped_rewards
 
             self.current_rewards += rewards
             self.current_lengths += 1
@@ -517,7 +517,7 @@ def play_steps(self):
             self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
             self.current_lengths = self.current_lengths * not_dones
-    
+
         if self.has_central_value and self.central_value_net.use_joint_obs_actions:
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
@@ -536,9 +536,9 @@ def play_steps(self):
        mb_advs = self.discount_values(fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards)
        mb_returns = mb_advs + mb_extrinsic_values
        batch_dict = {
-            'obs' : mb_obs,
-            'returns' : mb_returns,
-            'dones' : mb_dones,
+            'obs': mb_obs,
+            'returns': mb_returns,
+            'dones': mb_dones,
        }
        for k in update_list:
            batch_dict[update_dict[k]] = tensors_dict[k]
@@ -575,6 +575,7 @@ def play_steps_rnn(self):
            seq_indices, full_tensor = self.process_rnn_indices(mb_rnn_masks, indices, steps_mask, steps_state, mb_rnn_states)
            if full_tensor:
                break
+
            if self.has_central_value:
                self.central_value_net.pre_step_rnn(self.last_rnn_indices, self.last_state_indices)
 
@@ -587,13 +588,13 @@ def play_steps_rnn(self):
            self.rnn_states = res_dict['rnn_state']
 
            mb_dones[indices, play_mask] = self.dones.byte()
-            mb_obs[indices,play_mask] = self.obs['obs']
+            mb_obs[indices, play_mask] = self.obs['obs']
 
            for k in update_list:
-                tensors_dict[k][indices,play_mask] = res_dict[k]
+                tensors_dict[k][indices, play_mask] = res_dict[k]
 
            if self.has_central_value:
-                mb_vobs[indices[::self.num_agents] ,play_mask[::self.num_agents]//self.num_agents] = self.obs['states']
+                mb_vobs[indices[::self.num_agents], play_mask[::self.num_agents]//self.num_agents] = self.obs['states']
 
            self.obs, rewards, self.dones, infos = self.env_step(res_dict['action'])
@@ -665,6 +666,7 @@ def play_steps_rnn(self):
 
        return batch_dict
 
+
 class DiscreteA2CBase(A2CBase):
    def __init__(self, base_name, config):
        A2CBase.__init__(self, base_name, config)
@@ -687,14 +689,14 @@ def init_tensors(self):
 
        self.update_list = ['action', 'neglogp', 'value']
        self.update_dict = {
-            'action' : 'actions',
-            'neglogp' : 'neglogpacs',
-            'value' : 'values',
+            'action': 'actions',
+            'neglogp': 'neglogpacs',
+            'value': 'values',
        }
        self.tensors_dict = {
-            'action' : self.mb_actions,
-            'neglogp' : self.mb_neglogpacs,
-            'value' : self.mb_values,
+            'action': self.mb_actions,
+            'neglogp': self.mb_neglogpacs,
+            'value': self.mb_values,
        }
        if self.use_action_masks:
            self.mb_action_masks = torch.zeros((self.steps_num, batch_size, np.sum(self.actions_num)), dtype = torch.bool, device=self.ppo_device)
@@ -725,7 +727,7 @@ def train_epoch(self):
            self.train_central_value()
 
        if self.is_rnn:
-            print('non masked rnn obs ratio: ',rnn_masks.sum().item() / (rnn_masks.nelement()))
+            print('non masked rnn obs ratio: ', rnn_masks.sum().item() / (rnn_masks.nelement()))
 
        for _ in range(0, self.mini_epochs_num):
            ep_kls = []
@@ -898,27 +900,29 @@ def init_tensors(self):
 
        self.update_list = ['action', 'neglogp', 'value', 'mu', 'sigma']
        self.update_dict = {
-            'action' : 'actions',
-            'neglogp' : 'neglogpacs',
-            'value' : 'values',
-            'mu' : 'mus',
-            'sigma' : 'sigmas'
+            'action': 'actions',
+            'neglogp': 'neglogpacs',
+            'value': 'values',
+            'mu': 'mus',
+            'sigma': 'sigmas'
        }
        self.tensors_dict = {
-            'action' : self.mb_actions,
-            'neglogp' : self.mb_neglogpacs,
-            'value' : self.mb_values,
-            'mu' : self.mb_mus,
-            'sigma' : self.mb_sigmas,
+            'action': self.mb_actions,
+            'neglogp': self.mb_neglogpacs,
+            'value': self.mb_values,
+            'mu': self.mb_mus,
+            'sigma': self.mb_sigmas,
        }
 
    def train_epoch(self):
        play_time_start = time.time()
 
+        #with torch.cuda.amp.autocast(self.mixed_precision):
        with torch.no_grad():
            if self.is_rnn:
                batch_dict = self.play_steps_rnn()
            else:
-                batch_dict = self.play_steps() 
+                batch_dict = self.play_steps()
+
        play_time_end = time.time()
        update_time_start = time.time()
@@ -935,7 +939,7 @@ def train_epoch(self):
        b_losses = []
        entropies = []
        kls = []
-        
+
        if self.is_rnn:
            frames_mask_ratio = rnn_masks.sum().item() / (rnn_masks.nelement())
            print(frames_mask_ratio)
@@ -949,17 +953,17 @@ def train_epoch(self):
                c_losses.append(c_loss)
                ep_kls.append(kl)
                entropies.append(entropy)
-                
+
                if self.bounds_loss_coef is not None:
                    b_losses.append(b_loss)
 
                self.dataset.update_mu_sigma(cmu, csigma)
 
                if self.schedule_type == 'legacy':
-                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,kl)
+                    self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, kl)
                    self.update_lr(self.last_lr)
 
            if self.schedule_type == 'standard':
-                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0,kl)
+                self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, kl)
                self.update_lr(self.last_lr)
            kls.append(np.mean(ep_kls))
@@ -1062,7 +1066,7 @@ def train(self):
                self.writer.add_scalar('info/epochs', epoch_num, frame)
                self.algo_observer.after_print_stats(frame, epoch_num, total_time)
-                
+
                if self.game_rewards.current_size > 0:
                    mean_rewards = self.game_rewards.get_mean()
                    mean_lengths = self.game_lengths.get_mean()
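
The sketch below is not part of the patch. It is a minimal, self-contained illustration of the torch.cuda.amp pattern that this change wires into the actor and critic updates: autocast around the forward pass, GradScaler around backward/step, and unscaling before gradient clipping. The names model, optimizer, compute_loss, batch, grad_norm and truncate_grads are placeholders, not rl_games identifiers.

import torch
import torch.nn as nn

scaler = torch.cuda.amp.GradScaler(enabled=True)

def amp_train_step(model, optimizer, compute_loss, batch, grad_norm=1.0, truncate_grads=True):
    # Forward pass runs in mixed precision; fp32 master weights are untouched.
    with torch.cuda.amp.autocast(enabled=True):
        loss = compute_loss(model, batch)

    # Same grad-zeroing style as the patch (cheaper than optimizer.zero_grad()).
    for param in model.parameters():
        param.grad = None

    # Scale the loss so small fp16 gradients do not underflow, then backprop.
    scaler.scale(loss).backward()
    if truncate_grads:
        # Unscale before clipping so the norm is computed on the true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), grad_norm)

    # step() skips the update if inf/nan gradients were found; update() adapts the scale factor.
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()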