diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 073e529b..285b8abf 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -7,11 +7,10 @@
 from torch import optim
 import torch
-from torch import nn
-import numpy as np
-import gym
+
 
 class A2CAgent(a2c_common.ContinuousA2CBase):
+
     def __init__(self, base_name, params):
         a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
         obs_shape = self.obs_shape
@@ -68,9 +67,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)
 
-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
     def get_masked_action_values(self, obs, action_masks):
         assert False
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index a4f20d7b..d386cd8e 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -70,9 +70,9 @@ def save(self, fn):
         state = self.get_full_state_weights()
         torch_ext.save_checkpoint(fn, state)
 
-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
 
     def get_masked_action_values(self, obs, action_masks):
         processed_obs = self._preproc_obs(obs['obs'])
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index f17f52f2..3e1ea21f 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -38,6 +38,7 @@ def __init__(self, base_name, params):
         self.num_steps_per_episode = config.get("num_steps_per_episode", 1)
         self.normalize_input = config.get("normalize_input", False)
 
+        # TODO: double-check! To use bootstrap instead?
         self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach
 
         print(self.batch_size, self.num_actors, self.num_agents)
@@ -88,12 +89,8 @@ def __init__(self, base_name, params):
         self.target_entropy = self.target_entropy_coef * -self.env_info['action_space'].shape[0]
         print("Target entropy", self.target_entropy)
 
-        self.step = 0
         self.algo_observer = config['features']['observer']
 
-        # TODO: Is there a better way to get the maximum number of episodes?
-        self.max_episodes = torch.ones(self.num_actors, device=self._device)*self.num_steps_per_episode
-
     def load_networks(self, params):
         builder = model_builder.ModelBuilder()
         self.config['network'] = builder.load(params)
@@ -146,10 +143,10 @@ def base_init(self, base_name, config):
         self.min_alpha = torch.tensor(np.log(1)).float().to(self._device)
 
         self.frame = 0
+        self.epoch_num = 0
         self.update_time = 0
-        self.last_mean_rewards = -100500
+        self.last_mean_rewards = -1000000000
         self.play_time = 0
-        self.epoch_num = 0
 
         # TODO: put it into the separate class
         pbt_str = ''
@@ -205,16 +202,6 @@ def alpha(self):
     def device(self):
         return self._device
 
-    def get_full_state_weights(self):
-        state = self.get_weights()
-
-        state['steps'] = self.step
-        state['actor_optimizer'] = self.actor_optimizer.state_dict()
-        state['critic_optimizer'] = self.critic_optimizer.state_dict()
-        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
-
-        return state
-
     def get_weights(self):
         state = {'actor': self.model.sac_network.actor.state_dict(),
                  'critic': self.model.sac_network.critic.state_dict(),
@@ -233,17 +220,37 @@ def set_weights(self, weights):
         if self.normalize_input and 'running_mean_std' in weights:
             self.model.running_mean_std.load_state_dict(weights['running_mean_std'])
 
-    def set_full_state_weights(self, weights):
+    def get_full_state_weights(self):
+        state = self.get_weights()
+
+        state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
+        state['actor_optimizer'] = self.actor_optimizer.state_dict()
+        state['critic_optimizer'] = self.critic_optimizer.state_dict()
+        state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict()
+
+        return state
+
+    def set_full_state_weights(self, weights, set_epoch=True):
         self.set_weights(weights)
 
-        self.step = weights['step']
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         self.actor_optimizer.load_state_dict(weights['actor_optimizer'])
         self.critic_optimizer.load_state_dict(weights['critic_optimizer'])
         self.log_alpha_optimizer.load_state_dict(weights['log_alpha_optimizer'])
 
-    def restore(self, fn):
+    def restore(self, fn, set_epoch=True):
         checkpoint = torch_ext.load_checkpoint(fn)
-        self.set_full_state_weights(checkpoint)
+        self.set_full_state_weights(checkpoint, set_epoch=set_epoch)
+
+    def get_params(self, param_name):
+        pass
+
+    def set_params(self, param_name, param_value):
+        pass
 
     def get_masked_action_values(self, obs, action_masks):
         assert False
@@ -334,6 +341,7 @@ def preproc_obs(self, obs):
         if isinstance(obs, dict):
             obs = obs['obs']
         obs = self.model.norm_obs(obs)
+
         return obs
 
     def cast_obs(self, obs):
@@ -348,7 +356,7 @@ def cast_obs(self, obs):
 
         return obs
 
-    # todo: move to common utils
+    # TODO: move to common utils
     def obs_to_tensors(self, obs):
         obs_is_dict = isinstance(obs, dict)
         if obs_is_dict:
@@ -359,6 +367,7 @@ def obs_to_tensors(self, obs):
             upd_obs = self.cast_obs(obs)
         if not obs_is_dict or 'obs' not in obs:
             upd_obs = {'obs' : upd_obs}
+
         return upd_obs
 
     def _obs_to_tensors_internal(self, obs):
@@ -368,18 +377,19 @@ def _obs_to_tensors_internal(self, obs):
                 upd_obs[key] = self._obs_to_tensors_internal(value)
         else:
             upd_obs = self.cast_obs(obs)
+
         return upd_obs
 
     def preprocess_actions(self, actions):
         if not self.is_tensor_obses:
             actions = actions.cpu().numpy()
+
         return actions
 
     def env_step(self, actions):
         actions = self.preprocess_actions(actions)
         obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space)
 
-        self.step += self.num_actors
         if self.is_tensor_obses:
             return self.obs_to_tensors(obs), rewards.to(self._device), dones.to(self._device), infos
         else:
@@ -415,7 +425,7 @@ def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, act
     def clear_stats(self):
         self.game_rewards.clear()
         self.game_lengths.clear()
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.algo_observer.after_clear_stats()
 
     def play_steps(self, random_exploration = False):
@@ -505,10 +515,9 @@ def train_epoch(self):
     def train(self):
         self.init_tensors()
         self.algo_observer.after_init(self)
-        self.last_mean_rewards = -100500
         total_time = 0
         # rep_count = 0
-        self.frame = 0
+
         self.obs = self.env_reset()
 
         while True:
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index b6bdb687..6aeeebe5 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -161,6 +161,7 @@ def __init__(self, base_name, params):
         self.rnn_states = None
         self.name = base_name
 
+        # TODO: do we still need it?
         self.ppo = config.get('ppo', True)
         self.max_epochs = self.config.get('max_epochs', -1)
         self.max_frames = self.config.get('max_frames', -1)
@@ -229,22 +230,26 @@ def __init__(self, base_name, params):
         self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
         self.obs = None
         self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
+
         self.batch_size = self.horizon_length * self.num_actors * self.num_agents
         self.batch_size_envs = self.horizon_length * self.num_actors
 
+        assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
         self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
         self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
-        self.mini_epochs_num = self.config['mini_epochs']
 
+        self.num_minibatches = self.batch_size // self.minibatch_size
         assert(self.batch_size % self.minibatch_size == 0)
 
+        self.mini_epochs_num = self.config['mini_epochs']
+
         self.mixed_precision = self.config.get('mixed_precision', False)
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)
 
         self.last_lr = self.config['learning_rate']
         self.frame = 0
         self.update_time = 0
-        self.mean_rewards = self.last_mean_rewards = -100500
+        self.mean_rewards = self.last_mean_rewards = -1000000000
         self.play_time = 0
         self.epoch_num = 0
         self.curr_frames = 0
@@ -588,10 +593,11 @@ def train_central_value(self):
     def get_full_state_weights(self):
         state = self.get_weights()
         state['epoch'] = self.epoch_num
+        state['frame'] = self.frame
         state['optimizer'] = self.optimizer.state_dict()
+
         if self.has_central_value:
             state['assymetric_vf_nets'] = self.central_value_net.state_dict()
-        state['frame'] = self.frame
 
         # This is actually the best reward ever achieved. last_mean_rewards is perhaps not the best variable name
         # We save it to the checkpoint to prevent overriding the "best ever" checkpoint upon experiment restart
@@ -603,19 +609,22 @@ def get_full_state_weights(self):
         return state
 
-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch=True):
+
         self.set_weights(weights)
 
-        self.epoch_num = weights['epoch'] # frames as well?
+        if set_epoch:
+            self.epoch_num = weights['epoch']
+            self.frame = weights['frame']
+
         if self.has_central_value:
             self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
 
         self.optimizer.load_state_dict(weights['optimizer'])
 
-        self.frame = weights.get('frame', 0)
-        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)
-        env_state = weights.get('env_state', None)
+        self.last_mean_rewards = weights.get('last_mean_rewards', -1000000000)
 
         if self.vec_env is not None:
+            env_state = weights.get('env_state', None)
             self.vec_env.set_env_state(env_state)
 
     def get_weights(self):
@@ -651,6 +660,55 @@ def set_weights(self, weights):
         self.model.load_state_dict(weights['model'])
         self.set_stats_weights(weights)
 
+    def get_params(self, param_name):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "kl_threshold",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            return getattr(self, param_name)
+        elif param_name == "learning_rate":
+            return self.last_lr
+        else:
+            raise NotImplementedError(f"Can't get param {param_name}")
+
+    def set_params(self, param_name, param_value):
+        if param_name in [
+            "grad_norm",
+            "critic_coef",
+            "bounds_loss_coef",
+            "entropy_coef",
+            "gamma",
+            "tau",
+            "mini_epochs_num",
+            "e_clip",
+        ]:
+            setattr(self, param_name, param_value)
+        elif param_name == "learning_rate":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    raise NotImplementedError("Can't directly mutate LR on this schedule")
+                else:
+                    self.learning_rate = param_value
+
+                    for param_group in self.optimizer.param_groups:
+                        param_group["lr"] = self.learning_rate
+        elif param_name == "kl_threshold":
+            if self.global_rank == 0:
+                if self.is_adaptive_lr:
+                    self.kl_threshold = param_value
+                    self.scheduler.kl_threshold = param_value
+                else:
+                    raise NotImplementedError("Can't directly mutate kl threshold")
+        else:
+            raise NotImplementedError(f"No param found for {param_value}")
+
     def _preproc_obs(self, obs_batch):
         if type(obs_batch) is dict:
             obs_batch = copy.copy(obs_batch)
@@ -912,7 +970,6 @@ def prepare_dataset(self, batch_dict):
             values = self.value_mean_std(values)
             returns = self.value_mean_std(returns)
             self.value_mean_std.eval()
-
         advantages = torch.sum(advantages, axis=1)
diff --git a/rl_games/interfaces/base_algorithm.py b/rl_games/interfaces/base_algorithm.py
index 054483f7..8534f408 100644
--- a/rl_games/interfaces/base_algorithm.py
+++ b/rl_games/interfaces/base_algorithm.py
@@ -1,6 +1,7 @@
 from abc import ABC
 from abc import abstractmethod, abstractproperty
 
+
 class BaseAlgorithm(ABC):
     def __init__(self, base_name, config):
         pass
@@ -26,7 +27,7 @@ def get_full_state_weights(self):
         pass
 
     @abstractmethod
-    def set_full_state_weights(self, weights):
+    def set_full_state_weights(self, weights, set_epoch):
         pass
 
     @abstractmethod
@@ -37,5 +38,14 @@ def get_weights(self):
         pass
 
     @abstractmethod
     def set_weights(self, weights):
         pass
+    # Get algo training parameters
+    @abstractmethod
+    def get_params(self, param_name):
+        pass
+
+    # Set algo training parameters
+    @abstractmethod
+    def set_params(self, param_name, param_value):
+        pass
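Note on the checkpoint changes above: 'epoch' and 'frame' are now part of the full state saved by get_full_state_weights(), and restore() forwards a set_epoch flag that decides whether those counters are restored. The snippet below is a minimal usage sketch, not code from this diff; it assumes `agent` is an already constructed rl_games agent (e.g. A2CAgent) and `fn` is the path of a checkpoint previously written by agent.save().

def full_resume(agent, fn):
    # Continue the same run: weights, optimizer state and the saved
    # 'epoch'/'frame' counters are all restored (set_epoch defaults to True).
    agent.restore(fn)

def weights_only_restore(agent, fn):
    # PBT-style exploit step: adopt the weights and optimizer state of a
    # better-performing worker while keeping this worker's own epoch/frame
    # counters, so its logging and schedule position are unaffected.
    agent.restore(fn, set_epoch=False)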
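The new get_params/set_params methods in a2c_common.py give an external controller a uniform, name-based way to read and mutate training hyperparameters (the SAC versions are still pass stubs). Below is a hedged sketch of how a population-based-training loop might drive them; the helper name and the 0.8/1.2 perturbation factors are illustrative only and not part of this diff.

import random

def mutate_hyperparams(agent):
    # Parameter names come from the lists accepted by get_params()/set_params().
    for name in ("entropy_coef", "learning_rate"):
        value = agent.get_params(name)
        # Per set_params() above, mutating "learning_rate" raises
        # NotImplementedError when an adaptive (KL-based) schedule is active,
        # and the optimizer LR is only touched on global_rank == 0.
        agent.set_params(name, value * random.choice([0.8, 1.2]))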