Tuning code for PC and SQIL #841

Open · wants to merge 36 commits into master

Commits (36)
2459817
Fix and expand hyperparameter search space for PC.
ernestum Dec 4, 2023
c9ccf5b
Upgrade environment versions in the train_preference_comparisons config.
ernestum Jan 6, 2024
d7a7da8
Upgrade to ray 2.9.0.
ernestum Jan 10, 2024
55aa6eb
Ensure that PC does at least one comparison per iteration.
ernestum Jan 8, 2024
78553c9
Add initial epoch multiplier as a parameter to the PC script.
ernestum Jan 11, 2024
1145c07
Add tuning folder and move section about hp-tuning from the benchmark…
ernestum Dec 18, 2023
6ecfc34
Add pure optuna tuning script.
ernestum Jan 10, 2024
8828a35
Use functor to execute sacred experiments as optuna trials.
ernestum Jan 12, 2024
8e67dd4
Fix usage string.
ernestum Jan 12, 2024
865c1c5
First draft of tune_on_slurm.sh
ernestum Jan 12, 2024
8e31713
Move hyperparameter search space definitions to a separate file.
ernestum Jan 16, 2024
2f8cb4d
Add initial_epoch_multiplier to hyperparameter search space of PC
ernestum Jan 18, 2024
d9c4b57
Speed up trials by vectorizing environments
ernestum Jan 18, 2024
be085d3
Improve tune_on_slurm.sh
ernestum Jan 18, 2024
efc2dbf
Add script to tune PC on all environments.
ernestum Jan 22, 2024
2520508
Improve documentation for hyperparameter tuning.
ernestum Jan 22, 2024
147c996
Add script to re-run the best trials from a hyperparameter tuning run.
ernestum Jan 22, 2024
b2a98d7
Add gc_after_trial flag to avoid memory issues.
ernestum Feb 5, 2024
6f2e30c
Add feature to re-run a tuning run.
ernestum Feb 5, 2024
454c393
Write output of sbatch script and output of training script to separa…
ernestum Feb 5, 2024
5012b15
Add scripts to rerun the top trials of a hyperparameter sweep.
ernestum Feb 5, 2024
2c978af
Fix documentation and cout.txt placement in tune_on_slurm.sh
ernestum Feb 13, 2024
7178fee
Increase default memory for cheetah and humanoid.
ernestum Feb 19, 2024
263db3e
Remove top-k feature when re-running trials.
ernestum Feb 19, 2024
7311d1c
Fix paths in rerun batch script.
ernestum Feb 19, 2024
5769fc6
Add hyperparameters for SQIL.
ernestum Feb 26, 2024
d3860a3
Turn DQN into a named config.
ernestum Feb 26, 2024
80edc62
Add namedconfigs for all seals envs and set the number of timesteps i…
ernestum Feb 27, 2024
9bb89c9
Store command name as user attr in trial.
ernestum Feb 27, 2024
57018d5
Add hp search space for PC/classic control.
ernestum Feb 27, 2024
b7ca230
Fix fragment length for classic control environments.
ernestum Feb 28, 2024
288d38d
Add sbatch jobs for SQIL to tune_all_on_slurm.
ernestum Feb 28, 2024
2e356b7
Add benchmark analysis notebook.
ernestum Feb 29, 2024
f4bcbdc
Some formatting fixes.
ernestum Feb 29, 2024
60fc75a
Shellcheck fixes.
ernestum Feb 29, 2024
e69dbd6
More formatting fixes.
ernestum Feb 29, 2024

Files changed

22 changes: 0 additions & 22 deletions benchmarking/README.md
@@ -186,25 +186,3 @@ where:

If `your_runs_dir` contains runs for more than one algorithm, you will have to
disambiguate using the `--algo` option.

## Tuning Hyperparameters

The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
the search space defined in the `scripts/config/tuning.py`.

The tuning script proceeds in two phases:
1. Tune the hyperparameters using the search space provided.
2. Re-evaluate the best hyperparameter config found in the first phase based on the maximum mean return on a separate set of seeds. Report the mean and standard deviation of these trials.

To use it with the default search space:
```bash
python -m imitation.scripts.tuning with <algo> 'parallel_run_config.base_named_configs=["<env>"]'
```

In this command:
- `<algo>` provides the default search space and settings for the specific algorithm, as defined in `scripts/config/tuning.py`.
- `<env>` sets the environment to tune the algorithm in. Environments are defined in the algo-specific `scripts/config/train_[adversarial|imitation|preference_comparisons|rl].py` files. For the already-tuned environments, use the `<algo>_<env>` named configs here.

See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for many other arguments that can be
provided through the command line to change the tuning behavior.
2 changes: 1 addition & 1 deletion setup.py
@@ -13,7 +13,7 @@

IS_NOT_WINDOWS = os.name != "nt"

PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"]
PARALLEL_REQUIRE = ["ray[debug,tune]~=2.9.0"]
ATARI_REQUIRE = [
"seals[atari]~=0.2.1",
]
4 changes: 4 additions & 0 deletions src/imitation/algorithms/preference_comparisons.py
@@ -1678,6 +1678,10 @@ def train(
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
probs = unnormalized_probs / np.sum(unnormalized_probs)
shares = util.oric(probs * total_comparisons)
shares[
shares <= 0
] = 1 # ensure we at least request one comparison per iteration

schedule = [initial_comparisons] + shares.tolist()
print(f"Query schedule: {schedule}")

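The clamp above guards against iterations whose comparison share rounds down to zero. A rough standalone sketch of why that can happen, using plain numpy rounding as a stand-in for `util.oric`, so the numbers are only illustrative:

```python
import numpy as np

# With many iterations and a small remaining comparison budget, a hyperbolic
# schedule assigns the later iterations fractional shares below 0.5, which
# round to zero; the clamp then bumps them back up to one query each.
num_iterations, remaining_comparisons = 50, 30
t = np.linspace(0, 1, num_iterations)
unnormalized_probs = 1.0 / (1.0 + t)  # hyperbolic schedule
probs = unnormalized_probs / np.sum(unnormalized_probs)
shares = np.round(probs * remaining_comparisons).astype(int)  # stand-in for util.oric
print(int((shares == 0).sum()))  # several late iterations would get zero comparisons
shares[shares <= 0] = 1  # the fix applied in this PR
assert (shares > 0).all()
```
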
30 changes: 29 additions & 1 deletion src/imitation/scripts/config/train_imitation.py
@@ -45,6 +45,13 @@ def seals_mountain_car():
environment = dict(gym_id="seals/MountainCar-v0")
bc = dict(l2_weight=0.0)
dagger = dict(total_timesteps=20000)
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
def seals_ant():
environment = dict(gym_id="seals/Ant-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
@@ -57,11 +64,13 @@ def cartpole():
def seals_cartpole():
environment = dict(gym_id="seals/CartPole-v0")
dagger = dict(total_timesteps=20000)
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
def pendulum():
environment = dict(gym_id="Pendulum-v1")
sqil = dict(total_timesteps=1e5)


@train_imitation_ex.named_config
@@ -76,14 +85,33 @@ def half_cheetah():
dagger = dict(total_timesteps=60000)


@train_imitation_ex.named_config
def seals_half_cheetah():
environment = dict(gym_id="seals/HalfCheetah-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def seals_hopper():
environment = dict(gym_id="seals/Hopper-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def seals_walker():
environment = dict(gym_id="seals/Walker2d-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
def humanoid():
environment = dict(gym_id="Humanoid-v2")


@train_imitation_ex.named_config
def seals_humanoid():
environment = dict(gym_id="seals/Humanoid-v0")
environment = dict(gym_id="seals/Humanoid-v1")
sqil = dict(total_timesteps=2e6)


@train_imitation_ex.named_config
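Since the SQIL budgets above become per-environment named configs, and the DQN learner becomes a named config of the `rl` ingredient further down in this PR, the two can be combined when launching a run. A hedged sketch of doing so programmatically through sacred; the command name `sqil`, the ingredient-prefixed name `"rl.dqn"`, and the smoke-test budget are assumptions inferred from this diff, not a verified invocation:

```python
# Hedged sketch only: check the experiment's command list and print_config
# output for the real command and named-config names before relying on this.
from imitation.scripts.train_imitation import train_imitation_ex

run = train_imitation_ex.run(
    command_name="sqil",                                  # assumed command name
    named_configs=["seals_cartpole", "rl.dqn"],           # env budget + DQN learner
    config_updates={"sqil": {"total_timesteps": 1e4}},    # short smoke-test budget
)
print(run.result)
```
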
14 changes: 8 additions & 6 deletions src/imitation/scripts/config/train_preference_comparisons.py
@@ -42,6 +42,8 @@ def train_defaults():
transition_oversampling = 1
# fraction of total_comparisons that will be sampled right at the beginning
initial_comparison_frac = 0.1
# factor by which to oversample the number of epochs in the first iteration
initial_epoch_multiplier = 200.0
# fraction of sampled trajectories that will include some random actions
exploration_frac = 0.0
preference_model_kwargs = {}
@@ -77,7 +79,7 @@ def cartpole():

@train_preference_comparisons_ex.named_config
def seals_ant():
environment = dict(gym_id="seals/Ant-v0")
environment = dict(gym_id="seals/Ant-v1")
rl = dict(
batch_size=2048,
rl_kwargs=dict(
@@ -104,7 +106,7 @@ def half_cheetah():

@train_preference_comparisons_ex.named_config
def seals_half_cheetah():
environment = dict(gym_id="seals/HalfCheetah-v0")
environment = dict(gym_id="seals/HalfCheetah-v1")
rl = dict(
batch_size=512,
rl_kwargs=dict(
@@ -125,7 +127,7 @@ def seals_half_cheetah():

@train_preference_comparisons_ex.named_config
def seals_hopper():
environment = dict(gym_id="seals/Hopper-v0")
environment = dict(gym_id="seals/Hopper-v1")
policy = dict(
policy_cls="MlpPolicy",
policy_kwargs=dict(
@@ -151,7 +153,7 @@ def seals_hopper():

@train_preference_comparisons_ex.named_config
def seals_swimmer():
environment = dict(gym_id="seals/Swimmer-v0")
environment = dict(gym_id="seals/Swimmer-v1")
policy = dict(
policy_cls="MlpPolicy",
policy_kwargs=dict(
@@ -178,7 +180,7 @@ def seals_swimmer():

@train_preference_comparisons_ex.named_config
def seals_walker():
environment = dict(gym_id="seals/Walker2d-v0")
environment = dict(gym_id="seals/Walker2d-v1")
policy = dict(
policy_cls="MlpPolicy",
policy_kwargs=dict(
@@ -206,7 +208,7 @@ def seals_walker():
@train_preference_comparisons_ex.named_config
def seals_humanoid():
locals().update(**MUJOCO_SHARED_LOCALS)
environment = dict(gym_id="seals/Humanoid-v0")
environment = dict(gym_id="seals/Humanoid-v1")
total_timesteps = int(4e6)


53 changes: 31 additions & 22 deletions src/imitation/scripts/config/tuning.py
@@ -2,7 +2,6 @@

import ray.tune as tune
import sacred
from torch import nn

from imitation.algorithms import dagger as dagger_alg
from imitation.scripts.parallel import parallel_ex
@@ -188,38 +187,48 @@ def pc():
parallel_run_config = dict(
sacred_ex_name="train_preference_comparisons",
run_name="pc_tuning",
base_named_configs=["logging.wandb_logging"],
base_named_configs=[],
base_config_updates={
"environment": {"num_vec": 1},
"demonstrations": {"source": "huggingface"},
"total_timesteps": 2e7,
"total_comparisons": 5000,
"query_schedule": "hyperbolic",
"gatherer_kwargs": {"sample": True},
"total_comparisons": 1000,
"active_selection": True,
},
search_space={
"named_configs": [
["reward.normalize_output_disable"],
],
"named_configs": ["reward.reward_ensemble"],
"config_updates": {
"train": {
"policy_kwargs": {
"activation_fn": tune.choice(
[
nn.ReLU,
],
),
},
"active_selection_oversampling": tune.randint(1, 11),
"comparison_queue_size": tune.randint(
1, 1001,
), # upper bound determined by total_comparisons=1000
"exploration_frac": tune.uniform(0.0, 0.5),
"fragment_length": tune.randint(
1, 1001,
), # trajectories are 1000 steps long
"gatherer_kwargs": {
"temperature": tune.uniform(0.0, 2.0),
"discount_factor": tune.uniform(0.95, 1.0),
"sample": tune.choice([True, False]),
},
"initial_comparison_frac": tune.uniform(0.01, 1.0),
"num_iterations": tune.randint(1, 51),
"preference_model_kwargs": {
"noise_prob": tune.uniform(0.0, 0.1),
"discount_factor": tune.uniform(0.95, 1.0),
},
"num_iterations": tune.choice([25, 50]),
"initial_comparison_frac": tune.choice([0.1, 0.25]),
"query_schedule": tune.choice(
["hyperbolic", "constant", "inverse_quadratic",]
),
"trajectory_generator_kwargs": {
"switch_prob": tune.uniform(0.1, 1),
"random_prob": tune.uniform(0.1, 0.9),
},
"transition_oversampling": tune.uniform(0.9, 2.0),
"reward_trainer_kwargs": {
"epochs": tune.choice([1, 3, 6]),
"epochs": tune.randint(1, 11),
},
"rl": {
"batch_size": tune.choice([512, 2048, 8192]),
"rl_kwargs": {
"learning_rate": tune.loguniform(1e-5, 1e-2),
"ent_coef": tune.loguniform(1e-7, 1e-3),
},
},
5 changes: 5 additions & 0 deletions src/imitation/scripts/ingredients/rl.py
@@ -98,6 +98,11 @@ def sac():
locals() # quieten flake8


@rl_ingredient.named_config
def dqn():
rl_cls = sb3.DQN


def _maybe_add_relabel_buffer(
rl_kwargs: Dict[str, Any],
relabel_reward_fn: Optional[RewardFn] = None,
3 changes: 1 addition & 2 deletions src/imitation/scripts/parallel.py
@@ -188,13 +188,12 @@ def _ray_tune_sacred_wrapper(
`ex.run`) and `reporter`. The function returns the run result.
"""

def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]:
def inner(config: Mapping[str, Any]) -> Mapping[str, Any]:
"""Trainable function with the correct signature for `ray.tune`.

Args:
config: Keyword arguments for `ex.run()`, where `ex` is the
`sacred.Experiment` instance associated with `sacred_ex_name`.
reporter: Callback to report progress to Ray.

Returns:
Result from `ray.Run` object.
5 changes: 5 additions & 0 deletions src/imitation/scripts/train_preference_comparisons.py
@@ -68,6 +68,7 @@ def train_preference_comparisons(
fragment_length: int,
transition_oversampling: float,
initial_comparison_frac: float,
initial_epoch_multiplier: float,
exploration_frac: float,
trajectory_path: Optional[str],
trajectory_generator_kwargs: Mapping[str, Any],
@@ -106,6 +107,9 @@
sampled before the rest of training begins (using the randomly initialized
agent). This can be used to pretrain the reward model before the agent
is trained on the learned reward.
initial_epoch_multiplier: before agent training begins, train the reward
model for this many more epochs than usual (on fragments sampled from a
random agent).
exploration_frac: fraction of trajectory samples that will be created using
partially random actions, rather than the current policy. Might be helpful
if the learned policy explores too little and gets stuck with a wrong
@@ -258,6 +262,7 @@ def train_preference_comparisons(
fragment_length=fragment_length,
transition_oversampling=transition_oversampling,
initial_comparison_frac=initial_comparison_frac,
initial_epoch_multiplier=initial_epoch_multiplier,
custom_logger=custom_logger,
allow_variable_horizon=allow_variable_horizon,
query_schedule=query_schedule,
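For a sense of scale for the `initial_epoch_multiplier` argument plumbed through above (the config diff earlier sets a default of `200.0`), a rough arithmetic sketch, assuming the multiplier simply scales the per-iteration reward-model epoch count as the docstring describes:

```python
# Back-of-the-envelope only; the exact schedule is decided inside the
# preference-comparisons trainer, not reproduced here.
epochs_per_iteration = 3          # e.g. reward_trainer_kwargs["epochs"]
initial_epoch_multiplier = 200.0  # default added in this PR's config
first_iteration_epochs = int(epochs_per_iteration * initial_epoch_multiplier)
print(first_iteration_epochs)     # 600 epochs of reward-model pretraining
```
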
38 changes: 38 additions & 0 deletions tuning/README.md
@@ -0,0 +1,38 @@
# Tuning Hyperparameters
This directory contains scripts for tuning hyperparameters for imitation learning algorithms.
Additional helper scripts allow for running multiple tuning jobs in parallel on a SLURM cluster.

Use `tune.py` to tune hyperparameters for a single algorithm and environment using Optuna.
To use a custom algorithm or search space, add an entry to the dict in `hp_search_spaces.py`.

You can tune with multiple workers in parallel by running several instances of `tune.py` that all point to the same journal log file (see `tune.py --help` for details, and the sketch below for the mechanism this relies on).
To easily launch multiple workers on a SLURM cluster and ensure they don't conflict with each other,
use the `tune_on_slurm.py` script.
This script will launch a SLURM job array with the specified number of workers.
If you want to tune all algorithms on all environments on SLURM, use `tune_all_on_slurm.sh`.
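The shared journal file is what lets independent `tune.py` workers cooperate without a database server. A minimal sketch of that mechanism using Optuna directly; the study name, objective, and file name are placeholders rather than `tune.py`'s actual interface, and `JournalFileStorage` is the Optuna 3.x name (newer releases expose it as `optuna.storages.journal.JournalFileBackend`):

```python
import optuna

# Every worker process can run this same code concurrently: the journal file
# serializes their trial records, and load_if_exists attaches them all to one study.
storage = optuna.storages.JournalStorage(
    optuna.storages.JournalFileStorage("tuning_journal.log"),
)
study = optuna.create_study(
    study_name="pc_tuning_example",
    storage=storage,
    direction="maximize",
    load_if_exists=True,
)
study.optimize(lambda trial: trial.suggest_float("x", 0.0, 1.0), n_trials=10)
```
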

# Legacy Tuning Scripts

Note: there is also a legacy tuning workflow, which can be used as follows:

The hyperparameters of any algorithm in imitation can be tuned using `src/imitation/scripts/tuning.py`.
The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
the search space defined in the `scripts/config/tuning.py`.

The tuning script proceeds in two phases:
1. Tune the hyperparameters using the search space provided.
2. Re-evaluate the best hyperparameter config found in the first phase
based on the maximum mean return on a separate set of seeds.
Report the mean and standard deviation of these trials.

To use it with the default search space:
```bash
python -m imitation.scripts.tuning with <algo> 'parallel_run_config.base_named_configs=["<env>"]'
```

In this command:
- `<algo>` provides the default search space and settings for the specific algorithm, as defined in `scripts/config/tuning.py`.
- `<env>` sets the environment to tune the algorithm in. Environments are defined in the algo-specific `scripts/config/train_[adversarial|imitation|preference_comparisons|rl].py` files. For the already-tuned environments, use the `<algo>_<env>` named configs here.

See the documentation of `scripts/tuning.py` and `scripts/parallel.py` for many other arguments that can be
provided through the command line to change the tuning behavior.