diff --git a/README.md b/README.md index 41ce54d2..12eb7460 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,13 @@ ## Papers and related links -* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning: https://arxiv.org/abs/2108.10470 -* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger: https://s2r2-ig.github.io/ https://arxiv.org/abs/2108.09779 -* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? -* Superfast Adversarial Motion Priors (AMP) implementation: https://twitter.com/xbpeng4/status/1506317490766303235 https://github.com/NVIDIA-Omniverse/IsaacGymEnvs -* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation: https://cremebrule.github.io/oscar-web/ https://arxiv.org/abs/2110.00704 +* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning. Paper: https://arxiv.org/abs/2108.10470 +* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger. Site: https://s2r2-ig.github.io/ Paper: https://arxiv.org/abs/2108.09779 +* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? Paper: https://arxiv.org/abs/2011.09533 +* Superfast Adversarial Motion Priors (AMP) implementation. Twitter: https://twitter.com/xbpeng4/status/1506317490766303235 Repo: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs +* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation. Site: https://cremebrule.github.io/oscar-web/ Paper: https://arxiv.org/abs/2110.00704 +* EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine. Paper: https://arxiv.org/abs/2206.10558 Repo: https://github.com/sail-sg/envpool +* TimeChamber: A Massively Parallel Large Scale Self-Play Framework. Repo: https://github.com/inspirai/TimeChamber ## Some results on the different environments @@ -76,7 +78,7 @@ If you use rl-games in your research please use the following citation: title = {rl-games: A High-performance Framework for Reinforcement Learning}, author = {Makoviichuk, Denys and Makoviychuk, Viktor}, month = {May}, -year = {2022}, +year = {2021}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/Denys88/rl_games}}, @@ -274,6 +276,11 @@ Additional environment supported properties and functions ## Release Notes +1.6.0 +* Implemented SHAC algorithm: [Accelerated Policy Learning with Parallel Differentiable Simulation](https://short-horizon-actor-critic.github.io/) (ICLR 2022) +* Fixed various bugs related to num_frames/num_epochs interaction. +* Fixed a few SAC training configs, and improved SAC implementation. + 1.5.2 * Added observation normalization to the SAC. 
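The num_frames/num_epochs fix listed under 1.6.0 shows up in the scheduler hunks below (central_value.py, a2c_common.py, shac_agent.py): the linear learning-rate schedule can now anneal over either epochs or frames, and falls back to a constant rate when neither budget is set. A condensed sketch of that selection logic, using the classes from rl_games.common.schedulers; the standalone helper function itself is illustrative and not part of the codebase:

from rl_games.common import schedulers

def build_lr_scheduler(start_lr, min_lr, max_epochs=-1, max_frames=-1):
    # Neither budget set: there is nothing to anneal over, so keep the rate constant.
    if max_epochs == -1 and max_frames == -1:
        return schedulers.IdentityScheduler()
    # Prefer epochs as the step unit; fall back to frames when only max_frames is given.
    use_epochs = max_epochs != -1
    max_steps = max_epochs if use_epochs else max_frames
    return schedulers.LinearScheduler(start_lr,
                                      min_lr=min_lr,
                                      max_steps=max_steps,
                                      use_epochs=use_epochs,
                                      apply_to_entropy=False,
                                      start_entropy_coef=0.0)

# per-epoch update, matching the call sites in this diff:
# lr, _ = scheduler.update(lr, 0, epoch_num, frame, 0)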
diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 1b5fda68..09d122cb 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -6,14 +6,14 @@ from rl_games.common import datasets from torch import optim -import torch -from torch import nn -import numpy as np -import gym +import torch + class A2CAgent(a2c_common.ContinuousA2CBase): + def __init__(self, base_name, params): a2c_common.ContinuousA2CBase.__init__(self, base_name, params) + obs_shape = self.obs_shape build_config = { 'actions_num' : self.actions_num, @@ -23,7 +23,7 @@ def __init__(self, base_name, params): 'normalize_value' : self.normalize_value, 'normalize_input': self.normalize_input, } - + self.model = self.network.build(build_config) self.model.to(self.ppo_device) self.states = None @@ -47,6 +47,7 @@ def __init__(self, base_name, params): 'config' : self.central_value_config, 'writter' : self.writer, 'max_epochs' : self.max_epochs, + 'max_frames' : self.max_frames, 'multi_gpu' : self.multi_gpu, } self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index 190d0612..865a4be0 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -2,7 +2,6 @@ import torch from torch import nn import torch.distributed as dist -import gym import numpy as np from rl_games.algos_torch import torch_ext from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs @@ -14,8 +13,8 @@ class CentralValueTrain(nn.Module): def __init__(self, state_shape, value_size, ppo_device, num_agents, \ - horizon_length, num_actors, num_actions, seq_len, \ - normalize_value,network, config, writter, max_epochs, multi_gpu): + horizon_length, num_actors, num_actions, seq_len, normalize_value, \ + network, config, writter, max_epochs, max_frames, multi_gpu): nn.Module.__init__(self) self.ppo_device = ppo_device @@ -25,6 +24,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \ self.state_shape = state_shape self.value_size = value_size self.max_epochs = max_epochs + self.max_frames = max_frames self.multi_gpu = multi_gpu self.truncate_grads = config.get('truncate_grads', False) self.config = config @@ -43,14 +43,28 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \ self.lr = float(config['learning_rate']) self.linear_lr = config.get('lr_schedule') == 'linear' - # todo: support max frames as well if self.linear_lr: - self.scheduler = schedulers.LinearScheduler(self.lr, - max_steps = self.max_epochs, - apply_to_entropy = False, - start_entropy_coef = 0) + if self.max_epochs == -1 and self.max_frames == -1: + print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the constant (identity) one.") + self.scheduler = schedulers.IdentityScheduler() + else: + print("Linear lr schedule. 
Min lr = ", self.min_lr) + use_epochs = True + max_steps = self.max_epochs + + if self.max_epochs == -1: + use_epochs = False + max_steps = self.max_frames + + self.scheduler = schedulers.LinearScheduler(self.lr, + min_lr = self.min_lr, + max_steps = max_steps, + use_epochs = use_epochs, + apply_to_entropy = False, + start_entropy_coef = 0.0) else: self.scheduler = schedulers.IdentityScheduler() + self.mini_epoch = config['mini_epochs'] assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config)) @@ -172,7 +186,6 @@ def post_step_rnn(self, all_done_indices): def forward(self, input_dict): return self.model(input_dict) - def get_value(self, input_dict): self.eval() obs_batch = input_dict['states'] @@ -197,8 +210,8 @@ def train_critic(self, input_dict): def update_multiagent_tensors(self, value_preds, returns, actions, dones): batch_size = self.batch_size ma_batch_size = self.num_actors * self.num_agents * self.horizon_length - value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1) - returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1) + value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1) + returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1) value_preds = value_preds.contiguous().view(ma_batch_size, self.value_size)[:batch_size] returns = returns.contiguous().view(ma_batch_size, self.value_size)[:batch_size] dones = dones.contiguous().view(ma_batch_size, self.value_size)[:batch_size] @@ -216,12 +229,13 @@ def train_net(self): avg_loss = loss / (self.mini_epoch * self.num_minibatches) self.epoch_num += 1 - self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, 0, 0) + self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, self.frame, 0) self.update_lr(self.lr) self.frame += self.batch_size if self.writter != None: self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame) - self.writter.add_scalar('info/cval_lr', self.lr, self.frame) + self.writter.add_scalar('info/cval_lr', self.lr, self.frame) + return avg_loss def calc_gradients(self, batch): diff --git a/rl_games/algos_torch/model_builder.py b/rl_games/algos_torch/model_builder.py index c2045c5e..12057ac9 100644 --- a/rl_games/algos_torch/model_builder.py +++ b/rl_games/algos_torch/model_builder.py @@ -1,7 +1,6 @@ from rl_games.common import object_factory -import rl_games.algos_torch -from rl_games.algos_torch import network_builder -from rl_games.algos_torch import models +from rl_games.algos_torch import network_builder, models + NETWORK_REGISTRY = {} MODEL_REGISTRY = {} @@ -14,13 +13,13 @@ def register_model(name, target_class): class NetworkBuilder: + def __init__(self): self.network_factory = object_factory.ObjectFactory() self.network_factory.set_builders(NETWORK_REGISTRY) self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder()) self.network_factory.register_builder('resnet_actor_critic', lambda **kwargs: network_builder.A2CResnetBuilder()) - self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder()) self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder()) def load(self, params): @@ -32,6 +31,7 @@ def load(self, params): class ModelBuilder: + def __init__(self): self.model_factory = 
object_factory.ObjectFactory() self.model_factory.set_builders(MODEL_REGISTRY) @@ -46,6 +46,8 @@ def __init__(self): lambda network, **kwargs: models.ModelSACContinuous(network)) self.model_factory.register_builder('central_value', lambda network, **kwargs: models.ModelCentralValue(network)) + self.model_factory.register_builder('shac', + lambda network, **kwargs: models.ModelA2CContinuousSHAC(network)) self.network_builder = NetworkBuilder() def get_network_builder(self): diff --git a/rl_games/algos_torch/models.py b/rl_games/algos_torch/models.py index 79f6bf75..13eb8e69 100644 --- a/rl_games/algos_torch/models.py +++ b/rl_games/algos_torch/models.py @@ -1,4 +1,3 @@ -import rl_games.algos_torch.layers import numpy as np import torch.nn as nn import torch @@ -28,7 +27,9 @@ def build(self, config): return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape, normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size) + class BaseModelNetwork(nn.Module): + def __init__(self, obs_shape, normalize_value, normalize_input, value_size): nn.Module.__init__(self) self.obs_shape = obs_shape @@ -45,12 +46,11 @@ def __init__(self, obs_shape, normalize_value, normalize_input, value_size): self.running_mean_std = RunningMeanStd(obs_shape) def norm_obs(self, observation): - with torch.no_grad(): - return self.running_mean_std(observation) if self.normalize_input else observation + return self.running_mean_std(observation) if self.normalize_input else observation def unnorm_value(self, value): - with torch.no_grad(): - return self.value_mean_std(value, unnorm=True) if self.normalize_value else value + return self.value_mean_std(value, unnorm=True) if self.normalize_value else value + class ModelA2C(BaseModel): def __init__(self, network): @@ -64,7 +64,7 @@ def __init__(self, a2c_network, **kwargs): def is_rnn(self): return self.a2c_network.is_rnn() - + def get_default_rnn_state(self): return self.a2c_network.get_default_rnn_state() @@ -105,7 +105,9 @@ def forward(self, input_dict): } return result + class ModelA2CMultiDiscrete(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -169,7 +171,9 @@ def forward(self, input_dict): } return result + class ModelA2CContinuous(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -225,6 +229,7 @@ def forward(self, input_dict): class ModelA2CContinuousLogStd(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -278,7 +283,41 @@ def neglogp(self, x, mean, std, logstd): + logstd.sum(dim=-1) +class ModelA2CContinuousSHAC(BaseModel): + def __init__(self, network): + BaseModel.__init__(self, 'a2c') + self.network_builder = network + + class Network(BaseModelNetwork): + def __init__(self, a2c_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) + self.a2c_network = a2c_network + + def is_rnn(self): + return self.a2c_network.is_rnn() + + def get_default_rnn_state(self): + return self.a2c_network.get_default_rnn_state() + + def forward(self, input_dict): + input_dict['obs'] = self.norm_obs(input_dict['obs']) + mu, logstd, _, states = self.a2c_network(input_dict) + sigma = torch.exp(logstd) + distr = torch.distributions.Normal(mu, sigma) + entropy = distr.entropy().sum(dim=-1) + selected_action = distr.rsample() + result = { + 'actions': selected_action, + 'entropy': entropy, + 'rnn_states': states, + 'mus': mu, + 'sigmas': sigma + } 
+ return result + + class ModelCentralValue(BaseModel): + def __init__(self, network): BaseModel.__init__(self, 'a2c') self.network_builder = network @@ -304,7 +343,6 @@ def forward(self, input_dict): value, states = self.a2c_network(input_dict) if not is_train: value = self.unnorm_value(value) - result = { 'values': value, 'rnn_states': states @@ -312,16 +350,15 @@ def forward(self, input_dict): return result - class ModelSACContinuous(BaseModel): def __init__(self, network): BaseModel.__init__(self, 'sac') self.network_builder = network - + class Network(BaseModelNetwork): - def __init__(self, sac_network,**kwargs): - BaseModelNetwork.__init__(self,**kwargs) + def __init__(self, sac_network, **kwargs): + BaseModelNetwork.__init__(self, **kwargs) self.sac_network = sac_network def critic(self, obs, action): @@ -332,7 +369,7 @@ def critic_target(self, obs, action): def actor(self, obs): return self.sac_network.actor(obs) - + def is_rnn(self): return False @@ -344,3 +381,4 @@ def forward(self, input_dict): + diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index fec2b2bf..9fb4b7fc 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -180,14 +180,14 @@ def __init__(self, params, **kwargs): actions_num = kwargs.pop('actions_num') input_shape = kwargs.pop('input_shape') self.value_size = kwargs.pop('value_size', 1) - self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.num_seqs = kwargs.pop('num_seqs', 1) NetworkBuilder.BaseNetwork.__init__(self) self.load(params) self.actor_cnn = nn.Sequential() self.critic_cnn = nn.Sequential() self.actor_mlp = nn.Sequential() self.critic_mlp = nn.Sequential() - + if self.has_cnn: if self.permute_input: input_shape = torch_ext.shape_whc_to_cwh(input_shape) @@ -344,6 +344,7 @@ def forward(self, obs_dict): c_out = c_out.transpose(0,1) a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + if self.rnn_ln: a_out = self.a_layer_norm(a_out) c_out = self.c_layer_norm(c_out) @@ -358,7 +359,7 @@ def forward(self, obs_dict): else: a_out = self.actor_mlp(a_out) c_out = self.critic_mlp(c_out) - + value = self.value_act(self.value(c_out)) if self.is_discrete: @@ -431,7 +432,7 @@ def forward(self, obs_dict): else: sigma = self.sigma_act(self.sigma(out)) return mu, mu*0 + sigma, value, states - + def is_separate_critic(self): return self.separate @@ -511,6 +512,7 @@ def build(self, name, **kwargs): net = A2CBuilder.Network(self.params, **kwargs) return net + class Conv2dAuto(nn.Conv2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -604,7 +606,6 @@ def __init__(self, params, **kwargs): mlp_input_shape = self._calc_input_size(input_shape, self.cnn) in_mlp_shape = mlp_input_shape - if len(self.units) == 0: out_size = mlp_input_shape else: @@ -623,7 +624,6 @@ def __init__(self, params, **kwargs): rnn_in_size += actions_num self.rnn = self._build_rnn(self.rnn_name, rnn_in_size, self.rnn_units, self.rnn_layers) #self.layer_norm = torch.nn.LayerNorm(self.rnn_units) - mlp_args = { 'input_size' : in_mlp_shape, 'units' :self.units, @@ -641,9 +641,9 @@ def __init__(self, params, **kwargs): self.logits = torch.nn.Linear(out_size, actions_num) if self.is_continuous: self.mu = torch.nn.Linear(out_size, actions_num) - self.mu_act = self.activations_factory.create(self.space_config['mu_activation']) + self.mu_act = 
self.activations_factory.create(self.space_config['mu_activation']) mu_init = self.init_factory.create(**self.space_config['mu_init']) - self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) + self.sigma_act = self.activations_factory.create(self.space_config['sigma_activation']) sigma_init = self.init_factory.create(**self.space_config['sigma_init']) if self.fixed_sigma: @@ -670,7 +670,7 @@ def __init__(self, params, **kwargs): else: sigma_init(self.sigma.weight) - mlp_init(self.value.weight) + mlp_init(self.value.weight) def forward(self, obs_dict): if self.require_rewards or self.require_last_actions: @@ -869,7 +869,7 @@ def __init__(self, params, **kwargs): input_shape = kwargs.pop('input_shape') obs_dim = kwargs.pop('obs_dim') action_dim = kwargs.pop('action_dim') - self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1) + self.num_seqs = kwargs.pop('num_seqs', 1) NetworkBuilder.BaseNetwork.__init__(self) self.load(params) diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py index 9120bb94..d713e36a 100644 --- a/rl_games/algos_torch/players.py +++ b/rl_games/algos_torch/players.py @@ -16,6 +16,7 @@ def rescale_actions(low, high, action): class PpoPlayerContinuous(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) self.network = self.config['network'] @@ -81,7 +82,9 @@ def restore(self, fn): def reset(self): self.init_rnn() + class PpoPlayerDiscrete(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) @@ -185,6 +188,7 @@ def reset(self): class SACPlayer(BasePlayer): + def __init__(self, params): BasePlayer.__init__(self, params) self.network = self.config['network'] @@ -225,11 +229,64 @@ def restore(self, fn): def get_action(self, obs, is_deterministic=False): if self.has_batch_dimension == False: obs = unsqueeze_obs(obs) + dist = self.model.actor(obs) actions = dist.sample() if is_deterministic else dist.mean actions = actions.clamp(*self.action_range).to(self.device) if self.has_batch_dimension == False: actions = torch.squeeze(actions.detach()) + + return actions + + def reset(self): + pass + + +class SHACPlayer(BasePlayer): + + def __init__(self, params): + BasePlayer.__init__(self, params) + self.network = self.config['network'] + self.actions_num = self.action_space.shape[0] + self.action_range = [ + float(self.env_info['action_space'].low.min()), + float(self.env_info['action_space'].high.max()) + ] + + obs_shape = self.obs_shape + self.normalize_input = False + config = { + 'obs_dim': self.env_info["observation_space"].shape[0], + 'action_dim': self.env_info["action_space"].shape[0], + 'actions_num' : self.actions_num, + 'input_shape' : obs_shape, + 'value_size': self.env_info.get('value_size', 1), + 'normalize_value': False, + 'normalize_input': self.normalize_input, + } + self.model = self.network.build(config) + self.model.to(self.device) + self.model.eval() + self.is_rnn = self.model.is_rnn() + + def restore(self, fn): + checkpoint = torch_ext.load_checkpoint(fn) + self.model.sac_network.actor.load_state_dict(checkpoint['actor']) + self.model.sac_network.critic.load_state_dict(checkpoint['critic']) + self.model.sac_network.critic_target.load_state_dict(checkpoint['critic_target']) + if self.normalize_input and 'running_mean_std' in checkpoint: + self.model.running_mean_std.load_state_dict(checkpoint['running_mean_std']) + + def get_action(self, obs, is_determenistic=False): + if self.has_batch_dimension == False: + obs = unsqueeze_obs(obs) + + dist = 
self.model.actor(obs) + actions = dist.sample() if is_determenistic else dist.mean + actions = actions.clamp(*self.action_range).to(self.device) + if self.has_batch_dimension == False: + actions = torch.squeeze(actions.detach()) + return actions def reset(self): diff --git a/rl_games/algos_torch/running_mean_std.py b/rl_games/algos_torch/running_mean_std.py index 152295c1..d6dc285c 100644 --- a/rl_games/algos_torch/running_mean_std.py +++ b/rl_games/algos_torch/running_mean_std.py @@ -43,40 +43,41 @@ def _update_mean_var_count_from_moments(self, mean, var, count, batch_mean, batc return new_mean, new_var, new_count def forward(self, input, unnorm=False, mask=None): - if self.training: - if mask is not None: - mean, var = torch_ext.get_mean_std_with_masks(input, mask) - else: - mean = input.mean(self.axis) # along channel axis - var = input.var(self.axis) - self.running_mean, self.running_var, self.count = self._update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, - mean, var, input.size()[0] ) + with torch.no_grad(): + if self.training: + if mask is not None: + mean, var = torch_ext.get_mean_std_with_masks(input, mask) + else: + mean = input.mean(self.axis) # along channel axis + var = input.var(self.axis) + self.running_mean, self.running_var, self.count = self._update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, + mean, var, input.size()[0] ) - # change shape - if self.per_channel: - if len(self.insize) == 3: - current_mean = self.running_mean.view([1, self.insize[0], 1, 1]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0], 1, 1]).expand_as(input) - if len(self.insize) == 2: - current_mean = self.running_mean.view([1, self.insize[0], 1]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0], 1]).expand_as(input) - if len(self.insize) == 1: - current_mean = self.running_mean.view([1, self.insize[0]]).expand_as(input) - current_var = self.running_var.view([1, self.insize[0]]).expand_as(input) - else: - current_mean = self.running_mean - current_var = self.running_var + # change shape + if self.per_channel: + if len(self.insize) == 3: + current_mean = self.running_mean.view([1, self.insize[0], 1, 1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0], 1, 1]).expand_as(input) + if len(self.insize) == 2: + current_mean = self.running_mean.view([1, self.insize[0], 1]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0], 1]).expand_as(input) + if len(self.insize) == 1: + current_mean = self.running_mean.view([1, self.insize[0]]).expand_as(input) + current_var = self.running_var.view([1, self.insize[0]]).expand_as(input) + else: + current_mean = self.running_mean + current_var = self.running_var # get output if unnorm: y = torch.clamp(input, min=-5.0, max=5.0) - y = torch.sqrt(current_var.float() + self.epsilon)*y + current_mean.float() + y = torch.sqrt(current_var.float().clone() + self.epsilon)*y + current_mean.float().clone() else: if self.norm_only: y = input/ torch.sqrt(current_var.float() + self.epsilon) else: - y = (input - current_mean.float()) / torch.sqrt(current_var.float() + self.epsilon) + y = (input - current_mean.clone().float()) / torch.sqrt(current_var.clone().float() + self.epsilon) y = torch.clamp(y, min=-5.0, max=5.0) return y diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index e18582c5..527225e7 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ 
-5,12 +5,12 @@ from rl_games.common import experience from rl_games.common.a2c_common import print_statistics -from rl_games.interfaces.base_algorithm import BaseAlgorithm +from rl_games.interfaces.base_algorithm import BaseAlgorithm from torch.utils.tensorboard import SummaryWriter from datetime import datetime -from rl_games.algos_torch import model_builder +from rl_games.algos_torch import model_builder from torch import optim -import torch +import torch from torch import nn import torch.nn.functional as F import numpy as np @@ -28,6 +28,7 @@ def __init__(self, base_name, params): # TODO: Get obs shape and self.network self.load_networks(params) self.base_init(base_name, config) + self.num_warmup_steps = config["num_warmup_steps"] self.gamma = config["gamma"] self.critic_tau = config["critic_tau"] @@ -38,7 +39,7 @@ def __init__(self, base_name, params): self.num_steps_per_episode = config.get("num_steps_per_episode", 1) self.normalize_input = config.get("normalize_input", False) - self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach + self.max_env_steps = config.get("max_env_steps", 1000) # temporary, in future we will use other approach print(self.batch_size, self.num_actors, self.num_agents) @@ -58,9 +59,8 @@ def __init__(self, base_name, params): net_config = { 'obs_dim': self.env_info["observation_space"].shape[0], 'action_dim': self.env_info["action_space"].shape[0], - 'actions_num' : self.actions_num, - 'input_shape' : obs_shape, - 'normalize_input' : self.normalize_input, + 'actions_num': self.actions_num, + 'input_shape': obs_shape, 'normalize_input': self.normalize_input, } self.model = self.network.build(net_config) @@ -120,7 +120,7 @@ def base_init(self, base_name, config): self.rewards_shaper = config['reward_shaper'] self.observation_space = self.env_info['observation_space'] self.weight_decay = config.get('weight_decay', 0.0) - #self.use_action_masks = config.get('use_action_masks', False) + # self.use_action_masks = config.get('use_action_masks', False) self.is_train = config.get('is_train', True) self.c_loss = nn.MSELoss() @@ -212,14 +212,14 @@ def get_full_state_weights(self): state['steps'] = self.step state['actor_optimizer'] = self.actor_optimizer.state_dict() state['critic_optimizer'] = self.critic_optimizer.state_dict() - state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict() + state['log_alpha_optimizer'] = self.log_alpha_optimizer.state_dict() return state def get_weights(self): state = {'actor': self.model.sac_network.actor.state_dict(), - 'critic': self.model.sac_network.critic.state_dict(), - 'critic_target': self.model.sac_network.critic_target.state_dict()} + 'critic': self.model.sac_network.critic.state_dict(), + 'critic_target': self.model.sac_network.critic_target.state_dict()} return state def save(self, fn): @@ -271,7 +271,7 @@ def update_critic(self, obs, action, reward, next_obs, not_done, step): critic1_loss = self.c_loss(current_Q1, target_Q) critic2_loss = self.c_loss(current_Q2, target_Q) - critic_loss = critic1_loss + critic2_loss + critic_loss = critic1_loss + critic2_loss self.critic_optimizer.zero_grad(set_to_none=True) critic_loss.backward() self.critic_optimizer.step() @@ -308,7 +308,7 @@ def update_actor_and_alpha(self, obs, step): else: alpha_loss = None - return actor_loss.detach(), entropy.detach(), self.alpha.detach(), alpha_loss # TODO: maybe not self.alpha + return actor_loss.detach(), entropy.detach(), self.alpha.detach(), alpha_loss # TODO: maybe not self.alpha def 
soft_update_params(self, net, target_net, tau): for param, target_param in zip(net.parameters(), target_net.parameters()): @@ -327,7 +327,7 @@ def update(self, step): actor_loss_info = actor_loss, entropy, alpha, alpha_loss self.soft_update_params(self.model.sac_network.critic, self.model.sac_network.critic_target, - self.critic_tau) + self.critic_tau) return actor_loss_info, critic1_loss, critic2_loss def preproc_obs(self, obs): @@ -340,7 +340,7 @@ def cast_obs(self, obs): if isinstance(obs, torch.Tensor): self.is_tensor_obses = True elif isinstance(obs, np.ndarray): - assert(self.observation_space.dtype != np.int8) + assert (self.observation_space.dtype != np.int8) if self.observation_space.dtype == np.uint8: obs = torch.ByteTensor(obs).to(self._device) else: @@ -357,8 +357,8 @@ def obs_to_tensors(self, obs): upd_obs[key] = self._obs_to_tensors_internal(value) else: upd_obs = self.cast_obs(obs) - if not obs_is_dict or 'obs' not in obs: - upd_obs = {'obs' : upd_obs} + if not obs_is_dict or 'obs' not in obs: + upd_obs = {'obs': upd_obs} return upd_obs def _obs_to_tensors_internal(self, obs): @@ -377,7 +377,7 @@ def preprocess_actions(self, actions): def env_step(self, actions): actions = self.preprocess_actions(actions) - obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space) + obs, rewards, dones, infos = self.vec_env.step(actions) # (obs_space) -> (n, obs_space) self.step += self.num_actors if self.is_tensor_obses: @@ -468,7 +468,7 @@ def play_steps(self, random_exploration = False): if isinstance(obs, dict): obs = obs['obs'] - if isinstance(next_obs, dict): + if isinstance(next_obs, dict): next_obs = next_obs['obs'] rewards = self.rewards_shaper(rewards) diff --git a/rl_games/algos_torch/shac_agent.py b/rl_games/algos_torch/shac_agent.py new file mode 100644 index 00000000..ad2f858b --- /dev/null +++ b/rl_games/algos_torch/shac_agent.py @@ -0,0 +1,510 @@ +from rl_games.common import schedulers +from rl_games.common.a2c_common import print_statistics + +from rl_games.common.a2c_common import ContinuousA2CBase +from rl_games.algos_torch import torch_ext + +from rl_games.algos_torch import central_value +from rl_games.common import common_losses, datasets +from rl_games.algos_torch import model_builder + +from torch import optim +import torch +import time +import os +import copy +import numpy as np + + +def swap_and_flatten01(arr): + """ + swap and then flatten axes 0 and 1 + """ + if arr is None: + return arr + s = arr.size() + return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:]) + + +class SHACAgent(ContinuousA2CBase): + + def __init__(self, base_name, params): + ContinuousA2CBase.__init__(self, base_name, params) + + obs_shape = self.obs_shape + build_config = { + 'actions_num': self.actions_num, + 'input_shape': obs_shape, + 'num_seqs': self.num_actors * self.num_agents, + 'value_size': self.env_info.get('value_size', 1), + 'normalize_value': self.normalize_value, + 'normalize_input': self.normalize_input, + } + + # apply tanh to actions + self.apply_tanh = self.config.get('apply_tanh', False) + + if self.apply_tanh: + print("tanh activation is applied to actions") + + self.critic_lr = float(self.config.get('critic_learning_rate', 1e-4)) + self.use_target_critic = self.config.get('use_target_critic', True) + self.target_critic_alpha = self.config.get('target_critic_alpha', 0.4) + + # to work with shac/diffrl repo + # update when warp/brax envs with truncated info are supported + self.max_episode_length = params['diff_env'].get('episode_length', 
1000) + self.actor_model = self.network.build(build_config) + self.critic_model = self.critic_network.build(build_config) + + self.actor_model.to(self.ppo_device) + self.critic_model.to(self.ppo_device) + self.target_critic = copy.deepcopy(self.critic_model) + + if self.linear_lr: + if self.max_epochs == -1 and self.max_frames == -1: + print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the constant (identity) one.") + self.critic_scheduler = schedulers.IdentityScheduler() + else: + print("Linear lr schedule. Min lr = ", self.min_lr) + use_epochs = True + max_steps = self.max_epochs + + if self.max_epochs == -1: + use_epochs = False + max_steps = self.max_frames + + self.critic_scheduler = schedulers.LinearScheduler(self.critic_lr, + min_lr = self.min_lr, + max_steps = max_steps, + use_epochs = use_epochs, + apply_to_entropy = False, + start_entropy_coef = 0.0) + else: + self.critic_scheduler = schedulers.IdentityScheduler() + + if self.normalize_input: + self.critic_model.running_mean_std = self.actor_model.running_mean_std + self.target_critic.running_mean_std = self.critic_model.running_mean_std + if self.normalize_value: + self.target_critic.value_mean_std = self.critic_model.value_mean_std + self.actor_model.value_mean_std = None + + self.states = None + self.model = self.actor_model + self.init_rnn_from_model(self.actor_model) + + self.actor_lr = self.last_lr + self.betas = self.config.get('betas', [0.9, 0.999]) + self.optimizer = self.actor_optimizer = optim.Adam(self.actor_model.parameters(), self.actor_lr, betas=self.betas, eps=1e-08, + weight_decay=self.weight_decay) + self.critic_optimizer = optim.Adam(self.critic_model.parameters(), self.critic_lr, betas=self.betas, eps=1e-08, + weight_decay=self.weight_decay) + + self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, + self.ppo_device, self.seq_len) + + if self.normalize_value: + self.value_mean_std = self.critic_model.value_mean_std + + self.algo_observer.after_init(self) + + def play_steps(self): + update_list = self.update_list + accumulated_rewards = torch.zeros((self.horizon_length + 1, self.num_actors), dtype=torch.float32, device=self.device) + actor_loss = torch.tensor(0., dtype=torch.float32, device=self.device) + mb_values = self.experience_buffer.tensor_dict['values'] + gamma = torch.ones(self.num_actors, dtype=torch.float32, device=self.device) + step_time = 0.0 + + self.critic_model.eval() + self.target_critic.eval() + self.actor_model.train() + #torch.autograd.set_detect_anomaly(True) + + if self.normalize_input: + self.actor_model.running_mean_std.train() + if self.normalize_value: + self.value_mean_std.eval() + + obs = self.initialize_trajectory() + last_values = None + + for n in range(self.horizon_length): + res_dict = self.get_actions(obs) + if last_values is None: + last_values = res_dict['values'] = self.get_values(obs) + else: + res_dict['values'] = last_values + + with torch.no_grad(): + self.experience_buffer.update_data('obses', n, obs['obs'].detach()) + self.experience_buffer.update_data('dones', n, self.dones.detach()) + + self.experience_buffer.update_data('values', n, res_dict['values'].detach()) + + step_time_start = time.time() + actions = res_dict['actions'] + if self.apply_tanh: + actions = torch.tanh(actions) + obs, rewards, self.dones, infos = self.env_step(actions) + step_time_end = time.time() + episode_ended = self.current_lengths == self.max_episode_length - 1 + episode_ended_indices = 
episode_ended.nonzero(as_tuple=False) + step_time += (step_time_end - step_time_start) + + shaped_rewards = self.rewards_shaper(rewards) + + if self.normalize_input: + self.actor_model.running_mean_std.eval() + + real_obs = infos['obs_before_reset'] + last_obs_vals = last_values.clone() #.detach() + for ind in episode_ended_indices: + if torch.isnan(real_obs[ind]).sum() > 0 \ + or torch.isinf(real_obs[ind]).sum() > 0 \ + or (torch.abs(real_obs[ind]) > 1e6).sum() > 0: # ugly fix for nan values + print('Nan gradients: ', ind) + last_obs_vals[ind] = 0 + else: + curr_real_obs = self.obs_to_tensors(real_obs[ind]) + val = self.get_values(curr_real_obs) + last_obs_vals[ind] = val + + if self.normalize_input: + self.actor_model.running_mean_std.train() + shaped_rewards += last_obs_vals * episode_ended.unsqueeze(1).float() + self.experience_buffer.update_data('rewards', n, shaped_rewards.detach()) + + self.current_rewards += rewards.detach() + self.current_lengths += 1 + + env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) + + self.game_rewards.update(self.current_rewards[env_done_indices]) + self.game_lengths.update(self.current_lengths[env_done_indices]) + self.algo_observer.process_infos(infos, env_done_indices) + fdones = self.dones.float() + not_dones = 1.0 - fdones + + accumulated_rewards[n + 1] = accumulated_rewards[n] + gamma * shaped_rewards.squeeze(1) + + last_values = self.get_values(obs) + + self.current_rewards = self.current_rewards * not_dones.unsqueeze(1) + self.current_lengths = self.current_lengths * not_dones + if n < self.horizon_length - 1: + actor_loss = actor_loss - ( + accumulated_rewards[n + 1, env_done_indices]).sum() + else: + actor_loss = actor_loss - ( + accumulated_rewards[n + 1, :] + self.gamma * gamma * last_values.squeeze() * (1.0-episode_ended.float()) * not_dones).sum() + + gamma = gamma * self.gamma + gamma[env_done_indices] = 1.0 + accumulated_rewards[n + 1, env_done_indices] = 0.0 + fdones = self.dones.float().detach() + mb_fdones = self.experience_buffer.tensor_dict['dones'].float().detach() + mb_rewards = self.experience_buffer.tensor_dict['rewards'].detach() + mb_advs = self.discount_values(fdones, last_values.detach(), mb_fdones, mb_values.detach(), mb_rewards) + mb_returns = mb_advs + mb_values + + batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = swap_and_flatten01(mb_returns) + batch_dict['played_frames'] = self.batch_size + batch_dict['step_time'] = step_time + + actor_loss = actor_loss / (self.horizon_length * self.num_actors) + + return batch_dict, actor_loss + + def env_step(self, actions): + # todo: add preprocessing + #actions = self.preprocess_actions(actions) + obs, rewards, dones, infos = self.vec_env.step(actions) + + if self.value_size == 1: + rewards = rewards.unsqueeze(1) + return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos + + def load_networks(self, params): + ContinuousA2CBase.load_networks(self, params) + if 'critic_config' in self.config: + builder = model_builder.ModelBuilder() + print('Adding Critic Network') + network = builder.load(params['config']['critic_config']) + self.critic_network = network + + def get_actions(self, obs): + processed_obs = self._preproc_obs(obs['obs']) + input_dict = { + 'obs' : processed_obs, + 'rnn_states' : self.rnn_states + } + res_dict = self.actor_model(input_dict) + return res_dict + + def get_values(self, obs): + processed_obs = 
self._preproc_obs(obs['obs']) + input_dict = { + 'is_train': False, + 'prev_actions': None, + 'obs' : processed_obs, + 'rnn_states' : self.rnn_states + } + + processed_obs = self._preproc_obs(obs['obs']) + if self.use_target_critic: + result = self.target_critic(input_dict) + else: + result = self.critic_model(input_dict) + value = result['values'] + return value + + def initialize_trajectory(self): + #obs = self.vec_env.reset() + obs = self.vec_env.initialize_trajectory() + obs = self.obs_to_tensors(obs) + return obs + + def update_epoch(self): + self.epoch_num += 1 + return self.epoch_num + + def save(self, fn): + state = self.get_full_state_weights() + torch_ext.save_checkpoint(fn, state) + + def restore(self, fn): + checkpoint = torch_ext.load_checkpoint(fn) + self.set_full_state_weights(checkpoint) + + def get_masked_action_values(self, obs, action_masks): + assert False + + def prepare_critic_dataset(self, batch_dict): + obses = batch_dict['obses'].detach() + returns = batch_dict['returns'].detach() + dones = batch_dict['dones'].detach() + values = batch_dict['values'].detach() + rnn_states = batch_dict.get('rnn_states', None) + rnn_masks = batch_dict.get('rnn_masks', None) + + if self.normalize_value: + self.value_mean_std.train() + values = self.value_mean_std(values) + returns = self.value_mean_std(returns) + self.value_mean_std.eval() + + dataset_dict = {} + dataset_dict['old_values'] = values + dataset_dict['returns'] = returns + dataset_dict['obs'] = obses + dataset_dict['dones'] = dones + dataset_dict['rnn_states'] = rnn_states + dataset_dict['rnn_masks'] = rnn_masks + + self.dataset.update_values_dict(dataset_dict) + + def train_actor(self, actor_loss): + self.actor_model.train() + + self.actor_optimizer.zero_grad(set_to_none=True) + + actor_loss.backward() + + if self.truncate_grads: + torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.grad_norm) + + self.actor_optimizer.step() + + return actor_loss.detach() + + def train_critic(self, batch): + self.critic_model.train() + if self.normalize_input: + self.critic_model.running_mean_std.eval() + if self.normalize_value: + self.critic_model.value_mean_std.eval() + + obs_batch = self._preproc_obs(batch['obs']) + value_preds_batch = batch['old_values'] + returns_batch = batch['returns'] + dones_batch = batch['dones'] + rnn_masks_batch = batch.get('rnn_masks') + batch_dict = {'obs' : obs_batch, + 'seq_length' : self.seq_len, + 'dones' : dones_batch} + + res_dict = self.critic_model(batch_dict) + values = res_dict['values'] + loss = common_losses.critic_loss(value_preds_batch, values, self.e_clip, returns_batch, self.clip_value) + losses, _ = torch_ext.apply_masks([loss], rnn_masks_batch) + critic_loss = losses[0] + self.critic_optimizer.zero_grad(set_to_none=True) + critic_loss.backward() + + if self.truncate_grads: + torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), self.grad_norm) + + self.critic_optimizer.step() + + return critic_loss.detach() + + def update_lr(self, actor_lr, critic_lr): + for param_group in self.actor_optimizer.param_groups: + param_group['lr'] = actor_lr + + for param_group in self.critic_optimizer.param_groups: + param_group['lr'] = critic_lr + + def train_epoch(self): + play_time_start = time.time() + batch_dict, actor_loss = self.play_steps() + play_time_end = time.time() + update_time_start = time.time() + + self.curr_frames = batch_dict.pop('played_frames') + self.prepare_critic_dataset(batch_dict) + self.algo_observer.after_steps() + a_loss = self.train_actor(actor_loss) + 
a_losses = [a_loss] + c_losses = [] + + for mini_ep in range(0, self.mini_epochs_num): + ep_kls = [] + for i in range(len(self.dataset)): + c_loss = self.train_critic(self.dataset[i]) + c_losses.append(c_loss) + + self.diagnostics.mini_epoch(self, mini_ep) + + # update target critic + with torch.no_grad(): + alpha = self.target_critic_alpha + for param, param_targ in zip(self.critic_model.parameters(), self.target_critic.parameters()): + param_targ.data.mul_(alpha) + param_targ.data.add_((1. - alpha) * param.data) + + self.last_lr, _ = self.scheduler.update(self.last_lr, 0, self.epoch_num, self.frame, 0) + self.critic_lr, _ = self.critic_scheduler.update(self.critic_lr, 0, self.epoch_num, self.frame, 0) + + self.update_lr(self.last_lr, self.critic_lr) + update_time_end = time.time() + play_time = play_time_end - play_time_start + update_time = update_time_end - update_time_start + total_time = update_time_end - play_time_start + + return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses + + def train(self): + self.init_tensors() + self.mean_rewards = self.last_mean_rewards = -100500 + total_time = 0 + + while True: + epoch_num = self.update_epoch() + step_time, play_time, update_time, sum_time, a_losses, c_losses = self.train_epoch() + + # cleaning memory to optimize space + self.dataset.update_values_dict(None) + total_time += sum_time + curr_frames = self.curr_frames + self.frame += curr_frames + should_exit = False + + if self.rank == 0: + self.diagnostics.epoch(self, current_epoch=epoch_num) + scaled_time = self.num_agents * sum_time + scaled_play_time = self.num_agents * play_time + + frame = self.frame // self.num_agents + + print_statistics(self.print_stats, curr_frames, + step_time, scaled_play_time, scaled_time, + epoch_num, self.max_epochs, self.frame, self.max_frames) + + if self.print_stats: + print(f'actor loss: {a_losses[0].item():.2f}') + + self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, self.last_lr, curr_frames) + + self.algo_observer.after_print_stats(frame, epoch_num, total_time) + + if self.game_rewards.current_size > 0: + mean_rewards = self.game_rewards.get_mean() + mean_lengths = self.game_lengths.get_mean() + self.mean_rewards = mean_rewards[0] + + for i in range(self.value_size): + rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i) + self.writer.add_scalar(rewards_name + '/step'.format(i), mean_rewards[i], frame) + self.writer.add_scalar(rewards_name + '/iter'.format(i), mean_rewards[i], epoch_num) + self.writer.add_scalar(rewards_name + '/time'.format(i), mean_rewards[i], total_time) + + self.writer.add_scalar('episode_lengths/step', mean_lengths, frame) + self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num) + self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time) + + # removed equal signs (i.e. 
"rew=") from the checkpoint name since it messes with hydra CLI parsing + checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0]) + + if self.save_freq > 0: + if (epoch_num % self.save_freq == 0) and (mean_rewards <= self.last_mean_rewards): + self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name)) + + if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after: + print('saving next best rewards: ', mean_rewards) + self.last_mean_rewards = mean_rewards[0] + self.save(os.path.join(self.nn_dir, self.config['name'])) + + if 'score_to_win' in self.config: + if self.last_mean_rewards > self.config['score_to_win']: + print('Maximum reward achieved. Network won!') + self.save(os.path.join(self.nn_dir, checkpoint_name)) + should_exit = True + + if epoch_num >= self.max_epochs and self.max_epochs != -1: + if self.game_rewards.current_size == 0: + print('WARNING: Max epochs reached before any env terminated at least once') + mean_rewards = -np.inf + + self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_ep_' + str(epoch_num) \ + + '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_'))) + print('MAX EPOCHS NUM!') + should_exit = True + + if self.frame >= self.max_frames and self.max_frames != -1: + if self.game_rewards.current_size == 0: + print('WARNING: Max frames reached before any env terminated at least once') + mean_rewards = -np.inf + + self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_frame_' + str(self.frame) \ + + '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_'))) + print('MAX FRAMES NUM!') + should_exit = True + + update_time = 0 + + if should_exit: + return self.last_mean_rewards, epoch_num + + def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, last_lr, curr_frames): + # do we need scaled time? 
+ frame = self.frame + + self.diagnostics.send_info(self.writer) + self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / update_time, frame) + self.writer.add_scalar('performance/step_inference_fps', curr_frames / play_time, frame) + self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame) + self.writer.add_scalar('performance/rl_update_time', update_time, frame) + self.writer.add_scalar('performance/step_inference_time', play_time, frame) + self.writer.add_scalar('performance/step_time', step_time, frame) + self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) + self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) + self.writer.add_scalar('info/epochs', epoch_num, frame) + self.writer.add_scalar('info/actor_lr/frame', last_lr, frame) + self.writer.add_scalar('info/actor_lr/epoch_num', last_lr, epoch_num) + self.writer.add_scalar('info/critic_lr/frame', self.critic_lr, frame) + self.writer.add_scalar('info/critic_lr/epoch_num', self.critic_lr, epoch_num) + self.algo_observer.after_print_stats(frame, epoch_num, total_time) \ No newline at end of file diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index ad244578..4d211358 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -11,7 +11,7 @@ from rl_games.common.interval_summary_writer import IntervalSummaryWriter from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics from rl_games.algos_torch import model_builder -from rl_games.interfaces.base_algorithm import BaseAlgorithm +from rl_games.interfaces.base_algorithm import BaseAlgorithm import numpy as np import time import gym @@ -157,16 +157,22 @@ def __init__(self, base_name, params): self.is_adaptive_lr = config['lr_schedule'] == 'adaptive' self.linear_lr = config['lr_schedule'] == 'linear' + + # min learning rate used with linear and adaptive schedules + self.min_lr = float(config.get('min_learning_rate', 1e-6)) + # max learning rate used with adaptive schedule + self.max_lr = float(config.get('max_learning_rate', 1e-2)) + self.schedule_type = config.get('schedule_type', 'legacy') # Setting learning rate scheduler if self.is_adaptive_lr: self.kl_threshold = config['kl_threshold'] - self.scheduler = schedulers.AdaptiveScheduler(self.kl_threshold) + self.scheduler = schedulers.AdaptiveScheduler(kl_threshold = self.kl_threshold, + min_lr = self.min_lr, max_lr = self.max_lr) elif self.linear_lr: - - if self.max_epochs == -1 and self.max_frames != -1: + if self.max_epochs == -1 and self.max_frames == -1: print("Max epochs and max frames are not set. 
Linear learning rate schedule can't be used, switching to the contstant (identity) one.") self.scheduler = schedulers.IdentityScheduler() else: @@ -177,7 +183,8 @@ def __init__(self, base_name, params): use_epochs = False max_steps = self.max_frames - self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']), + self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']), + min_lr = self.min_lr, max_steps = max_steps, use_epochs = use_epochs, apply_to_entropy = config.get('schedule_entropy', False), @@ -230,7 +237,7 @@ def __init__(self, base_name, params): self.mixed_precision = self.config.get('mixed_precision', False) self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision) - self.last_lr = self.config['learning_rate'] + self.last_lr = float(self.config['learning_rate']) self.frame = 0 self.update_time = 0 self.mean_rewards = self.last_mean_rewards = -100500 @@ -339,7 +346,7 @@ def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, self.writer.add_scalar('performance/step_time', step_time, frame) self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame) self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame) - + self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame) self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame) self.writer.add_scalar('info/lr_mul', lr_mul, frame) @@ -437,7 +444,7 @@ def init_tensors(self): val_shape = (self.horizon_length, batch_size, self.value_size) current_rewards_shape = (batch_size, self.value_size) self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device) - self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device) + self.current_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.ppo_device) self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device) if self.is_rnn: @@ -522,6 +529,7 @@ def discount_values(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t] mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam + return mb_advs def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks): @@ -538,6 +546,7 @@ def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_ext masks_t = mb_masks[t].unsqueeze(1) delta = (mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t]) mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t + return mb_advs def clear_stats(self): @@ -755,6 +764,7 @@ def play_steps_rnn(self): self.current_lengths += 1 all_done_indices = self.dones.nonzero(as_tuple=False) env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False) + if len(all_done_indices) > 0: for s in self.rnn_states: s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0 @@ -789,6 +799,7 @@ def play_steps_rnn(self): states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size)) batch_dict['rnn_states'] = states batch_dict['step_time'] = step_time + return batch_dict @@ -859,7 +870,7 @@ def train_epoch(self): dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) av_kls /= self.rank_size - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, 
self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) kls.append(av_kls) self.diagnostics.mini_epoch(self, mini_ep) @@ -1120,7 +1131,7 @@ def train_epoch(self): if self.multi_gpu: dist.all_reduce(kl, op=dist.ReduceOp.SUM) av_kls /= self.rank_size - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) av_kls = torch_ext.mean_list(ep_kls) @@ -1128,7 +1139,7 @@ def train_epoch(self): dist.all_reduce(av_kls, op=dist.ReduceOp.SUM) av_kls /= self.rank_size if self.schedule_type == 'standard': - self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item()) + self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, self.frame, av_kls.item()) self.update_lr(self.last_lr) kls.append(av_kls) @@ -1310,8 +1321,6 @@ def train(self): should_exit_t = torch.tensor(should_exit, device=self.device).float() dist.broadcast(should_exit_t, 0) should_exit = should_exit_t.float().item() - if should_exit: - return self.last_mean_rewards, epoch_num if should_exit: return self.last_mean_rewards, epoch_num diff --git a/rl_games/common/player.py b/rl_games/common/player.py index 527602bd..71803382 100644 --- a/rl_games/common/player.py +++ b/rl_games/common/player.py @@ -4,10 +4,13 @@ import torch import copy from rl_games.common import env_configurations -from rl_games.algos_torch import model_builder +from rl_games.algos_torch import model_builder + class BasePlayer(object): + def __init__(self, params): + self.config = config = params['config'] self.load_networks(params) self.env_name = self.config['env_name'] @@ -15,11 +18,13 @@ def __init__(self, params): self.env_info = self.config.get('env_info') self.clip_actions = config.get('clip_actions', True) self.seed = self.env_config.pop('seed', None) + if self.env_info is None: self.env = self.create_env() self.env_info = env_configurations.get_env_info(self.env) else: self.env = config.get('vec_env') + self.value_size = self.env_info.get('value_size', 1) self.action_space = self.env_info['action_space'] self.num_agents = self.env_info['agents'] @@ -42,10 +47,9 @@ def __init__(self, params): self.device_name = self.config.get('device_name', 'cuda') self.render_env = self.player_config.get('render', False) self.games_num = self.player_config.get('games_num', 2000) - if 'deterministic' in self.player_config: - self.is_deterministic = self.player_config['deterministic'] - else: - self.is_deterministic = self.player_config.get('determenistic', True) + + self.is_deterministic = self.player_config.get('determenistic', True) + self.n_game_life = self.player_config.get('n_game_life', 1) self.print_stats = self.player_config.get('print_stats', True) self.render_sleep = self.player_config.get('render_sleep', 0.002) @@ -215,7 +219,7 @@ def run(self): steps += 1 if render: - self.env.render(mode='human') + #self.env.render() time.sleep(self.render_sleep) all_done_indices = done.nonzero(as_tuple=False) diff --git a/rl_games/common/schedulers.py b/rl_games/common/schedulers.py index 78a89ebc..0196d873 100644 --- a/rl_games/common/schedulers.py +++ b/rl_games/common/schedulers.py @@ -1,5 +1,3 @@ - - class 
RLScheduler: def __init__(self): pass @@ -7,29 +5,32 @@ def __init__(self): def update(self,current_lr, entropy_coef, epoch, frames, **kwargs): pass + class IdentityScheduler(RLScheduler): def __init__(self): super().__init__() - def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): return current_lr, entropy_coef class AdaptiveScheduler(RLScheduler): - def __init__(self, kl_threshold = 0.008): + def __init__(self, kl_threshold=0.008, min_lr=1e-6, max_lr=1e-2): super().__init__() - self.min_lr = 1e-6 - self.max_lr = 1e-2 + + self.min_lr = min_lr + self.max_lr = max_lr self.kl_threshold = kl_threshold def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): lr = current_lr if kl_dist > (2.0 * self.kl_threshold): lr = max(current_lr / 1.5, self.min_lr) + if kl_dist < (0.5 * self.kl_threshold): lr = min(current_lr * 1.5, self.max_lr) - return lr, entropy_coef + + return lr, entropy_coef class LinearScheduler(RLScheduler): @@ -50,9 +51,11 @@ def update(self, current_lr, entropy_coef, epoch, frames, kl_dist, **kwargs): steps = epoch else: steps = frames - mul = max(0, self.max_steps - steps)/self.max_steps + + mul = max(0, self.max_steps - steps)/self.max_steps lr = self.min_lr + (self.start_lr - self.min_lr) * mul + if self.apply_to_entropy: entropy_coef = self.min_entropy_coef + (self.start_entropy_coef - self.min_entropy_coef) * mul - return lr, entropy_coef \ No newline at end of file + return lr, entropy_coef diff --git a/rl_games/configs/atari/ppo_breakout.yaml b/rl_games/configs/atari/ppo_breakout.yaml index bc34be48..485ee75e 100644 --- a/rl_games/configs/atari/ppo_breakout.yaml +++ b/rl_games/configs/atari/ppo_breakout.yaml @@ -75,9 +75,9 @@ params: name: BreakoutNoFrameskip-v4 episode_life: True seed: 5 - player: render: False - games_num: 200 + games_num: 50 n_game_life: 5 determenistic: False + diff --git a/rl_games/envs/connect4_selfplay.py b/rl_games/envs/connect4_selfplay.py index ae761973..b91157b5 100644 --- a/rl_games/envs/connect4_selfplay.py +++ b/rl_games/envs/connect4_selfplay.py @@ -6,11 +6,14 @@ import os from collections import deque + class ConnectFourSelfPlay(gym.Env): + def __init__(self, name="connect_four_v0", **kwargs): gym.Env.__init__(self) + self.name = name - self.is_determenistic = kwargs.pop('is_deterministic', False) + self.is_determenistic = kwargs.pop('deterministic', False) self.is_human = kwargs.pop('is_human', False) self.random_agent = kwargs.pop('random_agent', False) self.config_path = kwargs.pop('config_path') @@ -63,7 +66,6 @@ def reset(self): opponent_action = np.random.choice(ids, 1)[0] else: opponent_action = self.agent.get_masked_action(op_obs, mask, self.is_deterministic).item() - obs, _, _, _ = self.env_step(opponent_action) @@ -83,9 +85,7 @@ def create_agent(self, config): self.agent = runner.create_player() self.agent.model.eval() - def step(self, action): - obs, reward, done, info = self.env_step(action) self.obs_deque.append(obs) diff --git a/rl_games/envs/slimevolley_selfplay.py b/rl_games/envs/slimevolley_selfplay.py index 11e9cb13..7510c82d 100644 --- a/rl_games/envs/slimevolley_selfplay.py +++ b/rl_games/envs/slimevolley_selfplay.py @@ -5,11 +5,14 @@ from rl_games.torch_runner import Runner import os + class SlimeVolleySelfplay(gym.Env): + def __init__(self, name="SlimeVolleyDiscrete-v0", **kwargs): gym.Env.__init__(self) + self.name = name - self.is_deterministic = kwargs.pop('is_deterministic', False) + self.is_deterministic = kwargs.pop('deterministic', False) self.config_path = 
kwargs.pop('config_path') self.agent = None self.pos_scale = 1 @@ -41,7 +44,6 @@ def create_agent(self, config='rl_games/configs/ma/ppo_slime_self_play.yaml'): self.agent = runner.create_player() - def step(self, action): op_obs = self.agent.obs_to_torch(self.opponent_obs) @@ -50,7 +52,7 @@ def step(self, action): self.sum_rewards += reward if reward < 0: reward = reward * self.neg_scale - + self.opponent_obs = info['otherObs'] if done: info['battle_won'] = np.sign(self.sum_rewards) diff --git a/rl_games/interfaces/base_algorithm.py b/rl_games/interfaces/base_algorithm.py index 054483f7..2783249a 100644 --- a/rl_games/interfaces/base_algorithm.py +++ b/rl_games/interfaces/base_algorithm.py @@ -1,6 +1,7 @@ from abc import ABC from abc import abstractmethod, abstractproperty + class BaseAlgorithm(ABC): def __init__(self, base_name, config): pass diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index 8a58d38a..8c577ecf 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -4,17 +4,13 @@ import random from copy import deepcopy import torch -#import yaml -#from rl_games import envs -from rl_games.common import object_factory -from rl_games.common import tr_helpers +from rl_games.common import object_factory, tr_helpers -from rl_games.algos_torch import a2c_continuous -from rl_games.algos_torch import a2c_discrete +from rl_games.algos_torch import a2c_continuous, a2c_discrete from rl_games.algos_torch import players +from rl_games.algos_torch import sac_agent, shac_agent from rl_games.common.algo_observer import DefaultAlgoObserver -from rl_games.algos_torch import sac_agent def _restore(agent, args): @@ -31,19 +27,20 @@ def _override_sigma(agent, args): else: print('Print cannot set new sigma because fixed_sigma is False') - class Runner: def __init__(self, algo_observer=None): self.algo_factory = object_factory.ObjectFactory() self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs)) - self.algo_factory.register_builder('a2c_discrete', lambda **kwargs : a2c_discrete.DiscreteA2CAgent(**kwargs)) + self.algo_factory.register_builder('a2c_discrete', lambda **kwargs : a2c_discrete.DiscreteA2CAgent(**kwargs)) self.algo_factory.register_builder('sac', lambda **kwargs: sac_agent.SACAgent(**kwargs)) + self.algo_factory.register_builder('shac', lambda **kwargs: shac_agent.SHACAgent(**kwargs)) #self.algo_factory.register_builder('dqn', lambda **kwargs : dqnagent.DQNAgent(**kwargs)) self.player_factory = object_factory.ObjectFactory() self.player_factory.register_builder('a2c_continuous', lambda **kwargs : players.PpoPlayerContinuous(**kwargs)) self.player_factory.register_builder('a2c_discrete', lambda **kwargs : players.PpoPlayerDiscrete(**kwargs)) self.player_factory.register_builder('sac', lambda **kwargs : players.SACPlayer(**kwargs)) + self.player_factory.register_builder('shac', lambda **kwargs : players.SHACPlayer(**kwargs)) #self.player_factory.register_builder('dqn', lambda **kwargs : players.DQNPlayer(**kwargs)) self.algo_observer = algo_observer if algo_observer else DefaultAlgoObserver() @@ -67,7 +64,6 @@ def load_config(self, params): self.algo_params = params['algo'] self.algo_name = self.algo_params['name'] self.exp_config = None - if self.seed: torch.manual_seed(self.seed) torch.cuda.manual_seed_all(self.seed) @@ -115,8 +111,6 @@ def reset(self): pass def run(self, args): - load_path = None - if args['train']: self.run_train(args) diff --git a/setup.py b/setup.py index 10ae34a3..fd1ca15d 100644 --- 
a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ #packages=[package for package in find_packages() if package.startswith('rl_games')], packages = ['.','rl_games','docs'], package_data={'rl_games':['*','*/*','*/*/*'],'docs':['*','*/*','*/*/*'],}, - version='1.5.2', + version='1.6.0', author='Denys Makoviichuk, Viktor Makoviichuk', author_email='trrrrr97@gmail.com, victor.makoviychuk@gmail.com', license="MIT",
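With 'shac' registered in rl_games/torch_runner.py above, the new agent and player are driven through the usual Runner flow. A minimal usage sketch under stated assumptions: the YAML path and its contents are hypothetical, the environment must expose the differentiable-sim hooks SHACAgent relies on (vec_env.initialize_trajectory and obs_before_reset in infos), and only the Runner API and the 'shac' registration come from this diff:

import yaml
from rl_games.torch_runner import Runner

# Hypothetical config: params.algo.name must be 'shac'; params.config carries the
# SHAC-specific keys read in shac_agent.py (critic_config, critic_learning_rate,
# target_critic_alpha, apply_tanh), and params.diff_env sets episode_length.
with open('rl_games/configs/my_shac_env.yaml', 'r') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})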