Hyperparameter Optimization with Ray Tune and Weights and Biases (#76)
ayulockin authored Nov 14, 2022
1 parent 3900999 commit 59b3c65
Showing 12 changed files with 650 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -149,3 +149,5 @@ OUT/

examples/experiments/grounded_program_synthesis/dataset
ckpts/

ray_results/
2 changes: 1 addition & 1 deletion configs/ppo_config.yml
@@ -6,7 +6,7 @@ model:

train:
seq_length: 48 # Size of LM context
epochs: 1000 # Train for max(epochs, total_steps)
epochs: 1000 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

68 changes: 68 additions & 0 deletions configs/ray_tune_configs/ppo_config.yml
@@ -0,0 +1,68 @@
tune_config:
  mode: "max"
  metric: "mean_reward"
  search_alg: "bohb" # random
  scheduler: "hyperbandforbohb" # fifo
  num_samples: 15
  max_concurrent_trials: null
  time_budget_s: null
  reuse_actors: null

model:
  model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
  tokenizer_path: "gpt2" # Name of hf tokenizer to load
  model_type: "AcceleratePPOModel" # Name of accelerate model type to load
  num_layers_unfrozen: # Number of bottom layers to freeze during training
    strategy: "choice"
    values: [2, 3]

train:
  seq_length: # Size of LM context
    strategy: "choice"
    values: [36, 48, 52]
  epochs: # Train for max(epochs, total_steps)
    strategy: "choice"
    values: [80, 100, 120]
  total_steps: 10000 # Train for max(epochs, total_steps)
  batch_size: 128 # batch size

  lr_init: 1.412e-4 # init learning rate
  lr_target: 1.412e-4 # target final learning rate
  opt_betas: [0.9, 0.95] # adam betas
  opt_eps: 1.0e-8 # adam eps
  weight_decay: 1.0e-6 # weight decay param

  checkpoint_interval: 10000 # checkpoint interval
  eval_interval: 4 # eval interval

  pipeline: "PPOPipeline" # prompt pipeline to load
  orchestrator: "PPOOrchestrator" # orchestrator to load
  project_name: "trlx-hyperopt-bohb"

method:
  name: 'ppoconfig' # Name of RL method config
  num_rollouts: # Number of rollouts to collect per epoch
    strategy: "choice"
    values: [96, 128]
  chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
  ppo_epochs: # Number of ppo epochs
    strategy: "randint"
    values: [3, 6] # 3 is inclusive, 6 is exclusive
  init_kl_coef: # init kl coefficient
    strategy: "quniform"
    values: [0.1, 0.3, 0.1]
  target: 6 # target kl coefficient, set None for fixed kl coef
  horizon: 10000 # PPO horizon
  gamma: 1 # PPO discount
  lam: # PPO lambda
    strategy: "uniform"
    values: [0.93, 0.98]
  cliprange: 0.2 # clip range
  cliprange_value: 0.2 # clip range
  vf_coef: 2.3 # value term weight
  gen_kwargs:
    max_length: 48 # LM max sample gen length
    min_length: 48 # LM min sample gen length
    top_k: 0.0 # top k
    top_p: 1.0 # top p
    do_sample: True # sample
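Each leaf in this config is either a fixed value or a strategy/values pair describing a Ray Tune search distribution. The actual conversion is done by trlx.ray_tune.get_param_space (added in this commit); as a rough sketch only — the helper name to_tune_param below is illustrative, not part of the repo — such an entry maps onto Tune's sampling primitives roughly like this:

# Rough sketch (assumption): how a strategy/values entry from the YAML above
# could be turned into a Ray Tune sampling primitive. The real conversion
# lives in trlx.ray_tune.get_param_space; the helper name is illustrative only.
from ray import tune

def to_tune_param(spec):
    if not isinstance(spec, dict) or "strategy" not in spec:
        return spec  # fixed values (e.g. total_steps: 10000) pass through unchanged
    strategy, values = spec["strategy"], spec["values"]
    if strategy == "choice":
        return tune.choice(values)      # pick one of the listed values
    if strategy == "randint":
        return tune.randint(*values)    # integer in [lower, upper)
    if strategy == "uniform":
        return tune.uniform(*values)    # float in [lower, upper]
    if strategy == "quniform":
        return tune.quniform(*values)   # lower, upper, quantization step
    raise ValueError(f"unknown strategy: {strategy}")

# Example: to_tune_param({"strategy": "randint", "values": [3, 6]}) samples ppo_epochs from {3, 4, 5}.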
2 changes: 2 additions & 0 deletions setup.cfg
@@ -20,6 +20,8 @@ install_requires =
    transformers>=4.21.2
    tqdm
    wandb
    ray>=2.0.1
    tabulate>=0.9.0

[options.extras_require]
dev =
96 changes: 96 additions & 0 deletions train_sweep.py
@@ -0,0 +1,96 @@
# Usage: python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
import wandb
import argparse
from pathlib import Path

import ray
from ray.air import session
from ray import tune

import trlx
from trlx.ray_tune import load_ray_yaml
from trlx.ray_tune import get_param_space
from trlx.ray_tune import get_tune_config
from trlx.ray_tune import get_train_function
from trlx.ray_tune.wandb import log_trials, create_report

from ray.tune.logger import JsonLoggerCallback
from ray.tune.logger import CSVLoggerCallback


def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict):
    tuner = tune.Tuner(
        tune.with_resources(train_function, resources=resources),
        param_space=param_space,
        tune_config=tune.TuneConfig(**tune_config),
        run_config=ray.air.RunConfig(
            local_dir="ray_results",
            callbacks=[CSVLoggerCallback()]
        ),
    )

    results = tuner.fit()

    log_trials(
        tuner._local_tuner.get_experiment_checkpoint_dir(),
        param_space["train"]["project_name"]
    )

    create_report(
        param_space,
        tune_config,
        Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem,
        results.get_best_result().config
    )

    print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--example-name", type=str, default="ppo_sentiments", help="Name of the example"
    )
    parser.add_argument(
        "--config", type=str, default=None, required=True, help="The config file defining the param_space."
    )
    parser.add_argument(
        "--num-cpus", type=int, default=4, help="Number of CPUs to use per exp."
    )
    parser.add_argument(
        "--num-gpus", type=int, default=1, help="Number of GPUs to use per exp."
    )
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )

    args, _ = parser.parse_known_args()

    # Read config and parse it
    config = load_ray_yaml(args.config)
    tune_config = get_tune_config(config)
    param_space = get_param_space(config)

    # Initialize Ray.
    if args.server_address:
        ray.init(address=f"ray://{args.server_address}")
    else:
        ray.init()

    resources = {
        "cpu": args.num_cpus,
        "gpu": args.num_gpus,
    }

    # Register the training function that will be used for training the model.
    train_function = get_train_function(args.example_name)
    tune.register_trainable("train_function", train_function)

    tune_function(train_function, param_space, tune_config, resources)

    # Shut down Ray.
    ray.shutdown()
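For readers who have not used the Tuner API before, here is a toy, self-contained illustration (not part of this commit) of the same fit-and-report loop that train_sweep.py wires up for trlx; the dummy trainable and its fake metric stand in for the real training function returned by get_train_function:

# Toy illustration (not from the commit): the Tuner/session.report mechanics used above.
import ray
from ray import tune
from ray.air import session

def dummy_trainable(config):
    # A real trial would train a model here; we just report a fake score under
    # the metric name that tune_config optimizes ("mean_reward").
    session.report({"mean_reward": config["lr"] * 100})

if __name__ == "__main__":
    ray.init()
    tuner = tune.Tuner(
        tune.with_resources(dummy_trainable, resources={"cpu": 1}),
        param_space={"lr": tune.uniform(1e-5, 1e-3)},
        tune_config=tune.TuneConfig(mode="max", metric="mean_reward", num_samples=4),
    )
    results = tuner.fit()
    print("Best trial config:", results.get_best_result().config)
    ray.shutdown()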
20 changes: 17 additions & 3 deletions trlx/data/configs.py
@@ -142,7 +142,21 @@ def to_dict(self):
"""
Convert TRLConfig to dictionary.
"""
data = self.model.__dict__.copy()
data.update(self.train.__dict__)
data.update(self.method.__dict__)
data = {
"model": self.model.__dict__,
"train": self.train.__dict__,
"method": self.method.__dict__,
}

return data

@classmethod
def from_dict(cls, config_dict: dict):
"""
Convert dictionary to TRLConfig.
"""
return cls(
ModelConfig.from_dict(config_dict["model"]),
TrainConfig.from_dict(config_dict["train"]),
get_method(config_dict["method"]["name"]).from_dict(config_dict["method"]),
)
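to_dict now returns a nested model/train/method dict rather than a flat merge, and the new from_dict classmethod inverts it; Ray Tune works with plain nested dicts, so a config has to survive the round trip. A hedged usage sketch (assumes configs/ppo_config.yml and TRLConfig.load_yaml, which trlx already uses elsewhere):

# Hedged sketch: why the nested to_dict/from_dict round trip matters for Tune.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ppo_config.yml")    # assumed path, as in the repo
as_dict = config.to_dict()                 # {"model": {...}, "train": {...}, "method": {...}}
as_dict["method"]["init_kl_coef"] = 0.2    # e.g. a value sampled by Ray Tune
restored = TRLConfig.from_dict(as_dict)    # back to a typed TRLConfig for training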
27 changes: 22 additions & 5 deletions trlx/model/accelerate_base_model.py
@@ -18,6 +18,17 @@
else:
    from tqdm import tqdm

import ray
from ray.air import session


def parse_results_for_session(results: dict):
    for k, v in results.items():
        if isinstance(v, torch.Tensor):
            results[k] = float(v)

    return results


@register_model
class AccelerateRLModel(BaseRLModel):
@@ -63,7 +74,7 @@ def __init__(self, config, train_mode=True):
            for m in gpt_blocks_to_freeze:
                m.requires_grad_(False)

        if self.accelerator.is_main_process:
        if self.accelerator.is_main_process and not ray.is_initialized():
            self.accelerator.init_trackers(
                project_name=self.config.train.project_name,
                config=self.config.to_dict(),
@@ -199,9 +210,8 @@ def evaluate(self):
                    columns_data.append(values)

            rows = list(zip(*columns_data))
            stats["samples"] = wandb.Table(columns=columns, rows=rows)

            print(rows[0])
            if not ray.is_initialized():
                stats["samples"] = wandb.Table(columns=columns, rows=rows)

        return stats

@@ -246,7 +256,14 @@ def learn(self):
"backward_time": backward_time,
}
)
self.accelerator.log(results)

if not ray.is_initialized():
self.accelerator.log(results)

# Report the metrics to Ray Tune.
if ray.is_initialized():
tmp_results = parse_results_for_session(results)
session.report(tmp_results)

desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items())
tbar.set_description(desc)
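Note why parse_results_for_session exists: session.report expects plain serializable numbers, while some of the logged stats are zero-dimensional torch tensors. A small illustration (the example dict below is made up; importing the helper directly is only for demonstration):

# Illustration only: tensors in the results dict are cast to Python floats
# before being handed to ray.air.session.report.
import torch
from trlx.model.accelerate_base_model import parse_results_for_session

results = {"mean_reward": torch.tensor(0.87), "learning_rate_group_0": 1.412e-4}
results = parse_results_for_session(results)
# -> every value is now a plain float, safe to pass to session.report(results)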
7 changes: 6 additions & 1 deletion trlx/orchestrator/ppo_orchestrator.py
@@ -10,6 +10,9 @@
from trlx.utils import Clock
from trlx.utils.modeling import logprobs_from_logits

import ray
from ray.air import session


@register_orchestrator
class PPOOrchestrator(Orchestrator):
@@ -124,7 +127,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
            ppo_rl_elements += new_ppo_rl_elements

        stats = {"exp_time": exp_time}
        self.rl_model.accelerator.log(stats, step=iter_count)

        if not ray.is_initialized():
            self.rl_model.accelerator.log(stats, step=iter_count)

        # Push samples and rewards to model's rollout storage
        self.rl_model.push_to_store(ppo_rl_elements)