Soft Actor-Critic #398

Merged: 37 commits, Aug 4, 2019
Commits
a347afa
add SAC class outline and spec
kengz Aug 1, 2019
c093d15
add calc_val_loss
kengz Aug 1, 2019
e698b32
finish first SAC implementation
kengz Aug 1, 2019
d78827c
allow use of num_envs by offsetting with training iter
kengz Aug 1, 2019
795d13e
expose training_iter to spec
kengz Aug 1, 2019
461d32e
add async_sac spec
kengz Aug 1, 2019
1ea3582
add continuous action bound fix
kengz Aug 1, 2019
b6636cf
fix random_baseline env filter
kengz Aug 1, 2019
c5519a5
update random baselines for roboschool
kengz Aug 1, 2019
ff82344
lower log freq
kengz Aug 1, 2019
672672f
update enjoy mode doc
kengz Aug 1, 2019
db7bd77
add sac bipedalwalker spec
kengz Aug 1, 2019
e9216bf
add tanh to action rsample
kengz Aug 1, 2019
4234eba
fix SAC
kengz Aug 2, 2019
63b4c6f
update PER error term
kengz Aug 2, 2019
b8dc8a7
add SAC+PER spec
kengz Aug 2, 2019
b4202d1
use body batch_size as default training_start_step
kengz Aug 2, 2019
782f1a8
remove sac target net copy to follow the paper
kengz Aug 2, 2019
5e64e88
use purely random action before training sac
kengz Aug 2, 2019
e18bae0
fix sac random action shape
kengz Aug 2, 2019
6e1dce9
remove cont_hard specs
kengz Aug 2, 2019
97f8db4
move ac roboschool spec down
kengz Aug 2, 2019
a447ad8
add sac specs
kengz Aug 2, 2019
065f04b
split ac roboschool specs to new files
kengz Aug 2, 2019
7a4e19a
add async_sac spec
kengz Aug 2, 2019
59b6a0b
add SAC to README
kengz Aug 2, 2019
4a37762
set session to 4
kengz Aug 2, 2019
75d9f3e
wrap agent.act in no_grad for speedup
kengz Aug 3, 2019
5a126fc
add discrete action for SAC using RelaxedOneHotCategorical with rsample
kengz Aug 3, 2019
34c2e6f
add default temperature for relaxed discrete policy dist
kengz Aug 4, 2019
10442f4
fix analyze session min condition
kengz Aug 4, 2019
1a891bd
commit the best lunar spec found
kengz Aug 4, 2019
7a66dc0
remove arbitrary action clamp
kengz Aug 4, 2019
c7fb22d
fix policy loss to use q_preds instead of q1_preds. lunar working
kengz Aug 4, 2019
016618c
commit lunar search spec
kengz Aug 4, 2019
9934616
add Roboschool benchmark table
kengz Aug 4, 2019
a6c06e7
fix sac readme typo
kengz Aug 4, 2019
16 changes: 14 additions & 2 deletions BENCHMARK.md
Expand Up @@ -12,12 +12,13 @@ The data can be downloaded into SLM Lab's `data/` folder and [reran in enjoy mod
- A2C (n-step): Advantage Actor-Critic with n-step return as advantage estimation
- A3C: Asynchronous Advantage Actor-Critic
- CER: Combined Experience Replay
- DDQN: Double Deep Q-Learning
- DDQN: Double Deep Q-Network
- DIST: Distributed
- DQN: Deep Q-learning
- DQN: Deep Q-Network
- GAE: Generalized Advantage Estimation
- PER: Prioritized Experience Replay
- PPO: Proximal Policy Optimization
- SAC: Soft Actor-Critic
- SIL: Self Imitation Learning

### Atari Benchmark
@@ -38,6 +39,17 @@ The specs for these are contained in the [`slm_lab/spec/benchmark`](https://gith
| Seaquest <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62232168-6bf34d80-b37a-11e9-9564-fa3609dc5c75.png"><img src="https://user-images.githubusercontent.com/8209263/62232167-6bf34d80-b37a-11e9-8db3-c79a0e78292b.png"></details> | 892.68 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62020266-29adee80-b177-11e9-83c2-fafbdbb982b9.png"></details> | 1,686.08 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62020350-772a5b80-b177-11e9-8917-e3c8a745cd08.png"></details> | 1,583.04 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62067113-cf994180-b1e7-11e9-870b-b9bba71f2a7e.png"></details> | 1,118.50 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62100462-a9ef5500-b246-11e9-8699-9356ff81ff93.png"></details> | **3,751.34** <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62230991-ebcbe880-b377-11e9-8de4-a01379d1d61c.png"></details> |


### Roboschool Benchmark

[Roboschool](https://github.com/openai/roboschool) by OpenAI offers free, open-source robotics simulations with improved physics. Although it mirrors the environments from MuJoCo, its environments' rewards are different; a minimal usage sketch follows the table below.

| Env. \ Alg. | A2C (GAE) | A2C (n-step) | PPO | SAC |
|:---|---|---|---|---|
| RoboschoolAnt | | | | 1153.87 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429426-5f952a80-b6c3-11e9-8cf7-ee2bc908b2b3.png"></details> |
| RoboschoolHalfCheetah | | | | 1204.68 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429436-7471be00-b6c3-11e9-8343-cd646aca68e7.png"></details> |
| RoboschoolHopper | | | | 1161.24 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429437-79367200-b6c3-11e9-8a05-2c1fd0eb5e1f.png"></details> |
| RoboschoolWalker2d | | | | 695.36 <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/62429440-7cc9f900-b6c3-11e9-8d06-1476393d0e9e.png"></details> |
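
For orientation, a minimal sketch of how these environments are typically created through gym; the `-v1` env IDs and the 2019-era gym API are assumptions based on roboschool's registry at the time, not part of this repo's code.

```python
# Minimal usage sketch (assumes the roboschool package, its v1 env IDs, and the old gym step API)
import gym
import roboschool  # noqa: F401 -- importing registers the Roboschool envs with gym

env = gym.make('RoboschoolAnt-v1')
state = env.reset()
action = env.action_space.sample()                 # continuous action vector
next_state, reward, done, info = env.step(action)  # one environment step
env.close()
```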


### Classic Benchmark

27 changes: 18 additions & 9 deletions README.md
@@ -27,15 +27,16 @@ The benchmark results also include complete [spec files](https://github.com/keng

Below shows the latest benchmark status. See the full [benchmark results here](https://github.com/kengz/SLM-Lab/blob/master/BENCHMARK.md).

| **Algorithm\Benchmark** | Atari |
|-------------------------|-------|
| SARSA | - |
| DQN | :white_check_mark: |
| Double-DQN, Dueling-DQN, PER-DQN | :white_check_mark: |
| REINFORCE | - |
| A2C (GAE & n-step) | :white_check_mark: |
| PPO | :white_check_mark: |
| SIL (A2C, PPO) | |
| **Algorithm\Benchmark** | Atari | Roboschool |
|-------------------------|-------|-------|
| SARSA | - | - |
| DQN (Deep Q-Network) | :white_check_mark: | - |
| Double-DQN, Dueling-DQN, PER-DQN | :white_check_mark: | - |
| REINFORCE | - | - |
| A2C with GAE & n-step (Advantage Actor-Critic) | :white_check_mark: | |
| PPO (Proximal Policy Optimization) | :white_check_mark: | |
| SIL (Self Imitation Learning) | | |
| SAC (Soft Actor-Critic) | | :white_check_mark: |

Due to their standardized design, all the algorithms can be parallelized asynchronously using Hogwild. Hence, SLM Lab also includes A3C, distributed-DQN, distributed-PPO.

@@ -141,6 +142,14 @@ Below shows a trial graph with multiple sessions:

![](https://kengz.gitbooks.io/slm-lab/content/assets/a2c_gae_pong_t0_trial_graph_mean_returns_ma_vs_frames.png)

### Enjoy mode

Once a Trial completes with a good model saved into the `data/` folder, for example `data/a2c_gae_pong_2019_08_01_010727`, use the `enjoy` mode to show the trained agent playing the environment. Use the `enjoy@{prename}` mode to pick a specific saved trial-session, for example:

```shell
python run_lab.py data/a2c_gae_pong_2019_08_01_010727/a2c_gae_pong_spec.json a2c_gae_pong enjoy@a2c_gae_pong_t0_s0
```

### Benchmark

To run a full benchmark, simply pick a spec file and run it in train mode. For example, for the A2C Atari benchmark, the spec file is `slm_lab/spec/benchmark/a2c/a2c_atari.json`. This file is parametrized to run on a set of environments. Run the benchmark:
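
A likely invocation, assuming the spec name inside that file follows the usual SLM-Lab convention of matching the file name (`a2c_atari`), mirrors the enjoy-mode command above:

```shell
python run_lab.py slm_lab/spec/benchmark/a2c/a2c_atari.json a2c_atari train
```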
14 changes: 0 additions & 14 deletions job/cont_benchmark.json

This file was deleted.

14 changes: 14 additions & 0 deletions job/roboschool_benchmark.json
@@ -0,0 +1,14 @@
{
"slm_lab/spec/benchmark/a2c/a2c_nstep_roboschool.json": {
"a2c_nstep_roboschool": "train",
},
"slm_lab/spec/benchmark/a2c/a2c_gae_roboschool.json": {
"a2c_gae_roboschool": "train",
},
"slm_lab/spec/benchmark/ppo/ppo_roboschool.json": {
"ppo_roboschool": "train",
},
"slm_lab/spec/benchmark/sac/sac_roboschool.json": {
"sac_roboschool": "train",
}
}
1 change: 1 addition & 0 deletions slm_lab/agent/algorithm/__init__.py
@@ -6,5 +6,6 @@
from .ppo import *
from .random import *
from .reinforce import *
from .sac import *
from .sarsa import *
from .sil import *
1 change: 1 addition & 0 deletions slm_lab/agent/algorithm/dqn.py
@@ -57,6 +57,7 @@ def init_algorithm_params(self):
action_pdtype='Argmax',
action_policy='epsilon_greedy',
explore_var_spec=None,
training_start_step=self.body.memory.batch_size,
))
util.set_attr(self, self.algorithm_spec, [
'action_pdtype',
11 changes: 6 additions & 5 deletions slm_lab/agent/algorithm/policy_util.py
@@ -7,7 +7,6 @@
import numpy as np
import pydash as ps
import torch
import torch.nn.functional as F

logger = logger.get_logger(__name__)

@@ -19,7 +18,7 @@
ACTION_PDS = {
'continuous': ['Normal', 'Beta', 'Gumbel', 'LogNormal'],
'multi_continuous': ['MultivariateNormal'],
'discrete': ['Categorical', 'Argmax', 'GumbelCategorical'],
'discrete': ['Categorical', 'Argmax', 'GumbelCategorical', 'RelaxedOneHotCategorical'],
'multi_discrete': ['MultiCategorical'],
'multi_binary': ['Bernoulli'],
}
@@ -95,14 +94,16 @@ def init_action_pd(ActionPD, pdparam):
- continuous: action_pd = ActionPD(loc, scale)
'''
if 'logits' in ActionPD.arg_constraints: # discrete
action_pd = ActionPD(logits=pdparam)
# for relaxed discrete dist. with reparametrizable discrete actions
pd_kwargs = {'temperature': torch.tensor(1.0)} if hasattr(ActionPD, 'temperature') else {}
action_pd = ActionPD(logits=pdparam, **pd_kwargs)
else: # continuous, args = loc and scale
if isinstance(pdparam, list): # split output
loc, scale = pdparam
else:
loc, scale = pdparam.transpose(0, 1)
# scale (stdev) must be > 0, use softplus with positive
scale = F.softplus(scale) + 1e-8
# scale (stdev) must be > 0: treat the raw output as log-std, clamp to [-20, 2], then exp
scale = torch.clamp(scale, min=-20, max=2).exp()
if isinstance(pdparam, list): # split output
# construct covars from a batched scale tensor
covars = torch.diag_embed(scale)
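
As an illustrative aside (not SLM-Lab code; shapes and values are made up), the snippet below sketches the two distribution paths handled above: a `RelaxedOneHotCategorical` whose `rsample()` stays differentiable for discrete actions, and a Normal whose scale comes from clamping a log-std head and exponentiating, matching the `[-20, 2]` range in the diff.

```python
import torch
from torch.distributions import Normal, RelaxedOneHotCategorical

# discrete: relaxed one-hot with a temperature so rsample() keeps gradients
logits = torch.randn(4)
relaxed_pd = RelaxedOneHotCategorical(temperature=torch.tensor(1.0), logits=logits)
soft_action = relaxed_pd.rsample()    # point on the probability simplex, differentiable

# continuous: treat the second network head as log-std, clamp it, then exponentiate
loc, log_std = torch.zeros(2), torch.randn(2)
scale = torch.clamp(log_std, min=-20, max=2).exp()
action_pd = Normal(loc, scale)
reparam_action = action_pd.rsample()  # reparametrized sample for pathwise gradients
```

Clamping the log-std to [-20, 2] bounds the stdev to roughly [2e-9, 7.4], a common stabilization in SAC-style implementations.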
228 changes: 228 additions & 0 deletions slm_lab/agent/algorithm/sac.py
@@ -0,0 +1,228 @@
from slm_lab.agent import net
from slm_lab.agent.algorithm import policy_util
from slm_lab.agent.algorithm.actor_critic import ActorCritic
from slm_lab.agent.net import net_util
from slm_lab.lib import logger, util
from slm_lab.lib.decorator import lab_api
import numpy as np
import torch

logger = logger.get_logger(__name__)


class SoftActorCritic(ActorCritic):
'''
Implementation of Soft Actor-Critic (SAC)
Original paper: "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor"
https://arxiv.org/abs/1801.01290

e.g. algorithm_spec
"algorithm": {
"name": "SoftActorCritic",
"action_pdtype": "default",
"action_policy": "default",
"gamma": 0.99,
"training_frequency": 1,
}
'''
@lab_api
def init_algorithm_params(self):
'''Initialize other algorithm parameters'''
# set default
util.set_attr(self, dict(
action_pdtype='default',
action_policy='default',
training_iter=self.body.env.num_envs,
training_start_step=self.body.memory.batch_size,
))
util.set_attr(self, self.algorithm_spec, [
'action_pdtype',
'action_policy',
'gamma', # the discount factor
'training_iter',
'training_frequency',
])
self.to_train = 0
self.action_policy = getattr(policy_util, self.action_policy)

@lab_api
def init_nets(self, global_nets=None):
'''
Networks: net (actor/policy), critic_net (value), target_critic_net, q1_net, q2_net
All networks are separate, and have the same hidden layer architectures and optim specs, so tuning is minimal
'''
self.shared = False # SAC does not share networks
in_dim = self.body.state_dim
out_dim = net_util.get_out_dim(self.body)
NetClass = getattr(net, self.net_spec['type'])
# main actor network
self.net = NetClass(self.net_spec, in_dim, out_dim)
self.net_names = ['net']
# critic network and its target network
val_out_dim = 1
self.critic_net = NetClass(self.net_spec, in_dim, val_out_dim)
self.target_critic_net = NetClass(self.net_spec, in_dim, val_out_dim)
self.net_names += ['critic_net', 'target_critic_net']
# two Q-networks to mitigate positive bias in q_loss and speed up training
q_in_dim = in_dim + self.body.action_dim # NOTE concat s, a for now
self.q1_net = NetClass(self.net_spec, q_in_dim, val_out_dim)
self.q2_net = NetClass(self.net_spec, q_in_dim, val_out_dim)
self.net_names += ['q1_net', 'q2_net']

# init net optimizer and its lr scheduler
self.optim = net_util.get_optim(self.net, self.net.optim_spec)
self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
self.critic_optim = net_util.get_optim(self.critic_net, self.critic_net.optim_spec)
self.critic_lr_scheduler = net_util.get_lr_scheduler(self.critic_optim, self.critic_net.lr_scheduler_spec)
self.q1_optim = net_util.get_optim(self.q1_net, self.q1_net.optim_spec)
self.q1_lr_scheduler = net_util.get_lr_scheduler(self.q1_optim, self.q1_net.lr_scheduler_spec)
self.q2_optim = net_util.get_optim(self.q2_net, self.q2_net.optim_spec)
self.q2_lr_scheduler = net_util.get_lr_scheduler(self.q2_optim, self.q2_net.lr_scheduler_spec)
net_util.set_global_nets(self, global_nets)
self.post_init_nets()

@lab_api
def act(self, state):
if self.body.env.clock.frame < self.training_start_step:
return policy_util.random(state, self, self.body).cpu().squeeze().numpy()
else:
action = self.action_policy(state, self, self.body)
if self.body.is_discrete:
# discrete action_policy output is a relaxed one-hot (probabilities); sample an integer action from it
action = torch.distributions.Categorical(probs=action).sample()
else:
action = torch.tanh(action) # continuous action bound
return action.cpu().squeeze().numpy()

def calc_q(self, state, action, net=None):
'''Forward-pass to calculate the predicted state-action-value from q1_net.'''
x = torch.cat((state, action), dim=-1)
net = self.q1_net if net is None else net
q_pred = net(x).view(-1)
return q_pred

def calc_v_targets(self, batch, action_pd):
'''V_tar = Q(s, a) - log pi(a|s), Q(s, a) = min(Q1(s, a), Q2(s, a))'''
states = batch['states']
with torch.no_grad():
if self.body.is_discrete:
actions = action_pd.sample()
log_probs = action_pd.log_prob(actions)
else:
mus = action_pd.sample()
actions = torch.tanh(mus)
# paper Appendix C. Enforcing Action Bounds for continuous actions
log_probs = action_pd.log_prob(mus) - torch.log(1 - actions.pow(2) + 1e-6).sum(1)

q1_preds = self.calc_q(states, actions, self.q1_net)
q2_preds = self.calc_q(states, actions, self.q2_net)
q_preds = torch.min(q1_preds, q2_preds)

v_targets = q_preds - log_probs
return v_targets

def calc_q_targets(self, batch):
'''Q_tar = r + gamma * V_pred(s'; target_critic)'''
with torch.no_grad():
target_next_v_preds = self.calc_v(batch['next_states'], net=self.target_critic_net)
q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * target_next_v_preds
return q_targets

def calc_reg_loss(self, preds, targets):
'''Calculate the regression loss for V and Q values, using the same loss function from net_spec'''
assert preds.shape == targets.shape, f'{preds.shape} != {targets.shape}'
reg_loss = self.net.loss_fn(preds, targets)
return reg_loss

def calc_policy_loss(self, batch, action_pd):
'''policy_loss = log pi(f(a)|s) - min(Q1(s, f(a)), Q2(s, f(a))), where f(a) = reparametrized action'''
states = batch['states']
if self.body.is_discrete:
reparam_actions = action_pd.rsample()
log_probs = action_pd.log_prob(reparam_actions)
else:
reparam_mus = action_pd.rsample() # reparametrization for paper eq. 11
reparam_actions = torch.tanh(reparam_mus)
# paper Appendix C. Enforcing Action Bounds for continuous actions
log_probs = action_pd.log_prob(reparam_mus) - torch.log(1 - reparam_actions.pow(2) + 1e-6).sum(1)

q1_preds = self.calc_q(states, reparam_actions, self.q1_net)
q2_preds = self.calc_q(states, reparam_actions, self.q2_net)
q_preds = torch.min(q1_preds, q2_preds)

policy_loss = (log_probs - q_preds).mean()
return policy_loss

def try_update_per(self, q_preds, q_targets):
if 'Prioritized' in util.get_class_name(self.body.memory): # PER
with torch.no_grad():
errors = (q_preds - q_targets).abs().cpu().numpy()
self.body.memory.update_priorities(errors)

def train(self):
'''Train actor critic by computing the loss in batch efficiently'''
if util.in_eval_lab_modes():
return np.nan
clock = self.body.env.clock
if self.to_train == 1:
for _ in range(self.training_iter):
batch = self.sample()
clock.set_batch_size(len(batch))

# forward passes for losses
states = batch['states']
actions = batch['actions']
if self.body.is_discrete:
# convert discrete actions to one-hot vectors for the Q-network input
# TODO support multi-discrete actions
actions = torch.eye(self.body.action_dim)[actions.long()]
pdparams = self.calc_pdparam(states)
action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)

# V-value loss
v_preds = self.calc_v(states, net=self.critic_net)
v_targets = self.calc_v_targets(batch, action_pd)
val_loss = self.calc_reg_loss(v_preds, v_targets)
self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)

# Q-value loss for both Q nets
q_targets = self.calc_q_targets(batch)
q1_preds = self.calc_q(states, actions, self.q1_net)
q1_loss = self.calc_reg_loss(q1_preds, q_targets)
self.q1_net.train_step(q1_loss, self.q1_optim, self.q1_lr_scheduler, clock=clock, global_net=self.global_q1_net)
q2_preds = self.calc_q(states, actions, self.q2_net)
q2_loss = self.calc_reg_loss(q2_preds, q_targets)
self.q2_net.train_step(q2_loss, self.q2_optim, self.q2_lr_scheduler, clock=clock, global_net=self.global_q2_net)

# policy loss
policy_loss = self.calc_policy_loss(batch, action_pd)
self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)

loss = policy_loss + val_loss + q1_loss + q2_loss

# update target_critic_net
self.update_nets()
# update PER priorities if available
self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)

# reset
self.to_train = 0
logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}')
return loss.item()
else:
return np.nan

def update_nets(self):
'''Update target critic net'''
if util.frame_mod(self.body.env.clock.frame, self.critic_net.update_frequency, self.body.env.num_envs):
if self.critic_net.update_type == 'replace':
net_util.copy(self.critic_net, self.target_critic_net)
elif self.critic_net.update_type == 'polyak':
net_util.polyak_update(self.critic_net, self.target_critic_net, self.critic_net.polyak_coef)
else:
raise ValueError('Unknown critic_net.update_type. Should be "replace" or "polyak". Exiting.')

@lab_api
def update(self):
'''Updates self.target_critic_net and the explore variables'''
return self.body.explore_var
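
For reference, the `torch.log(1 - actions.pow(2) + 1e-6)` terms in `calc_v_targets` and `calc_policy_loss` implement the change of variables from Appendix C of the SAC paper (the `1e-6` is only a numerical-stability epsilon): with an unsquashed Gaussian sample `u` and bounded action `a = tanh(u)`,

```latex
% log-likelihood of a tanh-squashed action (SAC paper, Appendix C)
% u ~ \mu(u|s) is the unsquashed Gaussian policy sample, a = \tanh(u)
\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{i=1}^{D} \log\left(1 - \tanh^{2}(u_i)\right)
```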
2 changes: 1 addition & 1 deletion slm_lab/experiment/analysis.py
@@ -240,7 +240,7 @@ def analyze_session(session_spec, session_df, df_mode, plot=True):
'''Analyze session and save data, then return metrics. Note there are 2 types of session_df: body.eval_df and body.train_df'''
info_prepath = session_spec['meta']['info_prepath']
session_df = session_df.copy()
assert len(session_df) > 1, f'Need more than 1 datapoint to calculate metrics'
assert len(session_df) > 2, f'Need more than 2 datapoints to calculate metrics' # first datapoint at frame 0 is empty
util.write(session_df, f'{info_prepath}_session_df_{df_mode}.csv')
# calculate metrics
session_metrics = calc_session_metrics(session_df, ps.get(session_spec, 'env.0.name'), info_prepath, df_mode)