From f71401e9fd3480921f10fe686c7484217a3862e3 Mon Sep 17 00:00:00 2001 From: Ayush Thakur Date: Tue, 10 Jan 2023 06:17:55 +0530 Subject: [PATCH] Fix wandb.errors.RequireError as reported in #162 (#167) * check wandb version * Update wandb.py * add wandb version + mean/reward * fix api changes * Update ILQL sweep hparams * Run `isort` * Remove `return`s from examples to avoid RayTune errors * Skip `isort` on wandb reports import * Remove erroring `tqdm` call in `BaseRLTrainer.evaluate` Co-authored-by: jon-tow --- configs/sweeps/ilql_sweep.yml | 2 +- configs/sweeps/ppo_sweep.yml | 4 ++-- examples/ppo_sentiments.py | 2 +- .../summarize_daily_cnn/t5_summarize_daily_cnn.py | 2 +- setup.cfg | 2 +- trlx/ray_tune/wandb.py | 12 ++++++------ trlx/trainer/accelerate_base_trainer.py | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/configs/sweeps/ilql_sweep.yml b/configs/sweeps/ilql_sweep.yml index db4464430..ead70b0fd 100644 --- a/configs/sweeps/ilql_sweep.yml +++ b/configs/sweeps/ilql_sweep.yml @@ -5,7 +5,7 @@ tune_config: scheduler: "fifo" num_samples: 32 -lr_init: +lr: strategy: "loguniform" values: [0.00001, 0.01] tau: diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index 95469ef1b..cb939e376 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -1,12 +1,12 @@ tune_config: mode: "max" - metric: "mean_reward" + metric: "reward/mean" search_alg: "random" scheduler: "fifo" num_samples: 32 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs -lr_init: +lr: strategy: "loguniform" values: [0.00001, 0.01] init_kl_coef: diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 17db29f84..8b9abea25 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -48,7 +48,7 @@ def reward_fn(samples: List[str]) -> List[float]: imdb = load_dataset("imdb", split="test") val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] - return 
trlx.train( + trlx.train( reward_fn=reward_fn, prompts=prompts, eval_prompts=val_prompts[0:1000], diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index a48fd9123..7ae25e60e 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -74,7 +74,7 @@ def reward_fn(samples: List[str]): ) # get prompt like trlx's prompt prompt_label[key.strip()] = val_summaries[i] - model = trlx.train( + trlx.train( config.model.model_path, reward_fn=reward_fn, prompts=prompts, diff --git a/setup.cfg b/setup.cfg index c96a6b55f..3bc55caac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,7 +19,7 @@ install_requires = torchtyping transformers>=4.21.2 tqdm - wandb + wandb>=0.13.5 ray>=2.0.1 tabulate>=0.9.0 networkx diff --git a/trlx/ray_tune/wandb.py b/trlx/ray_tune/wandb.py index b1b202963..fc4a69203 100644 --- a/trlx/ray_tune/wandb.py +++ b/trlx/ray_tune/wandb.py @@ -7,8 +7,8 @@ import wandb -wandb.require("report-editing") -import wandb.apis.reports as wb # noqa: E402 +import wandb.apis.reports as wb # isort: skip + ray_info = [ "done", @@ -84,10 +84,10 @@ def log_trials(trial_path: str, project_name: str): def create_report(project_name, param_space, tune_config, trial_path, best_config=None): def get_parallel_coordinate(param_space, metric): column_names = list(param_space.keys()) - columns = [wb.reports.PCColumn(column) for column in column_names] + columns = [wb.PCColumn(column) for column in column_names] return wb.ParallelCoordinatesPlot( - columns=columns + [wb.reports.PCColumn(metric)], + columns=columns + [wb.PCColumn(metric)], layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, ) @@ -155,7 +155,7 @@ def get_metrics_with_history(project_name, group_name, entity=None): get_scatter_plot(tune_config["metric"]), ], runsets=[ - wb.RunSet(project=project_name).set_filters_with_python_expr( + 
wb.Runset(project=project_name).set_filters_with_python_expr( f'group == "{trial_path}"' ) ], @@ -192,7 +192,7 @@ def get_metrics_with_history(project_name, group_name, entity=None): wb.PanelGrid( panels=line_plot_panels, runsets=[ - wb.RunSet(project=project_name).set_filters_with_python_expr( + wb.Runset(project=project_name).set_filters_with_python_expr( f'group == "{trial_path}"' ) ], diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index c44f7c0b3..d6d858f7d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -221,7 +221,7 @@ def evaluate(self): # noqa: C901 prompts_sizes = [] lst_prompts = [] generate_time = time() - for prompts in tqdm(self.eval_dataloader, desc="Generating samples"): + for prompts in self.eval_dataloader: if isinstance(prompts, torch.Tensor): samples = self.generate_eval(prompts) else: