From e7e42512bfcdd4d6e852d8fb6f884d6942e7c8e4 Mon Sep 17 00:00:00 2001
From: ViktorM
Date: Tue, 22 Aug 2023 09:16:16 -0700
Subject: [PATCH] Release notes updates. Save_freq SAC fix.

---
 README.md                                    | 2 ++
 rl_games/algos_torch/sac_agent.py            | 6 +++++-
 rl_games/common/a2c_common.py                | 4 ++--
 rl_games/common/vecenv.py                    | 1 +
 rl_games/configs/mujoco/sac_ant_envpool.yaml | 2 +-
 5 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 329684aa..6410f09d 100644
--- a/README.md
+++ b/README.md
@@ -303,6 +303,8 @@ Additional environment supported properties and functions
 * Fixed bug with SAC not saving weights with save_frequency.
 * Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
 * Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
+* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
+* Fixed SAC not loading weights properly.
 
 1.6.0
 
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index 3e1ea21f..f796a79f 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -130,6 +130,8 @@ def base_init(self, base_name, config):
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)
 
+        self.save_freq = config.get('save_frequency', 0)
+
         self.network = config['network']
         self.rewards_shaper = config['reward_shaper']
         self.num_agents = self.env_info.get('agents', 1)
@@ -203,6 +205,7 @@ def device(self):
         return self._device
 
     def get_weights(self):
+        print("Loading weights")
         state = {'actor': self.model.sac_network.actor.state_dict(),
                  'critic': self.model.sac_network.critic.state_dict(),
                  'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -221,6 +224,7 @@ def set_weights(self, weights):
             self.model.running_mean_std.load_state_dict(weights['running_mean_std'])
 
     def get_full_state_weights(self):
+        print("Loading full weights")
         state = self.get_weights()
 
         state['epoch'] = self.epoch_num
@@ -569,7 +573,7 @@ def train(self):
             should_exit = False
 
             if self.save_freq > 0:
-                if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+                if self.epoch_num % self.save_freq == 0:
                     self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
             if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 6aeeebe5..498906b7 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -1081,7 +1081,7 @@ def train(self):
                 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
                 if self.save_freq > 0:
-                    if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards):
+                    if epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
                 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
@@ -1358,7 +1358,7 @@ def train(self):
                 checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
                 if self.save_freq > 0:
-                    if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+                    if epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
                 if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index d6261656..50ab4999 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -8,6 +8,7 @@
 from time import sleep
 import torch
 
+
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
diff --git a/rl_games/configs/mujoco/sac_ant_envpool.yaml b/rl_games/configs/mujoco/sac_ant_envpool.yaml
index e8390ed5..719e8b8f 100644
--- a/rl_games/configs/mujoco/sac_ant_envpool.yaml
+++ b/rl_games/configs/mujoco/sac_ant_envpool.yaml
@@ -29,7 +29,7 @@ params:
     max_epochs: 10000
     num_steps_per_episode: 8
     save_best_after: 500
-    save_frequency: 10000
+    save_frequency: 1000
     gamma: 0.99
     init_alpha: 1
     alpha_lr: 5e-3