-
Notifications
You must be signed in to change notification settings - Fork 473
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Hyperparameter Optimization with Ray Tune and Weights and Biases (#76)
- Loading branch information
Showing
12 changed files
with
650 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,3 +149,5 @@ OUT/ | |
|
||
examples/experiments/grounded_program_synthesis/dataset | ||
ckpts/ | ||
|
||
ray_results/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
tune_config: | ||
mode: "max" | ||
metric: "mean_reward" | ||
search_alg: "bohb" # random | ||
scheduler: "hyperbandforbohb" # fifo | ||
num_samples: 15 | ||
max_concurrent_trials: null | ||
time_budget_s: null | ||
reuse_actors: null | ||
|
||
model: | ||
model_path: "lvwerra/gpt2-imdb" # Name of hf model to load | ||
tokenizer_path: "gpt2" # Name of hf tokenizer to load | ||
model_type: "AcceleratePPOModel" # Name of accelerate model type to load | ||
num_layers_unfrozen: # Number of bottom layers to freeze during training | ||
strategy: "choice" | ||
values: [2, 3] | ||
|
||
train: | ||
seq_length: # Size of LM context | ||
strategy: "choice" | ||
values: [36, 48, 52] | ||
epochs: # Train for max(epochs, total_steps) | ||
strategy: "choice" | ||
values: [80, 100, 120] | ||
total_steps: 10000 # Train for max(epochs, total_steps) | ||
batch_size: 128 # batch size | ||
|
||
lr_init: 1.412e-4 # init learning rate | ||
lr_target: 1.412e-4 # target final learning rate | ||
opt_betas: [0.9, 0.95] # adam betas | ||
opt_eps: 1.0e-8 # adam eps | ||
weight_decay: 1.0e-6 # weight decay param | ||
|
||
checkpoint_interval: 10000 # checkpoint interval | ||
eval_interval: 4 # eval interval | ||
|
||
pipeline: "PPOPipeline" # prompt pipeline to load | ||
orchestrator: "PPOOrchestrator" # orchestrator to load | ||
project_name: "trlx-hyperopt-bohb" | ||
|
||
method: | ||
name: 'ppoconfig' # Name of RL method config | ||
num_rollouts: # Number of rollouts to collect per epoch | ||
strategy: "choice" | ||
values: [96, 128] | ||
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator | ||
ppo_epochs: # Number of ppo epochs | ||
strategy: "randint" | ||
values: [3, 6] # 3 is inclusive, 6 is exclusive | ||
init_kl_coef: # init kl coefficient | ||
strategy: "quniform" | ||
values: [0.1, 0.3, 0.1] | ||
target: 6 # target kl coefficient, set None for fixed kl coef | ||
horizon: 10000 # PPO horizon | ||
gamma: 1 # PPO discount | ||
lam: # PPO lambda | ||
strategy: "uniform" | ||
values: [0.93, 0.98] | ||
cliprange: 0.2 # clip range | ||
cliprange_value: 0.2 # clip range | ||
vf_coef: 2.3 # value term weight | ||
gen_kwargs: | ||
max_length: 48 # LM max sample gen length | ||
min_length: 48 # LM min sample gen length | ||
top_k: 0.0 # top k | ||
top_p: 1.0 # top p | ||
do_sample: True # sample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Usage: python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments | ||
import wandb | ||
import argparse | ||
from pathlib import Path | ||
|
||
import ray | ||
from ray.air import session | ||
from ray import tune | ||
|
||
import trlx | ||
from trlx.ray_tune import load_ray_yaml | ||
from trlx.ray_tune import get_param_space | ||
from trlx.ray_tune import get_tune_config | ||
from trlx.ray_tune import get_train_function | ||
from trlx.ray_tune.wandb import log_trials, create_report | ||
|
||
from ray.tune.logger import JsonLoggerCallback | ||
from ray.tune.logger import CSVLoggerCallback | ||
|
||
|
||
def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict): | ||
tuner = tune.Tuner( | ||
tune.with_resources(train_function, resources=resources), | ||
param_space=param_space, | ||
tune_config=tune.TuneConfig(**tune_config), | ||
run_config = ray.air.RunConfig( | ||
local_dir="ray_results", | ||
callbacks=[CSVLoggerCallback()] | ||
), | ||
) | ||
|
||
results = tuner.fit() | ||
|
||
log_trials( | ||
tuner._local_tuner.get_experiment_checkpoint_dir(), | ||
param_space["train"]["project_name"] | ||
) | ||
|
||
create_report( | ||
param_space, | ||
tune_config, | ||
Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, | ||
results.get_best_result().config | ||
) | ||
|
||
print("Best hyperparameters found were: ", results.get_best_result().config) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--example-name", type=str, default="ppo_sentiments", help="Name of the example" | ||
) | ||
parser.add_argument( | ||
"--config", type=str, default=None, required=True, help="The config file defining the param_space." | ||
) | ||
parser.add_argument( | ||
"--num-cpus", type=int, default=4, help="Number of CPUs to use per exp." | ||
) | ||
parser.add_argument( | ||
"--num-gpus", type=int, default=1, help="Number of GPUs to use per exp." | ||
) | ||
parser.add_argument( | ||
"--server-address", | ||
type=str, | ||
default=None, | ||
required=False, | ||
help="The address of server to connect to if using Ray Client.", | ||
) | ||
|
||
args, _ = parser.parse_known_args() | ||
|
||
# Read config and parse it | ||
config = load_ray_yaml(args.config) | ||
tune_config = get_tune_config(config) | ||
param_space = get_param_space(config) | ||
|
||
# Initialize Ray. | ||
if args.server_address: | ||
ray.init(address=f"ray://{args.server_address}") | ||
else: | ||
ray.init() | ||
|
||
resources = { | ||
"cpu": args.num_cpus, | ||
"gpu": args.num_gpus, | ||
} | ||
|
||
# Register the training function that will be used for training the model. | ||
train_function = get_train_function(args.example_name) | ||
tune.register_trainable("train_function", train_function) | ||
|
||
tune_function(train_function, param_space, tune_config, resources) | ||
|
||
# Shut down Ray. | ||
ray.shutdown() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.