[WIP] Added SHAC Support #168

Status: Open. Wants to merge 33 commits into base: master.

Commits (33)
44bc254
WIP
ViktorM May 22, 2022
130bfc5
SHAC agent, network, model. WIP.
ViktorM May 23, 2022
afe1aea
first commit
DenSumy May 23, 2022
b49d21f
removed shac model
DenSumy May 23, 2022
15e83e4
merged master
DenSumy May 25, 2022
8bc6ccd
it works
DenSumy May 26, 2022
0725dde
added shac agent
DenSumy May 26, 2022
5dbdc0f
fixed actor loss
DenSumy May 27, 2022
27c0e02
removed shac
DenSumy May 27, 2022
6ce26a2
more cleanup
DenSumy May 27, 2022
50f5c68
Added linear lr for actor and critic. Currently assumes having the sa…
ViktorM May 27, 2022
191401c
Merge branch 'VM/shac' of https://github.com/Denys88/rl_games into VM…
ViktorM May 27, 2022
9c16caf
updated tb
DenSumy May 27, 2022
5d6ce14
best shac
DenSumy May 27, 2022
f67b680
fixed rms
DenSumy May 28, 2022
278c3e4
last
DenSumy May 28, 2022
cafe153
fixed copypaste
DenSumy May 28, 2022
8f1f4a2
Independent lr schedulers for actor and critic.
ViktorM May 29, 2022
6233e98
best version
DenSumy May 29, 2022
9113f82
last update
DenSumy May 30, 2022
187f21d
shac which works
DenSumy May 31, 2022
a965e64
small update to make it equal
DenSumy May 31, 2022
f0a35f8
Merging master.
ViktorM Jul 7, 2022
28fbe8c
Merged with master. Tanh to actions is optional now.
ViktorM Oct 4, 2022
162f6c6
SHAC cleanup. Updated release version.
ViktorM Oct 28, 2022
dfd66d2
Merge branch 'VM/shac' of https://github.com/Denys88/rl_games into VM…
ViktorM Oct 28, 2022
e2f1f4b
Updated SHAC agent to be aligned with PPO and SAC implementations. Ad…
ViktorM Jan 9, 2023
93a2b4b
More improvements.
ViktorM Jan 9, 2023
f1491b6
Added release notes.
ViktorM Jan 9, 2023
fd5ce6f
Fixed const learning rate for SHAC. Fixed linear scheduling with max_…
ViktorM Jan 9, 2023
28a35cf
Fixed max_frames for central value.
ViktorM Jan 9, 2023
b71596a
Readme update.
ViktorM Jan 12, 2023
aa60067
Readme update.
ViktorM Jan 12, 2023
19 changes: 13 additions & 6 deletions README.md
@@ -5,11 +5,13 @@

## Papers and related links

* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning: https://arxiv.org/abs/2108.10470
* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger: https://s2r2-ig.github.io/ https://arxiv.org/abs/2108.09779
* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? <https://arxiv.org/abs/2011.09533>
* Superfast Adversarial Motion Priors (AMP) implementation: https://twitter.com/xbpeng4/status/1506317490766303235 https://github.com/NVIDIA-Omniverse/IsaacGymEnvs
* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation: https://cremebrule.github.io/oscar-web/ https://arxiv.org/abs/2110.00704
* Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning. Paper: https://arxiv.org/abs/2108.10470
* Transferring Dexterous Manipulation from GPU Simulation to a Remote Real-World TriFinger. Site: https://s2r2-ig.github.io/ Paper: https://arxiv.org/abs/2108.09779
* Is Independent Learning All You Need in the StarCraft Multi-Agent Challenge? Paper: https://arxiv.org/abs/2011.09533
* Superfast Adversarial Motion Priors (AMP) implementation. Twitter: https://twitter.com/xbpeng4/status/1506317490766303235 Repo: https://github.com/NVIDIA-Omniverse/IsaacGymEnvs
* OSCAR: Data-Driven Operational Space Control for Adaptive and Robust Robot Manipulation. Site: https://cremebrule.github.io/oscar-web/ Paper: https://arxiv.org/abs/2110.00704
* EnvPool: A Highly Parallel Reinforcement Learning Environment Execution Engine. Paper: https://arxiv.org/abs/2206.10558 Repo: https://github.com/sail-sg/envpool
* TimeChamber: A Massively Parallel Large Scale Self-Play Framework. Repo: https://github.com/inspirai/TimeChamber

## Some results on the different environments

@@ -76,7 +78,7 @@ If you use rl-games in your research please use the following citation:
title = {rl-games: A High-performance Framework for Reinforcement Learning},
author = {Makoviichuk, Denys and Makoviychuk, Viktor},
month = {May},
year = {2022},
year = {2021},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/Denys88/rl_games}},
@@ -274,6 +276,11 @@ Additional environment supported properties and functions

## Release Notes

1.6.0
* Implemented SHAC algorithm: [Accelerated Policy Learning with Parallel Differentiable Simulation](https://short-horizon-actor-critic.github.io/) (ICLR 2022)
* Fixed various bugs related to num_frames/num_epochs interaction.
* Fixed a few SAC training configs, and improved SAC implementation.

1.5.2

* Added observation normalization to the SAC.
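As a reader's note on the 1.6.0 entry above: the sketch below shows how a SHAC run could be driven through the existing rl_games Runner. The config file path, the 'shac' names inside it, and the exact run arguments are assumptions for illustration; only the 'shac' model-name registration appears later in this PR.

```python
# Hypothetical sketch only: launching a SHAC training run through the rl_games Runner.
# The config path and the 'shac' algo/model names inside it are assumptions,
# not documentation from this PR.
import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/shac_ant.yaml') as f:  # hypothetical config file
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)                          # selects the algo/model by name from the config
runner.run({'train': True, 'play': False})   # same entry point as the existing PPO/SAC examples
```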
11 changes: 6 additions & 5 deletions rl_games/algos_torch/a2c_continuous.py
@@ -6,14 +6,14 @@
from rl_games.common import datasets

from torch import optim
import torch
from torch import nn
import numpy as np
import gym
import torch


class A2CAgent(a2c_common.ContinuousA2CBase):

def __init__(self, base_name, params):
a2c_common.ContinuousA2CBase.__init__(self, base_name, params)

obs_shape = self.obs_shape
build_config = {
'actions_num' : self.actions_num,
@@ -23,7 +23,7 @@ def __init__(self, base_name, params):
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
@@ -47,6 +47,7 @@ def __init__(self, base_name, params):
'config' : self.central_value_config,
'writter' : self.writer,
'max_epochs' : self.max_epochs,
'max_frames' : self.max_frames,
'multi_gpu' : self.multi_gpu,
}
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
40 changes: 27 additions & 13 deletions rl_games/algos_torch/central_value.py
@@ -2,7 +2,6 @@
import torch
from torch import nn
import torch.distributed as dist
import gym
import numpy as np
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd, RunningMeanStdObs
@@ -14,8 +13,8 @@
class CentralValueTrain(nn.Module):

def __init__(self, state_shape, value_size, ppo_device, num_agents, \
horizon_length, num_actors, num_actions, seq_len, \
normalize_value,network, config, writter, max_epochs, multi_gpu):
horizon_length, num_actors, num_actions, seq_len, normalize_value, \
network, config, writter, max_epochs, max_frames, multi_gpu):
nn.Module.__init__(self)

self.ppo_device = ppo_device
@@ -25,6 +24,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \
self.state_shape = state_shape
self.value_size = value_size
self.max_epochs = max_epochs
self.max_frames = max_frames
self.multi_gpu = multi_gpu
self.truncate_grads = config.get('truncate_grads', False)
self.config = config
@@ -43,14 +43,28 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, \
self.lr = float(config['learning_rate'])
self.linear_lr = config.get('lr_schedule') == 'linear'

# todo: support max frames as well
if self.linear_lr:
self.scheduler = schedulers.LinearScheduler(self.lr,
max_steps = self.max_epochs,
apply_to_entropy = False,
start_entropy_coef = 0)
if self.max_epochs == -1 and self.max_frames == -1:
print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the contstant (identity) one.")
self.scheduler = schedulers.IdentityScheduler()
else:
print("Linear lr schedule. Min lr = ", self.min_lr)
use_epochs = True
max_steps = self.max_epochs

if self.max_epochs == -1:
use_epochs = False
max_steps = self.max_frames

self.scheduler = schedulers.LinearScheduler(self.lr,
min_lr = self.min_lr,
max_steps = max_steps,
use_epochs = use_epochs,
apply_to_entropy = False,
start_entropy_coef = 0.0)
else:
self.scheduler = schedulers.IdentityScheduler()


self.mini_epoch = config['mini_epochs']
assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
@@ -172,7 +186,6 @@ def post_step_rnn(self, all_done_indices):
def forward(self, input_dict):
return self.model(input_dict)


def get_value(self, input_dict):
self.eval()
obs_batch = input_dict['states']
@@ -197,8 +210,8 @@ def train_critic(self, input_dict):
def update_multiagent_tensors(self, value_preds, returns, actions, dones):
batch_size = self.batch_size
ma_batch_size = self.num_actors * self.num_agents * self.horizon_length
value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1)
returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0,1)
value_preds = value_preds.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1)
returns = returns.view(self.num_actors, self.num_agents, self.horizon_length, self.value_size).transpose(0, 1)
value_preds = value_preds.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
returns = returns.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
dones = dones.contiguous().view(ma_batch_size, self.value_size)[:batch_size]
@@ -216,12 +229,13 @@ def train_net(self):
avg_loss = loss / (self.mini_epoch * self.num_minibatches)

self.epoch_num += 1
self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, 0, 0)
self.lr, _ = self.scheduler.update(self.lr, 0, self.epoch_num, self.frame, 0)
self.update_lr(self.lr)
self.frame += self.batch_size
if self.writter != None:
self.writter.add_scalar('losses/cval_loss', avg_loss, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)
self.writter.add_scalar('info/cval_lr', self.lr, self.frame)

return avg_loss

def calc_gradients(self, batch):
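To summarize the scheduler change above: the central value trainer now anneals its learning rate linearly over either max_epochs or max_frames, and falls back to a constant rate when neither limit is set. A simplified, standalone restatement of that selection logic (these classes only mimic rl_games' schedulers for illustration; signatures differ from the real ones):

```python
# Simplified restatement of the LR-schedule selection added to CentralValueTrain.
# Illustration only: not the rl_games scheduler classes or their exact signatures.

class IdentityScheduler:
    def update(self, lr, epoch, frame):
        return lr  # constant learning rate

class LinearScheduler:
    def __init__(self, start_lr, min_lr, max_steps, use_epochs):
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.max_steps = max_steps
        self.use_epochs = use_epochs

    def update(self, lr, epoch, frame):
        step = epoch if self.use_epochs else frame
        frac = max(0.0, 1.0 - step / self.max_steps)
        return self.min_lr + (self.start_lr - self.min_lr) * frac

def make_value_lr_scheduler(lr, min_lr, max_epochs, max_frames, linear_lr):
    if not linear_lr or (max_epochs == -1 and max_frames == -1):
        # No annealing horizon available: keep the learning rate constant.
        return IdentityScheduler()
    use_epochs = max_epochs != -1          # epochs take precedence when set, as in the diff
    max_steps = max_epochs if use_epochs else max_frames
    return LinearScheduler(lr, min_lr, max_steps, use_epochs)

# Example: anneal from 3e-4 to 1e-5 over 500 epochs.
sched = make_value_lr_scheduler(3e-4, 1e-5, max_epochs=500, max_frames=-1, linear_lr=True)
print(sched.update(3e-4, epoch=250, frame=0))  # roughly halfway between the two rates
```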
10 changes: 6 additions & 4 deletions rl_games/algos_torch/model_builder.py
@@ -1,7 +1,6 @@
from rl_games.common import object_factory
import rl_games.algos_torch
from rl_games.algos_torch import network_builder
from rl_games.algos_torch import models
from rl_games.algos_torch import network_builder, models


NETWORK_REGISTRY = {}
MODEL_REGISTRY = {}
@@ -14,13 +13,13 @@ def register_model(name, target_class):


class NetworkBuilder:

def __init__(self):
self.network_factory = object_factory.ObjectFactory()
self.network_factory.set_builders(NETWORK_REGISTRY)
self.network_factory.register_builder('actor_critic', lambda **kwargs: network_builder.A2CBuilder())
self.network_factory.register_builder('resnet_actor_critic',
lambda **kwargs: network_builder.A2CResnetBuilder())
self.network_factory.register_builder('rnd_curiosity', lambda **kwargs: network_builder.RNDCuriosityBuilder())
self.network_factory.register_builder('soft_actor_critic', lambda **kwargs: network_builder.SACBuilder())

def load(self, params):
@@ -32,6 +31,7 @@ def load(self, params):


class ModelBuilder:

def __init__(self):
self.model_factory = object_factory.ObjectFactory()
self.model_factory.set_builders(MODEL_REGISTRY)
@@ -46,6 +46,8 @@ def __init__(self):
lambda network, **kwargs: models.ModelSACContinuous(network))
self.model_factory.register_builder('central_value',
lambda network, **kwargs: models.ModelCentralValue(network))
self.model_factory.register_builder('shac',
lambda network, **kwargs: models.ModelA2CContinuousSHAC(network))
self.network_builder = NetworkBuilder()

def get_network_builder(self):
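The 'shac' entry above follows the same factory pattern as the other models: the registered lambda wraps an already-loaded network builder in ModelA2CContinuousSHAC, and its build() produces the torch module the agent uses. A rough sketch of that path; the build_config keys mirror the ones visible in the a2c_continuous.py diff earlier but may not be exhaustive, and the helper itself is hypothetical:

```python
# Rough sketch of what the registered 'shac' factory entry produces.
# `network` is assumed to be a network builder already loaded from a config;
# build_config keys are copied from a2c_continuous.py and may not be complete.
from rl_games.algos_torch import models

def build_shac_model(network, obs_shape, actions_num, device='cuda:0'):
    model_def = models.ModelA2CContinuousSHAC(network)  # same as the registered lambda
    build_config = {
        'actions_num': actions_num,
        'input_shape': obs_shape,
        'num_seqs': 1,
        'value_size': 1,
        'normalize_value': True,
        'normalize_input': True,
    }
    model = model_def.build(build_config)  # returns the inner Network (a BaseModelNetwork subclass)
    return model.to(device)
```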
62 changes: 50 additions & 12 deletions rl_games/algos_torch/models.py
@@ -1,4 +1,3 @@
import rl_games.algos_torch.layers
import numpy as np
import torch.nn as nn
import torch
@@ -28,7 +27,9 @@ def build(self, config):
return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape,
normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size)


class BaseModelNetwork(nn.Module):

def __init__(self, obs_shape, normalize_value, normalize_input, value_size):
nn.Module.__init__(self)
self.obs_shape = obs_shape
@@ -45,12 +46,11 @@ def __init__(self, obs_shape, normalize_value, normalize_input, value_size):
self.running_mean_std = RunningMeanStd(obs_shape)

def norm_obs(self, observation):
with torch.no_grad():
return self.running_mean_std(observation) if self.normalize_input else observation
return self.running_mean_std(observation) if self.normalize_input else observation

def unnorm_value(self, value):
with torch.no_grad():
return self.value_mean_std(value, unnorm=True) if self.normalize_value else value
return self.value_mean_std(value, unnorm=True) if self.normalize_value else value


class ModelA2C(BaseModel):
def __init__(self, network):
Expand All @@ -64,7 +64,7 @@ def __init__(self, a2c_network, **kwargs):

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

@@ -105,7 +107,9 @@ def forward(self, input_dict):
}
return result


class ModelA2CMultiDiscrete(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network
@@ -169,7 +171,9 @@ def forward(self, input_dict):
}
return result


class ModelA2CContinuous(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network
@@ -225,6 +229,7 @@ def forward(self, input_dict):


class ModelA2CContinuousLogStd(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network
@@ -278,7 +283,41 @@ def neglogp(self, x, mean, std, logstd):
+ logstd.sum(dim=-1)


class ModelA2CContinuousSHAC(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

def forward(self, input_dict):
input_dict['obs'] = self.norm_obs(input_dict['obs'])
mu, logstd, _, states = self.a2c_network(input_dict)
sigma = torch.exp(logstd)
distr = torch.distributions.Normal(mu, sigma)
entropy = distr.entropy().sum(dim=-1)
selected_action = distr.rsample()
result = {
'actions': selected_action,
'entropy': entropy,
'rnn_states': states,
'mus': mu,
'sigmas': sigma
}
return result


class ModelCentralValue(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network
@@ -304,24 +343,22 @@ def forward(self, input_dict):
value, states = self.a2c_network(input_dict)
if not is_train:
value = self.unnorm_value(value)

result = {
'values': value,
'rnn_states': states
}
return result



class ModelSACContinuous(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'sac')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, sac_network,**kwargs):
BaseModelNetwork.__init__(self,**kwargs)
def __init__(self, sac_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.sac_network = sac_network

def critic(self, obs, action):
@@ -332,7 +369,7 @@ def critic_target(self, obs, action):

def actor(self, obs):
return self.sac_network.actor(obs)

def is_rnn(self):
return False

@@ -344,3 +381,4 @@




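One detail worth noting in ModelA2CContinuousSHAC above: actions are drawn with distr.rsample() rather than sample(), so the sampled action stays differentiable with respect to mu and sigma. That is what lets a SHAC-style agent backpropagate a short-horizon return computed through a differentiable simulator into the policy parameters. A minimal standalone illustration of the difference (not rl_games code):

```python
# Minimal illustration of why rsample() matters for SHAC-style training.
import torch

mu = torch.zeros(3, requires_grad=True)
sigma = torch.ones(3, requires_grad=True)
dist = torch.distributions.Normal(mu, sigma)

a_plain = dist.sample()     # detached: no gradient path back to mu/sigma
a_reparam = dist.rsample()  # mu + sigma * eps with eps ~ N(0, 1): gradients flow

# Stand-in for a differentiable rollout objective.
loss = (a_reparam ** 2).sum()
loss.backward()

print(a_plain.requires_grad)   # False
print(mu.grad is not None)     # True: policy parameters receive gradients
```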