From d8645b2678c0d8a6e98a6e3f2b17f0ecfbff71ad Mon Sep 17 00:00:00 2001
From: Aleksei Petrenko
Date: Wed, 28 Sep 2022 19:07:02 -0700
Subject: [PATCH] Small changes for DexPBT paper. 1) Load env state from
 checkpoint in the player (i.e. if we train with a curriculum we want to load
 the same state). 2) Minor speedup by disabling validate_args in
 distributions. 3) Added another way to stop training: max_env_steps (#204)

---
 rl_games/algos_torch/models.py    |  4 ++--
 rl_games/algos_torch/players.py   | 14 +++++++++++++-
 rl_games/algos_torch/torch_ext.py |  2 +-
 rl_games/common/a2c_common.py     | 17 ++++++++---------
 4 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index 920cbf54..79f6bf75 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -195,7 +195,7 @@ def forward(self, input_dict):
             prev_actions = input_dict.get('prev_actions', None)
             input_dict['obs'] = self.norm_obs(input_dict['obs'])
             mu, sigma, value, states = self.a2c_network(input_dict)
-            distr = torch.distributions.Normal(mu, sigma)
+            distr = torch.distributions.Normal(mu, sigma, validate_args=False)
 
             if is_train:
                 entropy = distr.entropy().sum(dim=-1)
@@ -246,7 +246,7 @@ def forward(self, input_dict):
             input_dict['obs'] = self.norm_obs(input_dict['obs'])
             mu, logstd, value, states = self.a2c_network(input_dict)
             sigma = torch.exp(logstd)
-            distr = torch.distributions.Normal(mu, sigma)
+            distr = torch.distributions.Normal(mu, sigma, validate_args=False)
             if is_train:
                 entropy = distr.entropy().sum(dim=-1)
                 prev_neglogp = self.neglogp(prev_actions, mu, sigma, logstd)
diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 20f22aae..9120bb94 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -74,6 +74,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()
 
@@ -172,6 +176,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()
 
@@ -210,6 +218,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
@@ -221,4 +233,4 @@ def get_action(self, obs, is_deterministic=False):
         return actions
 
     def reset(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/rl_games/algos_torch/torch_ext.py b/rl_games/algos_torch/torch_ext.py
index 168d9b8c..b35033ad 100644
--- a/rl_games/algos_torch/torch_ext.py
+++ b/rl_games/algos_torch/torch_ext.py
@@ -84,7 +84,7 @@ def load_checkpoint(filename):
     return state
 
 def parameterized_truncated_normal(uniform, mu, sigma, a, b):
-    normal = torch.distributions.normal.Normal(0, 1)
+    normal = torch.distributions.normal.Normal(0, 1, validate_args=False)
 
     alpha = (a - mu) / sigma
     beta = (b - mu) / sigma
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 3c0be1a7..b9cee955 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -134,6 +134,7 @@ def __init__(self, base_name, params):
 
         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', 1e6)
+        self.max_frames = self.config.get('max_frames', 1e10)
 
         self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
         self.linear_lr = config['lr_schedule'] == 'linear'
@@ -932,7 +933,7 @@ def train(self):
                     fps_step = curr_frames / step_time
                     fps_step_inference = curr_frames / scaled_play_time
                     fps_total = curr_frames / scaled_time
-                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
 
                 self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
 
@@ -974,14 +975,12 @@ def train(self):
                         self.save(os.path.join(self.nn_dir, checkpoint_name))
                         should_exit = True
 
-                if epoch_num >= self.max_epochs:
+                if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                     if self.game_rewards.current_size == 0:
                         print('WARNING: Max epochs reached before any env terminated at least once')
                         mean_rewards = -np.inf
 
-                    self.save(os.path.join(self.nn_dir,
-                                           'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(
-                                               mean_rewards)))
+                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
                     print('MAX EPOCHS NUM!')
                     should_exit = True
                     update_time = 0
@@ -1191,7 +1190,7 @@ def train(self):
                     fps_step = curr_frames / step_time
                     fps_step_inference = curr_frames / scaled_play_time
                     fps_total = curr_frames / scaled_time
-                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
 
                 self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
                 if len(b_losses) > 0:
@@ -1235,11 +1234,11 @@ def train(self):
                         self.save(os.path.join(self.nn_dir, checkpoint_name))
                         should_exit = True
 
-                if epoch_num >= self.max_epochs:
+                if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                     if self.game_rewards.current_size == 0:
                         print('WARNING: Max epochs reached before any env terminated at least once')
                         mean_rewards = -np.inf
-                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
+                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')))
                     print('MAX EPOCHS NUM!')
                     should_exit = True
 
@@ -1253,4 +1252,4 @@ def train(self):
                 return self.last_mean_rewards, epoch_num
 
             if should_exit:
-                return self.last_mean_rewards, epoch_num
\ No newline at end of file
+                return self.last_mean_rewards, epoch_num
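
For context on change (1): the restored `env_state` is whatever the training side stored in the checkpoint under `checkpoint['env_state']`; the player simply hands it back to the environment through `set_env_state`. Below is a minimal sketch of an environment exposing that interface. The `CurriculumEnv` class and its `difficulty` field are hypothetical; only `set_env_state` (and a matching `get_env_state` on the save path) are taken from this patch.

```python
# Hypothetical environment illustrating the state interface the player relies on.
# Only set_env_state() is called by the restore() methods in this patch;
# get_env_state() is assumed to be what the trainer saves as checkpoint['env_state'].
class CurriculumEnv:
    def __init__(self):
        self.difficulty = 0.0  # curriculum variable that advances during training

    def get_env_state(self):
        # Anything picklable; stored in the checkpoint alongside model weights.
        return {'difficulty': self.difficulty}

    def set_env_state(self, env_state):
        # Restores curriculum progress so evaluation runs under the same
        # conditions the policy was trained with.
        self.difficulty = env_state['difficulty']
```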
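
For change (2): by default `torch.distributions.Normal` validates its arguments every time it is constructed, which adds Python-side overhead when the distribution is rebuilt on every forward pass. A rough, machine-dependent way to observe the difference; this micro-benchmark is illustrative and not part of the patch:

```python
import time

import torch

# Shapes roughly matching a large vectorized RL batch (illustrative values).
mu = torch.zeros(4096, 32)
sigma = torch.ones(4096, 32)

def time_construction(n_iters, validate):
    # Measures only distribution construction, which is what the patch touches.
    start = time.perf_counter()
    for _ in range(n_iters):
        torch.distributions.Normal(mu, sigma, validate_args=validate)
    return time.perf_counter() - start

print('validate_args=True :', time_construction(1000, True))
print('validate_args=False:', time_construction(1000, False))
```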
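
For change (3): besides `max_epochs`, training can now also stop once a total environment-frame budget is exhausted (`max_frames`, defaulting to 1e10; the commit message refers to the same idea as max_env_steps). A small sketch of the combined stop condition, with illustrative values standing in for the training config:

```python
# Illustrative values; in rl_games both come from the training config
# (config.get('max_epochs', 1e6) and config.get('max_frames', 1e10)).
max_epochs = 1e6
max_frames = 1e10

def should_stop(epoch_num, frame):
    # Mirrors the new check in train(): exit when either budget is reached.
    return epoch_num >= max_epochs or frame >= max_frames

print(should_stop(10, 2e10))  # True: frame budget exceeded
print(should_stop(10, 5e9))   # False: neither budget reached
```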