diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py
index 920cbf54..79f6bf75 100644
--- a/rl_games/algos_torch/models.py
+++ b/rl_games/algos_torch/models.py
@@ -195,7 +195,7 @@ def forward(self, input_dict):
             prev_actions = input_dict.get('prev_actions', None)
             input_dict['obs'] = self.norm_obs(input_dict['obs'])
             mu, sigma, value, states = self.a2c_network(input_dict)
-            distr = torch.distributions.Normal(mu, sigma)
+            distr = torch.distributions.Normal(mu, sigma, validate_args=False)
 
             if is_train:
                 entropy = distr.entropy().sum(dim=-1)
@@ -246,7 +246,7 @@ def forward(self, input_dict):
             input_dict['obs'] = self.norm_obs(input_dict['obs'])
             mu, logstd, value, states = self.a2c_network(input_dict)
             sigma = torch.exp(logstd)
-            distr = torch.distributions.Normal(mu, sigma)
+            distr = torch.distributions.Normal(mu, sigma, validate_args=False)
             if is_train:
                 entropy = distr.entropy().sum(dim=-1)
                 prev_neglogp = self.neglogp(prev_actions, mu, sigma, logstd)
diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 20f22aae..9120bb94 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -74,6 +74,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()
 
@@ -172,6 +176,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def reset(self):
         self.init_rnn()
 
@@ -210,6 +218,10 @@ def restore(self, fn):
         if self.normalize_input and 'running_mean_std' in checkpoint:
             self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
 
+        env_state = checkpoint.get('env_state', None)
+        if self.env is not None and env_state is not None:
+            self.env.set_env_state(env_state)
+
     def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
@@ -221,4 +233,4 @@ def get_action(self, obs, is_deterministic=False):
         return actions
 
     def reset(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/rl_games/algos_torch/torch_ext.py b/rl_games/algos_torch/torch_ext.py
index 168d9b8c..b35033ad 100644
--- a/rl_games/algos_torch/torch_ext.py
+++ b/rl_games/algos_torch/torch_ext.py
@@ -84,7 +84,7 @@ def load_checkpoint(filename):
     return state
 
 def parameterized_truncated_normal(uniform, mu, sigma, a, b):
-    normal = torch.distributions.normal.Normal(0, 1)
+    normal = torch.distributions.normal.Normal(0, 1, validate_args=False)
 
     alpha = (a - mu) / sigma
     beta = (b - mu) / sigma
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 3c0be1a7..b9cee955 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -134,6 +134,7 @@ def __init__(self, base_name, params):
 
         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', 1e6)
+        self.max_frames = self.config.get('max_frames', 1e10)
 
         self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
         self.linear_lr = config['lr_schedule'] == 'linear'
@@ -932,7 +933,7 @@ def train(self):
                    fps_step = curr_frames / step_time
                    fps_step_inference = curr_frames / scaled_play_time
                    fps_total = curr_frames / scaled_time
-                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
 
                 self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
 
@@ -974,14 +975,12 @@ def train(self):
                             self.save(os.path.join(self.nn_dir, checkpoint_name))
                             should_exit = True
 
-                if epoch_num >= self.max_epochs:
+                if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                     if self.game_rewards.current_size == 0:
                         print('WARNING: Max epochs reached before any env terminated at least once')
                         mean_rewards = -np.inf
-                    self.save(os.path.join(self.nn_dir,
-                                           'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(
-                                               mean_rewards)))
+                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
                     print('MAX EPOCHS NUM!')
                     should_exit = True
 
                 update_time = 0
@@ -1191,7 +1190,7 @@ def train(self):
                    fps_step = curr_frames / step_time
                    fps_step_inference = curr_frames / scaled_play_time
                    fps_total = curr_frames / scaled_time
-                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+                    print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
 
                 self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
                 if len(b_losses) > 0:
@@ -1235,11 +1234,11 @@ def train(self):
                             self.save(os.path.join(self.nn_dir, checkpoint_name))
                             should_exit = True
 
-                if epoch_num >= self.max_epochs:
+                if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
                     if self.game_rewards.current_size == 0:
                         print('WARNING: Max epochs reached before any env terminated at least once')
                         mean_rewards = -np.inf
-                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
+                    self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')))
                     print('MAX EPOCHS NUM!')
                     should_exit = True
 
@@ -1253,4 +1252,4 @@ def train(self):
                 return self.last_mean_rewards, epoch_num
 
             if should_exit:
-                return self.last_mean_rewards, epoch_num
\ No newline at end of file
+                return self.last_mean_rewards, epoch_num
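Note on the new stopping condition: with this patch, training ends when either the epoch budget (max_epochs) or the new frame budget (max_frames) is exhausted. The following is a minimal, illustrative Python sketch of that combined check only, not part of the patch; it uses plain arguments in place of the trainer's self.max_epochs / self.max_frames / self.frame attributes, with defaults mirroring the fallbacks added in a2c_common.py.

# Illustrative sketch of the combined exit check added to a2c_common.train().
def should_stop(epoch_num, frame, max_epochs=1e6, max_frames=1e10):
    # Stop when either the epoch budget or the collected-frame budget is reached.
    return epoch_num >= max_epochs or frame >= max_frames

# Example: cap a run at 500 epochs or 10M collected frames, whichever comes first.
print(should_stop(epoch_num=100, frame=12_000_000, max_epochs=500, max_frames=10_000_000))  # True
print(should_stop(epoch_num=100, frame=2_000_000, max_epochs=500, max_frames=10_000_000))   # False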