
Commit

Merge pull request #399 from kengz/sac2
Soft Actor-Critic improvements
kengz authored Aug 11, 2019
2 parents 85a2f39 + f6a2922 commit 4fb2efe
Showing 18 changed files with 362 additions and 141 deletions.
12 changes: 7 additions & 5 deletions BENCHMARK.md
@@ -43,12 +43,14 @@ The specs for these are contained in the [`slm_lab/spec/benchmark`](https://gith

[Roboschool](https://github.com/openai/roboschool) by OpenAI offers free, open-source robotics simulations with improved physics. Although it mirrors the MuJoCo environments, their rewards differ.

> The results for SAC are uploaded in [PR 399](https://github.com/kengz/SLM-Lab/pull/399).
| Env. \ Alg. | A2C (GAE) | A2C (n-step) | PPO | SAC |
|:---|---|---|---|---|---|
| RoboschoolAnt | | | | 1153.87 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429426-5f952a80-b6c3-11e9-8cf7-ee2bc908b2b3.png"></details> |
| RoboschoolHalfCheetah | | | | 1204.68 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429436-7471be00-b6c3-11e9-8343-cd646aca68e7.png"></details> |
| RoboschoolHopper | | | | 1161.24 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429437-79367200-b6c3-11e9-8a05-2c1fd0eb5e1f.png"></details> |
| RoboschoolWalker2d | | | | 695.36 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429440-7cc9f900-b6c3-11e9-8d06-1476393d0e9e.png"></details> |
|:---|---|---|---|---|
| RoboschoolAnt | | | | 2451.55 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62837481-c1eead80-bc24-11e9-913e-7685d64ecf87.png"></details> |
| RoboschoolHalfCheetah | | | | 2004.27 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62837485-daf75e80-bc24-11e9-8fba-279802ccdd1d.png"></details> |
| RoboschoolHopper | | | | 2090.52 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62837491-e8144d80-bc24-11e9-9d06-27a35b4aacca.png"></details> |
| RoboschoolWalker2d | | | | 1711.92 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62837495-f2364c00-bc24-11e9-8bdc-fa88831c227b.png"></details> |


### Classic Benchmark
2 changes: 1 addition & 1 deletion setup.py
@@ -36,7 +36,7 @@ def run_tests(self):

setup(
name='slm_lab',
version='4.0.0',
version='4.0.1',
description='Modular Deep Reinforcement Learning framework in PyTorch.',
long_description='https://github.com/kengz/slm_lab',
keywords='SLM Lab',
9 changes: 5 additions & 4 deletions slm_lab/agent/algorithm/policy_util.py
@@ -12,13 +12,13 @@

# register custom distributions
setattr(distributions, 'Argmax', distribution.Argmax)
setattr(distributions, 'GumbelCategorical', distribution.GumbelCategorical)
setattr(distributions, 'GumbelSoftmax', distribution.GumbelSoftmax)
setattr(distributions, 'MultiCategorical', distribution.MultiCategorical)
# probability distributions constraints for different action types; the first in the list is the default
ACTION_PDS = {
'continuous': ['Normal', 'Beta', 'Gumbel', 'LogNormal'],
'multi_continuous': ['MultivariateNormal'],
'discrete': ['Categorical', 'Argmax', 'GumbelCategorical', 'RelaxedOneHotCategorical'],
'discrete': ['Categorical', 'Argmax', 'GumbelSoftmax'],
'multi_discrete': ['MultiCategorical'],
'multi_binary': ['Bernoulli'],
}
@@ -93,7 +93,8 @@ def init_action_pd(ActionPD, pdparam):
- discrete: action_pd = ActionPD(logits)
- continuous: action_pd = ActionPD(loc, scale)
'''
if 'logits' in ActionPD.arg_constraints: # discrete
args = ActionPD.arg_constraints
if 'logits' in args: # discrete
# for relaxed discrete dist. with reparametrizable discrete actions
pd_kwargs = {'temperature': torch.tensor(1.0)} if hasattr(ActionPD, 'temperature') else {}
action_pd = ActionPD(logits=pdparam, **pd_kwargs)
@@ -104,7 +105,7 @@ def init_action_pd(ActionPD, pdparam):
loc, scale = pdparam.transpose(0, 1)
# scale (stdev) must be > 0, log-clamp-exp
scale = torch.clamp(scale, min=-20, max=2).exp()
if isinstance(pdparam, list): # split output
if 'covariance_matrix' in args: # split output
# construct covars from a batched scale tensor
covars = torch.diag_embed(scale)
action_pd = ActionPD(loc=loc, covariance_matrix=covars)
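
For reference, a minimal standalone sketch of the dispatch above, using stock `torch.distributions` classes in place of the custom registered ones (the helper name `make_action_pd` is hypothetical; the real code is `init_action_pd` in `policy_util.py`):

```python
import torch
from torch import distributions

def make_action_pd(ActionPD, pdparam):
    '''Hypothetical standalone version of the dispatch in init_action_pd.'''
    args = ActionPD.arg_constraints
    if 'logits' in args:  # discrete: network head emits logits
        return ActionPD(logits=pdparam)
    # continuous: network head emits stacked (loc, log_std) columns
    loc, scale = pdparam.transpose(0, 1)
    scale = torch.clamp(scale, min=-20, max=2).exp()  # log-std -> positive std
    if 'covariance_matrix' in args:  # multi_continuous
        return ActionPD(loc=loc, covariance_matrix=torch.diag_embed(scale))
    return ActionPD(loc=loc, scale=scale)

# discrete: a batch of 4 states, 3 actions
discrete_pd = make_action_pd(distributions.Categorical, torch.zeros(4, 3))
# continuous: a batch of 4 one-dimensional actions, columns are (loc, log_std)
continuous_pd = make_action_pd(distributions.Normal, torch.zeros(4, 2))
action = continuous_pd.sample()  # shape (4,)
```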
189 changes: 101 additions & 88 deletions slm_lab/agent/algorithm/sac.py
@@ -6,6 +6,7 @@
from slm_lab.lib.decorator import lab_api
import numpy as np
import torch
import torch.nn.functional as F

logger = logger.get_logger(__name__)

@@ -15,6 +16,8 @@ class SoftActorCritic(ActorCritic):
Implementation of Soft Actor-Critic (SAC)
Original paper: "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor"
https://arxiv.org/abs/1801.01290
Improvement of SAC paper: "Soft Actor-Critic Algorithms and Applications"
https://arxiv.org/abs/1812.05905
e.g. algorithm_spec
"algorithm": {
@@ -41,43 +44,48 @@ def init_algorithm_params(self):
'gamma', # the discount factor
'training_iter',
'training_frequency',
'training_start_step',
])
if self.body.is_discrete:
assert self.action_pdtype == 'GumbelSoftmax'
self.to_train = 0
self.action_policy = getattr(policy_util, self.action_policy)

@lab_api
def init_nets(self, global_nets=None):
'''
Networks: net(actor/policy), critic (value), target_critic, q1_net, q1_net
Networks: net(actor/policy), q1_net, target_q1_net, q2_net, target_q2_net
All networks are separate, and have the same hidden layer architectures and optim specs, so tuning is minimal
'''
self.shared = False # SAC does not share networks
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
# main actor network
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.net = NetClass(self.net_spec, self.body.state_dim, net_util.get_out_dim(self.body))
self.net_names = ['net']
# critic network and its target network
val_out_dim = 1
self.critic_net = NetClass(self.net_spec, in_dim, val_out_dim)
self.target_critic_net = NetClass(self.net_spec, in_dim, val_out_dim)
self.net_names += ['critic_net', 'target_critic_net']
# two Q-networks to mitigate positive bias in q_loss and speed up training
q_in_dim = in_dim + self.body.action_dim # NOTE concat s, a for now
self.q1_net = NetClass(self.net_spec, q_in_dim, val_out_dim)
self.q2_net = NetClass(self.net_spec, q_in_dim, val_out_dim)
self.net_names += ['q1_net', 'q2_net']
# two critic Q-networks to mitigate positive bias in q_loss and speed up training, uses q_net.py with prefix Q
QNetClass = getattr(net, 'Q' + self.net_spec['type'])
q_in_dim = [self.body.state_dim, self.body.action_dim]
self.q1_net = QNetClass(self.net_spec, q_in_dim, 1)
self.target_q1_net = QNetClass(self.net_spec, q_in_dim, 1)
self.q2_net = QNetClass(self.net_spec, q_in_dim, 1)
self.target_q2_net = QNetClass(self.net_spec, q_in_dim, 1)
self.net_names += ['q1_net', 'target_q1_net', 'q2_net', 'target_q2_net']
net_util.copy(self.q1_net, self.target_q1_net)
net_util.copy(self.q2_net, self.target_q2_net)
# temperature variable to be learned, and its target entropy
self.log_alpha = torch.zeros(1, requires_grad=True, device=self.net.device)
self.alpha = self.log_alpha.detach().exp()
self.target_entropy = - np.product(self.body.action_space.shape)

# init net optimizer and its lr scheduler
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
self.q1_optim = net_util.get_optim(self.q1_net, self.q1_net.optim_spec)
self.q1_lr_scheduler = net_util.get_lr_scheduler(self.q1_optim, self.q1_net.lr_scheduler_spec)
self.q2_optim = net_util.get_optim(self.q2_net, self.q2_net.optim_spec)
self.q2_lr_scheduler = net_util.get_lr_scheduler(self.q2_optim, self.q2_net.lr_scheduler_spec)
self.alpha_optim = net_util.get_optim(self.log_alpha, self.net.optim_spec)
self.alpha_lr_scheduler = net_util.get_lr_scheduler(self.alpha_optim, self.net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()
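
The learned temperature above follows the second SAC paper: the optimized variable is log α (so α stays positive after exponentiation), and the entropy target for continuous actions is the usual heuristic of minus the action dimensionality:

```latex
\alpha = \exp(\log \alpha), \qquad \bar{\mathcal{H}} = -\dim(\mathcal{A})
```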

@@ -87,45 +95,54 @@ def act(self, state):
return policy_util.random(state, self, self.body).cpu().squeeze().numpy()
else:
action = self.action_policy(state, self, self.body)
if self.body.is_discrete:
# discrete output is RelaxedOneHotCategorical, need to sample to int
action = torch.distributions.Categorical(probs=action).sample()
else:
action = torch.tanh(action) # continuous action bound
if not self.body.is_discrete:
action = self.scale_action(torch.tanh(action)) # continuous action bound
return action.cpu().squeeze().numpy()

def calc_q(self, state, action, net=None):
'''Forward-pass to calculate the predicted state-action-value from q1_net.'''
x = torch.cat((state, action), dim=-1)
net = self.q1_net if net is None else net
q_pred = net(x).view(-1)
return q_pred

def calc_v_targets(self, batch, action_pd):
'''V_tar = Q(s, a) - log pi(a|s), Q(s, a) = min(Q1(s, a), Q2(s, a))'''
states = batch['states']
with torch.no_grad():
if self.body.is_discrete:
actions = action_pd.sample()
log_probs = action_pd.log_prob(actions)
else:
mus = action_pd.sample()
actions = torch.tanh(mus)
# paper Appendix C. Enforcing Action Bounds for continuous actions
log_probs = action_pd.log_prob(mus) - torch.log(1 - actions.pow(2) + 1e-6).sum(1)
def scale_action(self, action):
'''Scale continuous actions from tanh range'''
action_space = self.body.action_space
low, high = torch.from_numpy(action_space.low), torch.from_numpy(action_space.high)
return action * (high - low) / 2 + (low + high) / 2

q1_preds = self.calc_q(states, actions, self.q1_net)
q2_preds = self.calc_q(states, actions, self.q2_net)
q_preds = torch.min(q1_preds, q2_preds)
def guard_q_actions(self, actions):
'''Guard to convert actions to one-hot for input to Q-network'''
if self.body.is_discrete:
# TODO support multi-discrete actions
actions = F.one_hot(actions.long(), self.body.action_dim).float()
return actions
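
A quick illustration of the one-hot guard (tensor values invented for the example):

```python
import torch
import torch.nn.functional as F

# e.g. 3 sampled discrete actions in a 4-action space
actions = torch.tensor([2, 0, 3])
one_hot = F.one_hot(actions.long(), num_classes=4).float()
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 0., 0., 1.]])
```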

def calc_log_prob_action(self, action_pd, reparam=False):
'''Calculate log_probs and actions with option to reparametrize from paper eq. 11'''
samples = action_pd.rsample() if reparam else action_pd.sample()
if self.body.is_discrete: # this is straightforward using GumbelSoftmax
actions = samples
log_probs = action_pd.log_prob(actions)
else:
mus = samples
actions = self.scale_action(torch.tanh(mus))
# paper Appendix C. Enforcing Action Bounds for continuous actions
log_probs = (action_pd.log_prob(mus) - torch.log(1 - actions.pow(2) + 1e-6).sum(1))
return log_probs, actions
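
The bound-enforcing correction above is the change of variables from Appendix C of the SAC paper (the code adds 1e-6 inside the log for numerical stability):

```latex
\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^2(u_i)\right),
\qquad a = \tanh(u), \; u \sim \mu(\cdot \mid s)
```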

v_targets = q_preds - log_probs
return v_targets
def calc_q(self, state, action, net):
'''Forward-pass to calculate the predicted state-action-value from q1_net.'''
q_pred = net(state, action).view(-1)
return q_pred

def calc_q_targets(self, batch):
'''Q_tar = r + gamma * V_pred(s'; target_critic)'''
'''Q_tar = r + gamma * (target_Q(s', a') - alpha * log pi(a'|s'))'''
next_states = batch['next_states']
with torch.no_grad():
target_next_v_preds = self.calc_v(batch['next_states'], net=self.target_critic_net)
q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * target_next_v_preds
pdparams = self.calc_pdparam(next_states)
action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
next_log_probs, next_actions = self.calc_log_prob_action(action_pd)
next_actions = self.guard_q_actions(next_actions) # non-reparam discrete actions need to be converted into one-hot

next_target_q1_preds = self.calc_q(next_states, next_actions, self.target_q1_net)
next_target_q2_preds = self.calc_q(next_states, next_actions, self.target_q2_net)
next_target_q_preds = torch.min(next_target_q1_preds, next_target_q2_preds)
q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * (next_target_q_preds - self.alpha * next_log_probs)
return q_targets
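
Written out, the clipped double-Q soft target computed here is:

```latex
y(r, s', d) = r + \gamma (1 - d) \left( \min_{i \in \{1, 2\}} Q_{\bar{\theta}_i}(s', a') - \alpha \log \pi(a' \mid s') \right),
\qquad a' \sim \pi(\cdot \mid s')
```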

def calc_reg_loss(self, preds, targets):
@@ -134,31 +151,33 @@ def calc_reg_loss(self, preds, targets):
reg_loss = self.net.loss_fn(preds, targets)
return reg_loss

def calc_policy_loss(self, batch, action_pd):
'''policy_loss = log pi(f(a)|s) - Q1(s, f(a)), where f(a) = reparametrized action'''
def calc_policy_loss(self, batch, log_probs, reparam_actions):
'''policy_loss = alpha * log pi(f(a)|s) - Q1(s, f(a)), where f(a) = reparametrized action'''
states = batch['states']
if self.body.is_discrete:
reparam_actions = action_pd.rsample()
log_probs = action_pd.log_prob(reparam_actions)
else:
reparam_mus = action_pd.rsample() # reparametrization for paper eq. 11
reparam_actions = torch.tanh(reparam_mus)
# paper Appendix C. Enforcing Action Bounds for continuous actions
log_probs = action_pd.log_prob(reparam_mus) - torch.log(1 - reparam_actions.pow(2) + 1e-6).sum(1)

q1_preds = self.calc_q(states, reparam_actions, self.q1_net)
q2_preds = self.calc_q(states, reparam_actions, self.q2_net)
q_preds = torch.min(q1_preds, q2_preds)

policy_loss = (log_probs - q_preds).mean()
policy_loss = (self.alpha * log_probs - q_preds).mean()
return policy_loss

def calc_alpha_loss(self, log_probs):
alpha_loss = - (self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
return alpha_loss
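
In equation form, the policy and temperature objectives minimized by the two methods above are (with f(ε; s) the reparametrized action and H̄ the target entropy):

```latex
J_\pi = \mathbb{E}\left[ \alpha \log \pi\big(f(\epsilon; s) \mid s\big) - \min_{i} Q_{\theta_i}\big(s, f(\epsilon; s)\big) \right],
\qquad
J(\alpha) = \mathbb{E}_{a \sim \pi}\left[ -\alpha \big(\log \pi(a \mid s) + \bar{\mathcal{H}}\big) \right]
```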

def try_update_per(self, q_preds, q_targets):
if 'Prioritized' in util.get_class_name(self.body.memory): # PER
with torch.no_grad():
errors = (q_preds - q_targets).abs().cpu().numpy()
self.body.memory.update_priorities(errors)

def train_alpha(self, alpha_loss):
'''Custom method to train the alpha variable'''
self.alpha_lr_scheduler.step(epoch=self.body.env.clock.frame)
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.detach().exp()

def train(self):
'''Train actor critic by computing the loss in batch efficiently'''
if util.in_eval_lab_modes():
@@ -169,38 +188,30 @@ def train(self):
batch = self.sample()
clock.set_batch_size(len(batch))

# forward passes for losses
states = batch['states']
actions = batch['actions']
if self.body.is_discrete:
# to one-hot discrete action for Q input.
# TODO support multi-discrete actions
actions = torch.eye(self.body.action_dim)[actions.long()]
pdparams = self.calc_pdparam(states)
action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)

# V-value loss
v_preds = self.calc_v(states, net=self.critic_net)
v_targets = self.calc_v_targets(batch, action_pd)
val_loss = self.calc_reg_loss(v_preds, v_targets)
self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)

# Q-value loss for both Q nets
actions = self.guard_q_actions(batch['actions'])
q_targets = self.calc_q_targets(batch)
# Q-value loss for both Q nets
q1_preds = self.calc_q(states, actions, self.q1_net)
q1_loss = self.calc_reg_loss(q1_preds, q_targets)
self.q1_net.train_step(q1_loss, self.q1_optim, self.q1_lr_scheduler, clock=clock, global_net=self.global_q1_net)

q2_preds = self.calc_q(states, actions, self.q2_net)
q2_loss = self.calc_reg_loss(q2_preds, q_targets)
self.q2_net.train_step(q2_loss, self.q2_optim, self.q2_lr_scheduler, clock=clock, global_net=self.global_q2_net)

# policy loss
policy_loss = self.calc_policy_loss(batch, action_pd)
action_pd = policy_util.init_action_pd(self.body.ActionPD, self.calc_pdparam(states))
log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)

loss = policy_loss + val_loss + q1_loss + q2_loss
# alpha loss
alpha_loss = self.calc_alpha_loss(log_probs)
self.train_alpha(alpha_loss)

# update target_critic_net
loss = q1_loss + q2_loss + policy_loss + alpha_loss
# update target networks
self.update_nets()
# update PER priorities if available
self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)
@@ -213,16 +224,18 @@ def train(self):
return np.nan

def update_nets(self):
'''Update target critic net'''
if util.frame_mod(self.body.env.clock.frame, self.critic_net.update_frequency, self.body.env.num_envs):
if self.critic_net.update_type == 'replace':
net_util.copy(self.critic_net, self.target_critic_net)
elif self.critic_net.update_type == 'polyak':
net_util.polyak_update(self.critic_net, self.target_critic_net, self.critic_net.polyak_coef)
'''Update target networks'''
if util.frame_mod(self.body.env.clock.frame, self.q1_net.update_frequency, self.body.env.num_envs):
if self.q1_net.update_type == 'replace':
net_util.copy(self.q1_net, self.target_q1_net)
net_util.copy(self.q2_net, self.target_q2_net)
elif self.q1_net.update_type == 'polyak':
net_util.polyak_update(self.q1_net, self.target_q1_net, self.q1_net.polyak_coef)
net_util.polyak_update(self.q2_net, self.target_q2_net, self.q2_net.polyak_coef)
else:
raise ValueError('Unknown critic_net.update_type. Should be "replace" or "polyak". Exiting.')
raise ValueError('Unknown q1_net.update_type. Should be "replace" or "polyak". Exiting.')
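
For the 'polyak' branch, a minimal sketch of what a soft target update does; the actual `net_util.polyak_update` signature and coefficient convention may differ:

```python
import torch

def polyak_update(src_net, tar_net, polyak_coef=0.005):
    '''Sketch of a soft (Polyak) target update:
    tar <- polyak_coef * src + (1 - polyak_coef) * tar.
    slm_lab's net_util.polyak_update may use a different convention.'''
    with torch.no_grad():
        for src_p, tar_p in zip(src_net.parameters(), tar_net.parameters()):
            tar_p.data.mul_(1.0 - polyak_coef)
            tar_p.data.add_(polyak_coef * src_p.data)
```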

@lab_api
def update(self):
'''Updates self.target_critic_net and the explore variables'''
'''Override parent method to do nothing'''
return self.body.explore_var
1 change: 1 addition & 0 deletions slm_lab/agent/net/__init__.py
@@ -3,3 +3,4 @@
from slm_lab.agent.net.conv import *
from slm_lab.agent.net.mlp import *
from slm_lab.agent.net.recurrent import *
from slm_lab.agent.net.q_net import *
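
The new `q_net` module imported here (its diff is not rendered in this view) supplies the state-action Q-networks that `init_nets` builds as `'Q' + net type`. Below is a rough, hypothetical sketch of the interface `calc_q` relies on: a network called as `net(state, action)` returning one value per sample; the real `QMLPNet` builds its layers from the net spec.

```python
import torch
import torch.nn as nn

class QMLPNetSketch(nn.Module):
    '''Hypothetical state-action Q-network illustrating the q_net interface.'''

    def __init__(self, state_dim, action_dim, hid_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim + action_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 1),  # scalar Q-value per (state, action) pair
        )

    def forward(self, state, action):
        # concatenate state and action along the feature dimension
        return self.model(torch.cat([state, action], dim=-1))

# usage mirroring calc_q: q_pred = net(state, action).view(-1)
q_net = QMLPNetSketch(state_dim=8, action_dim=2)
q_pred = q_net(torch.zeros(32, 8), torch.zeros(32, 2)).view(-1)  # shape (32,)
```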

