From 44bc254e723acf82f035e1c699ad4e036bd5f42a Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sat, 21 May 2022 23:46:31 -0700 Subject: [PATCH 01/26] WIP --- rl_games/algos_torch/models.py | 32 +- rl_games/algos_torch/players.py | 57 +++ rl_games/algos_torch/sac_agent.py | 2 +- rl_games/algos_torch/shac_agent.py | 572 +++++++++++++++++++++++++++++ rl_games/common/a2c_common.py | 4 +- rl_games/torch_runner.py | 17 +- 6 files changed, 671 insertions(+), 13 deletions(-) create mode 100644 rl_games/algos_torch/shac_agent.py diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 920cbf54..8a1ad5c9 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -169,6 +169,7 @@ def forward(self, input_dict): } return result + class ModelA2CContinuous(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'a2c') @@ -312,7 +313,6 @@ def forward(self, input_dict): return result - class ModelSACContinuous(BaseModel): def __init__(self, network): @@ -343,4 +343,34 @@ def forward(self, input_dict): return dist +class ModelSHAC(BaseModel): + + def __init__(self, network): + BaseModel.__init__(self, 'shac') + self.network_builder = network + + class Network(BaseModelNetwork): + def __init__(self, shac_network,**kwargs): + BaseModelNetwork.__init__(self,**kwargs) + self.shac_network = shac_network + + def critic(self, obs, action): + return self.shac_network.critic(obs, action) + + def critic_target(self, obs, action): + return self.shac_network.critic_target(obs, action) + + def actor(self, obs): + return self.shac_network.actor(obs) + + def is_rnn(self): + return False + + def forward(self, input_dict): + is_train = input_dict.pop('is_train', True) + mu, sigma = self.shac_network(input_dict) + dist = SquashedNormal(mu, sigma) + return dist + + diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py index 283a8072..7cf911a4 100644 --- a/rl_games/algos_torch/players.py +++ b/rl_games/algos_torch/players.py @@ -16,6 +16,7 @@ def rescale_actions(low, high, action): class PpoPlayerContinuous(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) self.network = self.config['network'] @@ -77,7 +78,9 @@ def restore(self, fn): def reset(self): self.init_rnn() + class PpoPlayerDiscrete(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) @@ -176,6 +179,7 @@ def reset(self): class SACPlayer(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) self.network = self.config['network'] @@ -212,11 +216,64 @@ def restore(self, fn): def get_action(self, obs, is_determenistic=False): if self.has_batch_dimension == False: obs = unsqueeze_obs(obs) + dist = self.model.actor(obs) actions = dist.sample() if is_determenistic else dist.mean actions = actions.clamp(*self.action_range).to(self.device) if self.has_batch_dimension == False: actions = torch.squeeze(actions.detach()) + + return actions + + def reset(self): + pass + + +class SHACPlayer(BasePlayer): + + def __init__(self, params): + BasePlayer.__init__(self, params) + self.network = self.config['network'] + self.actions_num = self.action_space.shape[0] + self.action_range = [ + float(self.env_info['action_space'].low.min()), + float(self.env_info['action_space'].high.max()) + ] + + obs_shape = self.obs_shape + self.normalize_input = False + config = { + 'obs_dim': self.env_info["observation_space"].shape[0], + 'action_dim': self.env_info["action_space"].shape[0], + 'actions_num' : self.actions_num, + 'input_shape' : 
obs_shape, + 'value_size': self.env_info.get('value_size', 1), + 'normalize_value': False, + 'normalize_input': self.normalize_input, + } + self.model = self.network.build(config) + self.model.to(self.device) + self.model.eval() + self.is_rnn = self.model.is_rnn() + + def restore(self, fn): + checkpoint = torch_ext.load_checkpoint(fn) + self.model.sac_network.actor.load_state_dict(checkpoint['actor']) + self.model.sac_network.critic.load_state_dict(checkpoint['critic']) + self.model.sac_network.critic_target.load_state_dict(checkpoint['critic_target']) + if self.normalize_input and 'running_mean_std' in checkpoint: + self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std']) + + def get_action(self, obs, is_determenistic=False): + if self.has_batch_dimension == False: + obs = unsqueeze_obs(obs) + + dist = self.model.actor(obs) + actions = dist.sample() if is_determenistic else dist.mean + actions = actions.clamp(*self.action_range).to(self.device) + if self.has_batch_dimension == False: + actions = torch.squeeze(actions.detach()) + return actions def reset(self): diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index b9fbeb0b..a6361ac7 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ -17,8 +17,8 @@ import time - class SACAgent(BaseAlgorithm): + def __init__(self, base_name, params): self.config = config = params['config'] print(config) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py new file mode 100644 index 00000000..865c5801 --- /dev/null +++ b/rl_games/algos_torch/shac_agent.py @@ -0,0 +1,572 @@ +from rl_games.algos_torch import torch_ext + +from rl_games.algos_torch.running_mean_std import RunningMeanStd + +from rl_games.common import vecenv +from rl_games.common import schedulers +from rl_games.common import experience +from rl_games.interfaces.base_algorithm import BaseAlgorithm +from torch.utils.tensorboard import SummaryWriter +from datetime import datetime +from rl_games.algos_torch import model_builder + +import torch +from torch import nn, optim +import torch.nn.functional as F +import numpy as np +import time + + +class SHACAgent(BaseAlgorithm): + + def __init__(self, base_name, params): + + self.config = config = params['config'] + print(config) + + # TODO: Get obs shape and self.network + self.load_networks(params) + self.base_init(base_name, config) + self.horizon_length = config["horizon_length"] + self.gamma = config["gamma"] + self.critic_tau = config["critic_tau"] + self.batch_size = config["batch_size"] + + self.num_steps_per_episode = config.get("num_steps_per_episode", 1) + self.normalize_input = config.get("normalize_input", False) + + self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach + + print(self.batch_size, self.num_actors, self.num_agents) + + self.num_frames_per_epoch = self.num_actors * self.num_steps_per_episode + + action_space = self.env_info['action_space'] + self.actions_num = action_space.shape[0] + + self.action_range = [ + float(self.env_info['action_space'].low.min()), + float(self.env_info['action_space'].high.max()) + ] + + obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) + net_config = { + 'obs_dim': self.env_info["observation_space"].shape[0], + 'action_dim': self.env_info["action_space"].shape[0], + 'actions_num' : self.actions_num, + 'input_shape' : obs_shape, + 'normalize_input' : self.normalize_input, + 'normalize_input': self.normalize_input, + } + 
self.model = self.network.build(net_config) + self.model.to(self.shac_device) + + print("Number of Agents", self.num_actors, "Batch Size", self.batch_size) + + self.actor_optimizer = torch.optim.Adam(self.model.shac_network.actor.parameters(), + lr=self.config['actor_lr'], + betas=self.config.get("actor_betas", [0.9, 0.999])) + + self.critic_optimizer = torch.optim.Adam(self.model.shac_network.critic.parameters(), + lr=self.config["critic_lr"], + betas=self.config.get("critic_betas", [0.9, 0.999])) + + self.step = 0 + self.algo_observer = config['features']['observer'] + + + # TODO: Is there a better way to get the maximum number of episodes? + #self.max_episodes = torch.ones(self.num_actors, device=self.shac_device)*self.num_steps_per_episode + + def load_networks(self, params): + builder = model_builder.ModelBuilder() + self.config['network'] = builder.load(params) + + def base_init(self, base_name, config): + self.env_config = config.get('env_config', {}) + self.num_actors = config.get('num_actors', 1) + self.env_name = config['env_name'] + print("Env name:", self.env_name) + + self.env_info = config.get('env_info') + if self.env_info is None: + self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config) + self.env_info = self.vec_env.get_env_info() + + self.shac_device = config.get('device', 'cuda:0') + + print('Env info:') + print(self.env_info) + + # shac params + self.gamma = config['params']['config'].get('gamma', 0.99) + + self.critic_method = config['params']['config'].get('critic_method', 'td-lambda') # ['one-step', 'td-lambda'] + if self.critic_method == 'td-lambda': + self.lam = config['params']['config'].get('lambda', 0.95) + + # self.steps_num = cfg["params"]["config"]["steps_num"] + self.max_epochs = config["params"]["config"]["max_epochs"] # add get + self.actor_lr = float(config["params"]["config"]["actor_learning_rate"]) + self.critic_lr = float(config['params']['config']['critic_learning_rate']) + self.lr_schedule = config['params']['config'].get('lr_schedule', 'linear') + + self.target_critic_alpha = config['params']['config'].get('target_critic_alpha', 0.4) + + self.obs_rms = None + if config.get('obs_rms', False): + self.obs_rms = RunningMeanStd(shape = (self.num_obs), device = self.device) + + self.ret_rms = None + if config.get('ret_rms', False): + self.ret_rms = RunningMeanStd(shape = (), device = self.device) + + #self.rew_scale = cfg['params']['config'].get('rew_scale', 1.0) + + self.critic_iterations = config.get('critic_iterations', 16) + self.num_batch = config.get('num_batch', 4) + self.batch_size = self.num_actors * self.horizon_length // self.num_batch + self.name = config.get('name', "Ant") + + self.truncate_grad = config["params"]["config"]["truncate_grads"] + self.grad_norm = config["params"]["config"]["grad_norm"] + ########### + + self.rewards_shaper = config['reward_shaper'] + self.observation_space = self.env_info['observation_space'] + self.weight_decay = config.get('weight_decay', 0.0) + #self.use_action_masks = config.get('use_action_masks', False) + self.is_train = config.get('is_train', True) + + self.c_loss = nn.MSELoss() + # self.c2_loss = nn.SmoothL1Loss() + + self.save_best_after = config.get('save_best_after', 100) + self.print_stats = config.get('print_stats', True) + self.rnn_states = None + self.name = base_name + + self.max_epochs = self.config.get('max_epochs', 1e6) + + self.network = config['network'] + self.rewards_shaper = config['reward_shaper'] + self.num_agents = self.env_info.get('agents', 1) + 
self.obs_shape = self.observation_space.shape + + self.games_to_track = self.config.get('games_to_track', 100) + self.game_rewards = torch_ext.AverageMeter(1, self.games_to_track).to(self.shac_device) + self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.shac_device) + self.obs = None + + self.frame = 0 + self.update_time = 0 + self.last_mean_rewards = -100500 + self.play_time = 0 + self.epoch_num = 0 + + self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) + print("Run Directory:", config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) + + self.is_tensor_obses = None + self.is_rnn = False + self.last_rnn_indices = None + self.last_state_indices = None + + # shac + + # replay buffer + self.obs_buf = torch.zeros((self.steps_num, self.num_envs, self.num_obs), dtype = torch.float32, device = self.device) + self.rew_buf = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) + self.done_mask = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) + self.next_values = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) + self.target_values = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) + self.ret = torch.zeros((self.num_envs), dtype = torch.float32, device = self.device) + + # for kl divergence computing + self.old_mus = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.old_sigmas = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.mus = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.sigmas = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + + # counting variables + self.iter_count = 0 + self.step_count = 0 + + # loss variables + self.episode_length_his = [] + self.episode_loss_his = [] + self.episode_discounted_loss_his = [] + self.episode_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_discounted_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_gamma = torch.ones(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_length = torch.zeros(self.num_envs, dtype = int) + self.best_policy_loss = np.inf + self.actor_loss = np.inf + self.value_loss = np.inf + + # average meter + self.episode_loss_meter = AverageMeter(1, 100).to(self.device) + self.episode_discounted_loss_meter = AverageMeter(1, 100).to(self.device) + self.episode_length_meter = AverageMeter(1, 100).to(self.device) + + # timer + self.time_report = TimeReport() + self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) + self.target_critic = copy.deepcopy(self.critic) + + # if cfg['params']['general']['train']: + # self.save('init_policy') + + def init_tensors(self): + if self.observation_space.dtype == np.uint8: + torch_dtype = torch.uint8 + else: + torch_dtype = torch.float32 + batch_size = self.num_agents * self.num_actors + + self.current_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.shac_device) + self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.shac_device) + + self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.shac_device) + + @property + def 
device(self): + return self.shac_device + + def get_full_state_weights(self): + state = self.get_weights() + + state['steps'] = self.step + state['actor_optimizer'] = self.actor_optimizer.state_dict() + state['critic_optimizer'] = self.critic_optimizer.state_dict() + + return state + + def get_weights(self): + state = {'actor': self.model.shac_network.actor.state_dict(), + 'critic': self.model.shac_network.critic.state_dict(), + 'critic_target': self.model.shac_network.critic_target.state_dict()} + return state + + def save(self, fn): + state = self.get_full_state_weights() + torch_ext.save_checkpoint(fn, state) + + def set_weights(self, weights): + self.model.shac_network.actor.load_state_dict(weights['actor']) + self.model.shac_network.critic.load_state_dict(weights['critic']) + self.model.shac_network.critic_target.load_state_dict(weights['critic_target']) + + if self.normalize_input and 'running_mean_std' in weights: + self.model.running_mean_std.load_state_dict(weights['running_mean_std']) + + def set_full_state_weights(self, weights): + self.set_weights(weights) + + self.step = weights['step'] + self.actor_optimizer.load_state_dict(weights['actor_optimizer']) + self.critic_optimizer.load_state_dict(weights['critic_optimizer']) + + def restore(self, fn): + checkpoint = torch_ext.load_checkpoint(fn) + self.set_full_state_weights(checkpoint) + + def get_masked_action_values(self, obs, action_masks): + assert False + + def set_eval(self): + self.model.eval() + + def set_train(self): + self.model.train() + + def update_critic(self, obs, action, reward, next_obs, not_done,step): + with torch.no_grad(): + dist = self.model.actor(next_obs) + next_action = dist.rsample() + log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) + target_Q1, target_Q2 = self.model.critic_target(next_obs, next_action) + target_V = torch.min(target_Q1, target_Q2) - self.alpha * log_prob + + target_Q = reward + (not_done * self.gamma * target_V) + target_Q = target_Q.detach() + + # get current Q estimates + current_Q1, current_Q2 = self.model.critic(obs, action) + + critic1_loss = self.c_loss(current_Q1, target_Q) + critic2_loss = self.c_loss(current_Q2, target_Q) + critic_loss = critic1_loss + critic2_loss + self.critic_optimizer.zero_grad(set_to_none=True) + critic_loss.backward() + self.critic_optimizer.step() + + return critic_loss.detach(), critic1_loss.detach(), critic2_loss.detach() + + def update_actor(self, obs, step): + for p in self.model.shac_network.critic.parameters(): + p.requires_grad = False + + dist = self.model.actor(obs) + action = dist.rsample() + log_prob = dist.log_prob(action).sum(-1, keepdim=True) + entropy = dist.entropy().sum(-1, keepdim=True).mean() + actor_Q1, actor_Q2 = self.model.critic(obs, action) + actor_Q = torch.min(actor_Q1, actor_Q2) + + actor_loss = (torch.max(self.alpha.detach(), self.min_alpha) * log_prob - actor_Q) + actor_loss = actor_loss.mean() + + self.actor_optimizer.zero_grad(set_to_none=True) + actor_loss.backward() + self.actor_optimizer.step() + + for p in self.model.shac_network.critic.parameters(): + p.requires_grad = True + + if self.learnable_temperature: + alpha_loss = (self.alpha * + (-log_prob - self.target_entropy).detach()).mean() + self.log_alpha_optimizer.zero_grad(set_to_none=True) + alpha_loss.backward() + self.log_alpha_optimizer.step() + else: + alpha_loss = None + + return actor_loss.detach(), entropy.detach(), self.alpha.detach(), alpha_loss # TODO: maybe not self.alpha + + def soft_update_params(self, net, target_net, tau): + for param, 
target_param in zip(net.parameters(), target_net.parameters()): + target_param.data.copy_(tau * param.data + + (1 - tau) * target_param.data) + + def update(self, step): + obs, action, reward, next_obs, done = self.replay_buffer.sample(self.batch_size) + not_done = ~done + + obs = self.preproc_obs(obs) + next_obs = self.preproc_obs(next_obs) + + critic_loss, critic1_loss, critic2_loss = self.update_critic(obs, action, reward, next_obs, not_done, step) + + actor_loss, entropy, alpha, alpha_loss = self.update_actor_and_alpha(obs, step) + + actor_loss_info = actor_loss, entropy, alpha, alpha_loss + self.soft_update_params(self.model.shac_network.critic, self.model.shac_network.critic_target, + self.critic_tau) + return actor_loss_info, critic1_loss, critic2_loss + + def preproc_obs(self, obs): + if isinstance(obs, dict): + obs = obs['obs'] + return obs + + def env_step(self, actions): + if not self.is_tensor_obses: + actions = actions.cpu().numpy() + obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space) + + self.step += self.num_actors + if self.is_tensor_obses: + return obs, rewards, dones, infos + else: + return torch.from_numpy(obs).to(self.sac_device), torch.from_numpy(rewards).to(self.sac_device), torch.from_numpy(dones).to(self.sac_device), infos + + def env_reset(self): + with torch.no_grad(): + obs = self.vec_env.reset() + + if self.is_tensor_obses is None: + self.is_tensor_obses = torch.is_tensor(obs) + print("Observations are tensors:", self.is_tensor_obses) + + if self.is_tensor_obses: + return obs.to(self.sac_device) + else: + return torch.from_numpy(obs).to(self.sac_device) + + def act(self, obs, action_dim, sample=False): + obs = self.preproc_obs(obs) + dist = self.model.actor(obs) + actions = dist.sample() if sample else dist.mean + actions = actions.clamp(*self.action_range) + assert actions.ndim == 2 + return actions + + def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, actor_loss_info): + actor_loss, entropy, alpha, alpha_loss = actor_loss_info + + actor_losses.append(actor_loss) + entropies.append(entropy) + if alpha_losses is not None: + alphas.append(alpha) + alpha_losses.append(alpha_loss) + + def clear_stats(self): + self.game_rewards.clear() + self.game_lengths.clear() + self.mean_rewards = self.last_mean_rewards = -100500 + self.algo_observer.after_clear_stats() + + def play_steps(self, random_exploration=False): + total_time_start = time.time() + total_update_time = 0 + total_time = 0 + step_time = 0.0 + actor_losses = [] + entropies = [] + alphas = [] + alpha_losses = [] + critic1_losses = [] + critic2_losses = [] + + obs = self.obs + for _ in range(self.num_steps_per_episode): + self.set_eval() + if random_exploration: + action = torch.rand((self.num_actors, *self.env_info["action_space"].shape), device=self.sac_device) * 2 - 1 + else: + with torch.no_grad(): + action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True) + + step_start = time.time() + + with torch.no_grad(): + next_obs, rewards, dones, infos = self.env_step(action) + step_end = time.time() + + self.current_rewards += rewards + self.current_lengths += 1 + + total_time += step_end - step_start + + step_time += (step_end - step_start) + + all_done_indices = dones.nonzero(as_tuple=False) + done_indices = all_done_indices[::self.num_agents] + self.game_rewards.update(self.current_rewards[done_indices]) + self.game_lengths.update(self.current_lengths[done_indices]) + + not_dones = 1.0 - dones.float() + + 
self.algo_observer.process_infos(infos, done_indices) + + no_timeouts = self.current_lengths != self.max_env_steps + dones = dones * no_timeouts + + self.current_rewards = self.current_rewards * not_dones + self.current_lengths = self.current_lengths * not_dones + + if isinstance(obs, dict): + obs = obs['obs'] + if isinstance(next_obs, dict): + next_obs = next_obs['obs'] + + rewards = self.rewards_shaper(rewards) + + self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1)) + + self.obs = obs = next_obs.clone() + + if not random_exploration: + self.set_train() + update_time_start = time.time() + actor_loss_info, critic1_loss, critic2_loss = self.update(self.epoch_num) + update_time_end = time.time() + update_time = update_time_end - update_time_start + + self.extract_actor_stats(actor_losses, entropies, alphas, alpha_losses, actor_loss_info) + critic1_losses.append(critic1_loss) + critic2_losses.append(critic2_loss) + else: + update_time = 0 + + total_update_time += update_time + + total_time_end = time.time() + total_time = total_time_end - total_time_start + play_time = total_time - total_update_time + + return step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses + + def train_epoch(self): + if self.epoch_num < self.num_seed_steps: + step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.play_steps(random_exploration=True) + else: + step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.play_steps(random_exploration=False) + + return step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses + + def train(self): + self.init_tensors() + self.algo_observer.after_init(self) + self.last_mean_rewards = -100500 + total_time = 0 + # rep_count = 0 + self.frame = 0 + self.obs = self.env_reset() + + while True: + self.epoch_num += 1 + step_time, play_time, update_time, epoch_total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.train_epoch() + + total_time += epoch_total_time + + scaled_time = epoch_total_time + scaled_play_time = play_time + curr_frames = self.num_frames_per_epoch + self.frame += curr_frames + frame = self.frame #TODO: Fix frame + # print(frame) + + if self.print_stats: + fps_step = curr_frames / scaled_play_time + fps_total = curr_frames / scaled_time + print(f'fps step: {fps_step:.1f} fps total: {fps_total:.1f}') + + self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / scaled_time, frame) + self.writer.add_scalar('performance/step_inference_fps', curr_frames / scaled_play_time, frame) + self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame) + self.writer.add_scalar('performance/rl_update_time', update_time, frame) + self.writer.add_scalar('performance/step_inference_time', play_time, frame) + self.writer.add_scalar('performance/step_time', step_time, frame) + + if self.epoch_num >= self.num_seed_steps: + self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(actor_losses).item(), frame) + self.writer.add_scalar('losses/c1_loss', torch_ext.mean_list(critic1_losses).item(), frame) + self.writer.add_scalar('losses/c2_loss', torch_ext.mean_list(critic2_losses).item(), frame) + self.writer.add_scalar('losses/entropy', 
torch_ext.mean_list(entropies).item(), frame) + if alpha_losses[0] is not None: + self.writer.add_scalar('losses/alpha_loss', torch_ext.mean_list(alpha_losses).item(), frame) + self.writer.add_scalar('info/alpha', torch_ext.mean_list(alphas).item(), frame) + + self.writer.add_scalar('info/epochs', self.epoch_num, frame) + self.algo_observer.after_print_stats(frame, self.epoch_num, total_time) + + if self.game_rewards.current_size > 0: + mean_rewards = self.game_rewards.get_mean() + mean_lengths = self.game_lengths.get_mean() + + self.writer.add_scalar('rewards/step', mean_rewards, frame) + # self.writer.add_scalar('rewards/iter', mean_rewards, epoch_num) + self.writer.add_scalar('rewards/time', mean_rewards, total_time) + self.writer.add_scalar('episode_lengths/step', mean_lengths, frame) + # self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) + self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time) + + if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after: + print('saving next best rewards: ', mean_rewards) + self.last_mean_rewards = mean_rewards + self.save("./nn/" + self.config['name']) + if self.last_mean_rewards > self.config.get('score_to_win', float('inf')): + print('Network won!') + self.save("./nn/" + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) + return self.last_mean_rewards, self.epoch_num + + if self.epoch_num > self.max_epochs: + self.save("./nn/" + 'last_' + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) + print('MAX EPOCHS NUM!') + return self.last_mean_rewards, self.epoch_num + update_time = 0 + + \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 9708aadc..9a0dfaa4 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -43,7 +43,10 @@ def rescale_actions(low, high, action): class A2CBase(BaseAlgorithm): def __init__(self, base_name, params): + self.config = config = params['config'] + + # population based training parameters pbt_str = '' self.population_based_training = config.get('population_based_training', False) if self.population_based_training: @@ -81,7 +84,6 @@ def __init__(self, base_name, params): else: self.diagnostics = DefaultDiagnostics() - self.network_path = config.get('network_path', "./nn/") self.log_path = config.get('log_path', "runs/") self.env_config = config.get('env_config', {}) diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index e697b1ed..0d513ccc 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -2,21 +2,15 @@ import random import copy import torch -import yaml -from rl_games import envs -from rl_games.common import object_factory -from rl_games.common import env_configurations -from rl_games.common import experiment -from rl_games.common import tr_helpers +from rl_games.common import object_factory, tr_helpers -from rl_games.algos_torch import model_builder from rl_games.algos_torch import a2c_continuous from rl_games.algos_torch import a2c_discrete from rl_games.algos_torch import players from rl_games.common.algo_observer import DefaultAlgoObserver -from rl_games.algos_torch import sac_agent -import rl_games.networks +from rl_games.algos_torch import sac_agent, shac_agent + def _restore(agent, args): if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='': @@ -31,18 +25,21 @@ def _override_sigma(agent, args): net.sigma.fill_(float(args['sigma'])) else: 
print('Print cannot set new sigma because fixed_sigma is False') + class Runner: def __init__(self, algo_observer=None): self.algo_factory = object_factory.ObjectFactory() self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs)) self.algo_factory.register_builder('a2c_discrete', lambda **kwargs : a2c_discrete.DiscreteA2CAgent(**kwargs)) self.algo_factory.register_builder('sac', lambda **kwargs: sac_agent.SACAgent(**kwargs)) + self.algo_factory.register_builder('shac', lambda **kwargs: shac_agent.SHACAgent(**kwargs)) #self.algo_factory.register_builder('dqn', lambda **kwargs : dqnagent.DQNAgent(**kwargs)) self.player_factory = object_factory.ObjectFactory() self.player_factory.register_builder('a2c_continuous', lambda **kwargs : players.PpoPlayerContinuous(**kwargs)) self.player_factory.register_builder('a2c_discrete', lambda **kwargs : players.PpoPlayerDiscrete(**kwargs)) self.player_factory.register_builder('sac', lambda **kwargs : players.SACPlayer(**kwargs)) + self.player_factory.register_builder('shac', lambda **kwargs : players.SHACPlayer(**kwargs)) #self.player_factory.register_builder('dqn', lambda **kwargs : players.DQNPlayer(**kwargs)) self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver() @@ -50,6 +47,7 @@ def __init__(self, algo_observer=None): ### it didnot help for lots for openai gym envs anyway :( #torch.backends.cudnn.deterministic = True #torch.use_deterministic_algorithms(True) + def reset(self): pass @@ -61,7 +59,6 @@ def load_config(self, params): self.exp_config = None if self.seed: - torch.manual_seed(self.seed) torch.cuda.manual_seed_all(self.seed) np.random.seed(self.seed) From 130bfc5b568de1fe75c64dd28d013dbb76c86dce Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sun, 22 May 2022 17:20:46 -0700 Subject: [PATCH 02/26] SHAC agent, network, model. WIP. 
--- rl_games/algos_torch/a2c_continuous.py | 8 +- rl_games/algos_torch/model_builder.py | 10 +- rl_games/algos_torch/models.py | 27 ++- rl_games/algos_torch/network_builder.py | 259 +++++++++++++++++++++--- rl_games/algos_torch/sac_agent.py | 44 ++-- rl_games/algos_torch/sac_helper.py | 4 +- rl_games/algos_torch/shac_agent.py | 252 +++++++++++------------ rl_games/common/a2c_common.py | 2 +- rl_games/interfaces/base_algorithm.py | 1 + rl_games/torch_runner.py | 2 - 10 files changed, 417 insertions(+), 192 deletions(-) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 7cfd07d3..82d3263a 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -6,10 +6,8 @@ from rl_games.common import datasets from torch import optim -import torch -from torch import nn -import numpy as np -import gym +import torch + class A2CAgent(a2c_common.ContinuousA2CBase): def __init__(self, base_name, params): @@ -23,7 +21,7 @@ def __init__(self, base_name, params): 'normalize_value' : self.normalize_value, 'normalize_input': self.normalize_input, } - + self.model = self.network.build(build_config) self.model.to(self.ppo_device) self.states = None diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index c2045c5e..8acb2318 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -1,7 +1,6 @@ from rl_games.common import object_factory -import rl_games.algos_torch -from rl_games.algos_torch import network_builder -from rl_games.algos_torch import models +from rl_games.algos_torch import network_builder, models + NETWORK_REGISTRY = {} MODEL_REGISTRY = {} @@ -14,6 +13,7 @@ def register_model(name, target_class): class NetworkBuilder: + def __init__(self): self.network_factory = object_factory.ObjectFactory() self.network_factory.set_builders(NETWORK_REGISTRY) @@ -22,6 +22,7 @@ def __init__(self): lambda **kwargs: network_builder.A2CResnetBuilder()) self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) + self.network_factory.register_builder('shac_actor_critic', lambda **kwargs: network_builder.SHACBuilder()) def load(self, params): network_name = params['name'] @@ -32,6 +33,7 @@ def load(self, params): class ModelBuilder: + def __init__(self): self.model_factory = object_factory.ObjectFactory() self.model_factory.set_builders(MODEL_REGISTRY) @@ -44,6 +46,8 @@ def __init__(self): lambda network, **kwargs: models.ModelA2CContinuousLogStd(network)) self.model_factory.register_builder('soft_actor_critic', lambda network, **kwargs: models.ModelSACContinuous(network)) + self.model_factory.register_builder('shac_actor_critic', + lambda network, **kwargs: models.ModelSHAC(network)) self.model_factory.register_builder('central_value', lambda network, **kwargs: models.ModelCentralValue(network)) self.network_builder = NetworkBuilder() diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 8a1ad5c9..7d8f91ab 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -1,4 +1,3 @@ -import rl_games.algos_torch.layers import numpy as np import torch.nn as nn import torch @@ -28,7 +27,9 @@ def build(self, config): return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape, normalize_value=normalize_value, 
normalize_input=normalize_input, value_size=value_size) + class BaseModelNetwork(nn.Module): + def __init__(self, obs_shape, normalize_value, normalize_input, value_size): nn.Module.__init__(self) self.obs_shape = obs_shape @@ -52,6 +53,7 @@ def unnorm_value(self, value): with torch.no_grad(): return self.value_mean_std(value, unnorm=True) if self.normalize_value else value + class ModelA2C(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'a2c') @@ -64,7 +66,7 @@ def __init__(self, a2c_network, **kwargs): def is_rnn(self): return self.a2c_network.is_rnn() - + def get_default_rnn_state(self): return self.a2c_network.get_default_rnn_state() @@ -105,7 +107,9 @@ def forward(self, input_dict): } return result + class ModelA2CMultiDiscrete(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -171,6 +175,7 @@ def forward(self, input_dict): class ModelA2CContinuous(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -226,6 +231,7 @@ def forward(self, input_dict): class ModelA2CContinuousLogStd(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -280,6 +286,7 @@ def neglogp(self, x, mean, std, logstd): class ModelCentralValue(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -318,10 +325,10 @@ class ModelSACContinuous(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'sac') self.network_builder = network - + class Network(BaseModelNetwork): - def __init__(self, sac_network,**kwargs): - BaseModelNetwork.__init__(self,**kwargs) + def __init__(self, sac_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) self.sac_network = sac_network def critic(self, obs, action): @@ -332,7 +339,7 @@ def critic_target(self, obs, action): def actor(self, obs): return self.sac_network.actor(obs) - + def is_rnn(self): return False @@ -348,10 +355,10 @@ class ModelSHAC(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'shac') self.network_builder = network - + class Network(BaseModelNetwork): - def __init__(self, shac_network,**kwargs): - BaseModelNetwork.__init__(self,**kwargs) + def __init__(self, shac_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) self.shac_network = shac_network def critic(self, obs, action): @@ -362,7 +369,7 @@ def critic_target(self, obs, action): def actor(self, obs): return self.shac_network.actor(obs) - + def is_rnn(self): return False diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 2424aa25..c10af008 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -180,14 +180,14 @@ def __init__(self, params, **kwargs): actions_num = kwargs.pop('actions_num') input_shape = kwargs.pop('input_shape') self.value_size = kwargs.pop('value_size', 1) - self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.num_seqs = kwargs.pop('num_seqs', 1) NetworkBuilder.BaseNetwork.__init__(self) self.load(params) self.actor_cnn = nn.Sequential() self.critic_cnn = nn.Sequential() self.actor_mlp = nn.Sequential() self.critic_mlp = nn.Sequential() - + if self.has_cnn: if self.permute_input: input_shape = torch_ext.shape_whc_to_cwh(input_shape) @@ -344,6 +344,7 @@ def forward(self, obs_dict): c_out = c_out.transpose(0,1) a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) c_out = 
c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + if self.rnn_ln: a_out = self.a_layer_norm(a_out) c_out = self.c_layer_norm(c_out) @@ -358,7 +359,7 @@ def forward(self, obs_dict): else: a_out = self.actor_mlp(a_out) c_out = self.critic_mlp(c_out) - + value = self.value_act(self.value(c_out)) if self.is_discrete: @@ -431,7 +432,7 @@ def forward(self, obs_dict): else: sigma = self.sigma_act(self.sigma(out)) return mu, mu*0 + sigma, value, states - + def is_separate_critic(self): return self.separate @@ -491,7 +492,7 @@ def load(self, params): self.is_discrete = False self.is_continuous = False self.is_multi_discrete = False - + if self.has_rnn: self.rnn_units = params['rnn']['units'] self.rnn_layers = params['rnn']['layers'] @@ -511,6 +512,7 @@ def build(self, name, **kwargs): net = A2CBuilder.Network(self.params, **kwargs) return net + class Conv2dAuto(nn.Conv2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -547,7 +549,7 @@ def __init__(self, channels, activation='relu', use_bn=False, use_zero_init=True if use_attention: self.ca = ChannelAttention(channels) self.sa = SpatialAttention() - + def forward(self, x): residual = x x = self.activate1(x) @@ -592,17 +594,17 @@ def __init__(self, params, **kwargs): actions_num = kwargs.pop('actions_num') input_shape = kwargs.pop('input_shape') input_shape = torch_ext.shape_whc_to_cwh(input_shape) - self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.num_seqs = kwargs.pop('num_seqs', 1) self.value_size = kwargs.pop('value_size', 1) NetworkBuilder.BaseNetwork.__init__(self, **kwargs) self.load(params) - + self.cnn = self._build_impala(input_shape, self.conv_depths) mlp_input_shape = self._calc_input_size(input_shape, self.cnn) in_mlp_shape = mlp_input_shape - + if len(self.units) == 0: out_size = mlp_input_shape else: @@ -617,7 +619,7 @@ def __init__(self, params, **kwargs): in_mlp_shape = self.rnn_units self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) #self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - + mlp_args = { 'input_size' : in_mlp_shape, 'units' :self.units, @@ -635,9 +637,9 @@ def __init__(self, params, **kwargs): self.logits = torch.nn.Linear(out_size, actions_num) if self.is_continuous: self.mu = torch.nn.Linear(out_size, actions_num) - self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) + self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) mu_init = self.init_factory.create(**self.space_config['mu_init']) - self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) sigma_init = self.init_factory.create(**self.space_config['sigma_init']) if self.fixed_sigma: @@ -646,7 +648,7 @@ def __init__(self, params, **kwargs): self.sigma = torch.nn.Linear(out_size, actions_num) mlp_init = self.init_factory.create(**self.initializer) - + for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out') @@ -654,7 +656,7 @@ def __init__(self, params, **kwargs): for m in self.mlp: if isinstance(m, nn.Linear): mlp_init(m.weight) - + if self.is_discrete: mlp_init(self.logits.weight) if self.is_continuous: @@ -664,7 +666,7 @@ def __init__(self, params, **kwargs): else: sigma_init(self.sigma.weight) - mlp_init(self.value.weight) + mlp_init(self.value.weight) def forward(self, obs_dict): obs = obs_dict['obs'] @@ -675,7 +677,7 @@ def forward(self, 
obs_dict): out = self.cnn(out) out = out.flatten(1) out = self.flatten_act(out) - + if self.has_rnn: if not self.is_rnn_before_mlp: out = self.mlp(out) @@ -765,6 +767,7 @@ def build(self, name, **kwargs): net = A2CResnetBuilder.Network(self.params, **kwargs) return net + class DiagGaussianActor(NetworkBuilder.BaseNetwork): """torch.distributions implementation of an diagonal Gaussian policy.""" def __init__(self, output_dim, log_std_bounds, **mlp_args): @@ -793,6 +796,7 @@ def forward(self, obs): # Modify to only return mu and std return dist + class DoubleQCritic(NetworkBuilder.BaseNetwork): """Critic network, employes double Q-learning.""" def __init__(self, output_dim, **mlp_args): @@ -806,7 +810,6 @@ def __init__(self, output_dim, **mlp_args): last_layer = list(self.Q2.children())[-2].out_features self.Q2 = nn.Sequential(*list(self.Q2.children()), nn.Linear(last_layer, output_dim)) - def forward(self, obs, action): assert obs.size(0) == action.size(0) @@ -823,24 +826,23 @@ def __init__(self, **kwargs): def load(self, params): self.params = params - + def build(self, name, **kwargs): net = SACBuilder.Network(self.params, **kwargs) return net - + class Network(NetworkBuilder.BaseNetwork): def __init__(self, params, **kwargs): actions_num = kwargs.pop('actions_num') input_shape = kwargs.pop('input_shape') obs_dim = kwargs.pop('obs_dim') action_dim = kwargs.pop('action_dim') - self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.num_seqs = kwargs.pop('num_seqs', 1) NetworkBuilder.BaseNetwork.__init__(self) self.load(params) mlp_input_shape = input_shape - actor_mlp_args = { 'input_size' : obs_dim, 'units' : self.units, @@ -881,6 +883,110 @@ def __init__(self, params, **kwargs): if getattr(m, "bias", None) is not None: torch.nn.init.zeros_(m.bias) + def _build_critic(self, output_dim, **mlp_args): + return DoubleQCritic(output_dim, **mlp_args) + + def _build_actor(self, output_dim, log_std_bounds, **mlp_args): + return DiagGaussianActor(output_dim, log_std_bounds, **mlp_args) + + def forward(self, obs_dict): + """TODO""" + obs = obs_dict['obs'] + mu, sigma = self.actor(obs) + return mu, sigma + + def is_separate_critic(self): + return self.separate + + def load(self, params): + self.separate = params.get('separate', True) + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + self.has_space = 'space' in params + self.value_shape = params.get('value_shape', 1) + self.central_value = params.get('central_value', False) + self.joint_obs_actions_config = params.get('joint_obs_actions', None) + self.log_std_bounds = params.get('log_std_bounds', None) + + if self.has_space: + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous'in params['space'] + if self.is_continuous: + self.space_config = params['space']['continuous'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + else: + self.is_discrete = False + self.is_continuous = False + + +class SACBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + net = SHACBuilder.Network(self.params, **kwargs) + return net + + class 
Network(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + obs_dim = kwargs.pop('obs_dim') + action_dim = kwargs.pop('action_dim') + self.num_seqs = kwargs.pop('num_seqs', 1) + NetworkBuilder.BaseNetwork.__init__(self) + self.load(params) + + mlp_input_shape = input_shape + + actor_mlp_args = { + 'input_size' : obs_dim, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + + critic_mlp_args = { + 'input_size' : obs_dim + action_dim, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + print("Building Actor") + self.actor = self._build_actor(2*action_dim, self.log_std_bounds, **actor_mlp_args) + + if self.separate: + print("Building Critic") + self.critic = self._build_critic(1, **critic_mlp_args) + print("Building Critic Target") + self.critic_target = self._build_critic(1, **critic_mlp_args) + self.critic_target.load_state_dict(self.critic.state_dict()) + + mlp_init = self.init_factory.create(**self.initializer) + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + cnn_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) def _build_critic(self, output_dim, **mlp_args): return DoubleQCritic(output_dim, **mlp_args) @@ -893,8 +999,7 @@ def forward(self, obs_dict): obs = obs_dict['obs'] mu, sigma = self.actor(obs) return mu, sigma - - + def is_separate_critic(self): return self.separate @@ -913,6 +1018,7 @@ def load(self, params): self.joint_obs_actions_config = params.get('joint_obs_actions', None) self.log_std_bounds = params.get('log_std_bounds', None) + # todo: add assert if discrete, there is no discrete action space for SHAC if self.has_space: self.is_discrete = 'discrete' in params['space'] self.is_continuous = 'continuous'in params['space'] @@ -924,4 +1030,109 @@ def load(self, params): self.is_discrete = False self.is_continuous = False - \ No newline at end of file + +class SHACBuilder(NetworkBuilder): + def __init__(self, **kwargs): + NetworkBuilder.__init__(self) + + def load(self, params): + self.params = params + + def build(self, name, **kwargs): + net = SHACBuilder.Network(self.params, **kwargs) + return net + + class Network(NetworkBuilder.BaseNetwork): + def __init__(self, params, **kwargs): + actions_num = kwargs.pop('actions_num') + input_shape = kwargs.pop('input_shape') + obs_dim = kwargs.pop('obs_dim') + action_dim = kwargs.pop('action_dim') + self.num_seqs = kwargs.pop('num_seqs', 1) + + NetworkBuilder.BaseNetwork.__init__(self) + self.load(params) + + mlp_input_shape = input_shape + + actor_mlp_args = { + 'input_size' : obs_dim, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + + critic_mlp_args = { + 'input_size' : obs_dim + action_dim, + 'units' : self.units, + 'activation' : self.activation, + 'norm_func_name' : self.normalization, + 'dense_func' : torch.nn.Linear, + 'd2rl' : 
self.is_d2rl, + 'norm_only_first_layer' : self.norm_only_first_layer + } + print("Building Actor") + self.actor = self._build_actor(2*action_dim, self.log_std_bounds, **actor_mlp_args) + + if self.separate: + print("Building Critic") + self.critic = self._build_critic(1, **critic_mlp_args) + print("Building Critic Target") + self.critic_target = self._build_critic(1, **critic_mlp_args) + self.critic_target.load_state_dict(self.critic.state_dict()) + + mlp_init = self.init_factory.create(**self.initializer) + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): + cnn_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + if isinstance(m, nn.Linear): + mlp_init(m.weight) + if getattr(m, "bias", None) is not None: + torch.nn.init.zeros_(m.bias) + + def _build_critic(self, output_dim, **mlp_args): + return DoubleQCritic(output_dim, **mlp_args) + + def _build_actor(self, output_dim, log_std_bounds, **mlp_args): + return DiagGaussianActor(output_dim, log_std_bounds, **mlp_args) + + def forward(self, obs_dict): + """TODO""" + obs = obs_dict['obs'] + mu, sigma = self.actor(obs) + return mu, sigma + + def is_separate_critic(self): + return self.separate + + def load(self, params): + self.separate = params.get('separate', True) + self.units = params['mlp']['units'] + self.activation = params['mlp']['activation'] + self.initializer = params['mlp']['initializer'] + self.is_d2rl = params['mlp'].get('d2rl', False) + self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) + self.value_activation = params.get('value_activation', 'None') + self.normalization = params.get('normalization', None) + self.has_space = 'space' in params + self.value_shape = params.get('value_shape', 1) + self.central_value = params.get('central_value', False) + self.joint_obs_actions_config = params.get('joint_obs_actions', None) + self.log_std_bounds = params.get('log_std_bounds', None) + + if self.has_space: + # todo: add assert if discrete, there is no discrete action space for SHAC + self.is_discrete = 'discrete' in params['space'] + self.is_continuous = 'continuous'in params['space'] + if self.is_continuous: + self.space_config = params['space']['continuous'] + elif self.is_discrete: + self.space_config = params['space']['discrete'] + else: + self.is_discrete = False + self.is_continuous = False \ No newline at end of file diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index a6361ac7..f3cff521 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ -1,27 +1,28 @@ from rl_games.algos_torch import torch_ext -from rl_games.algos_torch.running_mean_std import RunningMeanStd +#from rl_games.algos_torch.running_mean_std import RunningMeanStd -from rl_games.common import vecenv -from rl_games.common import schedulers -from rl_games.common import experience +from rl_games.common import vecenv, schedulers, experience from rl_games.interfaces.base_algorithm import BaseAlgorithm from torch.utils.tensorboard import SummaryWriter from datetime import datetime from rl_games.algos_torch import model_builder from torch import optim -import torch +import torch from torch import nn import torch.nn.functional as F import numpy as np import time +import gym + class SACAgent(BaseAlgorithm): def __init__(self, base_name, params): self.config = config = params['config'] print(config) + # TODO: Get obs shape and self.network self.load_networks(params) self.base_init(base_name, config) @@ 
-59,7 +60,7 @@ def __init__(self, base_name, params): 'input_shape' : obs_shape, 'normalize_input' : self.normalize_input, 'normalize_input': self.normalize_input, - } + } self.model = self.network.build(net_config) self.model.to(self.sac_device) @@ -87,7 +88,6 @@ def __init__(self, base_name, params): self.step = 0 self.algo_observer = config['features']['observer'] - # TODO: Is there a better way to get the maximum number of episodes? self.max_episodes = torch.ones(self.num_actors, device=self.sac_device)*self.num_steps_per_episode # self.episode_lengths = np.zeros(self.num_actors, dtype=int) @@ -132,7 +132,13 @@ def base_init(self, base_name, config): self.network = config['network'] self.rewards_shaper = config['reward_shaper'] self.num_agents = self.env_info.get('agents', 1) - self.obs_shape = self.observation_space.shape + #self.obs_shape = self.observation_space.shape + if isinstance(self.observation_space, gym.spaces.Dict): + self.obs_shape = {} + for k,v in self.observation_space.spaces.items(): + self.obs_shape[k] = v.shape + else: + self.obs_shape = self.observation_space.shape self.games_to_track = self.config.get('games_to_track', 100) self.game_rewards = torch_ext.AverageMeter(1, self.games_to_track).to(self.sac_device) @@ -140,16 +146,16 @@ def base_init(self, base_name, config): self.obs = None self.min_alpha = torch.tensor(np.log(1)).float().to(self.sac_device) - + self.frame = 0 self.update_time = 0 self.last_mean_rewards = -100500 self.play_time = 0 self.epoch_num = 0 - + self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) print("Run Directory:", config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) - + self.is_tensor_obses = None self.is_rnn = False self.last_rnn_indices = None @@ -166,7 +172,7 @@ def init_tensors(self): self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.sac_device) self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.sac_device) - + @property def alpha(self): return self.log_alpha.exp() @@ -174,7 +180,7 @@ def alpha(self): @property def device(self): return self.sac_device - + def get_full_state_weights(self): state = self.get_weights() @@ -257,7 +263,7 @@ def update_actor_and_alpha(self, obs, step): entropy = dist.entropy().sum(-1, keepdim=True).mean() actor_Q1, actor_Q2 = self.model.critic(obs, action) actor_Q = torch.min(actor_Q1, actor_Q2) - + actor_loss = (torch.max(self.alpha.detach(), self.min_alpha) * log_prob - actor_Q) actor_loss = actor_loss.mean() @@ -315,7 +321,7 @@ def env_step(self, actions): return obs, rewards, dones, infos else: return torch.from_numpy(obs).to(self.sac_device), torch.from_numpy(rewards).to(self.sac_device), torch.from_numpy(dones).to(self.sac_device), infos - + def env_reset(self): with torch.no_grad(): obs = self.vec_env.reset() @@ -323,7 +329,7 @@ def env_reset(self): if self.is_tensor_obses is None: self.is_tensor_obses = torch.is_tensor(obs) print("Observations are tensors:", self.is_tensor_obses) - + if self.is_tensor_obses: return obs.to(self.sac_device) else: @@ -339,7 +345,7 @@ def act(self, obs, action_dim, sample=False): def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, actor_loss_info): actor_loss, entropy, alpha, alpha_loss = actor_loss_info - + actor_losses.append(actor_loss) entropies.append(entropy) if alpha_losses is not None: @@ -511,6 +517,4 @@ def train(self): self.save("./nn/" + 'last_' + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) print('MAX 
EPOCHS NUM!') return self.last_mean_rewards, self.epoch_num - update_time = 0 - - \ No newline at end of file + update_time = 0 \ No newline at end of file diff --git a/rl_games/algos_torch/sac_helper.py b/rl_games/algos_torch/sac_helper.py index 1ce94a38..c600db89 100644 --- a/rl_games/algos_torch/sac_helper.py +++ b/rl_games/algos_torch/sac_helper.py @@ -1,11 +1,10 @@ # from rl_games.algos_torch.network_builder import NetworkBuilder from torch import distributions as pyd -import torch -import torch.nn as nn import math import torch.nn.functional as F import numpy as np + class TanhTransform(pyd.transforms.Transform): domain = pyd.constraints.real codomain = pyd.constraints.interval(-1.0, 1.0) @@ -35,6 +34,7 @@ def log_abs_det_jacobian(self, x, y): # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 return 2. * (math.log(2.) - x - F.softplus(-2. * x)) + class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): def __init__(self, loc, scale): self.loc = loc diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 865c5801..e491302f 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -2,9 +2,8 @@ from rl_games.algos_torch.running_mean_std import RunningMeanStd -from rl_games.common import vecenv -from rl_games.common import schedulers -from rl_games.common import experience +from rl_games.common import vecenv, schedulers, experience + from rl_games.interfaces.base_algorithm import BaseAlgorithm from torch.utils.tensorboard import SummaryWriter from datetime import datetime @@ -16,6 +15,8 @@ import numpy as np import time +import gym + class SHACAgent(BaseAlgorithm): @@ -26,10 +27,57 @@ def __init__(self, base_name, params): # TODO: Get obs shape and self.network self.load_networks(params) - self.base_init(base_name, config) + + self.env_config = config.get('env_config', {}) + self.num_actors = config.get('num_actors', 1) + self.env_name = config['env_name'] + print("Env name:", self.env_name) + + self.env_info = config.get('env_info') + if self.env_info is None: + self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config) + self.env_info = self.vec_env.get_env_info() + + self._device = config.get('device', 'cuda:0') + + print('Env info:') + print(self.env_info) + + # if cfg['train']: + # self.save('init_policy') + + self.rewards_shaper = config['reward_shaper'] + self.observation_space = self.env_info['observation_space'] + self.weight_decay = config.get('weight_decay', 0.0) + #self.use_action_masks = config.get('use_action_masks', False) + self.is_train = config.get('is_train', True) + + self.c_loss = nn.MSELoss() + # self.c2_loss = nn.SmoothL1Loss() + + self.save_best_after = config.get('save_best_after', 100) + self.print_stats = config.get('print_stats', True) + self.rnn_states = None + self.name = base_name + + self.max_epochs = self.config.get('max_epochs', 1e6) + + self.network = config['network'] + self.rewards_shaper = config['reward_shaper'] + self.num_agents = self.env_info.get('agents', 1) + + if isinstance(self.observation_space, gym.spaces.Dict): + self.obs_shape = {} + for k,v in self.observation_space.spaces.items(): + self.obs_shape[k] = v.shape + else: + self.obs_shape = self.observation_space.shape + self.horizon_length = config["horizon_length"] self.gamma = config["gamma"] - self.critic_tau = config["critic_tau"] + self.critic_coef = config.get('critic_coef', 1.0) + # 
self.tau = self.config['tau'] + self.critic_tau = config["critic_tau"] # todo align names with PPO impl self.batch_size = config["batch_size"] self.num_steps_per_episode = config.get("num_steps_per_episode", 1) @@ -59,7 +107,7 @@ def __init__(self, base_name, params): 'normalize_input': self.normalize_input, } self.model = self.network.build(net_config) - self.model.to(self.shac_device) + self.model.to(self.device) print("Number of Agents", self.num_actors, "Batch Size", self.batch_size) @@ -74,52 +122,31 @@ def __init__(self, base_name, params): self.step = 0 self.algo_observer = config['features']['observer'] - # TODO: Is there a better way to get the maximum number of episodes? - #self.max_episodes = torch.ones(self.num_actors, device=self.shac_device)*self.num_steps_per_episode + #self.max_episodes = torch.ones(self.num_actors, device=self.device)*self.num_steps_per_episode - def load_networks(self, params): - builder = model_builder.ModelBuilder() - self.config['network'] = builder.load(params) - def base_init(self, base_name, config): - self.env_config = config.get('env_config', {}) - self.num_actors = config.get('num_actors', 1) - self.env_name = config['env_name'] - print("Env name:", self.env_name) + # shac params + self.gamma = config.get('gamma', 0.99) - self.env_info = config.get('env_info') - if self.env_info is None: - self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config) - self.env_info = self.vec_env.get_env_info() + self.critic_method = config.get('critic_method', 'td-lambda') # ['one-step', 'td-lambda'] + if self.critic_method == 'td-lambda': + self.lam = config.get('lambda', 0.95) - self.shac_device = config.get('device', 'cuda:0') + self.max_epochs = config.get("max_epochs", 1000) # add get + self.actor_lr = float(config["actor_learning_rate"]) + self.critic_lr = float(config['critic_learning_rate']) + self.lr_schedule = config.get('lr_schedule', 'linear') - print('Env info:') - print(self.env_info) + self.target_critic_alpha = config.get('target_critic_alpha', 0.4) - # shac params - self.gamma = config['params']['config'].get('gamma', 0.99) - - self.critic_method = config['params']['config'].get('critic_method', 'td-lambda') # ['one-step', 'td-lambda'] - if self.critic_method == 'td-lambda': - self.lam = config['params']['config'].get('lambda', 0.95) - - # self.steps_num = cfg["params"]["config"]["steps_num"] - self.max_epochs = config["params"]["config"]["max_epochs"] # add get - self.actor_lr = float(config["params"]["config"]["actor_learning_rate"]) - self.critic_lr = float(config['params']['config']['critic_learning_rate']) - self.lr_schedule = config['params']['config'].get('lr_schedule', 'linear') - - self.target_critic_alpha = config['params']['config'].get('target_critic_alpha', 0.4) - - self.obs_rms = None - if config.get('obs_rms', False): - self.obs_rms = RunningMeanStd(shape = (self.num_obs), device = self.device) - - self.ret_rms = None - if config.get('ret_rms', False): - self.ret_rms = RunningMeanStd(shape = (), device = self.device) + # self.obs_rms = None + # if config.get('obs_rms', False): + # self.obs_rms = RunningMeanStd(shape = (self.num_obs), device = self.device) + + # self.ret_rms = None + # if config.get('ret_rms', False): + # self.ret_rms = RunningMeanStd(shape = (), device = self.device) #self.rew_scale = cfg['params']['config'].get('rew_scale', 1.0) @@ -128,66 +155,29 @@ def base_init(self, base_name, config): self.batch_size = self.num_actors * self.horizon_length // self.num_batch self.name = 
config.get('name', "Ant") - self.truncate_grad = config["params"]["config"]["truncate_grads"] - self.grad_norm = config["params"]["config"]["grad_norm"] + self.truncate_grad = config.get("truncate_grads", True) + self.grad_norm = config.get("grad_norm", 1.0) ########### - self.rewards_shaper = config['reward_shaper'] - self.observation_space = self.env_info['observation_space'] - self.weight_decay = config.get('weight_decay', 0.0) - #self.use_action_masks = config.get('use_action_masks', False) - self.is_train = config.get('is_train', True) - - self.c_loss = nn.MSELoss() - # self.c2_loss = nn.SmoothL1Loss() - - self.save_best_after = config.get('save_best_after', 100) - self.print_stats = config.get('print_stats', True) - self.rnn_states = None - self.name = base_name - - self.max_epochs = self.config.get('max_epochs', 1e6) - - self.network = config['network'] - self.rewards_shaper = config['reward_shaper'] - self.num_agents = self.env_info.get('agents', 1) - self.obs_shape = self.observation_space.shape - self.games_to_track = self.config.get('games_to_track', 100) - self.game_rewards = torch_ext.AverageMeter(1, self.games_to_track).to(self.shac_device) - self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.shac_device) + self.game_rewards = torch_ext.AverageMeter(1, self.games_to_track).to(self.device) + self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.device) self.obs = None - + self.frame = 0 self.update_time = 0 self.last_mean_rewards = -100500 self.play_time = 0 self.epoch_num = 0 - + self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) print("Run Directory:", config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) - - self.is_tensor_obses = None + + self.is_tensor_obses = True self.is_rnn = False self.last_rnn_indices = None self.last_state_indices = None - # shac - - # replay buffer - self.obs_buf = torch.zeros((self.steps_num, self.num_envs, self.num_obs), dtype = torch.float32, device = self.device) - self.rew_buf = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) - self.done_mask = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) - self.next_values = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) - self.target_values = torch.zeros((self.steps_num, self.num_envs), dtype = torch.float32, device = self.device) - self.ret = torch.zeros((self.num_envs), dtype = torch.float32, device = self.device) - - # for kl divergence computing - self.old_mus = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.old_sigmas = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.mus = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.sigmas = torch.zeros((self.steps_num, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - # counting variables self.iter_count = 0 self.step_count = 0 @@ -196,43 +186,57 @@ def base_init(self, base_name, config): self.episode_length_his = [] self.episode_loss_his = [] self.episode_discounted_loss_his = [] - self.episode_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_discounted_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_gamma = 
torch.ones(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_length = torch.zeros(self.num_envs, dtype = int) + self.best_policy_loss = np.inf self.actor_loss = np.inf self.value_loss = np.inf - + # average meter - self.episode_loss_meter = AverageMeter(1, 100).to(self.device) - self.episode_discounted_loss_meter = AverageMeter(1, 100).to(self.device) - self.episode_length_meter = AverageMeter(1, 100).to(self.device) + self.episode_loss_meter = torch_ext.AverageMeter(1, 100).to(self.device) + self.episode_discounted_loss_meter = torch_ext.AverageMeter(1, 100).to(self.device) + self.episode_length_meter = torch_ext.AverageMeter(1, 100).to(self.device) # timer - self.time_report = TimeReport() - self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) - self.target_critic = copy.deepcopy(self.critic) - - # if cfg['params']['general']['train']: - # self.save('init_policy') + # self.time_report = TimeReport() + # self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) + # self.target_critic = copy.deepcopy(self.critic) + + def load_networks(self, params): + builder = model_builder.ModelBuilder() + self.config['network'] = builder.load(params) def init_tensors(self): - if self.observation_space.dtype == np.uint8: - torch_dtype = torch.uint8 - else: - torch_dtype = torch.float32 + batch_size = self.num_agents * self.num_actors - self.current_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.shac_device) - self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.shac_device) + self.current_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.device) + self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.device) + + self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.device) + + # replay buffer + self.obs_buf = torch.zeros((self.horizon_length, self.num_envs, self.num_obs), dtype = torch.float32, device = self.device) + self.rew_buf = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) + self.done_mask = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) + self.next_values = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) + self.target_values = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) + self.ret = torch.zeros((self.num_envs), dtype = torch.float32, device = self.device) + + # for kl divergence computing + self.old_mus = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.old_sigmas = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.mus = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + self.sigmas = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) + + self.episode_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_discounted_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_gamma = torch.ones(self.num_envs, dtype = torch.float32, device = self.device) + self.episode_length = torch.zeros(self.num_envs, dtype = int) - self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.shac_device) 
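Reference sketch (not code from this commit): when critic_method == 'td-lambda', the target_values buffer allocated in init_tensors above is typically filled by a backward TD(lambda) recursion over the horizon. The helper name td_lambda_targets and the (horizon_length, num_envs) shapes of rew_buf, next_values and done_mask are assumptions taken from the tensor definitions above.

import torch

def td_lambda_targets(rew_buf, next_values, done_mask, gamma, lam):
    # Backward recursion for TD(lambda) value targets.
    # rew_buf[t], next_values[t], done_mask[t] hold r_t, V(s_{t+1}) and the
    # termination flag for step t, each shaped (num_envs,).
    horizon_length = rew_buf.shape[0]
    target_values = torch.zeros_like(rew_buf)
    next_target = next_values[-1]  # bootstrap value after the last step
    for t in reversed(range(horizon_length)):
        not_done = 1.0 - done_mask[t]
        # Blend the one-step bootstrap with the lambda-return carried from t+1;
        # a done flag cuts the recursion so returns do not leak across episodes.
        target_values[t] = rew_buf[t] + gamma * not_done * (
            (1.0 - lam) * next_values[t] + lam * next_target)
        next_target = target_values[t]
    return target_values

With the defaults read above (gamma = 0.99, lambda = 0.95) this interpolates between a pure one-step target (lam = 0) and a full discounted-return target (lam = 1).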
- @property def device(self): - return self.shac_device - + return self._device + def get_full_state_weights(self): state = self.get_weights() @@ -313,7 +317,7 @@ def update_actor(self, obs, step): entropy = dist.entropy().sum(-1, keepdim=True).mean() actor_Q1, actor_Q2 = self.model.critic(obs, action) actor_Q = torch.min(actor_Q1, actor_Q2) - + actor_loss = (torch.max(self.alpha.detach(), self.min_alpha) * log_prob - actor_Q) actor_loss = actor_loss.mean() @@ -370,8 +374,8 @@ def env_step(self, actions): if self.is_tensor_obses: return obs, rewards, dones, infos else: - return torch.from_numpy(obs).to(self.sac_device), torch.from_numpy(rewards).to(self.sac_device), torch.from_numpy(dones).to(self.sac_device), infos - + return torch.from_numpy(obs).to(self.device), torch.from_numpy(rewards).to(self.device), torch.from_numpy(dones).to(self.device), infos + def env_reset(self): with torch.no_grad(): obs = self.vec_env.reset() @@ -379,11 +383,11 @@ def env_reset(self): if self.is_tensor_obses is None: self.is_tensor_obses = torch.is_tensor(obs) print("Observations are tensors:", self.is_tensor_obses) - + if self.is_tensor_obses: - return obs.to(self.sac_device) + return obs.to(self.device) else: - return torch.from_numpy(obs).to(self.sac_device) + return torch.from_numpy(obs).to(self.device) def act(self, obs, action_dim, sample=False): obs = self.preproc_obs(obs) @@ -395,7 +399,7 @@ def act(self, obs, action_dim, sample=False): def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, actor_loss_info): actor_loss, entropy, alpha, alpha_loss = actor_loss_info - + actor_losses.append(actor_loss) entropies.append(entropy) if alpha_losses is not None: @@ -424,7 +428,7 @@ def play_steps(self, random_exploration=False): for _ in range(self.num_steps_per_episode): self.set_eval() if random_exploration: - action = torch.rand((self.num_actors, *self.env_info["action_space"].shape), device=self.sac_device) * 2 - 1 + action = torch.rand((self.num_actors, *self.env_info["action_space"].shape), device=self.device) * 2 - 1 else: with torch.no_grad(): action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True) @@ -567,6 +571,4 @@ def train(self): self.save("./nn/" + 'last_' + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) print('MAX EPOCHS NUM!') return self.last_mean_rewards, self.epoch_num - update_time = 0 - - \ No newline at end of file + update_time = 0 \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 9a0dfaa4..e15d85b9 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -159,7 +159,7 @@ def __init__(self, base_name, params): self.truncate_grads = self.config.get('truncate_grads', False) self.has_phasic_policy_gradients = False - if isinstance(self.observation_space,gym.spaces.Dict): + if isinstance(self.observation_space, gym.spaces.Dict): self.obs_shape = {} for k,v in self.observation_space.spaces.items(): self.obs_shape[k] = v.shape diff --git a/rl_games/interfaces/base_algorithm.py b/rl_games/interfaces/base_algorithm.py index 054483f7..2783249a 100644 --- a/rl_games/interfaces/base_algorithm.py +++ b/rl_games/interfaces/base_algorithm.py @@ -1,6 +1,7 @@ from abc import ABC from abc import abstractmethod, abstractproperty + class BaseAlgorithm(ABC): def __init__(self, base_name, config): pass diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 0d513ccc..143efaeb 100644 --- a/rl_games/torch_runner.py +++ 
b/rl_games/torch_runner.py @@ -99,8 +99,6 @@ def reset(self): pass def run(self, args): - load_path = None - if args['train']: self.run_train(args) From afe1aea4966ce338daf4e1a11427a824618a433a Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Sun, 22 May 2022 21:27:17 -0700 Subject: [PATCH 03/26] first commit --- rl_games/algos_torch/models.py | 3 +- rl_games/algos_torch/shac_agent.py | 643 +++++------------------------ 2 files changed, 116 insertions(+), 530 deletions(-) diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 7d8f91ab..b8833dac 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -375,8 +375,9 @@ def is_rnn(self): def forward(self, input_dict): is_train = input_dict.pop('is_train', True) + input_dict['obs'] = self.norm_obs(input_dict['obs']) mu, sigma = self.shac_network(input_dict) - dist = SquashedNormal(mu, sigma) + dist = torch.distributions.Normal(mu, sigma) return dist diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index e491302f..b90b6ada 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -4,7 +4,7 @@ from rl_games.common import vecenv, schedulers, experience -from rl_games.interfaces.base_algorithm import BaseAlgorithm +from rl_games.common.a2c_common import ContinuousA2CBase from torch.utils.tensorboard import SummaryWriter from datetime import datetime from rl_games.algos_torch import model_builder @@ -18,259 +18,110 @@ import gym -class SHACAgent(BaseAlgorithm): - +class SHACAgent(ContinuousA2CBase): def __init__(self, base_name, params): - - self.config = config = params['config'] - print(config) - - # TODO: Get obs shape and self.network - self.load_networks(params) - - self.env_config = config.get('env_config', {}) - self.num_actors = config.get('num_actors', 1) - self.env_name = config['env_name'] - print("Env name:", self.env_name) - - self.env_info = config.get('env_info') - if self.env_info is None: - self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config) - self.env_info = self.vec_env.get_env_info() - - self._device = config.get('device', 'cuda:0') - - print('Env info:') - print(self.env_info) - - # if cfg['train']: - # self.save('init_policy') - - self.rewards_shaper = config['reward_shaper'] - self.observation_space = self.env_info['observation_space'] - self.weight_decay = config.get('weight_decay', 0.0) - #self.use_action_masks = config.get('use_action_masks', False) - self.is_train = config.get('is_train', True) - - self.c_loss = nn.MSELoss() - # self.c2_loss = nn.SmoothL1Loss() - - self.save_best_after = config.get('save_best_after', 100) - self.print_stats = config.get('print_stats', True) - self.rnn_states = None - self.name = base_name - - self.max_epochs = self.config.get('max_epochs', 1e6) - - self.network = config['network'] - self.rewards_shaper = config['reward_shaper'] - self.num_agents = self.env_info.get('agents', 1) - - if isinstance(self.observation_space, gym.spaces.Dict): - self.obs_shape = {} - for k,v in self.observation_space.spaces.items(): - self.obs_shape[k] = v.shape - else: - self.obs_shape = self.observation_space.shape - - self.horizon_length = config["horizon_length"] - self.gamma = config["gamma"] - self.critic_coef = config.get('critic_coef', 1.0) - # self.tau = self.config['tau'] - self.critic_tau = config["critic_tau"] # todo align names with PPO impl - self.batch_size = config["batch_size"] - - self.num_steps_per_episode = 
config.get("num_steps_per_episode", 1) - self.normalize_input = config.get("normalize_input", False) - - self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach - - print(self.batch_size, self.num_actors, self.num_agents) - - self.num_frames_per_epoch = self.num_actors * self.num_steps_per_episode - - action_space = self.env_info['action_space'] - self.actions_num = action_space.shape[0] - - self.action_range = [ - float(self.env_info['action_space'].low.min()), - float(self.env_info['action_space'].high.max()) - ] - - obs_shape = torch_ext.shape_whc_to_cwh(self.obs_shape) - net_config = { - 'obs_dim': self.env_info["observation_space"].shape[0], - 'action_dim': self.env_info["action_space"].shape[0], - 'actions_num' : self.actions_num, - 'input_shape' : obs_shape, - 'normalize_input' : self.normalize_input, + ContinuousA2CBase.__init__(self, base_name, params) + obs_shape = self.obs_shape + build_config = { + 'actions_num': self.actions_num, + 'input_shape': obs_shape, + 'num_seqs': self.num_actors * self.num_agents, + 'value_size': self.env_info.get('value_size', 1), + 'normalize_value': self.normalize_value, 'normalize_input': self.normalize_input, - } - self.model = self.network.build(net_config) - self.model.to(self.device) - - print("Number of Agents", self.num_actors, "Batch Size", self.batch_size) - - self.actor_optimizer = torch.optim.Adam(self.model.shac_network.actor.parameters(), - lr=self.config['actor_lr'], - betas=self.config.get("actor_betas", [0.9, 0.999])) - - self.critic_optimizer = torch.optim.Adam(self.model.shac_network.critic.parameters(), - lr=self.config["critic_lr"], - betas=self.config.get("critic_betas", [0.9, 0.999])) - - self.step = 0 - self.algo_observer = config['features']['observer'] - - # TODO: Is there a better way to get the maximum number of episodes? 
- #self.max_episodes = torch.ones(self.num_actors, device=self.device)*self.num_steps_per_episode - - - # shac params - self.gamma = config.get('gamma', 0.99) - - self.critic_method = config.get('critic_method', 'td-lambda') # ['one-step', 'td-lambda'] - if self.critic_method == 'td-lambda': - self.lam = config.get('lambda', 0.95) - - self.max_epochs = config.get("max_epochs", 1000) # add get - self.actor_lr = float(config["actor_learning_rate"]) - self.critic_lr = float(config['critic_learning_rate']) - self.lr_schedule = config.get('lr_schedule', 'linear') + } + + self.model = self.network.build(build_config) + self.critic = self.critic_network.build(build_config) + if self.normalize_input: + self.critic.input_mean_std = self.model.input_mean_std + + self.model.to(self.ppo_device) + self.states = None + self.init_rnn_from_model(self.model) + self.last_lr = float(self.last_lr) + self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound' + self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, + weight_decay=self.weight_decay) + + if self.has_central_value: + cv_config = { + 'state_shape': self.state_shape, + 'value_size': self.value_size, + 'ppo_device': self.ppo_device, + 'num_agents': self.num_agents, + 'horizon_length': self.horizon_length, + 'num_actors': self.num_actors, + 'num_actions': self.actions_num, + 'seq_len': self.seq_len, + 'normalize_value': self.normalize_value, + 'network': self.central_value_config['network'], + 'config': self.central_value_config, + 'writter': self.writer, + 'max_epochs': self.max_epochs, + 'multi_gpu': self.multi_gpu, + 'hvd': self.hvd if self.multi_gpu else None + } + self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) + + self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, + self.ppo_device, self.seq_len) + if self.normalize_value: + self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std - self.target_critic_alpha = config.get('target_critic_alpha', 0.4) - - # self.obs_rms = None - # if config.get('obs_rms', False): - # self.obs_rms = RunningMeanStd(shape = (self.num_obs), device = self.device) - - # self.ret_rms = None - # if config.get('ret_rms', False): - # self.ret_rms = RunningMeanStd(shape = (), device = self.device) - - #self.rew_scale = cfg['params']['config'].get('rew_scale', 1.0) - - self.critic_iterations = config.get('critic_iterations', 16) - self.num_batch = config.get('num_batch', 4) - self.batch_size = self.num_actors * self.horizon_length // self.num_batch - self.name = config.get('name', "Ant") - - self.truncate_grad = config.get("truncate_grads", True) - self.grad_norm = config.get("grad_norm", 1.0) - ########### - - self.games_to_track = self.config.get('games_to_track', 100) - self.game_rewards = torch_ext.AverageMeter(1, self.games_to_track).to(self.device) - self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.device) - self.obs = None - - self.frame = 0 - self.update_time = 0 - self.last_mean_rewards = -100500 - self.play_time = 0 - self.epoch_num = 0 - - self.writer = SummaryWriter('runs/' + config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) - print("Run Directory:", config['name'] + datetime.now().strftime("_%d-%H-%M-%S")) - - self.is_tensor_obses = True - self.is_rnn = False - self.last_rnn_indices = None - self.last_state_indices = None - - # counting variables - 
self.iter_count = 0 - self.step_count = 0 - - # loss variables - self.episode_length_his = [] - self.episode_loss_his = [] - self.episode_discounted_loss_his = [] - - self.best_policy_loss = np.inf - self.actor_loss = np.inf - self.value_loss = np.inf - - # average meter - self.episode_loss_meter = torch_ext.AverageMeter(1, 100).to(self.device) - self.episode_discounted_loss_meter = torch_ext.AverageMeter(1, 100).to(self.device) - self.episode_length_meter = torch_ext.AverageMeter(1, 100).to(self.device) - - # timer - # self.time_report = TimeReport() - # self.all_params = list(self.actor.parameters()) + list(self.critic.parameters()) - # self.target_critic = copy.deepcopy(self.critic) + self.algo_observer.after_init(self) def load_networks(self, params): - builder = model_builder.ModelBuilder() - self.config['network'] = builder.load(params) - - def init_tensors(self): - - batch_size = self.num_agents * self.num_actors - - self.current_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.device) - self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.device) - - self.dones = torch.zeros((batch_size,), dtype=torch.uint8, device=self.device) - - # replay buffer - self.obs_buf = torch.zeros((self.horizon_length, self.num_envs, self.num_obs), dtype = torch.float32, device = self.device) - self.rew_buf = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) - self.done_mask = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) - self.next_values = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) - self.target_values = torch.zeros((self.horizon_length, self.num_envs), dtype = torch.float32, device = self.device) - self.ret = torch.zeros((self.num_envs), dtype = torch.float32, device = self.device) - - # for kl divergence computing - self.old_mus = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.old_sigmas = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.mus = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - self.sigmas = torch.zeros((self.horizon_length, self.num_envs, self.num_actions), dtype = torch.float32, device = self.device) - - self.episode_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_discounted_loss = torch.zeros(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_gamma = torch.ones(self.num_envs, dtype = torch.float32, device = self.device) - self.episode_length = torch.zeros(self.num_envs, dtype = int) - - @property - def device(self): - return self._device - - def get_full_state_weights(self): - state = self.get_weights() - - state['steps'] = self.step - state['actor_optimizer'] = self.actor_optimizer.state_dict() - state['critic_optimizer'] = self.critic_optimizer.state_dict() - - return state + ContinuousA2CBase.load_networks(self, params) + if critic_config in self.config: + builder = model_builder.ModelBuilder() + print('Adding Critic Network') + network = builder.load(params['config']['critic_config']) + self.critic_network = network + + def get_actions(self, obs): + processed_obs = self._preproc_obs(obs['obs']) + input_dict = { + 'is_train': False, + 'prev_actions': None, + 'obs' : processed_obs, + 'rnn_states' : self.rnn_states + } + 
res_dict = self.model(input_dict) + return res_dict + + def get_values(self, obs): + processed_obs = self._preproc_obs(obs['obs']) + input_dict = { + 'is_train': False, + 'prev_actions': None, + 'obs' : processed_obs, + 'rnn_states' : self.rnn_states + } + if self.has_central_value: + states = obs['states'] + self.central_value_net.eval() + input_dict = { + 'is_train': False, + 'states' : states, + 'actions' : None, + 'is_done': self.dones, + } + value = self.get_central_value(input_dict) + else: + processed_obs = self._preproc_obs(obs['obs']) + result = self.critic_model(input_dict) + value = result['values'] + return value - def get_weights(self): - state = {'actor': self.model.shac_network.actor.state_dict(), - 'critic': self.model.shac_network.critic.state_dict(), - 'critic_target': self.model.shac_network.critic_target.state_dict()} - return state + def update_epoch(self): + self.epoch_num += 1 + return self.epoch_num def save(self, fn): state = self.get_full_state_weights() torch_ext.save_checkpoint(fn, state) - def set_weights(self, weights): - self.model.shac_network.actor.load_state_dict(weights['actor']) - self.model.shac_network.critic.load_state_dict(weights['critic']) - self.model.shac_network.critic_target.load_state_dict(weights['critic_target']) - - if self.normalize_input and 'running_mean_std' in weights: - self.model.running_mean_std.load_state_dict(weights['running_mean_std']) - - def set_full_state_weights(self, weights): - self.set_weights(weights) - - self.step = weights['step'] - self.actor_optimizer.load_state_dict(weights['actor_optimizer']) - self.critic_optimizer.load_state_dict(weights['critic_optimizer']) - def restore(self, fn): checkpoint = torch_ext.load_checkpoint(fn) self.set_full_state_weights(checkpoint) @@ -278,297 +129,31 @@ def restore(self, fn): def get_masked_action_values(self, obs, action_masks): assert False - def set_eval(self): - self.model.eval() + def calc_gradients(self, input_dict): + value_preds_batch = input_dict['old_values'] + old_action_log_probs_batch = input_dict['old_logp_actions'] + advantage = input_dict['advantages'] + old_mu_batch = input_dict['mu'] + old_sigma_batch = input_dict['sigma'] + return_batch = input_dict['returns'] + actions_batch = input_dict['actions'] + obs_batch = input_dict['obs'] + obs_batch = self._preproc_obs(obs_batch) - def set_train(self): - self.model.train() - def update_critic(self, obs, action, reward, next_obs, not_done,step): - with torch.no_grad(): - dist = self.model.actor(next_obs) - next_action = dist.rsample() - log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) - target_Q1, target_Q2 = self.model.critic_target(next_obs, next_action) - target_V = torch.min(target_Q1, target_Q2) - self.alpha * log_prob - - target_Q = reward + (not_done * self.gamma * target_V) - target_Q = target_Q.detach() - - # get current Q estimates - current_Q1, current_Q2 = self.model.critic(obs, action) - - critic1_loss = self.c_loss(current_Q1, target_Q) - critic2_loss = self.c_loss(current_Q2, target_Q) - critic_loss = critic1_loss + critic2_loss - self.critic_optimizer.zero_grad(set_to_none=True) - critic_loss.backward() - self.critic_optimizer.step() - - return critic_loss.detach(), critic1_loss.detach(), critic2_loss.detach() - - def update_actor(self, obs, step): - for p in self.model.shac_network.critic.parameters(): - p.requires_grad = False - - dist = self.model.actor(obs) - action = dist.rsample() - log_prob = dist.log_prob(action).sum(-1, keepdim=True) - entropy = dist.entropy().sum(-1, 
keepdim=True).mean() - actor_Q1, actor_Q2 = self.model.critic(obs, action) - actor_Q = torch.min(actor_Q1, actor_Q2) - - actor_loss = (torch.max(self.alpha.detach(), self.min_alpha) * log_prob - actor_Q) - actor_loss = actor_loss.mean() - - self.actor_optimizer.zero_grad(set_to_none=True) - actor_loss.backward() - self.actor_optimizer.step() - - for p in self.model.shac_network.critic.parameters(): - p.requires_grad = True - - if self.learnable_temperature: - alpha_loss = (self.alpha * - (-log_prob - self.target_entropy).detach()).mean() - self.log_alpha_optimizer.zero_grad(set_to_none=True) - alpha_loss.backward() - self.log_alpha_optimizer.step() - else: - alpha_loss = None - - return actor_loss.detach(), entropy.detach(), self.alpha.detach(), alpha_loss # TODO: maybe not self.alpha - - def soft_update_params(self, net, target_net, tau): - for param, target_param in zip(net.parameters(), target_net.parameters()): - target_param.data.copy_(tau * param.data + - (1 - tau) * target_param.data) - - def update(self, step): - obs, action, reward, next_obs, done = self.replay_buffer.sample(self.batch_size) - not_done = ~done - - obs = self.preproc_obs(obs) - next_obs = self.preproc_obs(next_obs) + self.optimizer.zero_grad(set_to_none=True) - critic_loss, critic1_loss, critic2_loss = self.update_critic(obs, action, reward, next_obs, not_done, step) - actor_loss, entropy, alpha, alpha_loss = self.update_actor_and_alpha(obs, step) + self.scaler.scale(loss).backward() + self.trancate_gradients() - actor_loss_info = actor_loss, entropy, alpha, alpha_loss - self.soft_update_params(self.model.shac_network.critic, self.model.shac_network.critic_target, - self.critic_tau) - return actor_loss_info, critic1_loss, critic2_loss - - def preproc_obs(self, obs): - if isinstance(obs, dict): - obs = obs['obs'] - return obs - - def env_step(self, actions): - if not self.is_tensor_obses: - actions = actions.cpu().numpy() - obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space) - - self.step += self.num_actors - if self.is_tensor_obses: - return obs, rewards, dones, infos - else: - return torch.from_numpy(obs).to(self.device), torch.from_numpy(rewards).to(self.device), torch.from_numpy(dones).to(self.device), infos - - def env_reset(self): with torch.no_grad(): - obs = self.vec_env.reset() - - if self.is_tensor_obses is None: - self.is_tensor_obses = torch.is_tensor(obs) - print("Observations are tensors:", self.is_tensor_obses) - - if self.is_tensor_obses: - return obs.to(self.device) - else: - return torch.from_numpy(obs).to(self.device) - - def act(self, obs, action_dim, sample=False): - obs = self.preproc_obs(obs) - dist = self.model.actor(obs) - actions = dist.sample() if sample else dist.mean - actions = actions.clamp(*self.action_range) - assert actions.ndim == 2 - return actions - - def extract_actor_stats(self, actor_losses, entropies, alphas, alpha_losses, actor_loss_info): - actor_loss, entropy, alpha, alpha_loss = actor_loss_info - - actor_losses.append(actor_loss) - entropies.append(entropy) - if alpha_losses is not None: - alphas.append(alpha) - alpha_losses.append(alpha_loss) - - def clear_stats(self): - self.game_rewards.clear() - self.game_lengths.clear() - self.mean_rewards = self.last_mean_rewards = -100500 - self.algo_observer.after_clear_stats() - - def play_steps(self, random_exploration=False): - total_time_start = time.time() - total_update_time = 0 - total_time = 0 - step_time = 0.0 - actor_losses = [] - entropies = [] - alphas = [] - alpha_losses = [] - 
critic1_losses = [] - critic2_losses = [] - - obs = self.obs - for _ in range(self.num_steps_per_episode): - self.set_eval() - if random_exploration: - action = torch.rand((self.num_actors, *self.env_info["action_space"].shape), device=self.device) * 2 - 1 - else: - with torch.no_grad(): - action = self.act(obs.float(), self.env_info["action_space"].shape, sample=True) - - step_start = time.time() - - with torch.no_grad(): - next_obs, rewards, dones, infos = self.env_step(action) - step_end = time.time() - - self.current_rewards += rewards - self.current_lengths += 1 - - total_time += step_end - step_start - - step_time += (step_end - step_start) - - all_done_indices = dones.nonzero(as_tuple=False) - done_indices = all_done_indices[::self.num_agents] - self.game_rewards.update(self.current_rewards[done_indices]) - self.game_lengths.update(self.current_lengths[done_indices]) - - not_dones = 1.0 - dones.float() - - self.algo_observer.process_infos(infos, done_indices) - - no_timeouts = self.current_lengths != self.max_env_steps - dones = dones * no_timeouts - - self.current_rewards = self.current_rewards * not_dones - self.current_lengths = self.current_lengths * not_dones - - if isinstance(obs, dict): - obs = obs['obs'] - if isinstance(next_obs, dict): - next_obs = next_obs['obs'] - - rewards = self.rewards_shaper(rewards) - - self.replay_buffer.add(obs, action, torch.unsqueeze(rewards, 1), next_obs, torch.unsqueeze(dones, 1)) - - self.obs = obs = next_obs.clone() - - if not random_exploration: - self.set_train() - update_time_start = time.time() - actor_loss_info, critic1_loss, critic2_loss = self.update(self.epoch_num) - update_time_end = time.time() - update_time = update_time_end - update_time_start - - self.extract_actor_stats(actor_losses, entropies, alphas, alpha_losses, actor_loss_info) - critic1_losses.append(critic1_loss) - critic2_losses.append(critic2_loss) - else: - update_time = 0 - - total_update_time += update_time - - total_time_end = time.time() - total_time = total_time_end - total_time_start - play_time = total_time - total_update_time - - return step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses - - def train_epoch(self): - if self.epoch_num < self.num_seed_steps: - step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.play_steps(random_exploration=True) - else: - step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.play_steps(random_exploration=False) - - return step_time, play_time, total_update_time, total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses - - def train(self): - self.init_tensors() - self.algo_observer.after_init(self) - self.last_mean_rewards = -100500 - total_time = 0 - # rep_count = 0 - self.frame = 0 - self.obs = self.env_reset() - - while True: - self.epoch_num += 1 - step_time, play_time, update_time, epoch_total_time, actor_losses, entropies, alphas, alpha_losses, critic1_losses, critic2_losses = self.train_epoch() - - total_time += epoch_total_time - - scaled_time = epoch_total_time - scaled_play_time = play_time - curr_frames = self.num_frames_per_epoch - self.frame += curr_frames - frame = self.frame #TODO: Fix frame - # print(frame) - - if self.print_stats: - fps_step = curr_frames / scaled_play_time - fps_total = curr_frames / scaled_time - print(f'fps step: 
{fps_step:.1f} fps total: {fps_total:.1f}') - - self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / scaled_time, frame) - self.writer.add_scalar('performance/step_inference_fps', curr_frames / scaled_play_time, frame) - self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame) - self.writer.add_scalar('performance/rl_update_time', update_time, frame) - self.writer.add_scalar('performance/step_inference_time', play_time, frame) - self.writer.add_scalar('performance/step_time', step_time, frame) - - if self.epoch_num >= self.num_seed_steps: - self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(actor_losses).item(), frame) - self.writer.add_scalar('losses/c1_loss', torch_ext.mean_list(critic1_losses).item(), frame) - self.writer.add_scalar('losses/c2_loss', torch_ext.mean_list(critic2_losses).item(), frame) - self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame) - if alpha_losses[0] is not None: - self.writer.add_scalar('losses/alpha_loss', torch_ext.mean_list(alpha_losses).item(), frame) - self.writer.add_scalar('info/alpha', torch_ext.mean_list(alphas).item(), frame) - - self.writer.add_scalar('info/epochs', self.epoch_num, frame) - self.algo_observer.after_print_stats(frame, self.epoch_num, total_time) - - if self.game_rewards.current_size > 0: - mean_rewards = self.game_rewards.get_mean() - mean_lengths = self.game_lengths.get_mean() - - self.writer.add_scalar('rewards/step', mean_rewards, frame) - # self.writer.add_scalar('rewards/iter', mean_rewards, epoch_num) - self.writer.add_scalar('rewards/time', mean_rewards, total_time) - self.writer.add_scalar('episode_lengths/step', mean_lengths, frame) - # self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) - self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time) + reduce_kl = rnn_masks is None + kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl) + if rnn_masks is not None: + kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() # / sum_mask - if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after: - print('saving next best rewards: ', mean_rewards) - self.last_mean_rewards = mean_rewards - self.save("./nn/" + self.config['name']) - if self.last_mean_rewards > self.config.get('score_to_win', float('inf')): - print('Network won!') - self.save("./nn/" + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) - return self.last_mean_rewards, self.epoch_num - if self.epoch_num > self.max_epochs: - self.save("./nn/" + 'last_' + self.config['name'] + 'ep=' + str(self.epoch_num) + 'rew=' + str(mean_rewards)) - print('MAX EPOCHS NUM!') - return self.last_mean_rewards, self.epoch_num - update_time = 0 \ No newline at end of file + def train_actor_critic(self, input_dict): + self.calc_gradients(input_dict) + return self.train_result From b49d21f584459f1174830d5482bebbcbcc4bc144 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Sun, 22 May 2022 21:27:59 -0700 Subject: [PATCH 04/26] removed shac model --- rl_games/algos_torch/model_builder.py | 2 -- rl_games/algos_torch/models.py | 30 --------------------------- 2 files changed, 32 deletions(-) diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index 8acb2318..30cf6fd2 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -46,8 +46,6 @@ def __init__(self): lambda network, **kwargs: 
models.ModelA2CContinuousLogStd(network)) self.model_factory.register_builder('soft_actor_critic', lambda network, **kwargs: models.ModelSACContinuous(network)) - self.model_factory.register_builder('shac_actor_critic', - lambda network, **kwargs: models.ModelSHAC(network)) self.model_factory.register_builder('central_value', lambda network, **kwargs: models.ModelCentralValue(network)) self.network_builder = NetworkBuilder() diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index b8833dac..fdbc5020 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -350,35 +350,5 @@ def forward(self, input_dict): return dist -class ModelSHAC(BaseModel): - - def __init__(self, network): - BaseModel.__init__(self, 'shac') - self.network_builder = network - - class Network(BaseModelNetwork): - def __init__(self, shac_network, **kwargs): - BaseModelNetwork.__init__(self, **kwargs) - self.shac_network = shac_network - - def critic(self, obs, action): - return self.shac_network.critic(obs, action) - - def critic_target(self, obs, action): - return self.shac_network.critic_target(obs, action) - - def actor(self, obs): - return self.shac_network.actor(obs) - - def is_rnn(self): - return False - - def forward(self, input_dict): - is_train = input_dict.pop('is_train', True) - input_dict['obs'] = self.norm_obs(input_dict['obs']) - mu, sigma = self.shac_network(input_dict) - dist = torch.distributions.Normal(mu, sigma) - return dist - From 8bc6ccd3abee952101f45f68c7b4384ced46f976 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 11:01:27 -0700 Subject: [PATCH 05/26] it works --- rl_games/algos_torch/model_builder.py | 3 +- rl_games/algos_torch/models.py | 40 +- rl_games/algos_torch/running_mean_std.py | 45 +- rl_games/algos_torch/shac_agent.py | 416 ++++++++++++++---- rl_games/common/a2c_common.py | 5 +- rl_games/common/env_configurations.py | 2 +- rl_games/common/player.py | 2 +- rl_games/common/wrappers.py | 8 +- .../configs/atari/ppo_breakout_torch.yaml | 4 +- 9 files changed, 407 insertions(+), 118 deletions(-) diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index 30cf6fd2..d38ff340 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -22,7 +22,6 @@ def __init__(self): lambda **kwargs: network_builder.A2CResnetBuilder()) self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) - self.network_factory.register_builder('shac_actor_critic', lambda **kwargs: network_builder.SHACBuilder()) def load(self, params): network_name = params['name'] @@ -48,6 +47,8 @@ def __init__(self): lambda network, **kwargs: models.ModelSACContinuous(network)) self.model_factory.register_builder('central_value', lambda network, **kwargs: models.ModelCentralValue(network)) + self.model_factory.register_builder('shac', + lambda network, **kwargs: models.ModelA2CContinuousSHAC(network)) self.network_builder = NetworkBuilder() def get_network_builder(self): diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index fdbc5020..00b0059d 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -46,12 +46,10 @@ def __init__(self, obs_shape, normalize_value, normalize_input, value_size): self.running_mean_std = RunningMeanStd(obs_shape) def norm_obs(self, observation): - 
with torch.no_grad(): - return self.running_mean_std(observation) if self.normalize_input else observation + return self.running_mean_std(observation) if self.normalize_input else observation def unnorm_value(self, value): - with torch.no_grad(): - return self.value_mean_std(value, unnorm=True) if self.normalize_value else value + return self.value_mean_std(value, unnorm=True) if self.normalize_value else value class ModelA2C(BaseModel): @@ -285,6 +283,39 @@ def neglogp(self, x, mean, std, logstd): + logstd.sum(dim=-1) +class ModelA2CContinuousSHAC(BaseModel): + def __init__(self, network): + BaseModel.__init__(self, 'a2c') + self.network_builder = network + + class Network(BaseModelNetwork): + def __init__(self, a2c_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) + self.a2c_network = a2c_network + + def is_rnn(self): + return self.a2c_network.is_rnn() + + def get_default_rnn_state(self): + return self.a2c_network.get_default_rnn_state() + + def forward(self, input_dict): + input_dict['obs'] = self.norm_obs(input_dict['obs']) + mu, logstd, _, states = self.a2c_network(input_dict) + sigma = torch.exp(logstd) + distr = torch.distributions.Normal(mu, sigma) + entropy = distr.entropy().sum(dim=-1) + selected_action = distr.rsample() + result = { + 'actions': selected_action, + 'entropy': entropy, + 'rnn_states': states, + 'mus': mu, + 'sigmas': sigma + } + return result + + class ModelCentralValue(BaseModel): def __init__(self, network): @@ -312,7 +343,6 @@ def forward(self, input_dict): value, states = self.a2c_network(input_dict) if not is_train: value = self.unnorm_value(value) - result = { 'values': value, 'rnn_states': states diff --git a/rl_games/algos_torch/running_mean_std.py b/rl_games/algos_torch/running_mean_std.py index 152295c1..947f4250 100644 --- a/rl_games/algos_torch/running_mean_std.py +++ b/rl_games/algos_torch/running_mean_std.py @@ -43,29 +43,30 @@ def _update_mean_var_count_from_moments(self, mean, var, count, batch_mean, batc return new_mean, new_var, new_count def forward(self, input, unnorm=False, mask=None): - if self.training: - if mask is not None: - mean, var = torch_ext.get_mean_std_with_masks(input, mask) - else: - mean = input.mean(self.axis) # along channel axis - var = input.var(self.axis) - self.running_mean, self.running_var, self.count = self._update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, - mean, var, input.size()[0] ) + with torch.no_grad(): + if self.training: + if mask is not None: + mean, var = torch_ext.get_mean_std_with_masks(input, mask) + else: + mean = input.mean(self.axis) # along channel axis + var = input.var(self.axis) + self.running_mean, self.running_var, self.count = self._update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, + mean, var, input.size()[0] ) - # change shape - if self.per_channel: - if len(self.insize) == 3: - current_mean = self.running_mean.view([1, self.insize[0], 1, 1]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0], 1, 1]).expand_as(input) - if len(self.insize) == 2: - current_mean = self.running_mean.view([1, self.insize[0], 1]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0], 1]).expand_as(input) - if len(self.insize) == 1: - current_mean = self.running_mean.view([1, self.insize[0]]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0]]).expand_as(input) - else: - current_mean = self.running_mean - current_var = self.running_var + # change shape + if 
self.per_channel: + if len(self.insize) == 3: + current_mean = self.running_mean.view([1, self.insize[0], 1, 1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0], 1, 1]).expand_as(input) + if len(self.insize) == 2: + current_mean = self.running_mean.view([1, self.insize[0], 1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0], 1]).expand_as(input) + if len(self.insize) == 1: + current_mean = self.running_mean.view([1, self.insize[0]]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0]]).expand_as(input) + else: + current_mean = self.running_mean + current_var = self.running_var # get output diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index b90b6ada..f74ca28f 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -1,22 +1,29 @@ -from rl_games.algos_torch import torch_ext from rl_games.algos_torch.running_mean_std import RunningMeanStd - from rl_games.common import vecenv, schedulers, experience from rl_games.common.a2c_common import ContinuousA2CBase -from torch.utils.tensorboard import SummaryWriter -from datetime import datetime -from rl_games.algos_torch import model_builder +from rl_games.common import a2c_common +from rl_games.algos_torch import torch_ext -import torch -from torch import nn, optim -import torch.nn.functional as F -import numpy as np +from rl_games.algos_torch import central_value +from rl_games.common import common_losses +from rl_games.common import datasets +from rl_games.algos_torch import model_builder +from torch import optim +import torch import time +import os +import copy -import gym - +def swap_and_flatten01(arr): + """ + swap and then flatten axes 0 and 1 + """ + if arr is None: + return arr + s = arr.size() + return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:]) class SHACAgent(ContinuousA2CBase): def __init__(self, base_name, params): @@ -30,50 +37,130 @@ def __init__(self, base_name, params): 'normalize_value': self.normalize_value, 'normalize_input': self.normalize_input, } + self.critic_lr = self.config.get('critic_learning_rate', 0.0001) + self.target_critic_alpha = 0.4 + self.actor_model = self.network.build(build_config) + self.critic_model = self.critic_network.build(build_config) - self.model = self.network.build(build_config) - self.critic = self.critic_network.build(build_config) - if self.normalize_input: - self.critic.input_mean_std = self.model.input_mean_std + self.actor_model.to(self.ppo_device) + self.critic_model.to(self.ppo_device) + self.target_critic = copy.deepcopy(self.critic_model) - self.model.to(self.ppo_device) + if self.normalize_input: + self.critic_model.running_mean_std = self.actor_model.running_mean_std + self.target_critic.running_mean_std = self.critic_model.running_mean_std + if self.normalize_value: + self.target_critic.value_mean_std = self.critic_model.value_mean_std self.states = None - self.init_rnn_from_model(self.model) + self.model = self.actor_model + self.init_rnn_from_model(self.actor_model) self.last_lr = float(self.last_lr) - self.bound_loss_type = self.config.get('bound_loss_type', 'bound') # 'regularisation' or 'bound' - self.optimizer = optim.Adam(self.model.parameters(), float(self.last_lr), eps=1e-08, + self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), eps=1e-08, weight_decay=self.weight_decay) - - if self.has_central_value: - cv_config = { - 'state_shape': self.state_shape, - 'value_size': 
self.value_size, - 'ppo_device': self.ppo_device, - 'num_agents': self.num_agents, - 'horizon_length': self.horizon_length, - 'num_actors': self.num_actors, - 'num_actions': self.actions_num, - 'seq_len': self.seq_len, - 'normalize_value': self.normalize_value, - 'network': self.central_value_config['network'], - 'config': self.central_value_config, - 'writter': self.writer, - 'max_epochs': self.max_epochs, - 'multi_gpu': self.multi_gpu, - 'hvd': self.hvd if self.multi_gpu else None - } - self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), eps=1e-08, + weight_decay=self.weight_decay) self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) if self.normalize_value: - self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std + self.value_mean_std = self.critic_model.value_mean_std self.algo_observer.after_init(self) + + def play_steps(self): + update_list = self.update_list + accumulated_rewards = torch.zeros((self.horizon_length + 1, self.num_actors), dtype=torch.float32, device=self.device) + actor_loss = torch.tensor(0., dtype=torch.float32, device=self.device) + mb_values = self.experience_buffer.tensor_dict['values'] + gamma = torch.ones(self.num_actors, dtype = torch.float32, device = self.device) + step_time = 0.0 + self.critic_model.eval() + self.actor_model.train() + if self.normalize_input: + self.actor_model.running_mean_std.train() + if self.normalize_value: + self.actor_model.value_mean_std.eval() + obs = self.initialize_trajectory() + for n in range(self.horizon_length): + res_dict = self.get_actions(obs) + res_dict['values'] = self.get_values(obs) + + with torch.no_grad(): + self.experience_buffer.update_data('obses', n, obs['obs'].detach()) + self.experience_buffer.update_data('dones', n, self.dones.detach()) + + self.experience_buffer.update_data('values', n, res_dict['values'].detach()) + + step_time_start = time.time() + obs, rewards, self.dones, infos = self.env_step(torch.tanh(res_dict['actions'])) + step_time_end = time.time() + + step_time += (step_time_end - step_time_start) + + shaped_rewards = self.rewards_shaper(rewards) + + self.experience_buffer.update_data('rewards', n, shaped_rewards.detach()) + + self.current_rewards += rewards.detach() + self.current_lengths += 1 + all_done_indices = self.dones.nonzero(as_tuple=False) + env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) + + self.game_rewards.update(self.current_rewards[env_done_indices]) + self.game_lengths.update(self.current_lengths[env_done_indices]) + self.algo_observer.process_infos(infos, env_done_indices) + fdones = self.dones.float() + not_dones = 1.0 - fdones + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + + accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) + + last_values = self.get_values(obs) + + end_vals = last_values.squeeze(1) + end_vals = end_vals * not_dones + if n < self.horizon_length - 1: + #- self.gamma * gamma[env_done_indices] * end_vals[env_done_indices] + actor_loss = actor_loss - accumulated_rewards[n + 1] * fdones + gamma = gamma * self.gamma + gamma[env_done_indices] = 1.0 + accumulated_rewards[n + 1, env_done_indices] = 0.0 + else: + # 
terminate all envs at the end of optimization iteration + actor_loss = actor_loss + (-accumulated_rewards[n + 1] - self.gamma * gamma * end_vals) + + + + + fdones = self.dones.float().detach() + mb_fdones = self.experience_buffer.tensor_dict['dones'].float().detach() + mb_rewards = self.experience_buffer.tensor_dict['rewards'].detach() + mb_advs = self.discount_values(fdones, last_values.detach(), mb_fdones, mb_values.detach(), mb_rewards) + mb_returns = mb_advs + mb_values + + batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = swap_and_flatten01(mb_returns) + batch_dict['played_frames'] = self.batch_size + batch_dict['step_time'] = step_time + + actor_loss = actor_loss.mean() / self.horizon_length + return batch_dict, actor_loss + + def env_step(self, actions): + #actions = self.preprocess_actions(actions) + obs, rewards, dones, infos = self.vec_env.step(actions) + + if self.value_size == 1: + rewards = rewards.unsqueeze(1) + return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos + + def load_networks(self, params): ContinuousA2CBase.load_networks(self, params) - if critic_config in self.config: + if 'critic_config' in self.config: builder = model_builder.ModelBuilder() print('Adding Critic Network') network = builder.load(params['config']['critic_config']) @@ -82,12 +169,10 @@ def load_networks(self, params): def get_actions(self, obs): processed_obs = self._preproc_obs(obs['obs']) input_dict = { - 'is_train': False, - 'prev_actions': None, 'obs' : processed_obs, 'rnn_states' : self.rnn_states } - res_dict = self.model(input_dict) + res_dict = self.actor_model(input_dict) return res_dict def get_values(self, obs): @@ -98,21 +183,17 @@ def get_values(self, obs): 'obs' : processed_obs, 'rnn_states' : self.rnn_states } - if self.has_central_value: - states = obs['states'] - self.central_value_net.eval() - input_dict = { - 'is_train': False, - 'states' : states, - 'actions' : None, - 'is_done': self.dones, - } - value = self.get_central_value(input_dict) - else: - processed_obs = self._preproc_obs(obs['obs']) - result = self.critic_model(input_dict) - value = result['values'] - return value + + processed_obs = self._preproc_obs(obs['obs']) + result = self.critic_model(input_dict) + value = result['values'] + return value + + def initialize_trajectory(self): + #obs = self.vec_env.reset() + obs = self.vec_env.initialize_trajectory() + obs = self.obs_to_tensors(obs) + return obs def update_epoch(self): self.epoch_num += 1 @@ -129,31 +210,208 @@ def restore(self, fn): def get_masked_action_values(self, obs, action_masks): assert False - def calc_gradients(self, input_dict): - value_preds_batch = input_dict['old_values'] - old_action_log_probs_batch = input_dict['old_logp_actions'] - advantage = input_dict['advantages'] - old_mu_batch = input_dict['mu'] - old_sigma_batch = input_dict['sigma'] - return_batch = input_dict['returns'] - actions_batch = input_dict['actions'] - obs_batch = input_dict['obs'] - obs_batch = self._preproc_obs(obs_batch) + def prepare_critic_dataset(self, batch_dict): + obses = batch_dict['obses'].detach() + returns = batch_dict['returns'].detach() + dones = batch_dict['dones'].detach() + values = batch_dict['values'].detach() + rnn_states = batch_dict.get('rnn_states', None) + rnn_masks = batch_dict.get('rnn_masks', None) + + advantages = returns - values + + if self.normalize_value: + self.value_mean_std.train() + values = self.value_mean_std(values) + returns = 
self.value_mean_std(returns) + self.value_mean_std.eval() + + advantages = torch.sum(advantages, axis=1) + + if self.normalize_advantage: + advantages = torch_ext.normalization_with_masks(advantages, rnn_masks) + + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['advantages'] = advantages + dataset_dict['returns'] = returns + dataset_dict['obs'] = obses + dataset_dict['dones'] = dones + dataset_dict['rnn_states'] = rnn_states + dataset_dict['rnn_masks'] = rnn_masks + + self.dataset.update_values_dict(dataset_dict) + + + def train_actor(self, actor_loss): + self.actor_model.train() + + self.actor_optimizer.zero_grad(set_to_none=True) + + actor_loss.backward() + + if self.truncate_grads: + torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.grad_norm) + + self.actor_optimizer.step() + return actor_loss.detach() + + def train_critic(self, batch): + self.critic_model.train() + if self.normalize_input: + self.critic_model.running_mean_std.eval() + if self.normalize_value: + self.critic_model.value_mean_std.eval() + obs_batch = self._preproc_obs(batch['obs']) + value_preds_batch = batch['old_values'] + returns_batch = batch['returns'] + dones_batch = batch['dones'] + rnn_masks_batch = batch.get('rnn_masks') + batch_dict = {'obs' : obs_batch, + 'seq_length' : self.seq_len, + 'dones' : dones_batch} + + res_dict = self.critic_model(batch_dict) + values = res_dict['values'] + loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value) + losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch) + critic_loss = losses[0] + self.critic_optimizer.zero_grad(set_to_none=True) + critic_loss.backward() + + if self.truncate_grads: + torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), self.grad_norm) + + self.critic_optimizer.step() + return critic_loss.detach() - self.optimizer.zero_grad(set_to_none=True) + def train_epoch(self): + play_time_start = time.time() + batch_dict, actor_loss = self.play_steps() + play_time_end = time.time() + update_time_start = time.time() + self.curr_frames = batch_dict.pop('played_frames') + self.prepare_critic_dataset(batch_dict) + self.algo_observer.after_steps() + a_loss = self.train_actor(actor_loss) + a_losses = [a_loss] + c_losses = [] - self.scaler.scale(loss).backward() - self.trancate_gradients() + for mini_ep in range(0, self.mini_epochs_num): + ep_kls = [] + for i in range(len(self.dataset)): + c_loss = self.train_critic(self.dataset[i]) + c_losses.append(c_loss) + + self.diagnostics.mini_epoch(self, mini_ep) + + # update target critic with torch.no_grad(): - reduce_kl = rnn_masks is None - kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(), old_mu_batch, old_sigma_batch, reduce_kl) - if rnn_masks is not None: - kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel() # / sum_mask + alpha = self.target_critic_alpha + for param, param_targ in zip(self.critic_model.parameters(), self.target_critic.parameters()): + param_targ.data.mul_(alpha) + param_targ.data.add_((1. 
- alpha) * param.data) + update_time_end = time.time() + play_time = play_time_end - play_time_start + update_time = update_time_end - update_time_start + total_time = update_time_end - play_time_start + + return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses + + + def train(self): + self.init_tensors() + self.mean_rewards = self.last_mean_rewards = -100500 + start_time = time.time() + total_time = 0 + rep_count = 0 + # self.frame = 0 # loading from checkpoint + + + while True: + epoch_num = self.update_epoch() + step_time, play_time, update_time, sum_time, a_losses, c_losses = self.train_epoch() + + # cleaning memory to optimize space + self.dataset.update_values_dict(None) + total_time += sum_time + curr_frames = self.curr_frames + self.frame += curr_frames + should_exit = False + if self.rank == 0: + self.diagnostics.epoch(self, current_epoch=epoch_num) + scaled_time = self.num_agents * sum_time + scaled_play_time = self.num_agents * play_time + + frame = self.frame // self.num_agents + + if self.print_stats: + fps_step = curr_frames / step_time + fps_step_inference = curr_frames / scaled_play_time + fps_total = curr_frames / scaled_time + print(f'fps step: {fps_step:.1f} fps step and policy inference: {fps_step_inference:.1f} fps total: {fps_total:.1f} epoch: {epoch_num}/{self.max_epochs}') + print('a_loss:', a_losses[0].item()) + + self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, curr_frames) + + self.algo_observer.after_print_stats(frame, epoch_num, total_time) + + if self.game_rewards.current_size > 0: + mean_rewards = self.game_rewards.get_mean() + mean_lengths = self.game_lengths.get_mean() + self.mean_rewards = mean_rewards[0] + + for i in range(self.value_size): + rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i) + self.writer.add_scalar(rewards_name + '/step'.format(i), mean_rewards[i], frame) + self.writer.add_scalar(rewards_name + '/iter'.format(i), mean_rewards[i], epoch_num) + self.writer.add_scalar(rewards_name + '/time'.format(i), mean_rewards[i], total_time) + + self.writer.add_scalar('episode_lengths/step', mean_lengths, frame) + self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) + self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time) + + if self.has_self_play_config: + self.self_play_manager.update(self) + + # removed equal signs (i.e. 
"rew=") from the checkpoint name since it messes with hydra CLI parsing + checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0]) + + if self.save_freq > 0: + if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards): + self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name)) + + if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after: + print('saving next best rewards: ', mean_rewards) + self.last_mean_rewards = mean_rewards[0] + self.save(os.path.join(self.nn_dir, self.config['name'])) + if self.last_mean_rewards > self.config['score_to_win']: + print('Network won!') + self.save(os.path.join(self.nn_dir, checkpoint_name)) + should_exit = True + if epoch_num > self.max_epochs: + self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name)) + print('MAX EPOCHS NUM!') + should_exit = True + update_time = 0 + if should_exit: + return self.last_mean_rewards, epoch_num - def train_actor_critic(self, input_dict): - self.calc_gradients(input_dict) - return self.train_result + def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, curr_frames): + # do we need scaled time? + frame = self.frame + self.diagnostics.send_info(self.writer) + self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / update_time, frame) + self.writer.add_scalar('performance/step_inference_fps', curr_frames / play_time, frame) + self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame) + self.writer.add_scalar('performance/rl_update_time', update_time, frame) + self.writer.add_scalar('performance/step_inference_time', play_time, frame) + self.writer.add_scalar('performance/step_time', step_time, frame) + self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) + self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) + self.writer.add_scalar('info/epochs', epoch_num, frame) + self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index c74389c0..b8411b8c 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -43,10 +43,7 @@ def rescale_actions(low, high, action): class A2CBase(BaseAlgorithm): def __init__(self, base_name, params): - self.config = config = params['config'] - - # population based training parameters pbt_str = '' self.population_based_training = config.get('population_based_training', False) if self.population_based_training: @@ -126,7 +123,7 @@ def __init__(self, base_name, params): self.rnn_states = None self.name = base_name - self.ppo = config['ppo'] + self.ppo = config.get('ppo', True) self.max_epochs = self.config.get('max_epochs', 1e6) self.is_adaptive_lr = config['lr_schedule'] == 'adaptive' diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 935fc344..a1a08d46 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -101,7 +101,7 @@ def create_atari_gym_env(**kwargs): name = kwargs.pop('name') skip = kwargs.pop('skip',4) episode_life = kwargs.pop('episode_life',True) - env = wrappers.make_atari_deepmind(name, skip=skip,episode_life=episode_life) + env = wrappers.make_atari_deepmind(name, skip=skip,episode_life=episode_life, **kwargs) return env def create_dm_control_env(**kwargs): diff --git a/rl_games/common/player.py 
b/rl_games/common/player.py index f05d7721..caf523af 100644 --- a/rl_games/common/player.py +++ b/rl_games/common/player.py @@ -204,7 +204,7 @@ def run(self): steps += 1 if render: - self.env.render(mode='human') + #self.env.render() time.sleep(self.render_sleep) all_done_indices = done.nonzero(as_tuple=False) diff --git a/rl_games/common/wrappers.py b/rl_games/common/wrappers.py index 3baf1a83..846e8e16 100644 --- a/rl_games/common/wrappers.py +++ b/rl_games/common/wrappers.py @@ -596,8 +596,8 @@ def observation(self, observation): return observation * self.mask -def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directory=None): - env = gym.make(env_id) +def make_atari(env_id, timelimit=True, noop_max=0, skip=4, sticky=False, directory=None, **kwargs): + env = gym.make(env_id, **kwargs) if 'Montezuma' in env_id: env = MontezumaInfoWrapper(env, room_address=3 if 'Montezuma' in env_id else 1) env = StickyActionEnv(env) @@ -647,7 +647,7 @@ def make_car_racing(env_id, skip=4): env = make_atari(env_id, noop_max=0, skip=skip) return wrap_carracing(env, clip_rewards=False) -def make_atari_deepmind(env_id, noop_max=30, skip=4, sticky=False, episode_life=True): - env = make_atari(env_id, noop_max=noop_max, skip=skip, sticky=sticky) +def make_atari_deepmind(env_id, noop_max=30, skip=4, sticky=False, episode_life=True, **kwargs): + env = make_atari(env_id, noop_max=noop_max, skip=skip, sticky=sticky, **kwargs) return wrap_deepmind(env, episode_life=episode_life, clip_rewards=False) diff --git a/rl_games/configs/atari/ppo_breakout_torch.yaml b/rl_games/configs/atari/ppo_breakout_torch.yaml index 2886978e..dad54dea 100644 --- a/rl_games/configs/atari/ppo_breakout_torch.yaml +++ b/rl_games/configs/atari/ppo_breakout_torch.yaml @@ -75,8 +75,10 @@ params: name: 'BreakoutNoFrameskip-v4' episode_life: True seed: 5 + #render_mode: 'human' player: render: False - games_num: 200 + games_num: 50 n_game_life: 5 determenistic: False + From 0725dde92c688f96abd9c78af0fa48ba6390edff Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 16:49:22 -0700 Subject: [PATCH 06/26] added shac agent --- rl_games/algos_torch/shac_agent.py | 26 +++++++++++++++++--------- rl_games/common/a2c_common.py | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index f74ca28f..f124d317 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -38,7 +38,9 @@ def __init__(self, base_name, params): 'normalize_input': self.normalize_input, } self.critic_lr = self.config.get('critic_learning_rate', 0.0001) - self.target_critic_alpha = 0.4 + self.use_target_critic = self.config.get('use_target_critic', True) + self.target_critic_alpha = self.config.get('target_critic_alpha', 0.4) + self.max_episode_length = 1000 # temporary hardcoded self.actor_model = self.network.build(build_config) self.critic_model = self.critic_network.build(build_config) @@ -55,9 +57,10 @@ def __init__(self, base_name, params): self.model = self.actor_model self.init_rnn_from_model(self.actor_model) self.last_lr = float(self.last_lr) - self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), eps=1e-08, + self.betas = self.config.get('betas',[0.9, 0.999]) + self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) - self.critic_optimizer = 
optim.Adam(self.critic_model.parameters(), float(self.critic_lr), eps=1e-08, + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, @@ -104,6 +107,7 @@ def play_steps(self): self.current_rewards += rewards.detach() self.current_lengths += 1 + episode_ended = self.current_lengths == self.max_episode_length all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) @@ -119,18 +123,19 @@ def play_steps(self): accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) last_values = self.get_values(obs) - + episode_ended_vals = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])).squeeze() + episode_ended_vals = episode_ended * episode_ended_vals end_vals = last_values.squeeze(1) end_vals = end_vals * not_dones + if n < self.horizon_length - 1: - #- self.gamma * gamma[env_done_indices] * end_vals[env_done_indices] - actor_loss = actor_loss - accumulated_rewards[n + 1] * fdones + actor_loss = actor_loss - (accumulated_rewards[n + 1] * fdones + self.gamma * gamma * episode_ended_vals).sum() # gamma = gamma * self.gamma gamma[env_done_indices] = 1.0 accumulated_rewards[n + 1, env_done_indices] = 0.0 else: # terminate all envs at the end of optimization iteration - actor_loss = actor_loss + (-accumulated_rewards[n + 1] - self.gamma * gamma * end_vals) + actor_loss = actor_loss + (-accumulated_rewards[n + 1] - self.gamma * gamma * (end_vals + episode_ended_vals)).sum() @@ -146,7 +151,7 @@ def play_steps(self): batch_dict['played_frames'] = self.batch_size batch_dict['step_time'] = step_time - actor_loss = actor_loss.mean() / self.horizon_length + actor_loss = actor_loss / (self.horizon_length * self.horizon_length) return batch_dict, actor_loss def env_step(self, actions): @@ -185,7 +190,10 @@ def get_values(self, obs): } processed_obs = self._preproc_obs(obs['obs']) - result = self.critic_model(input_dict) + if self.use_target_critic: + result = self.target_critic(input_dict) + else: + result = self.critic_model(input_dict) value = result['values'] return value diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index b8411b8c..f3a4669c 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -386,7 +386,7 @@ def init_tensors(self): val_shape = (self.horizon_length, batch_size, self.value_size) current_rewards_shape = (batch_size, self.value_size) self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) - self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device) + self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.ppo_device) self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device) if self.is_rnn: From 5dbdc0f6fd04e6c09c5e3dfcdf5e67a56086a56c Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 20:37:53 -0700 Subject: [PATCH 07/26] fixed actor loss --- rl_games/algos_torch/shac_agent.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index f124d317..b2990b32 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -96,7 +96,8 @@ def 
play_steps(self): self.experience_buffer.update_data('values', n, res_dict['values'].detach()) step_time_start = time.time() - obs, rewards, self.dones, infos = self.env_step(torch.tanh(res_dict['actions'])) + actions = torch.tanh(res_dict['actions']) + obs, rewards, self.dones, infos = self.env_step(actions) step_time_end = time.time() step_time += (step_time_end - step_time_start) @@ -123,8 +124,8 @@ def play_steps(self): accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) last_values = self.get_values(obs) - episode_ended_vals = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])).squeeze() - episode_ended_vals = episode_ended * episode_ended_vals + episode_ended_vals = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])) + episode_ended_vals = episode_ended * episode_ended_vals.squeeze() end_vals = last_values.squeeze(1) end_vals = end_vals * not_dones @@ -135,7 +136,7 @@ def play_steps(self): accumulated_rewards[n + 1, env_done_indices] = 0.0 else: # terminate all envs at the end of optimization iteration - actor_loss = actor_loss + (-accumulated_rewards[n + 1] - self.gamma * gamma * (end_vals + episode_ended_vals)).sum() + actor_loss = actor_loss - (accumulated_rewards[n + 1] + self.gamma * gamma * (end_vals + episode_ended_vals)).sum() @@ -151,7 +152,7 @@ def play_steps(self): batch_dict['played_frames'] = self.batch_size batch_dict['step_time'] = step_time - actor_loss = actor_loss / (self.horizon_length * self.horizon_length) + actor_loss = actor_loss / (self.horizon_length * self.num_actors) return batch_dict, actor_loss def env_step(self, actions): From 27c0e02eb564213d107fb0ecc951b32160422d26 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 21:37:21 -0700 Subject: [PATCH 08/26] removed shac --- rl_games/algos_torch/network_builder.py | 107 ------------------------ 1 file changed, 107 deletions(-) diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index c10af008..04930c6a 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -1029,110 +1029,3 @@ def load(self, params): else: self.is_discrete = False self.is_continuous = False - - -class SHACBuilder(NetworkBuilder): - def __init__(self, **kwargs): - NetworkBuilder.__init__(self) - - def load(self, params): - self.params = params - - def build(self, name, **kwargs): - net = SHACBuilder.Network(self.params, **kwargs) - return net - - class Network(NetworkBuilder.BaseNetwork): - def __init__(self, params, **kwargs): - actions_num = kwargs.pop('actions_num') - input_shape = kwargs.pop('input_shape') - obs_dim = kwargs.pop('obs_dim') - action_dim = kwargs.pop('action_dim') - self.num_seqs = kwargs.pop('num_seqs', 1) - - NetworkBuilder.BaseNetwork.__init__(self) - self.load(params) - - mlp_input_shape = input_shape - - actor_mlp_args = { - 'input_size' : obs_dim, - 'units' : self.units, - 'activation' : self.activation, - 'norm_func_name' : self.normalization, - 'dense_func' : torch.nn.Linear, - 'd2rl' : self.is_d2rl, - 'norm_only_first_layer' : self.norm_only_first_layer - } - - critic_mlp_args = { - 'input_size' : obs_dim + action_dim, - 'units' : self.units, - 'activation' : self.activation, - 'norm_func_name' : self.normalization, - 'dense_func' : torch.nn.Linear, - 'd2rl' : self.is_d2rl, - 'norm_only_first_layer' : self.norm_only_first_layer - } - print("Building Actor") - self.actor = self._build_actor(2*action_dim, self.log_std_bounds, **actor_mlp_args) - - 
if self.separate: - print("Building Critic") - self.critic = self._build_critic(1, **critic_mlp_args) - print("Building Critic Target") - self.critic_target = self._build_critic(1, **critic_mlp_args) - self.critic_target.load_state_dict(self.critic.state_dict()) - - mlp_init = self.init_factory.create(**self.initializer) - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): - cnn_init(m.weight) - if getattr(m, "bias", None) is not None: - torch.nn.init.zeros_(m.bias) - if isinstance(m, nn.Linear): - mlp_init(m.weight) - if getattr(m, "bias", None) is not None: - torch.nn.init.zeros_(m.bias) - - def _build_critic(self, output_dim, **mlp_args): - return DoubleQCritic(output_dim, **mlp_args) - - def _build_actor(self, output_dim, log_std_bounds, **mlp_args): - return DiagGaussianActor(output_dim, log_std_bounds, **mlp_args) - - def forward(self, obs_dict): - """TODO""" - obs = obs_dict['obs'] - mu, sigma = self.actor(obs) - return mu, sigma - - def is_separate_critic(self): - return self.separate - - def load(self, params): - self.separate = params.get('separate', True) - self.units = params['mlp']['units'] - self.activation = params['mlp']['activation'] - self.initializer = params['mlp']['initializer'] - self.is_d2rl = params['mlp'].get('d2rl', False) - self.norm_only_first_layer = params['mlp'].get('norm_only_first_layer', False) - self.value_activation = params.get('value_activation', 'None') - self.normalization = params.get('normalization', None) - self.has_space = 'space' in params - self.value_shape = params.get('value_shape', 1) - self.central_value = params.get('central_value', False) - self.joint_obs_actions_config = params.get('joint_obs_actions', None) - self.log_std_bounds = params.get('log_std_bounds', None) - - if self.has_space: - # todo: add assert if discrete, there is no discrete action space for SHAC - self.is_discrete = 'discrete' in params['space'] - self.is_continuous = 'continuous'in params['space'] - if self.is_continuous: - self.space_config = params['space']['continuous'] - elif self.is_discrete: - self.space_config = params['space']['discrete'] - else: - self.is_discrete = False - self.is_continuous = False \ No newline at end of file From 6ce26a2d6a547a4c2f6429dabff82b95374779b4 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 21:38:15 -0700 Subject: [PATCH 09/26] more cleanup --- rl_games/algos_torch/model_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index d38ff340..12057ac9 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -20,7 +20,6 @@ def __init__(self): self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder()) self.network_factory.register_builder('resnet_actor_critic', lambda **kwargs: network_builder.A2CResnetBuilder()) - self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) def load(self, params): From 50f5c68329e6e9a68cdbcaa071fb4d0cca7cb562 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Thu, 26 May 2022 21:45:01 -0700 Subject: [PATCH 10/26] Added linear lr for actor and critic. Currently assumes having the same learning rate. 
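For reference, the schedule wired up here for both optimizers is the plain linear decay from rl_games.common.schedulers.LinearScheduler, evaluated once per epoch. A minimal standalone sketch of that decay follows; the start/min learning rates and step count are illustrative placeholders, not the exact training config:

def linear_lr(epoch, start_lr=3e-4, min_lr=1e-5, max_steps=1000):
    # decay linearly from start_lr to min_lr over max_steps epochs, then hold at min_lr
    mul = max(0, max_steps - epoch) / max_steps
    return min_lr + (start_lr - min_lr) * mul

# with a single shared schedule the actor and critic optimizers receive
# the same value every epoch, e.g. actor_lr = critic_lr = linear_lr(epoch_num)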
--- rl_games/algos_torch/central_value.py | 7 ++-- rl_games/algos_torch/shac_agent.py | 55 +++++++++++++++++++-------- rl_games/common/a2c_common.py | 10 +++-- rl_games/common/schedulers.py | 17 +++++---- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index e1074ce0..2815262b 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -176,8 +176,8 @@ def train_critic(self, input_dict): def update_multiagent_tensors(self, value_preds, returns, actions, dones): batch_size = self.batch_size ma_batch_size = self.num_actors * self.num_agents * self.horizon_length - value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1) - returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1) + value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1) + returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1) value_preds = value_preds.contiguous().view(ma_batch_size, self.value_size)[:batch_size] returns = returns.contiguous().view(ma_batch_size, self.value_size)[:batch_size] dones = dones.contiguous().view(ma_batch_size, self.value_size)[:batch_size] @@ -200,7 +200,8 @@ def train_net(self): self.frame += self.batch_size if self.writter != None: self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame) - self.writter.add_scalar('info/cval_lr', self.lr, self.frame) + self.writter.add_scalar('info/cval_lr', self.lr, self.frame) + return avg_loss def calc_gradients(self, batch): diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index f124d317..f32b04ce 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -1,4 +1,3 @@ - from rl_games.algos_torch.running_mean_std import RunningMeanStd from rl_games.common import vecenv, schedulers, experience @@ -16,6 +15,7 @@ import os import copy + def swap_and_flatten01(arr): """ swap and then flatten axes 0 and 1 @@ -25,7 +25,9 @@ def swap_and_flatten01(arr): s = arr.size() return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:]) + class SHACAgent(ContinuousA2CBase): + def __init__(self, base_name, params): ContinuousA2CBase.__init__(self, base_name, params) obs_shape = self.obs_shape @@ -40,6 +42,7 @@ def __init__(self, base_name, params): self.critic_lr = self.config.get('critic_learning_rate', 0.0001) self.use_target_critic = self.config.get('use_target_critic', True) self.target_critic_alpha = self.config.get('target_critic_alpha', 0.4) + self.max_episode_length = 1000 # temporary hardcoded self.actor_model = self.network.build(build_config) self.critic_model = self.critic_network.build(build_config) @@ -53,15 +56,19 @@ def __init__(self, base_name, params): self.target_critic.running_mean_std = self.critic_model.running_mean_std if self.normalize_value: self.target_critic.value_mean_std = self.critic_model.value_mean_std + self.states = None self.model = self.actor_model self.init_rnn_from_model(self.actor_model) + self.last_lr = float(self.last_lr) self.betas = self.config.get('betas',[0.9, 0.999]) self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) - self.critic_optimizer = optim.Adam(self.critic_model.parameters(), 
float(self.critic_lr), betas=self.betas, eps=1e-08, - weight_decay=self.weight_decay) + # self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), betas=self.betas, eps=1e-08, + # weight_decay=self.weight_decay) + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, + weight_decay=self.weight_decay) self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) @@ -70,7 +77,6 @@ def __init__(self, base_name, params): self.algo_observer.after_init(self) - def play_steps(self): update_list = self.update_list accumulated_rewards = torch.zeros((self.horizon_length + 1, self.num_actors), dtype=torch.float32, device=self.device) @@ -80,11 +86,13 @@ def play_steps(self): step_time = 0.0 self.critic_model.eval() self.actor_model.train() + if self.normalize_input: self.actor_model.running_mean_std.train() if self.normalize_value: self.actor_model.value_mean_std.eval() obs = self.initialize_trajectory() + for n in range(self.horizon_length): res_dict = self.get_actions(obs) res_dict['values'] = self.get_values(obs) @@ -137,9 +145,6 @@ def play_steps(self): # terminate all envs at the end of optimization iteration actor_loss = actor_loss + (-accumulated_rewards[n + 1] - self.gamma * gamma * (end_vals + episode_ended_vals)).sum() - - - fdones = self.dones.float().detach() mb_fdones = self.experience_buffer.tensor_dict['dones'].float().detach() mb_rewards = self.experience_buffer.tensor_dict['rewards'].detach() @@ -162,7 +167,6 @@ def env_step(self, actions): rewards = rewards.unsqueeze(1) return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos - def load_networks(self, params): ContinuousA2CBase.load_networks(self, params) if 'critic_config' in self.config: @@ -250,7 +254,6 @@ def prepare_critic_dataset(self, batch_dict): self.dataset.update_values_dict(dataset_dict) - def train_actor(self, actor_loss): self.actor_model.train() @@ -262,6 +265,7 @@ def train_actor(self, actor_loss): torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.grad_norm) self.actor_optimizer.step() + return actor_loss.detach() def train_critic(self, batch): @@ -270,6 +274,7 @@ def train_critic(self, batch): self.critic_model.running_mean_std.eval() if self.normalize_value: self.critic_model.value_mean_std.eval() + obs_batch = self._preproc_obs(batch['obs']) value_preds_batch = batch['old_values'] returns_batch = batch['returns'] @@ -294,6 +299,18 @@ def train_critic(self, batch): return critic_loss.detach() + def update_lr(self, actor_lr, critic_lr): + if self.multi_gpu: + lr_tensor = torch.tensor([lr]) + self.hvd.broadcast_value(lr_tensor, 'learning_rate') + lr = lr_tensor.item() + + for param_group in self.actor_optimizer.param_groups: + param_group['lr'] = actor_lr + + for param_group in self.critic_optimizer.param_groups: + param_group['lr'] = critic_lr + def train_epoch(self): play_time_start = time.time() batch_dict, actor_loss = self.play_steps() @@ -313,15 +330,19 @@ def train_epoch(self): c_loss = self.train_critic(self.dataset[i]) c_losses.append(c_loss) - self.diagnostics.mini_epoch(self, mini_ep) - # update target critic + # update target critic with torch.no_grad(): alpha = self.target_critic_alpha for param, param_targ in zip(self.critic_model.parameters(), self.target_critic.parameters()): param_targ.data.mul_(alpha) param_targ.data.add_((1. 
- alpha) * param.data) + + self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) + #self.critic_lr, _ = self.scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) + self.update_lr(self.last_lr, self.last_lr) + update_time_end = time.time() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start @@ -329,7 +350,6 @@ def train_epoch(self): return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses - def train(self): self.init_tensors() self.mean_rewards = self.last_mean_rewards = -100500 @@ -338,7 +358,6 @@ def train(self): rep_count = 0 # self.frame = 0 # loading from checkpoint - while True: epoch_num = self.update_epoch() step_time, play_time, update_time, sum_time, a_losses, c_losses = self.train_epoch() @@ -349,6 +368,7 @@ def train(self): curr_frames = self.curr_frames self.frame += curr_frames should_exit = False + if self.rank == 0: self.diagnostics.epoch(self, current_epoch=epoch_num) scaled_time = self.num_agents * sum_time @@ -360,8 +380,8 @@ def train(self): fps_step = curr_frames / step_time fps_step_inference = curr_frames / scaled_play_time fps_total = curr_frames / scaled_time - print(f'fps step: {fps_step:.1f} fps step and policy inference: {fps_step_inference:.1f} fps total: {fps_total:.1f} epoch: {epoch_num}/{self.max_epochs}') - print('a_loss:', a_losses[0].item()) + print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs}') + print('actor loss:', a_losses[0].item()) self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, curr_frames) @@ -405,7 +425,9 @@ def train(self): self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name)) print('MAX EPOCHS NUM!') should_exit = True + update_time = 0 + if should_exit: return self.last_mean_rewards, epoch_num @@ -422,4 +444,7 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('info/epochs', epoch_num, frame) + self.writer.add_scalar('info/last_lr', self.last_lr, frame) + self.writer.add_scalar('info/last_lr', self.last_lr, epoch_num) + self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index f3a4669c..599d3552 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -253,7 +253,7 @@ def __init__(self, base_name, params): def trancate_gradients_and_step(self): if self.multi_gpu: self.optimizer.synchronize() - + if self.truncate_grads: self.scaler.unscale_(self.optimizer) nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) @@ -288,7 +288,7 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('performance/step_time', step_time, frame) self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) - + self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame) self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame) self.writer.add_scalar('info/lr_mul', lr_mul, frame) @@ -471,6 +471,7 @@ def discount_values(self, fdones, 
last_extrinsic_values, mb_fdones, mb_extrinsic delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t] mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam + return mb_advs def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks): @@ -487,6 +488,7 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext masks_t = mb_masks[t].unsqueeze(1) delta = (mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t]) mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t + return mb_advs def clear_stats(self): @@ -704,6 +706,7 @@ def play_steps_rnn(self): self.current_lengths += 1 all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) + if len(all_done_indices) > 0: for s in self.rnn_states: s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0 @@ -738,6 +741,7 @@ def play_steps_rnn(self): states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size)) batch_dict['rnn_states'] = states batch_dict['step_time'] = step_time + return batch_dict @@ -799,7 +803,7 @@ def train_epoch(self): a_losses.append(a_loss) c_losses.append(c_loss) ep_kls.append(kl) - entropies.append(entropy) + entropies.append(entropy) av_kls = torch_ext.mean_list(ep_kls) if self.multi_gpu: diff --git a/rl_games/common/schedulers.py b/rl_games/common/schedulers.py index 562d3562..df945671 100644 --- a/rl_games/common/schedulers.py +++ b/rl_games/common/schedulers.py @@ -1,5 +1,3 @@ - - class RLScheduler: def __init__(self): pass @@ -7,11 +5,11 @@ def __init__(self): def update(self,current_lr, entropy_coef, epoch, frames, **kwargs): pass + class IdentityScheduler(RLScheduler): def __init__(self): super().__init__() - def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): return current_lr, entropy_coef @@ -27,13 +25,15 @@ def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): lr = current_lr if kl_dist > (2.0 * self.kl_threshold): lr = max(current_lr / 1.5, self.min_lr) + if kl_dist < (0.5 * self.kl_threshold): lr = min(current_lr * 1.5, self.max_lr) - return lr, entropy_coef + + return lr, entropy_coef class LinearScheduler(RLScheduler): - def __init__(self, start_lr, min_lr=1e-6, max_steps = 1000000, use_epochs=True, apply_to_entropy=False, **kwargs): + def __init__(self, start_lr, min_lr=1e-5, max_steps = 1000000, use_epochs=True, apply_to_entropy=False, **kwargs): super().__init__() self.start_lr = start_lr self.min_lr = min_lr @@ -49,8 +49,11 @@ def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): steps = epoch else: steps = frames - mul = max(0, self.max_steps - steps)/self.max_steps + + mul = max(0, self.max_steps - steps)/self.max_steps lr = self.min_lr + (self.start_lr - self.min_lr) * mul + if self.apply_to_entropy: entropy_coef = self.min_entropy_coef + (self.start_entropy_coef - self.min_entropy_coef) * mul - return lr, entropy_coef \ No newline at end of file + + return lr, entropy_coef \ No newline at end of file From 9c16caf7623f8564fe3d8408740ca726c5106f39 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Thu, 26 May 2022 23:49:38 -0700 Subject: [PATCH 11/26] updated tb --- rl_games/algos_torch/shac_agent.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git 
a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 10eb0aca..8fbc3db4 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -56,6 +56,7 @@ def __init__(self, base_name, params): self.target_critic.running_mean_std = self.critic_model.running_mean_std if self.normalize_value: self.target_critic.value_mean_std = self.critic_model.value_mean_std + self.actor_model.value_mean_std = None self.states = None self.model = self.actor_model @@ -85,12 +86,10 @@ def play_steps(self): gamma = torch.ones(self.num_actors, dtype = torch.float32, device = self.device) step_time = 0.0 self.critic_model.eval() + self.target_critic.eval() self.actor_model.train() - if self.normalize_input: self.actor_model.running_mean_std.train() - if self.normalize_value: - self.actor_model.value_mean_std.eval() obs = self.initialize_trajectory() for n in range(self.horizon_length): @@ -132,8 +131,8 @@ def play_steps(self): accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) last_values = self.get_values(obs) - episode_ended_vals = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])) - episode_ended_vals = episode_ended * episode_ended_vals.squeeze() + episode_ended_values = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])) + episode_ended_vals = episode_ended * episode_ended_values.squeeze() end_vals = last_values.squeeze(1) end_vals = end_vals * not_dones @@ -143,6 +142,7 @@ def play_steps(self): gamma[env_done_indices] = 1.0 accumulated_rewards[n + 1, env_done_indices] = 0.0 else: + #last_values = last_values.detach() * not_dones.unsqueeze(1) + episode_ended.unsqueeze(1) * episode_ended_values.detach() # terminate all envs at the end of optimization iteration actor_loss = actor_loss - (accumulated_rewards[n + 1] + self.gamma * gamma * (end_vals + episode_ended_vals)).sum() @@ -231,7 +231,6 @@ def prepare_critic_dataset(self, batch_dict): rnn_states = batch_dict.get('rnn_states', None) rnn_masks = batch_dict.get('rnn_masks', None) - advantages = returns - values if self.normalize_value: self.value_mean_std.train() @@ -239,14 +238,10 @@ def prepare_critic_dataset(self, batch_dict): returns = self.value_mean_std(returns) self.value_mean_std.eval() - advantages = torch.sum(advantages, axis=1) - if self.normalize_advantage: - advantages = torch_ext.normalization_with_masks(advantages, rnn_masks) dataset_dict = {} dataset_dict['old_values'] = values - dataset_dict['advantages'] = advantages dataset_dict['returns'] = returns dataset_dict['obs'] = obses dataset_dict['dones'] = dones @@ -339,11 +334,9 @@ def train_epoch(self): for param, param_targ in zip(self.critic_model.parameters(), self.target_critic.parameters()): param_targ.data.mul_(alpha) param_targ.data.add_((1. 
- alpha) * param.data) - - self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) + self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) #self.critic_lr, _ = self.scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) self.update_lr(self.last_lr, self.last_lr) - update_time_end = time.time() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start @@ -445,7 +438,7 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('info/epochs', epoch_num, frame) - self.writer.add_scalar('info/last_lr', self.last_lr, frame) - self.writer.add_scalar('info/last_lr', self.last_lr, epoch_num) + self.writer.add_scalar('info/last_lr/frame', self.last_lr, frame) + self.writer.add_scalar('info/last_lr/epoch_num', self.last_lr, epoch_num) self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file From 5d6ce14c7783ed07c0c0455d43967eb4f4d24a47 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Fri, 27 May 2022 00:13:24 -0700 Subject: [PATCH 12/26] best shac --- rl_games/algos_torch/shac_agent.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 8fbc3db4..bac65eef 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -142,7 +142,7 @@ def play_steps(self): gamma[env_done_indices] = 1.0 accumulated_rewards[n + 1, env_done_indices] = 0.0 else: - #last_values = last_values.detach() * not_dones.unsqueeze(1) + episode_ended.unsqueeze(1) * episode_ended_values.detach() + last_values = last_values.detach() * not_dones.unsqueeze(1) + episode_ended.unsqueeze(1) * episode_ended_values.detach() # terminate all envs at the end of optimization iteration actor_loss = actor_loss - (accumulated_rewards[n + 1] + self.gamma * gamma * (end_vals + episode_ended_vals)).sum() @@ -231,7 +231,6 @@ def prepare_critic_dataset(self, batch_dict): rnn_states = batch_dict.get('rnn_states', None) rnn_masks = batch_dict.get('rnn_masks', None) - if self.normalize_value: self.value_mean_std.train() values = self.value_mean_std(values) @@ -396,9 +395,6 @@ def train(self): self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time) - if self.has_self_play_config: - self.self_play_manager.update(self) - # removed equal signs (i.e. 
"rew=") from the checkpoint name since it messes with hydra CLI parsing checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0]) From f67b6805e681b7c49128e67a13b4ade7b706d959 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Fri, 27 May 2022 19:32:06 -0700 Subject: [PATCH 13/26] fixed rms --- rl_games/algos_torch/running_mean_std.py | 2 +- rl_games/algos_torch/shac_agent.py | 47 +++++++++++++++--------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/rl_games/algos_torch/running_mean_std.py b/rl_games/algos_torch/running_mean_std.py index 947f4250..a7a028bf 100644 --- a/rl_games/algos_torch/running_mean_std.py +++ b/rl_games/algos_torch/running_mean_std.py @@ -77,7 +77,7 @@ def forward(self, input, unnorm=False, mask=None): if self.norm_only: y = input/ torch.sqrt(current_var.float() + self.epsilon) else: - y = (input - current_mean.float()) / torch.sqrt(current_var.float() + self.epsilon) + y = (input - current_mean.clone().float()) / torch.sqrt(current_var.clone().float() + self.epsilon) y = torch.clamp(y, min=-5.0, max=5.0) return y diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index bac65eef..bbe3f953 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -91,10 +91,13 @@ def play_steps(self): if self.normalize_input: self.actor_model.running_mean_std.train() obs = self.initialize_trajectory() - + last_values = None for n in range(self.horizon_length): res_dict = self.get_actions(obs) - res_dict['values'] = self.get_values(obs) + if last_values is None: + res_dict['values'] = self.get_values(obs) + else: + res_dict['values'] = last_values with torch.no_grad(): self.experience_buffer.update_data('obses', n, obs['obs'].detach()) @@ -125,27 +128,37 @@ def play_steps(self): fdones = self.dones.float() not_dones = 1.0 - fdones - self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) - self.current_lengths = self.current_lengths * not_dones - accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) last_values = self.get_values(obs) - episode_ended_values = self.get_values(self.obs_to_tensors(infos['obs_before_reset'])) - episode_ended_vals = episode_ended * episode_ended_values.squeeze() - end_vals = last_values.squeeze(1) - end_vals = end_vals * not_dones + for id in env_done_indices: + if self.current_lengths[id] < self.max_episode_length: # early termination + last_values[id] = 0. 
+ else: # otherwise, use terminal value critic to estimate the long-term performance + if self.normalize_input: + self.actor_model.running_mean_std.eval() + real_obs = self.obs_to_tensors(infos['obs_before_reset'][id]) + last_values[id] = self.get_values(real_obs) + if self.normalize_input: + self.actor_model.running_mean_std.train() + + if (last_values > 1e6).sum() > 0 or (last_values < -1e6).sum() > 0: + print('next value error') + raise ValueError + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + accumulated_rewards[n + 1, :] = accumulated_rewards[n, :] + gamma * shaped_rewards.squeeze(1) if n < self.horizon_length - 1: - actor_loss = actor_loss - (accumulated_rewards[n + 1] * fdones + self.gamma * gamma * episode_ended_vals).sum() # - gamma = gamma * self.gamma - gamma[env_done_indices] = 1.0 - accumulated_rewards[n + 1, env_done_indices] = 0.0 + actor_loss = actor_loss - ( + accumulated_rewards[n + 1, env_done_indices] + self.gamma * gamma[env_done_indices] * + last_values.squeeze(1)[env_done_indices]).sum() else: - last_values = last_values.detach() * not_dones.unsqueeze(1) + episode_ended.unsqueeze(1) * episode_ended_values.detach() - # terminate all envs at the end of optimization iteration - actor_loss = actor_loss - (accumulated_rewards[n + 1] + self.gamma * gamma * (end_vals + episode_ended_vals)).sum() - + actor_loss = actor_loss - ( + accumulated_rewards[n + 1, :] + self.gamma * gamma * last_values.squeeze()).sum() + gamma = gamma * self.gamma + gamma[env_done_indices] = 1.0 + accumulated_rewards[n + 1, env_done_indices] = 0.0 fdones = self.dones.float().detach() mb_fdones = self.experience_buffer.tensor_dict['dones'].float().detach() mb_rewards = self.experience_buffer.tensor_dict['rewards'].detach() From 278c3e4997330fb5a26ec9f5a9b4ec86f7e93387 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Fri, 27 May 2022 21:22:52 -0700 Subject: [PATCH 14/26] last --- rl_games/algos_torch/shac_agent.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index bbe3f953..227fa438 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -109,16 +109,17 @@ def play_steps(self): actions = torch.tanh(res_dict['actions']) obs, rewards, self.dones, infos = self.env_step(actions) step_time_end = time.time() - + episode_ended = self.current_lengths == self.max_episode_length - 1 step_time += (step_time_end - step_time_start) shaped_rewards = self.rewards_shaper(rewards) - + real_obs = self.obs_to_tensors(infos['obs_before_reset']) + shaped_rewards += self.gamma * self.get_values(real_obs) * episode_ended.unsqueeze(1).float() self.experience_buffer.update_data('rewards', n, shaped_rewards.detach()) self.current_rewards += rewards.detach() self.current_lengths += 1 - episode_ended = self.current_lengths == self.max_episode_length + all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) @@ -131,31 +132,16 @@ def play_steps(self): accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) last_values = self.get_values(obs) - for id in env_done_indices: - if self.current_lengths[id] < self.max_episode_length: # early termination - last_values[id] = 0. 
- else: # otherwise, use terminal value critic to estimate the long-term performance - if self.normalize_input: - self.actor_model.running_mean_std.eval() - real_obs = self.obs_to_tensors(infos['obs_before_reset'][id]) - last_values[id] = self.get_values(real_obs) - if self.normalize_input: - self.actor_model.running_mean_std.train() - - if (last_values > 1e6).sum() > 0 or (last_values < -1e6).sum() > 0: - print('next value error') - raise ValueError self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) self.current_lengths = self.current_lengths * not_dones accumulated_rewards[n + 1, :] = accumulated_rewards[n, :] + gamma * shaped_rewards.squeeze(1) if n < self.horizon_length - 1: actor_loss = actor_loss - ( - accumulated_rewards[n + 1, env_done_indices] + self.gamma * gamma[env_done_indices] * - last_values.squeeze(1)[env_done_indices]).sum() + accumulated_rewards[n + 1, env_done_indices]).sum() else: actor_loss = actor_loss - ( - accumulated_rewards[n + 1, :] + self.gamma * gamma * last_values.squeeze()).sum() + accumulated_rewards[n + 1, :] + self.gamma * gamma * last_values.squeeze() * (1.0-episode_ended.float()) * not_dones).sum() gamma = gamma * self.gamma gamma[env_done_indices] = 1.0 accumulated_rewards[n + 1, env_done_indices] = 0.0 From cafe15323065c1f7cec03a6f63973c217f48fda4 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Sat, 28 May 2022 13:57:42 -0700 Subject: [PATCH 15/26] fixed copypaste --- rl_games/algos_torch/shac_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 227fa438..02048f93 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -135,7 +135,6 @@ def play_steps(self): self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) self.current_lengths = self.current_lengths * not_dones - accumulated_rewards[n + 1, :] = accumulated_rewards[n, :] + gamma * shaped_rewards.squeeze(1) if n < self.horizon_length - 1: actor_loss = actor_loss - ( accumulated_rewards[n + 1, env_done_indices]).sum() From 8f1f4a2e98a9f8270abfc301add8bffba938f477 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sat, 28 May 2022 22:42:45 -0700 Subject: [PATCH 16/26] Independent lr schedulers for actor and critic. 
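The split boils down to keeping one schedule per optimizer and writing each result into that optimizer's param_groups, instead of reusing the actor value for the critic. A self-contained sketch of the pattern, with tiny stand-in modules and placeholder learning rates (the real agent builds its networks through the model builder):

import torch.nn as nn
import torch.optim as optim

actor = nn.Linear(8, 2)   # stand-ins for the actor and critic networks
critic = nn.Linear(8, 1)
actor_opt = optim.Adam(actor.parameters(), lr=3e-4)
critic_opt = optim.Adam(critic.parameters(), lr=5e-4)

def decay(start_lr, min_lr, epoch, max_epochs):
    # same linear rule as LinearScheduler, reproduced inline
    mul = max(0, max_epochs - epoch) / max_epochs
    return min_lr + (start_lr - min_lr) * mul

def update_lr(epoch, max_epochs=1000):
    # independent schedules: actor and critic no longer share one learning rate
    actor_lr = decay(3e-4, 1e-5, epoch, max_epochs)
    critic_lr = decay(5e-4, 1e-4, epoch, max_epochs)
    for g in actor_opt.param_groups:
        g['lr'] = actor_lr
    for g in critic_opt.param_groups:
        g['lr'] = critic_lr
    return actor_lr, critic_lr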
--- rl_games/algos_torch/shac_agent.py | 27 ++++++++++++--------------- rl_games/common/a2c_common.py | 2 +- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 02048f93..1fc0636c 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -39,10 +39,12 @@ def __init__(self, base_name, params): 'normalize_value': self.normalize_value, 'normalize_input': self.normalize_input, } - self.critic_lr = self.config.get('critic_learning_rate', 0.0001) + self.critic_lr = float(self.config.get('critic_learning_rate', 0.0001)) self.use_target_critic = self.config.get('use_target_critic', True) self.target_critic_alpha = self.config.get('target_critic_alpha', 0.4) + self.critic_scheduler = schedulers.LinearScheduler(self.critic_lr, min_lr=1e-4, max_steps=self.max_epochs) + self.max_episode_length = 1000 # temporary hardcoded self.actor_model = self.network.build(build_config) self.critic_model = self.critic_network.build(build_config) @@ -66,9 +68,7 @@ def __init__(self, base_name, params): self.betas = self.config.get('betas',[0.9, 0.999]) self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) - # self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), betas=self.betas, eps=1e-08, - # weight_decay=self.weight_decay) - self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, @@ -88,10 +88,12 @@ def play_steps(self): self.critic_model.eval() self.target_critic.eval() self.actor_model.train() + if self.normalize_input: self.actor_model.running_mean_std.train() obs = self.initialize_trajectory() last_values = None + for n in range(self.horizon_length): res_dict = self.get_actions(obs) if last_values is None: @@ -141,6 +143,7 @@ def play_steps(self): else: actor_loss = actor_loss - ( accumulated_rewards[n + 1, :] + self.gamma * gamma * last_values.squeeze() * (1.0-episode_ended.float()) * not_dones).sum() + gamma = gamma * self.gamma gamma[env_done_indices] = 1.0 accumulated_rewards[n + 1, env_done_indices] = 0.0 @@ -235,8 +238,6 @@ def prepare_critic_dataset(self, batch_dict): returns = self.value_mean_std(returns) self.value_mean_std.eval() - - dataset_dict = {} dataset_dict['old_values'] = values dataset_dict['returns'] = returns @@ -293,11 +294,6 @@ def train_critic(self, batch): return critic_loss.detach() def update_lr(self, actor_lr, critic_lr): - if self.multi_gpu: - lr_tensor = torch.tensor([lr]) - self.hvd.broadcast_value(lr_tensor, 'learning_rate') - lr = lr_tensor.item() - for param_group in self.actor_optimizer.param_groups: param_group['lr'] = actor_lr @@ -331,9 +327,10 @@ def train_epoch(self): for param, param_targ in zip(self.critic_model.parameters(), self.target_critic.parameters()): param_targ.data.mul_(alpha) param_targ.data.add_((1. 
- alpha) * param.data) - self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) - #self.critic_lr, _ = self.scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) - self.update_lr(self.last_lr, self.last_lr) + self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) + self.critic_lr, _ = self.critic_scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) + + self.update_lr(self.last_lr, self.critic_lr) update_time_end = time.time() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start @@ -351,7 +348,7 @@ def train(self): while True: epoch_num = self.update_epoch() - step_time, play_time, update_time, sum_time, a_losses, c_losses = self.train_epoch() + step_time, play_time, update_time, sum_time, a_losses, c_losses = self.train_epoch() # cleaning memory to optimize space self.dataset.update_values_dict(None) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 599d3552..b72947be 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -11,7 +11,7 @@ from rl_games.common.interval_summary_writer import IntervalSummaryWriter from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics from rl_games.algos_torch import model_builder -from rl_games.interfaces.base_algorithm import BaseAlgorithm +from rl_games.interfaces.base_algorithm import BaseAlgorithm import numpy as np import time import gym From 6233e98a33cc3102484842316421cdffc51a3599 Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Sun, 29 May 2022 16:17:14 -0700 Subject: [PATCH 17/26] best version --- rl_games/algos_torch/running_mean_std.py | 2 +- rl_games/algos_torch/shac_agent.py | 44 +++++++++++++++++------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/rl_games/algos_torch/running_mean_std.py b/rl_games/algos_torch/running_mean_std.py index a7a028bf..d6dc285c 100644 --- a/rl_games/algos_torch/running_mean_std.py +++ b/rl_games/algos_torch/running_mean_std.py @@ -72,7 +72,7 @@ def forward(self, input, unnorm=False, mask=None): if unnorm: y = torch.clamp(input, min=-5.0, max=5.0) - y = torch.sqrt(current_var.float() + self.epsilon)*y + current_mean.float() + y = torch.sqrt(current_var.float().clone() + self.epsilon)*y + current_mean.float().clone() else: if self.norm_only: y = input/ torch.sqrt(current_var.float() + self.epsilon) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 02048f93..57ec3af3 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -39,7 +39,7 @@ def __init__(self, base_name, params): 'normalize_value': self.normalize_value, 'normalize_input': self.normalize_input, } - self.critic_lr = self.config.get('critic_learning_rate', 0.0001) + self.critic_lr = float(self.config.get('critic_learning_rate', 0.0001)) self.use_target_critic = self.config.get('use_target_critic', True) self.target_critic_alpha = self.config.get('target_critic_alpha', 0.4) @@ -62,17 +62,24 @@ def __init__(self, base_name, params): self.model = self.actor_model self.init_rnn_from_model(self.actor_model) - self.last_lr = float(self.last_lr) + self.actor_lr = float(self.last_lr) self.betas = self.config.get('betas',[0.9, 0.999]) - self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, + self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.actor_lr), betas=self.betas, 
eps=1e-08, weight_decay=self.weight_decay) # self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.critic_lr), betas=self.betas, eps=1e-08, # weight_decay=self.weight_decay) - self.critic_optimizer = optim.Adam(self.critic_model.parameters(), float(self.last_lr), betas=self.betas, eps=1e-08, + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), self.critic_lr, betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + if self.linear_lr: + self.critic_scheduler = schedulers.LinearScheduler(self.critic_lr, + max_steps=self.max_epochs, + apply_to_entropy=0, + start_entropy_coef=0) + else: + self.critic_scheduler = schedulers.IdentityScheduler() if self.normalize_value: self.value_mean_std = self.critic_model.value_mean_std @@ -90,12 +97,14 @@ def play_steps(self): self.actor_model.train() if self.normalize_input: self.actor_model.running_mean_std.train() + if self.normalize_value: + self.value_mean_std.eval() obs = self.initialize_trajectory() last_values = None for n in range(self.horizon_length): res_dict = self.get_actions(obs) if last_values is None: - res_dict['values'] = self.get_values(obs) + last_values = res_dict['values'] = self.get_values(obs) else: res_dict['values'] = last_values @@ -113,8 +122,18 @@ def play_steps(self): step_time += (step_time_end - step_time_start) shaped_rewards = self.rewards_shaper(rewards) - real_obs = self.obs_to_tensors(infos['obs_before_reset']) - shaped_rewards += self.gamma * self.get_values(real_obs) * episode_ended.unsqueeze(1).float() + + real_obs = infos['obs_before_reset'] + if torch.isnan(real_obs).sum() > 0 \ + or torch.isinf(real_obs).sum() > 0 \ + or (torch.abs(real_obs) > 1e6).sum() > 0: # ugly fix for nan values + print('KTOTO NOOB') + last_obs_vals = last_values.detach() + else: + real_obs = self.obs_to_tensors(real_obs) + last_obs_vals = self.get_values(real_obs) + + shaped_rewards += last_obs_vals * episode_ended.unsqueeze(1).float() self.experience_buffer.update_data('rewards', n, shaped_rewards.detach()) self.current_rewards += rewards.detach() @@ -332,8 +351,8 @@ def train_epoch(self): param_targ.data.mul_(alpha) param_targ.data.add_((1. 
- alpha) * param.data) self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) - #self.critic_lr, _ = self.scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) - self.update_lr(self.last_lr, self.last_lr) + self.critic_lr, _ = self.critic_scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) + self.update_lr(self.last_lr, self.critic_lr) update_time_end = time.time() play_time = play_time_end - play_time_start update_time = update_time_end - update_time_start @@ -432,7 +451,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('info/epochs', epoch_num, frame) - self.writer.add_scalar('info/last_lr/frame', self.last_lr, frame) - self.writer.add_scalar('info/last_lr/epoch_num', self.last_lr, epoch_num) - + self.writer.add_scalar('info/actor_lr/frame', self.last_lr, frame) + self.writer.add_scalar('info/actor_lr/epoch_num', self.last_lr, epoch_num) + self.writer.add_scalar('info/critic_lr/frame', self.critic_lr, frame) + self.writer.add_scalar('info/critic_lr/epoch_num', self.critic_lr, epoch_num) self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file From 187f21d39243f1a3b1f683e2b8c77e787a39b0eb Mon Sep 17 00:00:00 2001 From: Viktor Makoviichuk Date: Mon, 30 May 2022 18:23:40 -0700 Subject: [PATCH 18/26] shac which works --- rl_games/algos_torch/shac_agent.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 1900a539..575ad46b 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -95,6 +95,7 @@ def play_steps(self): self.critic_model.eval() self.target_critic.eval() self.actor_model.train() + #torch.autograd.set_detect_anomaly(True) if self.normalize_input: self.actor_model.running_mean_std.train() @@ -121,20 +122,28 @@ def play_steps(self): obs, rewards, self.dones, infos = self.env_step(actions) step_time_end = time.time() episode_ended = self.current_lengths == self.max_episode_length - 1 + episode_ended_indices = episode_ended.nonzero(as_tuple=False) step_time += (step_time_end - step_time_start) shaped_rewards = self.rewards_shaper(rewards) + if self.normalize_input: + self.actor_model.running_mean_std.eval() real_obs = infos['obs_before_reset'] - if torch.isnan(real_obs).sum() > 0 \ - or torch.isinf(real_obs).sum() > 0 \ - or (torch.abs(real_obs) > 1e6).sum() > 0: # ugly fix for nan values - print('KTOTO NOOB') - last_obs_vals = last_values.detach() - else: - real_obs = self.obs_to_tensors(real_obs) - last_obs_vals = self.get_values(real_obs) - + last_obs_vals = last_values.clone() #.detach() + for ind in episode_ended_indices: + if torch.isnan(real_obs[ind]).sum() > 0 \ + or torch.isinf(real_obs[ind]).sum() > 0 \ + or (torch.abs(real_obs[ind]) > 1e6).sum() > 0: # ugly fix for nan values + print('KTOTO NOOB:', ind) + else: + print('VSE OK:', ind) + curr_real_obs = self.obs_to_tensors(real_obs[ind]) + val = self.get_values(curr_real_obs) + last_obs_vals[ind] = val + + if self.normalize_input: + self.actor_model.running_mean_std.train() shaped_rewards += last_obs_vals * episode_ended.unsqueeze(1).float() self.experience_buffer.update_data('rewards', n, shaped_rewards.detach()) From a965e641175957653bfea98886d4923e70e289ca Mon Sep 
17 00:00:00 2001 From: Viktor Makoviichuk Date: Mon, 30 May 2022 18:29:23 -0700 Subject: [PATCH 19/26] small update to make it equal --- rl_games/algos_torch/shac_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index 575ad46b..f3e260c1 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -136,6 +136,7 @@ def play_steps(self): or torch.isinf(real_obs[ind]).sum() > 0 \ or (torch.abs(real_obs[ind]) > 1e6).sum() > 0: # ugly fix for nan values print('KTOTO NOOB:', ind) + last_obs_vals[ind] = 0 else: print('VSE OK:', ind) curr_real_obs = self.obs_to_tensors(real_obs[ind]) From 162f6c6f502c24e7bd4d40c2e2f9101ec481fd6d Mon Sep 17 00:00:00 2001 From: ViktorM Date: Fri, 28 Oct 2022 02:03:06 -0700 Subject: [PATCH 20/26] SHAC cleanup. Updated release version. --- rl_games/algos_torch/shac_agent.py | 3 +-- setup.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index f3e260c1..19a31bc1 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -135,10 +135,9 @@ def play_steps(self): if torch.isnan(real_obs[ind]).sum() > 0 \ or torch.isinf(real_obs[ind]).sum() > 0 \ or (torch.abs(real_obs[ind]) > 1e6).sum() > 0: # ugly fix for nan values - print('KTOTO NOOB:', ind) + print('Nan gradients: ', ind) last_obs_vals[ind] = 0 else: - print('VSE OK:', ind) curr_real_obs = self.obs_to_tensors(real_obs[ind]) val = self.get_values(curr_real_obs) last_obs_vals[ind] = val diff --git a/setup.py b/setup.py index 10ae34a3..fd1ca15d 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ #packages=[package for package in find_packages() if package.startswith('rl_games')], packages = ['.','rl_games','docs'], package_data={'rl_games':['*','*/*','*/*/*'],'docs':['*','*/*','*/*/*'],}, - version='1.5.2', + version='1.6.0', author='Denys Makoviichuk, Viktor Makoviichuk', author_email='trrrrr97@gmail.com, victor.makoviychuk@gmail.com', license="MIT", From 93a2b4b7d6750af31de9cb8f43bb8d04b3999a9b Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sun, 8 Jan 2023 16:48:14 -0800 Subject: [PATCH 21/26] More improvements. --- rl_games/algos_torch/shac_agent.py | 42 +++++++++++++++++++++--------- rl_games/common/a2c_common.py | 4 +-- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index bfa2cdbb..f76bd8fa 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -13,6 +13,7 @@ import time import os import copy +import numpy as np def swap_and_flatten01(arr): @@ -61,7 +62,7 @@ def __init__(self, base_name, params): self.target_critic = copy.deepcopy(self.critic_model) if self.linear_lr: - if self.max_epochs == -1 and self.max_frames != -1: + if self.max_epochs == -1 and self.max_frames == -1: print("Max epochs and max frames are not set. 
Linear learning rate schedule can't be used, switching to the contstant (identity) one.") self.critic_scheduler = schedulers.IdentityScheduler() else: @@ -400,7 +401,6 @@ def train(self): start_time = time.time() total_time = 0 rep_count = 0 - # self.frame = 0 # loading from checkpoint while True: epoch_num = self.update_epoch() @@ -425,10 +425,9 @@ def train(self): epoch_num, self.max_epochs, self.frame, self.max_frames) if self.print_stats: - #print('actor loss:', a_losses[0].item()) print(f'actor loss: {a_losses[0].item():.2f}') - self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, curr_frames) + self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, self.last_lr, curr_frames) self.algo_observer.after_print_stats(frame, epoch_num, total_time) @@ -458,22 +457,39 @@ def train(self): print('saving next best rewards: ', mean_rewards) self.last_mean_rewards = mean_rewards[0] self.save(os.path.join(self.nn_dir, self.config['name'])) - if self.last_mean_rewards > self.config['score_to_win']: - print('Network won!') - self.save(os.path.join(self.nn_dir, checkpoint_name)) - should_exit = True - if epoch_num > self.max_epochs: - self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name)) + if 'score_to_win' in self.config: + if self.last_mean_rewards > self.config['score_to_win']: + print('Maximum reward achieved. Network won!') + self.save(os.path.join(self.nn_dir, checkpoint_name)) + should_exit = True + + if epoch_num >= self.max_epochs and self.max_epochs != -1: + if self.game_rewards.current_size == 0: + print('WARNING: Max epochs reached before any env terminated at least once') + mean_rewards = -np.inf + + self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_ep_' + str(epoch_num) \ + + '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_'))) print('MAX EPOCHS NUM!') should_exit = True + if self.frame >= self.max_frames and self.max_frames != -1: + if self.game_rewards.current_size == 0: + print('WARNING: Max frames reached before any env terminated at least once') + mean_rewards = -np.inf + + self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_frame_' + str(self.frame) \ + + '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_'))) + print('MAX FRAMES NUM!') + should_exit = True + update_time = 0 if should_exit: return self.last_mean_rewards, epoch_num - def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, curr_frames): + def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, last_lr, curr_frames): # do we need scaled time? 
frame = self.frame self.diagnostics.send_info(self.writer) @@ -486,8 +502,8 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) self.writer.add_scalar('info/epochs', epoch_num, frame) - self.writer.add_scalar('info/actor_lr/frame', self.last_lr, frame) - self.writer.add_scalar('info/actor_lr/epoch_num', self.last_lr, epoch_num) + self.writer.add_scalar('info/actor_lr/frame', last_lr, frame) + self.writer.add_scalar('info/actor_lr/epoch_num', last_lr, epoch_num) self.writer.add_scalar('info/critic_lr/frame', self.critic_lr, frame) self.writer.add_scalar('info/critic_lr/epoch_num', self.critic_lr, epoch_num) self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index ff8bff6f..8b29017f 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -172,7 +172,7 @@ def __init__(self, base_name, params): min_lr = self.min_lr, max_lr = self.max_lr) elif self.linear_lr: - if self.max_epochs == -1 and self.max_frames != -1: + if self.max_epochs == -1 and self.max_frames == -1: print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the contstant (identity) one.") self.scheduler = schedulers.IdentityScheduler() else: @@ -1321,8 +1321,6 @@ def train(self): should_exit_t = torch.tensor(should_exit, device=self.device).float() dist.broadcast(should_exit_t, 0) should_exit = should_exit_t.float().item() - if should_exit: - return self.last_mean_rewards, epoch_num if should_exit: return self.last_mean_rewards, epoch_num From f1491b6148b86915b099cadda52ec2a488585a26 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Sun, 8 Jan 2023 19:34:51 -0800 Subject: [PATCH 22/26] Added release notes. --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 41ce54d2..570bbedc 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ If you use rl-games in your research please use the following citation: title = {rl-games: A High-performance Framework for Reinforcement Learning}, author = {Makoviichuk, Denys and Makoviychuk, Viktor}, month = {May}, -year = {2022}, +year = {2021}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/Denys88/rl_games}}, @@ -274,6 +274,11 @@ Additional environment supported properties and functions ## Release Notes +1.6.0 +* Implemented SHAC algorithm: [Accelerated Policy Learning with Parallel Differentiable Simulation](https://short-horizon-actor-critic.github.io/) (ICLR 2022) +* Fixed various bugs related to num_frames/num_epochs interaction. +* Fixed a few SAC training configs, and improved SAC implementation. + 1.5.2 * Added observation normalization to the SAC. From fd5ce6f2da06a22dbed1c4f819073b0b5451d4cb Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 9 Jan 2023 00:20:38 -0800 Subject: [PATCH 23/26] Fixed const learning rate for SHAC. Fixed linear scheduling with max_frames. 
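
The linear schedule here amounts to interpolating the learning rate toward a floor over either epochs or frames, which is why the frame count is now threaded into the scheduler.update() calls in the hunks below. A minimal sketch of that behaviour follows — the min_lr floor, the clamping, and the class name are illustrative assumptions, not the rl_games implementation:

    # Minimal linear-LR sketch that can step by epochs or by frames.
    # update() mirrors the call sites in the diff: (lr, entropy_coef, epoch_num, frame, kl).
    class LinearLRSketch:
        def __init__(self, start_lr, min_lr=1e-6, max_steps=1000, use_epochs=True):
            self.start_lr = start_lr
            self.min_lr = min_lr          # assumed floor, not taken from the library
            self.max_steps = max_steps    # max_epochs or max_frames, depending on use_epochs
            self.use_epochs = use_epochs

        def update(self, current_lr, entropy_coef, epoch_num, frame, kl_dist):
            step = epoch_num if self.use_epochs else frame
            progress = min(step / self.max_steps, 1.0)
            lr = self.start_lr + (self.min_lr - self.start_lr) * progress
            return lr, entropy_coef

    # Frame-driven usage, so the schedule keeps moving when max_epochs == -1:
    sched = LinearLRSketch(start_lr=3e-4, min_lr=1e-6, max_steps=200_000_000, use_epochs=False)
    lr, _ = sched.update(3e-4, 0.0, epoch_num=10, frame=1_000_000, kl_dist=0.0)
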
--- rl_games/algos_torch/central_value.py | 4 +--- rl_games/algos_torch/shac_agent.py | 17 +++++++++-------- rl_games/common/a2c_common.py | 8 ++++---- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index 9f300185..ce827283 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -2,7 +2,6 @@ import torch from torch import nn import torch.distributed as dist -import gym import numpy as np from rl_games.algos_torch import torch_ext from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs @@ -172,7 +171,6 @@ def post_step_rnn(self, all_done_indices): def forward(self, input_dict): return self.model(input_dict) - def get_value(self, input_dict): self.eval() obs_batch = input_dict['states'] @@ -216,7 +214,7 @@ def train_net(self): avg_loss = loss / (self.mini_epoch * self.num_minibatches) self.epoch_num += 1 - self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, 0, 0) + self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, self.frame, 0) self.update_lr(self.lr) self.frame += self.batch_size if self.writter != None: diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py index f76bd8fa..ad2f858b 100644 --- a/rl_games/algos_torch/shac_agent.py +++ b/rl_games/algos_torch/shac_agent.py @@ -94,9 +94,9 @@ def __init__(self, base_name, params): self.model = self.actor_model self.init_rnn_from_model(self.actor_model) - self.actor_lr = float(self.last_lr) - self.betas = self.config.get('betas',[0.9, 0.999]) - self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), float(self.actor_lr), betas=self.betas, eps=1e-08, + self.actor_lr = self.last_lr + self.betas = self.config.get('betas', [0.9, 0.999]) + self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), self.actor_lr, betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) self.critic_optimizer = optim.Adam(self.critic_model.parameters(), self.critic_lr, betas=self.betas, eps=1e-08, weight_decay=self.weight_decay) @@ -157,6 +157,7 @@ def play_steps(self): if self.normalize_input: self.actor_model.running_mean_std.eval() + real_obs = infos['obs_before_reset'] last_obs_vals = last_values.clone() #.detach() for ind in episode_ended_indices: @@ -178,7 +179,6 @@ def play_steps(self): self.current_rewards += rewards.detach() self.current_lengths += 1 - all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) self.game_rewards.update(self.current_rewards[env_done_indices]) @@ -215,9 +215,11 @@ def play_steps(self): batch_dict['step_time'] = step_time actor_loss = actor_loss / (self.horizon_length * self.num_actors) + return batch_dict, actor_loss def env_step(self, actions): + # todo: add preprocessing #actions = self.preprocess_actions(actions) obs, rewards, dones, infos = self.vec_env.step(actions) @@ -384,8 +386,8 @@ def train_epoch(self): param_targ.data.mul_(alpha) param_targ.data.add_((1. 
- alpha) * param.data) - self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, 0, None) - self.critic_lr, _ = self.critic_scheduler.update(self.critic_lr, 0, self.epoch_num, 0, None) + self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, self.frame, 0) + self.critic_lr, _ = self.critic_scheduler.update(self.critic_lr, 0, self.epoch_num, self.frame, 0) self.update_lr(self.last_lr, self.critic_lr) update_time_end = time.time() @@ -398,9 +400,7 @@ def train_epoch(self): def train(self): self.init_tensors() self.mean_rewards = self.last_mean_rewards = -100500 - start_time = time.time() total_time = 0 - rep_count = 0 while True: epoch_num = self.update_epoch() @@ -492,6 +492,7 @@ def train(self): def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, last_lr, curr_frames): # do we need scaled time? frame = self.frame + self.diagnostics.send_info(self.writer) self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / update_time, frame) self.writer.add_scalar('performance/step_inference_fps', curr_frames / play_time, frame) diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 8b29017f..4d211358 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -237,7 +237,7 @@ def __init__(self, base_name, params): self.mixed_precision = self.config.get('mixed_precision', False) self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision) - self.last_lr = self.config['learning_rate'] + self.last_lr = float(self.config['learning_rate']) self.frame = 0 self.update_time = 0 self.mean_rewards = self.last_mean_rewards = -100500 @@ -870,7 +870,7 @@ def train_epoch(self): dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) av_kls /= self.rank_size - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) kls.append(av_kls) self.diagnostics.mini_epoch(self, mini_ep) @@ -1131,7 +1131,7 @@ def train_epoch(self): if self.multi_gpu: dist.all_reduce(kl, op=dist.ReduceOp.SUM) av_kls /= self.rank_size - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) av_kls = torch_ext.mean_list(ep_kls) @@ -1139,7 +1139,7 @@ def train_epoch(self): dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) av_kls /= self.rank_size if self.schedule_type == 'standard': - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) kls.append(av_kls) From 28a35cf7463bba9ca07b46a3be5bb1e259ad1b8a Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 9 Jan 2023 00:34:42 -0800 Subject: [PATCH 24/26] Fixed max_frames for central value. 
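
This ports the same epochs-vs-frames fallback the agents use to the central value trainer: keep a constant (identity) schedule when neither max_epochs nor max_frames is set, otherwise drive the linear schedule by whichever limit is configured. Roughly, as a sketch — build_lr_scheduler is an illustrative helper, not library API; the scheduler constructors and keyword arguments are the ones used in the diff below:

    from rl_games.common import schedulers

    def build_lr_scheduler(lr, min_lr, max_epochs, max_frames, linear_lr):
        # With no limit configured a linear schedule has nothing to anneal
        # against, so keep the learning rate constant.
        if not linear_lr or (max_epochs == -1 and max_frames == -1):
            return schedulers.IdentityScheduler()
        use_epochs = max_epochs != -1
        max_steps = max_epochs if use_epochs else max_frames
        return schedulers.LinearScheduler(lr,
                                          min_lr=min_lr,
                                          max_steps=max_steps,
                                          use_epochs=use_epochs,
                                          apply_to_entropy=False,
                                          start_entropy_coef=0.0)
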
--- rl_games/algos_torch/a2c_continuous.py | 1 + rl_games/algos_torch/central_value.py | 29 +++++++++++++++++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index c24b713a..09d122cb 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -47,6 +47,7 @@ def __init__(self, base_name, params): 'config' : self.central_value_config, 'writter' : self.writer, 'max_epochs' : self.max_epochs, + 'max_frames' : self.max_frames, 'multi_gpu' : self.multi_gpu, } self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index ce827283..865a4be0 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -13,8 +13,8 @@ class CentralValueTrain(nn.Module): def __init__(self, state_shape, value_size, ppo_device, num_agents, \ - horizon_length, num_actors, num_actions, seq_len, \ - normalize_value,network, config, writter, max_epochs, multi_gpu): + horizon_length, num_actors, num_actions, seq_len, normalize_value, \ + network, config, writter, max_epochs, max_frames, multi_gpu): nn.Module.__init__(self) self.ppo_device = ppo_device @@ -24,6 +24,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \ self.state_shape = state_shape self.value_size = value_size self.max_epochs = max_epochs + self.max_frames = max_frames self.multi_gpu = multi_gpu self.truncate_grads = config.get('truncate_grads', False) self.config = config @@ -42,14 +43,28 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \ self.lr = float(config['learning_rate']) self.linear_lr = config.get('lr_schedule') == 'linear' - # todo: support max frames if self.linear_lr: - self.scheduler = schedulers.LinearScheduler(self.lr, - max_steps = self.max_epochs, - apply_to_entropy = False, - start_entropy_coef = 0) + if self.max_epochs == -1 and self.max_frames == -1: + print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the contstant (identity) one.") + self.scheduler = schedulers.IdentityScheduler() + else: + print("Linear lr schedule. Min lr = ", self.min_lr) + use_epochs = True + max_steps = self.max_epochs + + if self.max_epochs == -1: + use_epochs = False + max_steps = self.max_frames + + self.scheduler = schedulers.LinearScheduler(self.lr, + min_lr = self.min_lr, + max_steps = max_steps, + use_epochs = use_epochs, + apply_to_entropy = False, + start_entropy_coef = 0.0) else: self.scheduler = schedulers.IdentityScheduler() + self.mini_epoch = config['mini_epochs'] assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config)) From b71596a3073865f9daf6f471c1edff2a53d379f6 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Wed, 11 Jan 2023 23:08:23 -0800 Subject: [PATCH 25/26] Readme update. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 570bbedc..c0dd1965 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ * Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? 
* Superfast Adversarial Motion Priors (AMP) implementation: https://twitter.com/xbpeng4/status/1506317490766303235 https://github.com/NVIDIA-Omniverse/IsaacGymEnvs * OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation: https://cremebrule.github.io/oscar-web/ https://arxiv.org/abs/2110.00704 +* EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine: https://github.com/sail-sg/envpool https://arxiv.org/abs/2206.10558 +* TimeChamber: A Massively Parallel Large Scale Self-Play Framework: https://github.com/inspirai/TimeChamber ## Some results on the different environments From aa600678f7e2536301939d48e72deec85c5b1e25 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Wed, 11 Jan 2023 23:15:47 -0800 Subject: [PATCH 26/26] Readme update. --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c0dd1965..12eb7460 100644 --- a/README.md +++ b/README.md @@ -5,13 +5,13 @@ ## Papers and related links -* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning: https://arxiv.org/abs/2108.10470 -* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger: https://s2r2-ig.github.io/ https://arxiv.org/abs/2108.09779 -* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? -* Superfast Adversarial Motion Priors (AMP) implementation: https://twitter.com/xbpeng4/status/1506317490766303235 https://github.com/NVIDIA-Omniverse/IsaacGymEnvs -* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation: https://cremebrule.github.io/oscar-web/ https://arxiv.org/abs/2110.00704 -* EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine: https://github.com/sail-sg/envpool https://arxiv.org/abs/2206.10558 -* TimeChamber: A Massively Parallel Large Scale Self-Play Framework: https://github.com/inspirai/TimeChamber +* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning. Paper: https://arxiv.org/abs/2108.10470 +* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger. Site: https://s2r2-ig.github.io/ Paper: https://arxiv.org/abs/2108.09779 +* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? Paper: https://arxiv.org/abs/2011.09533 +* Superfast Adversarial Motion Priors (AMP) implementation. Twitter: https://twitter.com/xbpeng4/status/1506317490766303235 Repo: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs +* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation. Site: https://cremebrule.github.io/oscar-web/ Paper: https://arxiv.org/abs/2110.00704 +* EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine. Paper: https://arxiv.org/abs/2206.10558 Repo: https://github.com/sail-sg/envpool +* TimeChamber: A Massively Parallel Large Scale Self-Play Framework. Repo: https://github.com/inspirai/TimeChamber ## Some results on the different environments