diff --git a/README.md b/README.md
index 329684aa..6410f09d 100644
--- a/README.md
+++ b/README.md
@@ -303,6 +303,8 @@ Additional environment supported properties and functions
 * Fixed bug with SAC not saving weights with save_frequency.
 * Added multi-node training support for GPU-accelerated training environments like Isaac Gym. No changes in training scripts are required. Thanks to @ankurhanda and @ArthurAllshire for assistance in implementation.
 * Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
+* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
+* Fixed SAC not loading weights properly.
 
 1.6.0
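Editor's note: the two new changelog entries above describe the runtime get/set parameter API and the set_epoch flag on restore(), both introduced by the diffs that follow. Below is a minimal, illustrative sketch of how a PBT-style outer loop might drive them for a PPO (A2C) agent. The make_agent() helper and the checkpoint path are hypothetical stand-ins; only get_param, set_param and restore(fn, set_epoch=...) come from this changeset.

# Illustrative sketch only -- not part of this diff.
# `make_agent()` is a hypothetical helper standing in for however the agent is
# normally constructed (e.g. through rl_games' Runner); the checkpoint path is made up.

def exploit_and_explore(agent, checkpoint_path):
    # Read a current hyperparameter back from the running agent...
    lr = agent.get_param('learning_rate')

    # ...and perturb a few of the supported parameters in place.
    agent.set_param('learning_rate', lr * 0.5)
    agent.set_param('gamma', 0.995)
    agent.set_param('entropy_coef', 0.001)

    # Copy weights from a better population member, but keep this worker's own
    # epoch/frame counters so schedules and logging are not rewound.
    agent.restore(checkpoint_path, set_epoch=False)

agent = make_agent()  # hypothetical constructor
exploit_and_explore(agent, 'nn/last_my_run_ep_100_rew_300.pth')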
diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 073e529b..285b8abf 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -7,11 +7,10 @@
 from torch import optim
 import torch
-from torch import nn
-import numpy as np
-import gym
+
 
 class A2CAgent(a2c_common.ContinuousA2CBase):
+
     def __init__(self, base_name, params):
         a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
         obs_shape = self.obs_shape
@@ -68,9 +67,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)
 
-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
     def get_masked_action_values(self, obs, action_masks):
         assert False
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index a4f20d7b..d386cd8e 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)
 
-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
     def get_masked_action_values(self, obs, action_masks):
         processed_obs = self._preproc_obs(obs['obs'])
diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 9120bb94..3dfcdadd 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -16,6 +16,7 @@ def rescale_actions(low, high, action):
 
 
 class PpoPlayerContinuous(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
@@ -81,7 +82,9 @@ def restore(self, fn):
     def reset(self):
         self.init_rnn()
 
+
 class PpoPlayerDiscrete(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
 
@@ -185,6 +188,7 @@ def reset(self):
 
 
 class SACPlayer(BasePlayer):
+
     def __init__(self, params):
         BasePlayer.__init__(self, params)
         self.network = self.config['network']
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index f17f52f2..3697a7f3 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):
         self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
         self.normalize_input = config.get("normalize_input", False)
 
+        # TODO: double-check! To use bootstrap instead?
         self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach
 
         print(self.batch_size, self.num_actors, self.num_agents)
@@ -60,7 +61,6 @@ def __init__(self, base_name, params):
             'action_dim': self.env_info["action_space"].shape[0],
             'actions_num' : self.actions_num,
             'input_shape' : obs_shape,
-            'normalize_input' : self.normalize_input,
             'normalize_input': self.normalize_input,
         }
         self.model = self.network.build(net_config)
@@ -88,12 +88,8 @@ def __init__(self, base_name, params):
         self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
         print("Target entropy", self.target_entropy)
 
-        self.step = 0
         self.algo_observer = config['features']['observer']
 
-        # TODO: Is there a better way to get the maximum number of episodes?
-        self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode
-
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -133,6 +129,8 @@ def base_init(self, base_name, config):
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)
 
+        self.save_freq = config.get('save_frequency', 0)
+
         self.network = config['network']
         self.rewards_shaper = config['reward_shaper']
         self.num_agents = self.env_info.get('agents', 1)
@@ -146,10 +144,10 @@ def base_init(self, base_name, config):
         self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)
 
         self.frame = 0
+        self.epoch_num = 0
         self.update_time = 0
-        self.last_mean_rewards = -100500
+        self.last_mean_rewards = -1000000000
         self.play_time = 0
-        self.epoch_num = 0
 
         # TODO: put it into the separate class
         pbt_str = ''
@@ -205,17 +203,8 @@ def alpha(self):
     def device(self):
         return self._device
 
-    def get_full_state_weights(self):
-        state = self.get_weights()
-
-        state['steps'] = self.step
-        state['actor_optimizer'] = self.actor_optimizer.state_dict()
-        state['critic_optimizer'] = self.critic_optimizer.state_dict()
-        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
-
-        return state
-
     def get_weights(self):
+        print("Loading weights")
         state = {'actor': self.model.sac_network.actor.state_dict(),
                  'critic': self.model.sac_network.critic.state_dict(),
                  'critic_target': self.model.sac_network.critic_target.state_dict()}
@@ -233,17 +222,45 @@ def set_weights(self, weights):
         if self.normalize_input and 'running_mean_std' in weights:
             self.model.running_mean_std.load_state_dict(weights['running_mean_std'])
 
-    def set_full_state_weights(self, weights):
+    def get_full_state_weights(self):
+        print("Loading full weights")
+        state = self.get_weights()
+
+        state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
+        state['actor_optimizer'] = self.actor_optimizer.state_dict()
+        state['critic_optimizer'] = self.critic_optimizer.state_dict()
+        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
+
+        return state
+
+    def set_full_state_weights(self, weights, set_epoch=True):
         self.set_weights(weights)
 
-        self.step = weights['step']
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
         self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
         self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])
 
-    def restore(self, fn):
+        self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)
+
+        if self.vec_env is not None:
+            env_state = weights.get('env_state', None)
+            self.vec_env.set_env_state(env_state)
+
+    def restore(self, fn, set_epoch=True):
+        print("SAC restore")
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
+
+    def get_param(self, param_name):
+        pass
+
+    def set_param(self, param_name, param_value):
+        pass
 
     def get_masked_action_values(self, obs, action_masks):
         assert False
@@ -334,6 +351,7 @@ def preproc_obs(self, obs):
         if isinstance(obs, dict):
             obs = obs['obs']
         obs = self.model.norm_obs(obs)
+
         return obs
 
     def cast_obs(self, obs):
@@ -348,7 +366,7 @@ def cast_obs(self, obs):
 
         return obs
 
-    # todo: move to common utils
+    # TODO: move to common utils
     def obs_to_tensors(self, obs):
         obs_is_dict = isinstance(obs, dict)
         if obs_is_dict:
@@ -359,6 +377,7 @@ def obs_to_tensors(self, obs):
             upd_obs = self.cast_obs(obs)
         if not obs_is_dict or 'obs' not in obs:
             upd_obs = {'obs' : upd_obs}
+
         return upd_obs
 
     def _obs_to_tensors_internal(self, obs):
@@ -368,18 +387,19 @@ def _obs_to_tensors_internal(self, obs):
                 upd_obs[key] = self._obs_to_tensors_internal(value)
         else:
             upd_obs = self.cast_obs(obs)
+
         return upd_obs
 
     def preprocess_actions(self, actions):
         if not self.is_tensor_obses:
             actions = actions.cpu().numpy()
+
         return actions
 
     def env_step(self, actions):
         actions = self.preprocess_actions(actions)
         obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)
 
-        self.step += self.num_actors
         if self.is_tensor_obses:
             return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
         else:
@@ -415,7 +435,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
     def clear_stats(self):
         self.game_rewards.clear()
         self.game_lengths.clear()
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.algo_observer.after_clear_stats()
 
     def play_steps(self, random_exploration = False):
@@ -431,6 +451,11 @@ def play_steps(self, random_exploration = False):
         critic2_losses = []
 
         obs = self.obs
+        if isinstance(obs, dict):
+            obs = self.obs['obs']
+
+        next_obs_processed = obs.clone()
+
         for s in range(self.num_steps_per_episode):
             self.set_eval()
             if random_exploration:
@@ -466,16 +491,17 @@ def play_steps(self, random_exploration = False):
             self.current_rewards = self.current_rewards * not_dones
             self.current_lengths = self.current_lengths * not_dones
 
-            if isinstance(obs, dict):
-                obs = obs['obs']
             if isinstance(next_obs, dict):
-                next_obs = next_obs['obs']
+                next_obs_processed = next_obs['obs']
+
+            self.obs = next_obs.clone()
 
             rewards = self.rewards_shaper(rewards)
 
-            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1))
+            self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs_processed, torch.unsqueeze(dones, 1))
 
-            self.obs = obs = next_obs.clone()
+            if isinstance(obs, dict):
+                obs = self.obs['obs']
 
             if not random_exploration:
                 self.set_train()
@@ -505,10 +531,9 @@ def train_epoch(self):
     def train(self):
         self.init_tensors()
         self.algo_observer.after_init(self)
-        self.last_mean_rewards = -100500
         total_time = 0
         # rep_count = 0
-        self.frame = 0
+
         self.obs = self.env_reset()
 
         while True:
@@ -560,7 +585,7 @@ def train(self):
                 should_exit = False
 
                 if self.save_freq > 0:
-                    if (self.epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+                    if self.epoch_num % self.save_freq == 0:
                         self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
                 if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
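Editor's note on the sac_agent.py changes above: full SAC checkpoints now store 'epoch' and 'frame' (plus the three optimizer states) instead of the old 'steps' entry, which is what the "Fixed SAC not loading weights properly" changelog line refers to. The snippet below is an illustrative way to inspect one such checkpoint; the path is hypothetical, while torch_ext.load_checkpoint is the same loader restore() already uses.

# Illustrative sketch only -- not part of this diff. The checkpoint path is made up.
from rl_games.algos_torch import torch_ext

ckpt = torch_ext.load_checkpoint('nn/last_AntSAC_ep_1000_rew_5000.pth')

# Network weights written by get_weights() ('running_mean_std' appears only when
# normalize_input is enabled).
for key in ('actor', 'critic', 'critic_target'):
    print(key, 'present:', key in ckpt)

# Bookkeeping written by the new get_full_state_weights(): 'epoch' and 'frame'
# replace the old 'steps' entry, alongside the optimizer states.
for key in ('epoch', 'frame', 'actor_optimizer', 'critic_optimizer', 'log_alpha_optimizer'):
    print(key, 'present:', key in ckpt)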
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index b6bdb687..646dd809 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -161,6 +161,7 @@ def __init__(self, base_name, params):
         self.rnn_states = None
         self.name = base_name
 
+        # TODO: do we still need it?
         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)
@@ -229,22 +230,26 @@ def __init__(self, base_name, params):
         self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
         self.obs = None
         self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
+
         self.batch_size = self.horizon_length * self.num_actors * self.num_agents
         self.batch_size_envs = self.horizon_length * self.num_actors
+
         assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
         self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
         self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
-        self.mini_epochs_num = self.config['mini_epochs']
+
         self.num_minibatches = self.batch_size // self.minibatch_size
         assert(self.batch_size % self.minibatch_size == 0)
 
+        self.mini_epochs_num = self.config['mini_epochs']
+
         self.mixed_precision = self.config.get('mixed_precision', False)
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)
 
         self.last_lr = self.config['learning_rate']
 
         self.frame = 0
         self.update_time = 0
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.play_time = 0
         self.epoch_num = 0
         self.curr_frames = 0
@@ -588,10 +593,11 @@ def train_central_value(self):
     def get_full_state_weights(self):
         state = self.get_weights()
         state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
         state['optimizer'] = self.optimizer.state_dict()
+
         if self.has_central_value:
             state['assymetric_vf_nets'] = self.central_value_net.state_dict()
-        state['frame'] = self.frame
 
         # This is actually the best reward ever achieved. last_mean_rewards is perhaps not the best variable name
         # We save it to the checkpoint to prevent overriding the "best ever" checkpoint upon experiment restart
@@ -603,19 +609,22 @@ def get_full_state_weights(self):
 
         return state
 
-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch=True):
+
         self.set_weights(weights)
-        self.epoch_num = weights['epoch'] # frames as well?
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         if self.has_central_value:
             self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
 
         self.optimizer.load_state_dict(weights['optimizer'])
 
-        self.frame = weights.get('frame', 0)
-        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)
-        env_state = weights.get('env_state', None)
+        self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)
 
         if self.vec_env is not None:
+            env_state = weights.get('env_state', None)
             self.vec_env.set_env_state(env_state)
 
     def get_weights(self):
@@ -651,6 +660,55 @@ def set_weights(self, weights):
         self.model.load_state_dict(weights['model'])
         self.set_stats_weights(weights)
 
+    def get_param(self, param_name):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "kl_threshold",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            return getattr(self, param_name)
+        elif param_name == "learning_rate":
+            return self.last_lr
+        else:
+            raise NotImplementedError(f"Can't get param {param_name}")
+
+    def set_param(self, param_name, param_value):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            setattr(self, param_name, param_value)
+        elif param_name == "learning_rate":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    raise NotImplementedError("Can't directly mutate LR on this schedule")
+                else:
+                    self.learning_rate = param_value
+
+                for param_group in self.optimizer.param_groups:
+                    param_group["lr"] = self.learning_rate
+        elif param_name == "kl_threshold":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    self.kl_threshold = param_value
+                    self.scheduler.kl_threshold = param_value
+                else:
+                    raise NotImplementedError("Can't directly mutate kl threshold")
+        else:
+            raise NotImplementedError(f"No param found for {param_name}")
+
     def _preproc_obs(self, obs_batch):
         if type(obs_batch) is dict:
             obs_batch = copy.copy(obs_batch)
@@ -912,7 +970,6 @@ def prepare_dataset(self, batch_dict):
                 values = self.value_mean_std(values)
                 returns = self.value_mean_std(returns)
                 self.value_mean_std.eval()
-
 
         advantages = torch.sum(advantages, axis=1)
 
@@ -1024,7 +1081,7 @@ def train(self):
             checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
             if self.save_freq > 0:
-                if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards):
+                if epoch_num % self.save_freq == 0:
                     self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
             if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
@@ -1301,7 +1358,7 @@ def train(self):
             checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])
 
             if self.save_freq > 0:
-                if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
+                if epoch_num % self.save_freq == 0:
                     self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
 
             if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
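Editor's note on the set_param implementation above: with an adaptive KL-based learning-rate schedule the learning rate itself cannot be overwritten (on rank 0 it raises NotImplementedError), so a tuning loop would mutate kl_threshold instead. A hedged sketch of that behaviour, assuming `agent` is an already-constructed continuous or discrete A2C (PPO) agent; SAC currently stubs these methods out.

# Illustrative sketch only -- not part of this diff. `agent` is assumed to be a
# constructed A2C (PPO) agent instance.

# Plain scalar hyperparameters can be read back and overwritten directly.
gamma = agent.get_param('gamma')
agent.set_param('gamma', min(0.999, gamma + 0.001))

# The learning rate is special-cased: on a fixed/linear schedule set_param also rewrites
# the optimizer's param groups; on an adaptive KL schedule it raises, and the KL target
# is the knob to mutate instead.
try:
    agent.set_param('learning_rate', 3e-4)
except NotImplementedError:
    agent.set_param('kl_threshold', 0.016)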
diff --git a/rl_games/common/player.py b/rl_games/common/player.py
index f1a5c35e..98be6501 100644
--- a/rl_games/common/player.py
+++ b/rl_games/common/player.py
@@ -364,9 +364,9 @@ def run(self):
                         cur_rewards_done = cur_rewards/done_count
                         cur_steps_done = cur_steps/done_count
                         if print_game_res:
-                            print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4} w: {game_res}')
+                            print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f} w: {game_res}')
                         else:
-                            print(f'reward: {cur_rewards_done:.4} steps: {cur_steps_done:.4f}')
+                            print(f'reward: {cur_rewards_done:.2f} steps: {cur_steps_done:.1f}')
 
                         sum_game_res += game_res
                         if batch_size//self.num_agents == 1 or games_played >= n_games:
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index d6261656..50ab4999 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -8,6 +8,7 @@
 from time import sleep
 import torch
 
+
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
diff --git a/rl_games/configs/mujoco/sac_ant_envpool.yaml b/rl_games/configs/mujoco/sac_ant_envpool.yaml
index e8390ed5..719e8b8f 100644
--- a/rl_games/configs/mujoco/sac_ant_envpool.yaml
+++ b/rl_games/configs/mujoco/sac_ant_envpool.yaml
@@ -29,7 +29,7 @@ params:
     max_epochs: 10000
     num_steps_per_episode: 8
     save_best_after: 500
-    save_frequency: 10000
+    save_frequency: 1000
     gamma: 0.99
     init_alpha: 1
     alpha_lr: 5e-3
diff --git a/rl_games/interfaces/base_algorithm.py b/rl_games/interfaces/base_algorithm.py
index 054483f7..0edf4504 100644
--- a/rl_games/interfaces/base_algorithm.py
+++ b/rl_games/interfaces/base_algorithm.py
@@ -1,6 +1,7 @@
 from abc import ABC
 from abc import abstractmethod, abstractproperty
 
+
 class BaseAlgorithm(ABC):
     def __init__(self, base_name, config):
         pass
@@ -26,7 +27,7 @@ def get_full_state_weights(self):
         pass
 
     @abstractmethod
-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch):
         pass
 
     @abstractmethod
@@ -37,5 +38,14 @@ def get_weights(self):
     def set_weights(self, weights):
         pass
 
+    # Get algo training parameters
+    @abstractmethod
+    def get_param(self, param_name):
+        pass
+
+    # Set algo training parameters
+    @abstractmethod
+    def set_param(self, param_name, param_value):
+        pass
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index e8ec70ee..39b13edb 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -131,11 +131,8 @@ def reset(self):
         pass
 
     def run(self, args):
-        load_path = None
-
         if args['train']:
             self.run_train(args)
-
         elif args['play']:
             self.run_play(args)
         else:
diff --git a/runner.py b/runner.py
index 25f79af4..c1188751 100644
--- a/runner.py
+++ b/runner.py
@@ -50,11 +50,9 @@
         except yaml.YAMLError as exc:
             print(exc)
 
-    #rank = int(os.getenv("LOCAL_RANK", "0"))
     global_rank = int(os.getenv("RANK", "0"))
     if args["track"] and global_rank == 0:
         import wandb
-
         wandb.init(
             project=args["wandb_project_name"],
             entity=args["wandb_entity"],