Small changes for DexPBT paper. 1) load env state from checkpoint in player (i.e. if we train with curriculum we want to load the same state). 2) Minor speedup by disabling validate_args in distributions. 3) Added another way to stop training: max_env_steps (#204)
alex-petrenko authored Sep 29, 2022
1 parent f3e9c7f commit d8645b2
Showing 4 changed files with 24 additions and 13 deletions.
4 changes: 2 additions & 2 deletions rl_games/algos_torch/models.py
@@ -195,7 +195,7 @@ def forward(self, input_dict):
prev_actions = input_dict.get('prev_actions', None)
input_dict['obs'] = self.norm_obs(input_dict['obs'])
mu, sigma, value, states = self.a2c_network(input_dict)
- distr = torch.distributions.Normal(mu, sigma)
+ distr = torch.distributions.Normal(mu, sigma, validate_args=False)

if is_train:
entropy = distr.entropy().sum(dim=-1)
@@ -246,7 +246,7 @@ def forward(self, input_dict):
input_dict['obs'] = self.norm_obs(input_dict['obs'])
mu, logstd, value, states = self.a2c_network(input_dict)
sigma = torch.exp(logstd)
- distr = torch.distributions.Normal(mu, sigma)
+ distr = torch.distributions.Normal(mu, sigma, validate_args=False)
if is_train:
entropy = distr.entropy().sum(dim=-1)
prev_neglogp = self.neglogp(prev_actions, mu, sigma, logstd)
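The only change in models.py is passing validate_args=False when the action distribution is built, which skips PyTorch's constraint checks (e.g. that sigma is strictly positive) on construction and on log_prob/entropy calls. A minimal sketch of the same pattern outside rl_games, with illustrative shapes:

```python
import torch

# Illustrative batch of action means and stddevs (not rl_games code).
mu = torch.zeros(4096, 8)
sigma = torch.ones(4096, 8)

# validate_args=False skips the internal constraint checks (e.g. sigma > 0)
# both at construction time and on log_prob/entropy calls, which saves a bit
# of Python overhead when a fresh distribution is built every forward pass.
distr = torch.distributions.Normal(mu, sigma, validate_args=False)

actions = distr.sample()
neglogp = -distr.log_prob(actions).sum(dim=-1)
entropy = distr.entropy().sum(dim=-1)
```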
14 changes: 13 additions & 1 deletion rl_games/algos_torch/players.py
@@ -74,6 +74,10 @@ def restore(self, fn):
if self.normalize_input and 'running_mean_std' in checkpoint:
self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+ env_state = checkpoint.get('env_state', None)
+ if self.env is not None and env_state is not None:
+     self.env.set_env_state(env_state)

def reset(self):
self.init_rnn()

@@ -172,6 +176,10 @@ def restore(self, fn):
if self.normalize_input and 'running_mean_std' in checkpoint:
self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+ env_state = checkpoint.get('env_state', None)
+ if self.env is not None and env_state is not None:
+     self.env.set_env_state(env_state)

def reset(self):
self.init_rnn()

@@ -210,6 +218,10 @@ def restore(self, fn):
if self.normalize_input and 'running_mean_std' in checkpoint:
self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

+ env_state = checkpoint.get('env_state', None)
+ if self.env is not None and env_state is not None:
+     self.env.set_env_state(env_state)

def get_action(self, obs, is_deterministic=False):
if self.has_batch_dimension == False:
obs = unsqueeze_obs(obs)
@@ -221,4 +233,4 @@ def get_action(self, obs, is_deterministic=False):
return actions

def reset(self):
- pass
+ pass
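Each player's restore() now also pulls an optional env_state entry out of the checkpoint and pushes it back into the environment, so evaluation runs see the same curriculum state the policy was trained with. A hedged sketch of that restore path, where the function name, the env handle, and the checkpoint layout are illustrative assumptions:

```python
import torch

def restore_env_state(checkpoint_path, env):
    # Hedged sketch of the restore path added to the players: if the checkpoint
    # carries an 'env_state' entry (e.g. curriculum progress captured at save
    # time) and the player actually holds an env, push that state back into it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    env_state = checkpoint.get("env_state", None)
    if env is not None and env_state is not None:
        env.set_env_state(env_state)  # counterpart of env.get_env_state() at save time
    return checkpoint
```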
2 changes: 1 addition & 1 deletion rl_games/algos_torch/torch_ext.py
@@ -84,7 +84,7 @@ def load_checkpoint(filename):
return state

def parameterized_truncated_normal(uniform, mu, sigma, a, b):
- normal = torch.distributions.normal.Normal(0, 1)
+ normal = torch.distributions.normal.Normal(0, 1, validate_args=False)

alpha = (a - mu) / sigma
beta = (b - mu) / sigma
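torch_ext.py gets the same validate_args=False treatment for the standard normal used by parameterized_truncated_normal. The hunk only shows the first lines of that helper; the sketch below reconstructs inverse-CDF truncated-normal sampling around them and is an illustrative assumption, not the file's exact body:

```python
import torch

def truncated_normal_sample(uniform, mu, sigma, a, b):
    # Inverse-CDF sampling from a normal truncated to [a, b]; only the first
    # lines match the hunk above, the rest is an illustrative reconstruction.
    normal = torch.distributions.normal.Normal(0, 1, validate_args=False)

    alpha = (a - mu) / sigma   # standardized lower bound
    beta = (b - mu) / sigma    # standardized upper bound
    alpha_cdf = normal.cdf(alpha)
    beta_cdf = normal.cdf(beta)

    # Map the uniform draw into [Phi(alpha), Phi(beta)], invert the standard
    # normal CDF, then rescale back to the requested mean and stddev.
    p = alpha_cdf + uniform * (beta_cdf - alpha_cdf)
    return mu + sigma * normal.icdf(p)

samples = truncated_normal_sample(torch.rand(1000), mu=0.0, sigma=1.0, a=-2.0, b=2.0)
```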
17 changes: 8 additions & 9 deletions rl_games/common/a2c_common.py
@@ -134,6 +134,7 @@ def __init__(self, base_name, params):

self.ppo = config.get('ppo', True)
self.max_epochs = self.config.get('max_epochs', 1e6)
+ self.max_frames = self.config.get('max_frames', 1e10)

self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
self.linear_lr = config['lr_schedule'] == 'linear'
@@ -932,7 +933,7 @@ def train(self):
fps_step = curr_frames / step_time
fps_step_inference = curr_frames / scaled_play_time
fps_total = curr_frames / scaled_time
- print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+ print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)

@@ -974,14 +975,12 @@ def train(self):
self.save(os.path.join(self.nn_dir, checkpoint_name))
should_exit = True

- if epoch_num >= self.max_epochs:
+ if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
if self.game_rewards.current_size == 0:
print('WARNING: Max epochs reached before any env terminated at least once')
mean_rewards = -np.inf

- self.save(os.path.join(self.nn_dir,
-     'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(
-     mean_rewards)))
+ self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
print('MAX EPOCHS NUM!')
should_exit = True
update_time = 0
@@ -1191,7 +1190,7 @@ def train(self):
fps_step = curr_frames / step_time
fps_step_inference = curr_frames / scaled_play_time
fps_total = curr_frames / scaled_time
- print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}')
+ print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')

self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
if len(b_losses) > 0:
@@ -1235,11 +1234,11 @@ def train(self):
self.save(os.path.join(self.nn_dir, checkpoint_name))
should_exit = True

- if epoch_num >= self.max_epochs:
+ if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
if self.game_rewards.current_size == 0:
print('WARNING: Max epochs reached before any env terminated at least once')
mean_rewards = -np.inf
- self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
+ self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX EPOCHS NUM!')
should_exit = True

@@ -1253,4 +1252,4 @@ def train(self):
return self.last_mean_rewards, epoch_num

if should_exit:
- return self.last_mean_rewards, epoch_num
+ return self.last_mean_rewards, epoch_num
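The a2c_common.py changes add a max_frames budget next to max_epochs: the frame counter is printed in the per-epoch FPS line and either limit now ends training. A simplified, illustrative sketch of the combined stopping rule; only the config keys and the condition mirror the diff:

```python
# Illustrative values; in rl_games these come from the training config.
config = {'max_epochs': 1000, 'max_frames': 5_000_000}

max_epochs = config.get('max_epochs', 1e6)
max_frames = config.get('max_frames', 1e10)

num_actors, horizon_length = 4096, 16   # illustrative rollout sizes
epoch_num, frame = 0, 0

while True:
    epoch_num += 1
    frame += num_actors * horizon_length   # env frames collected this epoch

    # Training now ends on whichever budget is hit first: epochs or total frames.
    if epoch_num >= max_epochs or frame >= max_frames:
        print(f'stopping at epoch {epoch_num}/{max_epochs}, frame {frame}/{max_frames}')
        break
```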
