
Commit

Release notes updates. Save_freq SAC fix.
ViktorM committed Aug 22, 2023
1 parent 289e32f commit e7e4251
Showing 5 changed files with 11 additions and 4 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -303,6 +303,8 @@ Additional environment supported properties and functions
* Fixed bug with SAC not saving weights with save_frequency.
* Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.

1.6.0

6 changes: 5 additions & 1 deletion rl_games/algos_torch/sac_agent.py
@@ -130,6 +130,8 @@ def base_init(self, base_name, config):
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)

+ self.save_freq = config.get('save_frequency', 0)
+
self.network = config['network']
self.rewards_shaper = config['reward_shaper']
self.num_agents = self.env_info.get('agents', 1)
@@ -203,6 +205,7 @@ def device(self):
return self._device

def get_weights(self):
print("Loading weights")
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -221,6 +224,7 @@ def set_weights(self, weights):
self.model.running_mean_std.load_state_dict(weights['running_mean_std'])

def get_full_state_weights(self):
print("Loading full weights")
state = self.get_weights()

state['epoch'] = self.epoch_num
@@ -569,7 +573,7 @@ def train(self):
should_exit = False

if self.save_freq > 0:
- if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+ if self.epoch_num % self.save_freq == 0:
self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
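
Taken together, the sac_agent.py hunks above mean base_init now reads save_frequency from the config and train() writes the rolling 'last_' checkpoint purely on the epoch counter, instead of also requiring mean_rewards[0] <= self.last_mean_rewards. A minimal standalone sketch of the fixed behaviour follows; the helper name is illustrative and only assumes the agent fields shown in the diff (save_freq, nn_dir, save):

import os

def maybe_save_periodic_checkpoint(agent, epoch_num, checkpoint_name):
    # Sketch of the fixed logic: the rolling 'last_' checkpoint is written every
    # save_freq epochs, regardless of whether the mean reward improved. Before
    # this commit the save was also gated on the reward comparison, so an
    # improving SAC run could go a long time without a periodic checkpoint.
    if agent.save_freq > 0 and epoch_num % agent.save_freq == 0:
        agent.save(os.path.join(agent.nn_dir, 'last_' + checkpoint_name))
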
4 changes: 2 additions & 2 deletions rl_games/common/a2c_common.py
@@ -1081,7 +1081,7 @@ def train(self):
checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])

if self.save_freq > 0:
- if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards):
+ if epoch_num % self.save_freq == 0:
self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
@@ -1358,7 +1358,7 @@ def train(self):
checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])

if self.save_freq > 0:
- if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+ if epoch_num % self.save_freq == 0:
self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
1 change: 1 addition & 0 deletions rl_games/common/vecenv.py
@@ -8,6 +8,7 @@
from time import sleep
import torch

+
class RayWorker:
def __init__(self, config_name, config):
self.env = configurations[config_name]['env_creator'](**config)
2 changes: 1 addition & 1 deletion rl_games/configs/mujoco/sac_ant_envpool.yaml
@@ -29,7 +29,7 @@ params:
max_epochs: 10000
num_steps_per_episode: 8
save_best_after: 500
- save_frequency: 10000
+ save_frequency: 1000
gamma: 0.99
init_alpha: 1
alpha_lr: 5e-3
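
A note on the tuned value (an inference from this config, not stated in the commit): with max_epochs: 10000, the previous save_frequency: 10000 allowed the periodic 'last_' checkpoint to fire at most once, at the very end of the run, while 1000 yields one roughly every 1000 epochs, about ten per run, alongside the best-reward checkpoints governed by save_best_after.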
