Tune HER hyperparams (#58)
* Update HER hyperparams

* Contrib repo is now required

* Save hyperparams

* Remove reward offset

* Update params

* Update hyperparams

* Add TQC hyperparam opt support

* Update requirements

* Update docker image

* Attempt to fix CI

* Fix bug when using HER + DQN/TQC for hyperparam optimization

* Fix SQLAlchemy version

* Maybe pip will be happy now?

* Use latest contrib version

* Test if hack is still needed

* Remove hack

* Cleanup
araffin authored Dec 13, 2020
1 parent e9437a7 commit 8ea4f4a
Showing 17 changed files with 177 additions and 293 deletions.
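Most of what follows edits the YAML files under `hyperparams/`, which the zoo's `train.py` looks up by environment id before building a model. As a rough, simplified sketch of how one of the tuned entries below is consumed (the zoo's real loading code also resolves `env_wrapper` classes, schedules, and the `policy_kwargs` strings):

```python
import yaml

# Minimal sketch: read the HER hyperparameter file and pick one entry by env id.
# The zoo's own loading logic does more (wrapper resolution, schedule parsing, etc.).
with open("hyperparams/her.yml") as f:
    hyperparams = yaml.safe_load(f)

fetch_push = hyperparams["FetchPush-v1"]
print(fetch_push["model_class"])   # 'tqc' after this commit
print(fetch_push["n_timesteps"])   # 1e6 (the `!!float` tag parses to a Python float)
```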
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -30,7 +30,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
-pip install torch==1.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
2 changes: 1 addition & 1 deletion .github/workflows/trained_agents.yml
@@ -30,7 +30,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch - faster to download
-pip install torch==1.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+pip install torch==1.7.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
# Use headless version
pip install opencv-python-headless
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -1,15 +1,19 @@
-## Pre-Release 0.11.0a0 (WIP)
+## Pre-Release 0.11.0a2 (WIP)

### Breaking Changes
- Removed `LinearNormalActionNoise`
- Evaluation is now deterministic by default, except for Atari games
+- `sb3_contrib` is now required
+- `TimeFeatureWrapper` was moved to the contrib repo

### New Features
- Added option to choose which `VecEnv` class to use for multiprocessing
+- Added hyperparameter optimization support for `TQC`

### Bug fixes
- Improved detection of Atari games
- Fix potential bug in plotting script when there is not enough timesteps
+- Fixed a bug when using HER + DQN/TQC for hyperparam optimization

### Documentation

@@ -21,6 +25,7 @@
- Changed `PPO` atari hyperparameters (removed vf clipping)
- Changed `A2C` atari hyperparameters (eps value of the optimizer)
- Updated benchmark script
+- Updated hyperparameter optim search space (commented gSDE for A2C/PPO)

## Pre-Release 0.10.0 (2020-10-28)

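The next two files only repoint `env_wrapper` entries from the zoo's own `utils.wrappers.TimeFeatureWrapper` to the contrib version, in line with the breaking change noted above. For reference, a minimal sketch of using the contrib wrapper directly (the environment id and `max_steps` value are just examples):

```python
import gym
from sb3_contrib.common.wrappers import TimeFeatureWrapper

# TimeFeatureWrapper appends a normalized "remaining time" feature to Box
# observations, which helps off-policy agents cope with episode time limits.
env = TimeFeatureWrapper(gym.make("Pendulum-v0"), max_steps=200)
obs = env.reset()
print(obs.shape)  # one extra dimension compared to the unwrapped observation
```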
8 changes: 4 additions & 4 deletions hyperparams/a2c.yml
@@ -125,7 +125,7 @@ BipedalWalkerHardcore-v3:

# Tuned
HalfCheetahBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
normalize: true
n_envs: 4
n_timesteps: !!float 2e6
@@ -145,7 +145,7 @@ HalfCheetahBulletEnv-v0:
policy_kwargs: "dict(log_std_init=-2, ortho_init=False, full_std=True)"

Walker2DBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
normalize: true
n_envs: 4
n_timesteps: !!float 2e6
@@ -198,7 +198,7 @@ AntBulletEnv-v0:

# Tuned
HopperBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
normalize: true
n_envs: 4
n_timesteps: !!float 2e6
@@ -218,7 +218,7 @@ HopperBulletEnv-v0:
# Tuned but unstable
# Not working without SDE?
ReacherBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
normalize: true
n_envs: 4
n_timesteps: !!float 2e6
16 changes: 8 additions & 8 deletions hyperparams/ddpg.yml
@@ -60,7 +60,7 @@ BipedalWalkerHardcore-v3:

# Tuned
HalfCheetahBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -75,7 +75,7 @@ HalfCheetahBulletEnv-v0:

# Tuned
AntBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -90,7 +90,7 @@ AntBulletEnv-v0:

# Tuned
HopperBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -107,7 +107,7 @@ HopperBulletEnv-v0:

# Tuned
Walker2DBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -124,7 +124,7 @@ Walker2DBulletEnv-v0:

# TO BE tested
HumanoidBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 2e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -139,7 +139,7 @@ HumanoidBulletEnv-v0:

# To be tuned
ReacherBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
@@ -154,7 +154,7 @@ ReacherBulletEnv-v0:

# To be tuned
InvertedDoublePendulumBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
gamma: 0.98
@@ -169,7 +169,7 @@ InvertedDoublePendulumBulletEnv-v0:

# To be tuned
InvertedPendulumSwingupBulletEnv-v0:
-env_wrapper: utils.wrappers.TimeFeatureWrapper
+env_wrapper: sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 3e5
policy: 'MlpPolicy'
gamma: 0.98
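As with `a2c.yml`, the `ddpg.yml` changes above only swap the wrapper import path; the tuned values themselves are untouched. A rough stand-alone equivalent of the `HalfCheetahBulletEnv-v0` entry, reproducing only the hyperparameters visible in this diff and leaving everything else at SB3 defaults (assumes `pybullet` is installed):

```python
import gym
import pybullet_envs  # noqa: F401  (registers the *BulletEnv-v0 environments)
from sb3_contrib.common.wrappers import TimeFeatureWrapper
from stable_baselines3 import DDPG

# Rough equivalent of the HalfCheetahBulletEnv-v0 entry above; only the
# hyperparameters visible in this diff are set explicitly.
env = TimeFeatureWrapper(gym.make("HalfCheetahBulletEnv-v0"))
model = DDPG("MlpPolicy", env, gamma=0.98, verbose=1)
model.learn(total_timesteps=1_000_000)
```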
163 changes: 60 additions & 103 deletions hyperparams/her.yml
@@ -4,7 +4,7 @@ NeckGoalEnvRelativeSparse-v2:
# env_wrapper:
# - utils.wrappers.HistoryWrapper:
# horizon: 2
-# - utils.wrappers.TimeFeatureWrapper
+# - sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
@@ -31,7 +31,7 @@ NeckGoalEnvRelativeDense-v2:
env_wrapper:
- utils.wrappers.HistoryWrapperObsDict:
horizon: 2
-# - utils.wrappers.TimeFeatureWrapper
+# - sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
learning_rate: !!float 7.3e-4
@@ -53,6 +53,22 @@ NeckGoalEnvRelativeDense-v2:
goal_selection_strategy: 'future'
online_sampling: False

+FetchPush-v1:
+env_wrapper:
+- sb3_contrib.common.wrappers.TimeFeatureWrapper
+n_timesteps: !!float 1e6
+policy: 'MlpPolicy'
+model_class: 'tqc'
+n_sampled_goal: 4
+goal_selection_strategy: 'future'
+buffer_size: 1000000
+batch_size: 2048
+gamma: 0.95
+learning_rate: !!float 1e-3
+tau: 0.05
+policy_kwargs: "dict(n_critics=2, net_arch=[512, 512, 512])"
+online_sampling: True

# DDPG hyperparams
#parking-v0:
# n_timesteps: !!float 2e5
@@ -70,121 +86,76 @@ NeckGoalEnvRelativeDense-v2:
# online_sampling: True
# max_episode_length: 100


# SAC hyperparams, her paper
parking-v0:
n_timesteps: !!float 2e5
policy: 'MlpPolicy'
model_class: 'sac'
model_class: 'tqc'
n_sampled_goal: 4
goal_selection_strategy: 'future'
buffer_size: 1000000
batch_size: 256
batch_size: 1024
gamma: 0.95
learning_rate: !!float 1e-3
# noise_type: 'normal'
# noise_std: 0.2
policy_kwargs: "dict(net_arch=[256, 256, 256])"
online_sampling: False
# normalize: True
tau: 0.05
policy_kwargs: "dict(n_critics=2, net_arch=[512, 512, 512])"
online_sampling: True
max_episode_length: 100

# TD3 hyperparams, her paper
#parking-v0:
# n_timesteps: !!float 2e5
# policy: 'MlpPolicy'
# model_class: 'td3'
# n_sampled_goal: 4
# goal_selection_strategy: 'future'
# buffer_size: 1000000
# batch_size: 256
# gamma: 0.95
# learning_rate: !!float 1e-3
# noise_type: 'normal'
# noise_std: 0.2
# policy_kwargs: "dict(net_arch=[256, 256, 256])"
# online_sampling: True
# max_episode_length: 100

# normalize: True

# Mujoco Robotic Env
# DDPG hyperparams
# FetchReach-v1:
# n_timesteps: !!float 20000
# policy: 'MlpPolicy'
# model_class: 'ddpg'
# n_sampled_goal: 4
# goal_selection_strategy: 'future'
# buffer_size: 1000000
# batch_size: 256
# gamma: 0.95
# random_exploration: 0.3
# actor_lr: !!float 1e-3
# critic_lr: !!float 1e-3
# noise_type: 'normal'
# noise_std: 0.2
# normalize_observations: true
# normalize_returns: false
# policy_kwargs: "dict(layers=[256, 256, 256])"
# online_sampling: True

# NOTE: should be run with 8 workers: mpirun -n 8
# FetchPush-v1:
# n_timesteps: !!float 2e6
# policy: 'MlpPolicy'
# model_class: 'ddpg'
# n_sampled_goal: 4
# goal_selection_strategy: 'future'
# buffer_size: 200000
# batch_size: 256
# gamma: 0.95
# random_exploration: 0.3
# actor_lr: !!float 1e-3
# critic_lr: !!float 1e-3
# noise_type: 'normal'
# noise_std: 0.2
# normalize_observations: true
# normalize_returns: false
# policy_kwargs: "dict(layers=[16, 16, 16])"
FetchSlide-v1:
env_wrapper:
- sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
model_class: 'tqc'
n_sampled_goal: 4
goal_selection_strategy: 'future'
buffer_size: 1000000
batch_size: 2048
gamma: 0.95
learning_rate: !!float 1e-3
tau: 0.05
# ent_coef: 0.01
policy_kwargs: "dict(n_critics=2, net_arch=[512, 512, 512])"
online_sampling: True

FetchPush-v1:
env_wrapper:
- utils.wrappers.HistoryWrapperObsDict:
horizon: 2
# - utils.wrappers.TimeFeatureObsDictWrapper
n_timesteps: !!float 3e6
- sb3_contrib.common.wrappers.TimeFeatureWrapper
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
model_class: 'sac'
model_class: 'tqc'
n_sampled_goal: 4
goal_selection_strategy: 'future'
buffer_size: 1000000
ent_coef: 'auto'
batch_size: 2048
gamma: 0.95
learning_rate: !!float 7e-4
use_sde: True
gradient_steps: -1
train_freq: -1
n_episodes_rollout: 1
sde_sample_freq: 10
# noise_type: 'normal'
# noise_std: 0.2
learning_starts: 1000
learning_rate: !!float 1e-3
tau: 0.05
# ent_coef: 0.01
policy_kwargs: "dict(n_critics=2, net_arch=[256, 256, 256])"
online_sampling: True
normalize: True

FetchPickAndPlace-v1:
n_timesteps: !!float 4e6
env_wrapper:
- sb3_contrib.common.wrappers.TimeFeatureWrapper
# - utils.wrappers.DoneOnSuccessWrapper:
# reward_offset: 0
# n_successes: 4
# - stable_baselines3.common.monitor.Monitor
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
model_class: 'sac'
model_class: 'tqc'
n_sampled_goal: 4
goal_selection_strategy: 'future'
buffer_size: 1000000
ent_coef: 'auto'
# batch_size: 256
batch_size: 1024
gamma: 0.95
# learning_rate: !!float 1e-3
learning_starts: 1000
train_freq: 1
learning_rate: !!float 1e-3
tau: 0.05
policy_kwargs: "dict(n_critics=2, net_arch=[512, 512, 512])"
online_sampling: True

# SAC hyperparams
@@ -202,17 +173,3 @@ FetchReach-v1:
learning_starts: 1000
online_sampling: True
normalize: True


# TD3 hyperparams
# FetchReach-v1:
# n_timesteps: !!float 25000
# policy: 'MlpPolicy'
# model_class: 'td3'
# n_sampled_goal: 4
# goal_selection_strategy: 'future'
# buffer_size: 1000000
# batch_size: 256
# gamma: 0.95
# learning_rate: 0.001
# learning_starts: 1000
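All of the new Fetch entries above use `model_class: 'tqc'`, i.e. HER wrapped around the TQC implementation from `sb3_contrib`. As a rough stand-alone sketch of the `FetchPush-v1` entry, assuming the HER model-class API of the SB3 0.11 era (keyword values are taken from the YAML above; the Fetch environments additionally require `mujoco-py`):

```python
import gym
from sb3_contrib import TQC
from sb3_contrib.common.wrappers import TimeFeatureWrapper
from stable_baselines3 import HER

# Sketch of the FetchPush-v1 entry above. In SB3 0.11 HER is a model wrapper
# that takes the off-policy algorithm class as an argument; later releases
# replaced this with a HER replay buffer instead.
env = TimeFeatureWrapper(gym.make("FetchPush-v1"))
model = HER(
    "MlpPolicy",
    env,
    TQC,                                   # model_class: 'tqc'
    n_sampled_goal=4,
    goal_selection_strategy="future",
    online_sampling=True,
    buffer_size=1_000_000,
    batch_size=2048,
    gamma=0.95,
    learning_rate=1e-3,
    tau=0.05,
    policy_kwargs=dict(n_critics=2, net_arch=[512, 512, 512]),
    verbose=1,
)
model.learn(total_timesteps=1_000_000)
```

Within the zoo itself the same entry is launched with `python train.py --algo her --env FetchPush-v1`, and the newly added TQC support for hyperparameter optimization mentioned in the changelog goes through `train.py`'s optimization flags (`-optimize`, `--n-trials`, sampler and pruner options as documented in the zoo README).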