Hyperparameter Optimization with Ray Tune and Weights and Biases #76

Merged: 30 commits, Nov 14, 2022

Commits
7447bed
req + sweep config + sweep file + config done
ayulockin Oct 26, 2022
dd642d7
sample config + tune_config parsing
ayulockin Oct 27, 2022
053698f
wip trainable fn
ayulockin Oct 27, 2022
cf19b71
ray tune integrated
ayulockin Oct 27, 2022
0341c94
ray tune working
ayulockin Oct 28, 2022
4f880cf
add scheduler + search algo
ayulockin Oct 28, 2022
a8cbcd4
fix config
ayulockin Oct 28, 2022
e48a40d
update ilql model
ayulockin Oct 28, 2022
968deb8
condition on ray init
ayulockin Oct 28, 2022
4d19f5b
condition on ray init
ayulockin Oct 28, 2022
b0de179
remove torch to float convertion
ayulockin Oct 28, 2022
cf23408
refactor
ayulockin Oct 28, 2022
d6cf517
refactor utils + documentation
ayulockin Oct 28, 2022
bd2d2c6
refactor + generalized
ayulockin Oct 28, 2022
311141b
bayesian opt with hyperband
ayulockin Oct 31, 2022
363fb0a
remove comment
ayulockin Oct 31, 2022
55aab3d
Merge branch 'master' into hyperopt
ayulockin Nov 1, 2022
5157c3c
decouple wandb to fix multiprocessing issues with bayes opt
ayulockin Nov 1, 2022
19b052a
log to wandb
ayulockin Nov 1, 2022
cb39537
integrate wandb
ayulockin Nov 1, 2022
7ea8671
trained fully
ayulockin Nov 2, 2022
89d2e60
remove empty file
ayulockin Nov 2, 2022
d863986
remove run
ayulockin Nov 9, 2022
a48f349
revert print
ayulockin Nov 9, 2022
c40edbe
default=1
ayulockin Nov 9, 2022
0ff95fb
add programatic report
ayulockin Nov 10, 2022
62a8063
revert unnecessary changes
ayulockin Nov 14, 2022
72dbeec
support random search + fifo scheduler
ayulockin Nov 14, 2022
486c14e
Merge remote-tracking branch 'origin/master' into hyperopt
ayulockin Nov 14, 2022
dcaaeaf
adapt new config
ayulockin Nov 14, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -147,3 +147,5 @@ wandb/
OUT/

ckpts/

ray_results/
2 changes: 1 addition & 1 deletion configs/ppo_config.yml
@@ -6,7 +6,7 @@ model:

train:
seq_length: 48 # Size of LM context
epochs: 1000 # Train for max(epochs, total_steps)
epochs: 1000 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

69 changes: 69 additions & 0 deletions configs/ray_tune_configs/ppo_config.yml
@@ -0,0 +1,69 @@
tune_config:
mode: "max"
metric: "mean_reward"
search_alg: "bohb" # null
scheduler: "hyperbandforbohb" # null
num_samples: 15
max_concurrent_trials: null
time_budget_s: null
reuse_actors: null

model:
model_path: "lvwerra/gpt2-imdb" # Name of hf model to load
tokenizer_path: "gpt2" # Name of hf tokenizer to load
model_type: "AcceleratePPOModel" # Name of accelerate model type to load
num_layers_unfrozen: # Number of bottom layers to freeze during training
strategy: "choice"
values: [2, 3]

train:
seq_length: # Size of LM context
strategy: "choice"
values: [36, 48, 52]
epochs: # Train for max(epochs, total_steps)
strategy: "choice"
values: [80, 100, 120]
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

lr_ramp_steps: 100 # learning rate warm up
lr_decay_steps: 79000 # learning rate decay
weight_decay: 1.0e-6 # weight decay param
learning_rate_init: 1.412e-4 # init learning rate
learning_rate_target: 1.412e-4 # target final learning rate
opt_betas: [0.9, 0.95] # adam betas

checkpoint_interval: 10000 # checkpoint interval
eval_interval: 4 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
project_name: "trlx-hyperopt-bohb"

method:
name: 'ppoconfig' # Name of RL method config
num_rollouts: # Number of rollouts to collect per epoch
strategy: "choice"
values: [96, 128]
chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
ppo_epochs: # Number of ppo epochs
strategy: "randint"
values: [3, 6] # 3 is inclusive, 6 is exclusive
init_kl_coef: # init kl coefficient
strategy: "quniform"
values: [0.1, 0.3, 0.1]
target: 6 # target kl coefficient, set None for fixed kl coef
horizon: 10000 # PPO horizon
gamma: 1 # PPO discount
lam: # PPO lambda
strategy: "uniform"
values: [0.93, 0.98]
cliprange: 0.2 # clip range
cliprange_value: 0.2 # clip range
vf_coef: 2.3 # value term weight
gen_kwargs:
max_length: 48 # LM max sample gen length
min_length: 48 # LM min sample gen length
top_k: 0.0 # top k
top_p: 1.0 # top p
do_sample: True # sample
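
Note: the strategy/values pairs above are not native Ray Tune syntax; they are parsed by the helpers in trlx/ray_tune (not shown in this diff). A minimal sketch, under that assumption, of how one entry could map onto Ray Tune's search-space primitives:

from ray import tune

def to_tune_param(spec: dict):
    # Assumed mapping for illustration; the real conversion lives in trlx.ray_tune.get_param_space.
    strategy, values = spec["strategy"], spec["values"]
    if strategy == "choice":
        return tune.choice(values)        # sample one of the listed values
    if strategy == "randint":
        return tune.randint(*values)      # lower inclusive, upper exclusive
    if strategy == "uniform":
        return tune.uniform(*values)      # continuous uniform over [lower, upper]
    if strategy == "quniform":
        return tune.quniform(*values)     # (lower, upper, quantization step)
    raise ValueError(f"unknown strategy: {strategy}")

# e.g. ppo_epochs above samples an integer from [3, 6)
ppo_epochs = to_tune_param({"strategy": "randint", "values": [3, 6]})
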
2 changes: 2 additions & 0 deletions setup.cfg
@@ -20,6 +20,8 @@ install_requires =
transformers>=4.21.2
tqdm
wandb
ray>=2.0.1
tabulate>=0.9.0

[options.extras_require]
dev =
88 changes: 88 additions & 0 deletions train_sweep.py
@@ -0,0 +1,88 @@
# Usage: python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments
import wandb
import argparse

import ray
from ray.air import session
from ray import tune

import trlx
from trlx.ray_tune import load_ray_yaml
from trlx.ray_tune import get_param_space
from trlx.ray_tune import get_tune_config
from trlx.ray_tune import get_train_function
from trlx.ray_tune.wandb import log_trials

from ray.tune.logger import JsonLoggerCallback
from ray.tune.logger import CSVLoggerCallback


def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict):
tuner = tune.Tuner(
tune.with_resources(train_function, resources=resources),
param_space=param_space,
tune_config=tune.TuneConfig(**tune_config),
run_config=ray.air.RunConfig(
local_dir="ray_results",
callbacks=[CSVLoggerCallback()]
),
)

results = tuner.fit()

log_trials(
tuner._local_tuner.get_experiment_checkpoint_dir(),
param_space["train"]["project_name"]
)

print("Best hyperparameters found were: ", results.get_best_result().config)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--example-name", type=str, default="ppo_sentiments", help="Name of the example"
)
parser.add_argument(
"--config", type=str, default=None, required=True, help="The config file defining the param_space."
)
parser.add_argument(
"--num-cpus", type=int, default=4, help="Number of CPUs to use per exp."
)
parser.add_argument(
"--num-gpus", type=int, default=0, help="Number of GPUs to use per exp."
)
parser.add_argument(
"--server-address",
type=str,
default=None,
required=False,
help="The address of server to connect to if using Ray Client.",
)

args, _ = parser.parse_known_args()

# Read config and parse it
config = load_ray_yaml(args.config)
tune_config = get_tune_config(config)
param_space = get_param_space(config)

# Initialize Ray.
if args.server_address:
ray.init(address=f"ray://{args.server_address}")
else:
ray.init()

resources = {
"cpu": args.num_cpus,
"gpu": args.num_gpus,
}

# Register the training function that will be used for training the model.
train_function = get_train_function(args.example_name)
tune.register_trainable("train_function", train_function)

tune_function(train_function, param_space, tune_config, resources)

# Shut down Ray.
ray.shutdown()
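
For context, get_tune_config (imported above but not shown in this diff) has to turn the string names from the YAML tune_config block into Ray Tune objects before they reach tune.TuneConfig. A hedged sketch of what that resolution might look like; resolve_tune_config is a hypothetical helper, and the BOHB search algorithm additionally requires the optional hpbandster and ConfigSpace packages:

from ray import tune
from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.search.bohb import TuneBOHB

def resolve_tune_config(cfg: dict) -> dict:
    # Assumed mapping; the real logic lives in trlx.ray_tune.get_tune_config and may differ.
    resolved = dict(cfg)
    if resolved.get("search_alg") == "bohb":
        resolved["search_alg"] = TuneBOHB()
    if resolved.get("scheduler") == "hyperbandforbohb":
        resolved["scheduler"] = HyperBandForBOHB(time_attr="training_iteration")
    # Drop unset keys so tune.TuneConfig falls back to its own defaults.
    return {k: v for k, v in resolved.items() if v is not None}

tune_config = tune.TuneConfig(**resolve_tune_config({
    "mode": "max", "metric": "mean_reward",
    "search_alg": "bohb", "scheduler": "hyperbandforbohb",
    "num_samples": 15,
}))
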
20 changes: 17 additions & 3 deletions trlx/data/configs.py
@@ -143,7 +143,21 @@ def to_dict(self):
"""
Convert TRLConfig to dictionary.
"""
data = self.model.__dict__.copy()
data.update(self.train.__dict__)
data.update(self.method.__dict__)
data = {
"model": self.model.__dict__,
"train": self.train.__dict__,
"method": self.method.__dict__,
}

return data

@classmethod
def from_dict(cls, config_dict: dict):
"""
Convert dictionary to TRLConfig.
"""
return cls(
ModelConfig.from_dict(config_dict["model"]),
TrainConfig.from_dict(config_dict["train"]),
get_method(config_dict["method"]["name"]).from_dict(config_dict["method"]),
)
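
Note: splitting to_dict into per-section sub-dicts (and adding from_dict) is what lets the sweep rebuild a typed TRLConfig from the dict that Ray Tune samples. A small usage sketch of the round trip, assuming the existing TRLConfig.load_yaml helper:

from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ppo_config.yml")
d = config.to_dict()                  # {"model": {...}, "train": {...}, "method": {...}}
d["train"]["batch_size"] = 64         # e.g. a value sampled by the sweep
config = TRLConfig.from_dict(d)
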
47 changes: 24 additions & 23 deletions trlx/model/accelerate_base_model.py
@@ -6,7 +6,6 @@

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from transformers import AutoTokenizer

import wandb
@@ -18,17 +17,28 @@
else:
from tqdm import tqdm

import ray
from ray.air import session


def parse_results_for_session(results: dict):
for k, v in results.items():
if isinstance(v, torch.Tensor):
results[k] = float(v)

return results


@register_model
class AccelerateRLModel(BaseRLModel):
"""
RL Model that uses accelerate for training
"""

def __init__(self, config, train_mode=True):
def __init__(self, config, accelerator, train_mode=True):
super().__init__(config, train_mode)

self.accelerator = Accelerator(log_with="wandb")
self.accelerator = accelerator

if int(os.environ.get("WORLD_SIZE", 1)) > 1:
torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
@@ -63,21 +73,6 @@ def __init__(self, config, train_mode=True):
for m in gpt_blocks_to_freeze:
m.requires_grad_(False)

if self.accelerator.is_main_process:
self.accelerator.init_trackers(
project_name=self.config.train.project_name,
config=self.config.to_dict(),
init_kwargs={
"wandb": {
"name": f"{config.model.model_path}",
"entity": self.config.train.entity_name,
"mode": "disabled"
if os.environ.get("debug", False)
else "online",
}
},
)

self.opt = torch.optim.AdamW(
self.model.parameters(),
lr=float(self.config.train.learning_rate_init),
@@ -174,7 +169,7 @@ def evaluate(self):
columns.append("reward")
columns_data.append(rewards)
stats["mean_reward"] = mean_reward
print(f"{mean_reward=}")
# print(f"{mean_reward=}")

# additionally log any other metrics
if self.metric_fn:
@@ -194,9 +189,8 @@ def evaluate(self):
columns_data.append(values)

rows = list(zip(*columns_data))
stats["samples"] = wandb.Table(columns=columns, rows=rows)

print(rows[0])
if not ray.is_initialized():
stats["samples"] = wandb.Table(columns=columns, rows=rows)

return stats

@@ -241,7 +235,14 @@ def learn(self):
"backward_time": backward_time,
}
)
self.accelerator.log(results)

if not ray.is_initialized():
self.accelerator.log(results)

# Report the metrics to Ray Tune.
if ray.is_initialized():
tmp_results = parse_results_for_session(results)
session.report(tmp_results)

desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items())
tbar.set_description(desc)
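
For context, ray.air.session.report can only be called from inside a Tune trial, and the dict it receives is what the scheduler and search algorithm act on; converting tensors to plain floats first (as parse_results_for_session does above) keeps that dict serializable. A minimal self-contained sketch of the same pattern outside trlx:

import torch
from ray import tune
from ray.air import session

def trainable(config):
    for step in range(10):
        reward = torch.tensor(step / 10.0)
        # Mirror parse_results_for_session: report floats, not tensors.
        session.report({"mean_reward": float(reward)})

tuner = tune.Tuner(
    trainable,
    tune_config=tune.TuneConfig(metric="mean_reward", mode="max", num_samples=2),
)
tuner.fit()
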
3 changes: 2 additions & 1 deletion trlx/model/accelerate_ilql_model.py
@@ -14,11 +14,12 @@ class AccelerateILQLModel(AccelerateRLModel):
def __init__(
self,
config,
accelerator,
logit_mask=None,
metric_fn=None,
train_mode=True,
):
super().__init__(config, train_mode)
super().__init__(config, accelerator, train_mode)
self.logit_mask = logit_mask
self.metric_fn = metric_fn
self.reward_fn = None
4 changes: 2 additions & 2 deletions trlx/model/accelerate_ppo_model.py
@@ -34,8 +34,8 @@ def update(self, current, n_steps):

@register_model
class AcceleratePPOModel(AccelerateRLModel):
def __init__(self, config):
super().__init__(config)
def __init__(self, config, accelerator):
super().__init__(config, accelerator)

self.store = PPORolloutStorage(self.tokenizer.pad_token_id)

7 changes: 6 additions & 1 deletion trlx/orchestrator/ppo_orchestrator.py
@@ -10,6 +10,9 @@
from trlx.utils import Clock
from trlx.utils.modeling import logprobs_from_logits

import ray
from ray.air import session


@register_orchestrator
class PPOOrchestrator(Orchestrator):
@@ -124,7 +127,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
ppo_rl_elements += new_ppo_rl_elements

stats = {"exp_time": exp_time}
self.rl_model.accelerator.log(stats, step=iter_count)

if not ray.is_initialized():
self.rl_model.accelerator.log(stats, step=iter_count)

# Push samples and rewards to model's rollout storage
self.rl_model.push_to_store(ppo_rl_elements)