diff --git a/.gitignore b/.gitignore index 9fc606114..45e03b9e5 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,5 @@ OUT/ examples/experiments/grounded_program_synthesis/dataset ckpts/ + +ray_results/ diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index ac22cd38c..38a36cddc 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -6,7 +6,7 @@ model: train: seq_length: 48 # Size of LM context - epochs: 1000 # Train for max(epochs, total_steps) + epochs: 1000 # Train for max(epochs, total_steps) total_steps: 10000 # Train for max(epochs, total_steps) batch_size: 128 # batch size diff --git a/configs/ray_tune_configs/ppo_config.yml b/configs/ray_tune_configs/ppo_config.yml new file mode 100644 index 000000000..8cc3d0fb9 --- /dev/null +++ b/configs/ray_tune_configs/ppo_config.yml @@ -0,0 +1,68 @@ +tune_config: + mode: "max" + metric: "mean_reward" + search_alg: "bohb" # random + scheduler: "hyperbandforbohb" # fifo + num_samples: 15 + max_concurrent_trials: null + time_budget_s: null + reuse_actors: null + +model: + model_path: "lvwerra/gpt2-imdb" # Name of hf model to load + tokenizer_path: "gpt2" # Name of hf tokenizer to load + model_type: "AcceleratePPOModel" # Name of accelerate model type to load + num_layers_unfrozen: # Number of bottom layers to freeze during training + strategy: "choice" + values: [2, 3] + +train: + seq_length: # Size of LM context + strategy: "choice" + values: [36, 48, 52] + epochs: # Train for max(epochs, total_steps) + strategy: "choice" + values: [80, 100, 120] + total_steps: 10000 # Train for max(epochs, total_steps) + batch_size: 128 # batch size + + lr_init: 1.412e-4 # init learning rate + lr_target: 1.412e-4 # target final learning rate + opt_betas: [0.9, 0.95] # adam betas + opt_eps: 1.0e-8 # adam eps + weight_decay: 1.0e-6 # weight decay param + + checkpoint_interval: 10000 # checkpoint interval + eval_interval: 4 # eval interval + + pipeline: "PPOPipeline" # prompt pipeline to load + orchestrator: "PPOOrchestrator" # orchestrator to load + project_name: "trlx-hyperopt-bohb" + +method: + name: 'ppoconfig' # Name of RL method config + num_rollouts: # Number of rollouts to collect per epoch + strategy: "choice" + values: [96, 128] + chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator + ppo_epochs: # Number of ppo epochs + strategy: "randint" + values: [3, 6] # 3 is inclusive, 6 is exclusive + init_kl_coef: # init kl coefficient + strategy: "quniform" + values: [0.1, 0.3, 0.1] + target: 6 # target kl coefficient, set None for fixed kl coef + horizon: 10000 # PPO horizon + gamma: 1 # PPO discount + lam: # PPO lambda + strategy: "uniform" + values: [0.93, 0.98] + cliprange: 0.2 # clip range + cliprange_value: 0.2 # clip range + vf_coef: 2.3 # value term weight + gen_kwargs: + max_length: 48 # LM max sample gen length + min_length: 48 # LM min sample gen length + top_k: 0.0 # top k + top_p: 1.0 # top p + do_sample: True # sample diff --git a/setup.cfg b/setup.cfg index 1566b4077..d80a565e3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,8 @@ install_requires = transformers>=4.21.2 tqdm wandb + ray>=2.0.1 + tabulate>=0.9.0 [options.extras_require] dev = diff --git a/train_sweep.py b/train_sweep.py new file mode 100644 index 000000000..e59c988af --- /dev/null +++ b/train_sweep.py @@ -0,0 +1,96 @@ +# Usage: python train_sweep.py --config configs/ray_tune_configs/ppo_config.yml --example-name ppo_sentiments +import wandb +import argparse +from pathlib import Path + +import ray +from ray.air import session 
+from ray import tune + +import trlx +from trlx.ray_tune import load_ray_yaml +from trlx.ray_tune import get_param_space +from trlx.ray_tune import get_tune_config +from trlx.ray_tune import get_train_function +from trlx.ray_tune.wandb import log_trials, create_report + +from ray.tune.logger import JsonLoggerCallback +from ray.tune.logger import CSVLoggerCallback + + +def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict): + tuner = tune.Tuner( + tune.with_resources(train_function, resources=resources), + param_space=param_space, + tune_config=tune.TuneConfig(**tune_config), + run_config = ray.air.RunConfig( + local_dir="ray_results", + callbacks=[CSVLoggerCallback()] + ), + ) + + results = tuner.fit() + + log_trials( + tuner._local_tuner.get_experiment_checkpoint_dir(), + param_space["train"]["project_name"] + ) + + create_report( + param_space, + tune_config, + Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, + results.get_best_result().config + ) + + print("Best hyperparameters found were: ", results.get_best_result().config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--example-name", type=str, default="ppo_sentiments", help="Name of the example" + ) + parser.add_argument( + "--config", type=str, default=None, required=True, help="The config file defining the param_space." + ) + parser.add_argument( + "--num-cpus", type=int, default=4, help="Number of CPUs to use per exp." + ) + parser.add_argument( + "--num-gpus", type=int, default=1, help="Number of GPUs to use per exp." + ) + parser.add_argument( + "--server-address", + type=str, + default=None, + required=False, + help="The address of server to connect to if using Ray Client.", + ) + + args, _ = parser.parse_known_args() + + # Read config and parse it + config = load_ray_yaml(args.config) + tune_config = get_tune_config(config) + param_space = get_param_space(config) + + # Initialize Ray. + if args.server_address: + ray.init(address=f"ray://{args.server_address}") + else: + ray.init() + + resources = { + "cpu": args.num_cpus, + "gpu": args.num_gpus, + } + + # Register the training function that will be used for training the model. + train_function = get_train_function(args.example_name) + tune.register_trainable("train_function", train_function) + + tune_function(train_function, param_space, tune_config, resources) + + # Shut down Ray. + ray.shutdown() diff --git a/trlx/data/configs.py b/trlx/data/configs.py index aaa6f22f9..516ad2231 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -142,7 +142,21 @@ def to_dict(self): """ Convert TRLConfig to dictionary. """ - data = self.model.__dict__.copy() - data.update(self.train.__dict__) - data.update(self.method.__dict__) + data = { + "model": self.model.__dict__, + "train": self.train.__dict__, + "method": self.method.__dict__, + } + return data + + @classmethod + def from_dict(cls, config_dict: dict): + """ + Convert dictionary to TRLConfig. 
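+
+        Expects the nested layout produced by `to_dict` above: a dict with
+        top-level "model", "train" and "method" entries (the same structure
+        Ray Tune passes as a sampled `param_space`).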
+ """ + return cls( + ModelConfig.from_dict(config_dict["model"]), + TrainConfig.from_dict(config_dict["train"]), + get_method(config_dict["method"]["name"]).from_dict(config_dict["method"]), + ) diff --git a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py index 272a6a038..2a6bf3078 100644 --- a/trlx/model/accelerate_base_model.py +++ b/trlx/model/accelerate_base_model.py @@ -18,6 +18,17 @@ else: from tqdm import tqdm +import ray +from ray.air import session + + +def parse_results_for_session(results: dict): + for k, v in results.items(): + if isinstance(v, torch.Tensor): + results[k] = float(v) + + return results + @register_model class AccelerateRLModel(BaseRLModel): @@ -63,7 +74,7 @@ def __init__(self, config, train_mode=True): for m in gpt_blocks_to_freeze: m.requires_grad_(False) - if self.accelerator.is_main_process: + if self.accelerator.is_main_process and not ray.is_initialized(): self.accelerator.init_trackers( project_name=self.config.train.project_name, config=self.config.to_dict(), @@ -199,9 +210,8 @@ def evaluate(self): columns_data.append(values) rows = list(zip(*columns_data)) - stats["samples"] = wandb.Table(columns=columns, rows=rows) - - print(rows[0]) + if not ray.is_initialized(): + stats["samples"] = wandb.Table(columns=columns, rows=rows) return stats @@ -246,7 +256,14 @@ def learn(self): "backward_time": backward_time, } ) - self.accelerator.log(results) + + if not ray.is_initialized(): + self.accelerator.log(results) + + # Report the metrics to Ray Tune. + if ray.is_initialized(): + tmp_results = parse_results_for_session(results) + session.report(tmp_results) desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items()) tbar.set_description(desc) diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index 7b8957001..7f6096aae 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -10,6 +10,9 @@ from trlx.utils import Clock from trlx.utils.modeling import logprobs_from_logits +import ray +from ray.air import session + @register_orchestrator class PPOOrchestrator(Orchestrator): @@ -124,7 +127,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): ppo_rl_elements += new_ppo_rl_elements stats = {"exp_time": exp_time} - self.rl_model.accelerator.log(stats, step=iter_count) + + if not ray.is_initialized(): + self.rl_model.accelerator.log(stats, step=iter_count) # Push samples and rewards to model's rollout storage self.rl_model.push_to_store(ppo_rl_elements) diff --git a/trlx/ray_tune/__init__.py b/trlx/ray_tune/__init__.py new file mode 100644 index 000000000..d0a63d071 --- /dev/null +++ b/trlx/ray_tune/__init__.py @@ -0,0 +1,196 @@ +import yaml +from ray import tune + + +def load_ray_yaml(path: str): + with open(path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return config + + +def get_param_space(config: dict): + """Get the param space from the config file.""" + + def get_strategy(value): + """Get search space strategy from config. + A search space defines valid values for your hyperparameters and + can specify how these values are sampled. + + Refer to the documentation for more info: + https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs + + The user will have to define the search space in the config file by providing + the name of the `strategy` and the `values` to sample from. + + The valid strategies are: + - `uniform` (List) - Samples uniformly between the given bounds. 
+        - `quniform` (List) - Samples uniformly between the given bounds, quantized.
+        - `loguniform` (List) - Samples uniformly between the given bounds on a log scale.
+        - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized.
+        - `randn` (List) - Samples from a normal distribution.
+        - `qrandn` (List) - Samples from a normal distribution, quantized.
+        - `randint` (List) - Samples integers uniformly between the given bounds.
+        - `qrandint` (List) - Samples integers uniformly between the given bounds, quantized.
+        - `lograndint` (List) - Samples integers uniformly between the given bounds on a log scale.
+        - `qlograndint` (List) - Samples integers uniformly between the given bounds on a log scale, quantized.
+        - `choice` (List) - Samples from a discrete set of values.
+        - `grid` (List) - Grid search over the given list of values.
+
+        """
+
+        strategy = value["strategy"]
+        if strategy == "uniform":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 2
+            return tune.uniform(*value["values"])
+        elif strategy == "quniform":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 3
+            return tune.quniform(*value["values"])
+        elif strategy == "loguniform":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 3
+            return tune.loguniform(*value["values"])
+        elif strategy == "qloguniform":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 4
+            return tune.qloguniform(*value["values"])
+        elif strategy == "randn":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 2
+            return tune.randn(*value["values"])
+        elif strategy == "qrandn":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 3
+            return tune.qrandn(*value["values"])
+        elif strategy == "randint":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 2
+            return tune.randint(*value["values"])
+        elif strategy == "qrandint":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 3
+            return tune.qrandint(*value["values"])
+        elif strategy == "lograndint":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 3
+            return tune.lograndint(*value["values"])
+        elif strategy == "qlograndint":
+            assert isinstance(value["values"], list)
+            assert len(value["values"]) == 4
+            return tune.qlograndint(*value["values"])
+        elif strategy == "choice":
+            assert isinstance(value["values"], list)
+            return tune.choice(value["values"])
+        elif strategy == "grid":
+            assert isinstance(value["values"], list)
+            return tune.grid_search(value["values"])
+        else:
+            raise NotImplementedError(f"Strategy `{strategy}` is not supported.")
+
+    def parse_param_space(param_space: dict):
+        """Parse the param space from the config file by
+        replacing the strategies with the relevant Ray Tune distribution APIs.
+        """
+        for k, v in param_space.items():
+            if isinstance(v, dict):
+                if "strategy" in v.keys():
+                    strategy = get_strategy(v)
+                    param_space[k] = strategy
+
+        return param_space
+
+    model_dict = parse_param_space(config["model"])
+    train_dict = parse_param_space(config["train"])
+    method_dict = parse_param_space(config["method"])
+
+    return {"model": model_dict, "train": train_dict, "method": method_dict}
+
+
+def get_search_alg(tune_config: dict):
+    """Initialize the search algorithm and return it.
+
+    Bayesian Optimization (`bayesopt`), BOHB (`bohb`) and random search
+    (`random`) are currently supported.
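+
+    Both `bayesopt` and `bohb` read the `metric` and `mode` keys from the
+    tune config; `random` returns None so that Ray Tune falls back to its
+    default random/grid search behaviour.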
+    """
+    search_alg = tune_config["search_alg"]
+
+    if search_alg == "bayesopt":
+        try:
+            from ray.tune.search.bayesopt import BayesOptSearch
+        except ImportError:
+            raise ImportError(
+                "Please pip install bayesian-optimization to use BayesOptSearch."
+            )
+
+        assert (
+            "metric" in tune_config.keys() and "mode" in tune_config.keys()
+        ), "Please specify metric and mode for BayesOptSearch."
+
+        return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"])
+    elif search_alg == "bohb":
+        try:
+            from ray.tune.search.bohb import TuneBOHB
+        except ImportError:
+            raise ImportError(
+                "Please pip install hpbandster and ConfigSpace to use TuneBOHB."
+            )
+
+        assert (
+            "metric" in tune_config.keys() and "mode" in tune_config.keys()
+        ), "Please specify metric and mode for TuneBOHB."
+
+        return TuneBOHB()
+    elif search_alg == "random":
+        return None
+    else:
+        raise NotImplementedError("Search algorithm not supported.")
+
+
+def get_scheduler(tune_config: dict):
+    """Initialize the scheduler and return it.
+
+    The schedulers can early terminate bad trials, pause trials,
+    clone trials, and alter hyperparameters of a running trial.
+
+    Refer to the documentation for more info:
+    https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers
+
+    Currently available schedulers are:
+        - `hyperband` - Implements the HyperBand early stopping algorithm.
+        - `hyperbandforbohb` - The HyperBand variant required by the BOHB search algorithm.
+        - `fifo` - Runs trials in submission order (Ray Tune's default, no early stopping).
+
+    """
+    scheduler = tune_config["scheduler"]
+
+    if scheduler == "hyperband":
+        return tune.schedulers.HyperBandScheduler()
+    elif scheduler == "hyperbandforbohb":
+        return tune.schedulers.HyperBandForBOHB()
+    elif scheduler == "fifo":
+        return None
+    else:
+        raise NotImplementedError("Scheduler not supported.")
+
+
+def get_tune_config(config: dict):
+    """Get the tune config used to initialize `tune.TuneConfig`,
+    which is then passed to `tune.Tuner`.
+    """
+    tune_config = config["tune_config"]
+
+    if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None:
+        tune_config["search_alg"] = get_search_alg(tune_config)
+
+    if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None:
+        tune_config["scheduler"] = get_scheduler(tune_config)
+
+    # Remove config keys with None values.
+    tune_config = {k: v for k, v in tune_config.items() if v is not None}
+
+    return tune_config
+
+
+def get_train_function(example_name: str):
+    if example_name == "ppo_sentiments":
+        from .train_funcs import ppo_sentiments_train
+
+        return ppo_sentiments_train
+    else:
+        raise NotImplementedError("Example not implemented yet.")
diff --git a/trlx/ray_tune/train_funcs.py b/trlx/ray_tune/train_funcs.py
new file mode 100644
index 000000000..c00971706
--- /dev/null
+++ b/trlx/ray_tune/train_funcs.py
@@ -0,0 +1,32 @@
+# Find the optimal hyperparameters to generate positive movie
+# reviews by tuning a model pretrained on IMDB with a sentiment reward function.
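+#
+# This function is registered as a Ray Tune trainable in `train_sweep.py`;
+# Tune calls it with a sampled hyperparameter dict built from the param space
+# in `configs/ray_tune_configs/ppo_config.yml`, which is converted back into a
+# `TRLConfig` below.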
+ +from datasets import load_dataset + +import trlx +from trlx.data.configs import TRLConfig + + +def ppo_sentiments_train(config: dict): + from transformers import pipeline + + config = TRLConfig.from_dict(config) + + sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1) + + def reward_fn(samples): + outputs = sentiment_fn(samples, return_all_scores=True) + sentiments = [output[1]["score"] for output in outputs] + return sentiments + + # Take few words off of movies reviews as prompts + imdb = load_dataset("imdb", split="train+test") + prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] + + model = trlx.train( + "lvwerra/gpt2-imdb", + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=["I don't know much about Hungarian underground"] * 64, + config=config, + ) diff --git a/trlx/ray_tune/wandb.py b/trlx/ray_tune/wandb.py new file mode 100644 index 000000000..191c5f059 --- /dev/null +++ b/trlx/ray_tune/wandb.py @@ -0,0 +1,201 @@ +"""Utility function to log the results of a Ray Tune experiment to W&B.""" + +import os +import json +import pandas as pd +from pathlib import Path + +import wandb + +wandb.require("report-editing") +import wandb.apis.reports as wb + + +def parse_result(result): + tmp_result = {} + for key, value in result.items(): + if not "time" in key and isinstance(value, (int, float)): + tmp_result[key] = value + + tmp_result.pop("done", None) + tmp_result.pop("timesteps_total", None) + tmp_result.pop("episodes_total", None) + tmp_result.pop("iterations_since_restore", None) + tmp_result.pop("training_iteration", None) + tmp_result.pop("pid", None) + + return tmp_result + + +def log_trials(trial_path: str, project_name: str): + trial_path = Path(trial_path) + files = os.listdir(trial_path) + + trial_paths = [] + for filename in files: + tmp_path = os.path.join(trial_path, filename) + if os.path.isdir(tmp_path): + trial_paths.append(tmp_path) + + for trial in trial_paths: + files = os.listdir(trial) + + # Open params.json and load the configs for that trial. + with open(os.path.join(trial, "params.json"), "r") as f: + params = json.load(f) + + # Initialize wandb + run = wandb.init( + project=project_name, + config=dict(params), + group=trial_path.stem, + job_type="hyperopt", + ) + + # Open result.json and log the metrics to W&B. + with open(os.path.join(trial, "result.json"), "r") as f: + for line in f: + result = dict(json.loads(line)) + result.pop("config", None) + result = parse_result(result) + wandb.log(result) + + # Close the W&B run. + run.finish() + + +def create_report(param_space, tune_config, trial_path, best_config=None): + def get_column_names(param_space): + column_names = [] + for k, v in param_space.items(): + if isinstance(v, dict): + for kk, vv in v.items(): + try: + if vv.__module__ == "ray.tune.search.sample": + column_names.append(f"{k}.{kk}") + except AttributeError: + pass + + return column_names + + def get_parallel_coordinate(param_space, metric): + column_names = get_column_names(param_space) + columns = [wb.reports.PCColumn(column) for column in column_names] + + return wb.ParallelCoordinatesPlot( + columns=columns + [wb.reports.PCColumn(metric)], + layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, + ) + + def get_param_importance(metric): + return wb.ParameterImportancePlot( + # Get it from the metric name. + with_respect_to=metric, + layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, + ) + + def get_scatter_plot(metric): + return wb.ScatterPlot( + # Get it from the metric name. + title=f"{metric} v. 
Index",
+            x="Index",
+            y=metric,
+            running_ymin=True,
+            font_size="small",
+            layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2},
+        )
+
+    def get_metrics_with_history(project_name, group_name, entity=None):
+        entity_project = f"{entity}/{project_name}" if entity else project_name
+        api = wandb.Api()
+        runs = api.runs(entity_project)
+        metrics = []
+        for run in runs:
+            if run.group == str(group_name):
+                history = run.history()
+                metrics = history.columns
+                break
+
+        metrics = [metric for metric in metrics if not metric.startswith("_")]
+        return metrics
+
+    report = wb.Report(
+        project=param_space["train"]["project_name"],
+        title=f"Hyperparameter Optimization Report: {trial_path}",
+        description="This is a report that shows the results of a hyperparameter optimization experiment.",
+    )
+
+    report.blocks = [
+        wb.P(
+            "The following plots show the results of the hyperparameter optimization experiment. "
+            "Use this as a starting point for your analysis. Go into edit mode to customize the report. "
+            "Share it with your team to collaborate on the analysis."
+        ),
+        wb.H1(text="Analysis"),
+        wb.P(
+            "The parallel coordinates chart (top) summarizes the relationship between large numbers of hyperparameters "
+            "and model metrics at a glance. \nThe scatter plot (right) compares the different trials and gives you an "
+            "insight into how the trials progressed. \nThe parameter importance plot (left) lists the hyperparameters "
+            "that were the best predictors of, and most highly correlated with, desirable values of your metrics."
+        ),
+        wb.PanelGrid(
+            panels=[
+                get_parallel_coordinate(param_space, tune_config["metric"]),
+                get_param_importance(tune_config["metric"]),
+                get_scatter_plot(tune_config["metric"]),
+            ],
+            runsets=[
+                wb.RunSet(
+                    project=param_space["train"]["project_name"]
+                ).set_filters_with_python_expr(f'group == "{trial_path}"')
+            ],
+        ),
+    ]
+
+    metrics = get_metrics_with_history(
+        param_space["train"]["project_name"],
+        trial_path,
+    )
+
+    line_plot_panels = []
+    for metric in metrics:
+        line_plot_panels.append(
+            wb.LinePlot(
+                title=f"{metric}",
+                x="Step",
+                y=[f"{metric}"],
+                title_x="Step",
+                smoothing_show_original=True,
+                max_runs_to_show=10,
+                plot_type="line",
+                font_size="auto",
+                legend_position="north",
+            )
+        )
+
+    report.blocks = report.blocks + [
+        wb.H1(text="Metrics"),
+        wb.P(
+            "The following line plots show the metrics for each trial. Use these to investigate the "
+            "performance of the model for each trial at the metric level."
+        ),
+        wb.PanelGrid(
+            panels=line_plot_panels,
+            runsets=[
+                wb.RunSet(
+                    project=param_space["train"]["project_name"]
+                ).set_filters_with_python_expr(f'group == "{trial_path}"')
+            ],
+        ),
+    ]
+
+    if best_config:
+        report.blocks = report.blocks + [
+            wb.H1(text="Best Config"),
+            wb.P(
+                "The code block shown below is the best config found by the hyperparameter "
+                "optimization experiment according to Ray Tune."
+ ), + wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), + ] + + report.save() diff --git a/trlx/trlx.py b/trlx/trlx.py index 4e01b5d82..9fa6bcdb9 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -1,6 +1,8 @@ import os from typing import Callable, Iterable, List, Optional, Tuple +from accelerate import Accelerator + from trlx.data.configs import TRLConfig from trlx.model.accelerate_ilql_model import AccelerateILQLModel @@ -12,6 +14,8 @@ from trlx.pipeline.offline_pipeline import PromptPipeline from trlx.utils.loading import get_model, get_orchestrator +import ray + def train( model_path: Optional[str] = None, @@ -76,7 +80,9 @@ def train( config.model.model_path = model_path model = AccelerateILQLModel( - config=config, logit_mask=logit_mask, metric_fn=metric_fn + config=config, + logit_mask=logit_mask, + metric_fn=metric_fn, ) batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))