Merge pull request #429 from kengz/asac
latest SAC discrete results
kengz authored Nov 13, 2019
2 parents 8112907 + 8cd0c06 commit 7dc93a8
Showing 6 changed files with 46 additions and 35 deletions.
12 changes: 6 additions & 6 deletions BENCHMARK.md
@@ -64,15 +64,15 @@ SLM Lab's benchmark includes environments from the following offerings:
 ||||||||
 |:---:|:---:|:---:|:---:|:---:|:---:|:---:|
 | Env. \ Alg. | DQN | DDQN+PER | A2C (GAE) | A2C (n-step) | PPO | SAC |
-| Breakout <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737546-dabb6380-f9c8-11e9-901e-b96cc28f1fdf.png"></details> | 80.88 | 182 | 377 | 398 | **443** | - |
+| Breakout <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737546-dabb6380-f9c8-11e9-901e-b96cc28f1fdf.png"></details> | 80.88 | 182 | 377 | 398 | **443** | 3.51* |
 | Pong <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737554-e018ae00-f9c8-11e9-92b5-3bd8d213b1e0.png"></details> | 18.48 | 20.5 | 19.31 | 19.56 | **20.58** | 19.87* |
-| Seaquest <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737557-e3139e80-f9c8-11e9-9446-119593ca956b.png"></details> | 1185 | **4405** | 1070 | 1684 | 1715 | - |
-| Qbert <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737559-e575f880-f9c8-11e9-8c98-f14c82041a45.png"></details> | 5494 | 11426 | 12405 | **13590** | 13460 | 214* |
+| Seaquest <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737557-e3139e80-f9c8-11e9-9446-119593ca956b.png"></details> | 1185 | **4405** | 1070 | 1684 | 1715 | 171* |
+| Qbert <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737559-e575f880-f9c8-11e9-8c98-f14c82041a45.png"></details> | 5494 | 11426 | 12405 | **13590** | 13460 | 923* |
 | LunarLander <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737566-e7d85280-f9c8-11e9-8df8-39c1205c5308.png"></details> | 192 | 233 | 25.21 | 68.23 | 214 | **276** |
-| UnityHallway <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737569-ead34300-f9c8-11e9-9e26-61fe1d779989.png"></details> | -0.32 | 0.27 | 0.08 | -0.96 | **0.73** | - |
-| UnityPushBlock <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737577-eeff6080-f9c8-11e9-931c-843ba697779c.png"></details> | 4.88 | 4.93 | 4.68 | 4.93 | **4.97** | - |
+| UnityHallway <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737569-ead34300-f9c8-11e9-9e26-61fe1d779989.png"></details> | -0.32 | 0.27 | 0.08 | -0.96 | **0.73** | 0.01 |
+| UnityPushBlock <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737577-eeff6080-f9c8-11e9-931c-843ba697779c.png"></details> | 4.88 | 4.93 | 4.68 | 4.93 | **4.97** | -0.70 |

->Episode score at the end of training attained by SLM Lab implementations on discrete-action control problems. Reported episode scores are the average over the last 100 checkpoints, and then averaged over 4 Sessions. Results marked with `*` were trained using the hybrid synchronous/asynchronous version of SAC to parallelize and speed up training time.
+>Episode score at the end of training attained by SLM Lab implementations on discrete-action control problems. Reported episode scores are the average over the last 100 checkpoints, and then averaged over 4 Sessions. A Random baseline with score averaged over 100 episodes is included. Results marked with `*` were trained using the hybrid synchronous/asynchronous version of SAC to parallelize and speed up training time. For SAC, Breakout, Pong and Seaquest were trained for 2M frames instead of 10M frames.
 >For the full Atari benchmark, see [Atari Benchmark](https://github.com/kengz/SLM-Lab/blob/benchmark/BENCHMARK.md#atari-benchmark)
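
Reviewer note: the caption in the BENCHMARK.md hunk describes each reported score as the mean over the last 100 checkpoints of a Session, then the mean over 4 Sessions. A minimal sketch of that arithmetic follows. The function name `final_score`, the `last_n` parameter and the sample data are hypothetical, and this is not taken from SLM Lab's own analysis code.

```python
from statistics import mean


def final_score(session_returns, last_n=100):
    """Two-stage average: last `last_n` checkpoints per session, then across sessions.

    `session_returns` is a list with one entry per Session, each entry being
    the list of checkpoint scores recorded for that Session (assumed layout).
    """
    per_session = [mean(returns[-last_n:]) for returns in session_returns]
    return mean(per_session)


# Hypothetical data: 4 sessions with 150 checkpoints each.
sessions = [[float(i + s) for i in range(150)] for s in range(4)]
print(final_score(sessions))  # 101.0
```

Averaging within each Session first keeps a Session with more checkpoints from dominating the result, which is one plausible reason to report the score this way.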
10 changes: 5 additions & 5 deletions README.md
@@ -54,13 +54,13 @@ Due to their standardized design, all the algorithms can be parallelized asynchr
 ||||||||
 |:---:|:---:|:---:|:---:|:---:|:---:|:---:|
 | Env. \ Alg. | DQN | DDQN+PER | A2C (GAE) | A2C (n-step) | PPO | SAC |
-| Breakout <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737546-dabb6380-f9c8-11e9-901e-b96cc28f1fdf.png"></details> | 80.88 | 182 | 377 | 398 | **443** | - |
+| Breakout <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737546-dabb6380-f9c8-11e9-901e-b96cc28f1fdf.png"></details> | 80.88 | 182 | 377 | 398 | **443** | 3.51* |
 | Pong <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737554-e018ae00-f9c8-11e9-92b5-3bd8d213b1e0.png"></details> | 18.48 | 20.5 | 19.31 | 19.56 | **20.58** | 19.87* |
-| Seaquest <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737557-e3139e80-f9c8-11e9-9446-119593ca956b.png"></details> | 1185 | **4405** | 1070 | 1684 | 1715 | - |
-| Qbert <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737559-e575f880-f9c8-11e9-8c98-f14c82041a45.png"></details> | 5494 | 11426 | 12405 | **13590** | 13460 | 214* |
+| Seaquest <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737557-e3139e80-f9c8-11e9-9446-119593ca956b.png"></details> | 1185 | **4405** | 1070 | 1684 | 1715 | 171* |
+| Qbert <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737559-e575f880-f9c8-11e9-8c98-f14c82041a45.png"></details> | 5494 | 11426 | 12405 | **13590** | 13460 | 923* |
 | LunarLander <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737566-e7d85280-f9c8-11e9-8df8-39c1205c5308.png"></details> | 192 | 233 | 25.21 | 68.23 | 214 | **276** |
-| UnityHallway <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737569-ead34300-f9c8-11e9-9e26-61fe1d779989.png"></details> | -0.32 | 0.27 | 0.08 | -0.96 | **0.73** | - |
-| UnityPushBlock <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737577-eeff6080-f9c8-11e9-931c-843ba697779c.png"></details> | 4.88 | 4.93 | 4.68 | 4.93 | **4.97** | - |
+| UnityHallway <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737569-ead34300-f9c8-11e9-9e26-61fe1d779989.png"></details> | -0.32 | 0.27 | 0.08 | -0.96 | **0.73** | 0.01 |
+| UnityPushBlock <details><summary><i>graph</i></summary><img src="https://user-images.githubusercontent.com/8209263/67737577-eeff6080-f9c8-11e9-931c-843ba697779c.png"></details> | 4.88 | 4.93 | 4.68 | 4.93 | **4.97** | -0.70 |

 >For the full Atari benchmark, see [Atari Benchmark](https://github.com/kengz/SLM-Lab/blob/benchmark/BENCHMARK.md#atari-benchmark)
18 changes: 13 additions & 5 deletions bin/plot_benchmark.py
@@ -17,8 +17,10 @@
 trial_metrics_path = '*t0_trial_metrics.pkl'
 env_name_map = {
     'lunar': 'LunarLander',
+    'reakout': 'Breakout',
     'ong': 'Pong',
     'bert': 'Qbert',
+    'eaquest': 'Seaquest',
     'humanoid': 'RoboschoolHumanoid',
     'humanoidflagrun': 'RoboschoolHumanoidFlagrun',
     'humanoidflagrunharder': 'RoboschoolHumanoidFlagrunHarder',
@@ -160,9 +162,9 @@ def plot_envs(algos, envs, data_folder, legend_list, frame_scales=None):
     'SAC',
 ]
 envs = [
-    'Breakout',
+    'reakout',
     'ong',
-    'Seaquest',
+    'eaquest',
     'bert',
     'lunar',
     'UnityHallway',
@@ -177,8 +179,8 @@ def plot_envs(algos, envs, data_folder, legend_list, frame_scales=None):

 # plot normal
 envs = [
-    'Breakout',
-    'Seaquest',
+    # 'Breakout',
+    # 'Seaquest',
     'lunar',
     'UnityHallway',
     'UnityPushBlock',
@@ -187,11 +189,17 @@ def plot_envs(algos, envs, data_folder, legend_list, frame_scales=None):

 # Replot Pong and Qbert for Async SAC
 envs = [
+    'reakout',
     'ong',
-    'bert',
+    'eaquest',
 ]
 plot_envs(algos, envs, data_folder, legend_list, frame_scales=[(-1, 6)])
+
+envs = [
+    'bert',
+]
+plot_envs(algos, envs, data_folder, legend_list, frame_scales=[(-1, 8)])


 # Continuous
 # Roboschool + Unity
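
Reviewer note: the truncated keys in `env_name_map` (`'reakout'`, `'ong'`, `'bert'`, `'eaquest'`) look like fragments meant to match a result folder whether the environment name is capitalized or not, while the map supplies the full display name. The sketch below illustrates that idea only. How `plot_benchmark.py` actually locates its data is not shown in this diff, so the helper names `display_name` and `match_run_dirs`, the glob-style pattern and the folder names are all assumptions.

```python
from fnmatch import fnmatchcase

# Subset of the map from the diff above; keys are case-agnostic name fragments.
ENV_NAME_MAP = {
    'lunar': 'LunarLander',
    'reakout': 'Breakout',
    'ong': 'Pong',
    'bert': 'Qbert',
    'eaquest': 'Seaquest',
}


def display_name(env_key):
    """Return the full name for a fragment, or the key itself if unmapped."""
    return ENV_NAME_MAP.get(env_key, env_key)


def match_run_dirs(algo, env_key, dir_names):
    """Select folder names that start with `algo` and contain `env_key` (hypothetical layout)."""
    pattern = f'{algo}*{env_key}*'
    return [d for d in dir_names if fnmatchcase(d, pattern)]


# Hypothetical result folders.
run_dirs = ['sac_pong_2019', 'sac_Pong_2019', 'sac_qbert_2019', 'ppo_pong_2019']
print(match_run_dirs('sac', 'ong', run_dirs))
print(display_name('ong'), display_name('UnityHallway'))
```

One caveat of fragment matching: a short key such as `'ong'` would also match any other folder whose name happens to contain that substring, so the approach relies on the folder names staying distinct.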
20 changes: 12 additions & 8 deletions slm_lab/spec/benchmark/async_sac/async_sac_atari.json
@@ -49,6 +49,7 @@
 "name": "BreakoutNoFrameskip-v4",
 "frame_op": "concat",
 "frame_op_len": 4,
+"image_downsize": [64, 64],
 "reward_scale": "sign",
 "num_envs": 4,
 "max_t": null,
@@ -60,8 +61,8 @@
 },
 "meta": {
 "distributed": "shared",
-"log_frequency": 500,
-"eval_frequency": 500,
+"log_frequency": 1000,
+"eval_frequency": 1000,
 "rigorous_eval": 0,
 "max_session": 6,
 "max_trial": 1,
@@ -117,6 +118,7 @@
 "name": "PongNoFrameskip-v4",
 "frame_op": "concat",
 "frame_op_len": 4,
+"image_downsize": [64, 64],
 "reward_scale": "sign",
 "num_envs": 4,
 "max_t": null,
@@ -128,8 +130,8 @@
 },
 "meta": {
 "distributed": "shared",
-"log_frequency": 500,
-"eval_frequency": 500,
+"log_frequency": 1000,
+"eval_frequency": 1000,
 "rigorous_eval": 0,
 "max_session": 6,
 "max_trial": 1,
@@ -185,6 +187,7 @@
 "name": "QbertNoFrameskip-v4",
 "frame_op": "concat",
 "frame_op_len": 4,
+"image_downsize": [64, 64],
 "reward_scale": "sign",
 "num_envs": 4,
 "max_t": null,
@@ -196,8 +199,8 @@
 },
 "meta": {
 "distributed": "shared",
-"log_frequency": 500,
-"eval_frequency": 500,
+"log_frequency": 1000,
+"eval_frequency": 1000,
 "rigorous_eval": 0,
 "max_session": 6,
 "max_trial": 1,
@@ -253,6 +256,7 @@
 "name": "SeaquestNoFrameskip-v4",
 "frame_op": "concat",
 "frame_op_len": 4,
+"image_downsize": [64, 64],
 "reward_scale": "sign",
 "num_envs": 4,
 "max_t": null,
@@ -264,8 +268,8 @@
 },
 "meta": {
 "distributed": "shared",
-"log_frequency": 500,
-"eval_frequency": 500,
+"log_frequency": 1000,
+"eval_frequency": 1000,
 "rigorous_eval": 0,
 "max_session": 6,
 "max_trial": 1,
10 changes: 5 additions & 5 deletions slm_lab/spec/benchmark/async_sac/async_sac_qbert.json
@@ -12,7 +12,7 @@
 },
 "memory": {
 "name": "Replay",
-"batch_size": 512,
+"batch_size": 256,
 "max_size": 200000,
 "use_cer": false
 },
@@ -53,7 +53,7 @@
 "reward_scale": "sign",
 "num_envs": 4,
 "max_t": null,
-"max_frame": 5e6
+"max_frame": 2e6
 }],
 "body": {
 "product": "outer",
@@ -64,7 +64,7 @@
 "log_frequency": 1000,
 "eval_frequency": 1000,
 "rigorous_eval": 0,
-"max_session": 6,
+"max_session": 4,
 "max_trial": 1,
 }
 },
@@ -119,7 +119,7 @@
 "frame_op": "concat",
 "frame_op_len": 4,
 "image_downsize": [64, 64],
-"reward_scale": "sign",
+"reward_scale":null,
 "num_envs": 4,
 "max_t": null,
 "max_frame": 1e7
@@ -133,7 +133,7 @@
 "log_frequency": 1000,
 "eval_frequency": 1000,
 "rigorous_eval": 0,
-"max_session": 6,
+"max_session": 4,
 "max_trial": 1,
 }
 },
11 changes: 5 additions & 6 deletions slm_lab/spec/benchmark/sac/sac_unity.json
@@ -12,22 +12,21 @@
 "memory": {
 "name": "Replay",
 "batch_size": 256,
-"max_size": 200000,
+"max_size": 100000,
 "use_cer": false
 },
 "net": {
 "type": "MLPNet",
-"hid_layers": [256, 256],
-"hid_layers_activation": "relu",
+"hid_layers": [64, 64, 32],
+"hid_layers_activation": "leakyrelu",
 "init_fn": "orthogonal_",
 "clip_grad_val": 0.5,
 "loss_spec": {
 "name": "MSELoss"
 },
 "optim_spec": {
-"name": "Lookahead",
-"optimizer": "RAdam",
-"lr": 3e-3,
+"name": "RAdam",
+"lr": 3e-4,
 },
 "lr_scheduler_spec": null,
 "update_type": "polyak",