Skip to content

Commit

Permalink
Fix wandb.errors.RequireError as reported in #162 (#167)
Browse files Browse the repository at this point in the history
* check wandb version

* Update wandb.py

* add wandb version + mean/reward

* fix api changes

* Update ILQL sweep hparams

* Run `isort`

* Remove `return`s from examples to avoid RayTune errors

* Skip `isort` on wandb reports import

* Remove erroring `tqdm` call in `BaseRLTrainer.evaluate`

Co-authored-by: jon-tow <[email protected]>
  • Loading branch information
ayulockin and jon-tow authored Jan 10, 2023
1 parent 0c5246f commit f71401e
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion configs/sweeps/ilql_sweep.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ tune_config:
scheduler: "fifo"
num_samples: 32

lr_init:
lr:
strategy: "loguniform"
values: [0.00001, 0.01]
tau:
Expand Down
4 changes: 2 additions & 2 deletions configs/sweeps/ppo_sweep.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
tune_config:
mode: "max"
metric: "mean_reward"
metric: "reward/mean"
search_alg: "random"
scheduler: "fifo"
num_samples: 32

# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
lr_init:
lr:
strategy: "loguniform"
values: [0.00001, 0.01]
init_kl_coef:
Expand Down
2 changes: 1 addition & 1 deletion examples/ppo_sentiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def reward_fn(samples: List[str]) -> List[float]:
imdb = load_dataset("imdb", split="test")
val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

return trlx.train(
trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=val_prompts[0:1000],
Expand Down
2 changes: 1 addition & 1 deletion examples/summarize_daily_cnn/t5_summarize_daily_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def reward_fn(samples: List[str]):
) # get prompt like trlx's prompt
prompt_label[key.strip()] = val_summaries[i]

model = trlx.train(
trlx.train(
config.model.model_path,
reward_fn=reward_fn,
prompts=prompts,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ install_requires =
torchtyping
transformers>=4.21.2
tqdm
wandb
wandb>=0.13.5
ray>=2.0.1
tabulate>=0.9.0
networkx
Expand Down
12 changes: 6 additions & 6 deletions trlx/ray_tune/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import wandb

wandb.require("report-editing")
import wandb.apis.reports as wb # noqa: E402
import wandb.apis.reports as wb # isort: skip


ray_info = [
"done",
Expand Down Expand Up @@ -84,10 +84,10 @@ def log_trials(trial_path: str, project_name: str):
def create_report(project_name, param_space, tune_config, trial_path, best_config=None):
def get_parallel_coordinate(param_space, metric):
column_names = list(param_space.keys())
columns = [wb.reports.PCColumn(column) for column in column_names]
columns = [wb.PCColumn(column) for column in column_names]

return wb.ParallelCoordinatesPlot(
columns=columns + [wb.reports.PCColumn(metric)],
columns=columns + [wb.PCColumn(metric)],
layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2},
)

Expand Down Expand Up @@ -155,7 +155,7 @@ def get_metrics_with_history(project_name, group_name, entity=None):
get_scatter_plot(tune_config["metric"]),
],
runsets=[
wb.RunSet(project=project_name).set_filters_with_python_expr(
wb.Runset(project=project_name).set_filters_with_python_expr(
f'group == "{trial_path}"'
)
],
Expand Down Expand Up @@ -192,7 +192,7 @@ def get_metrics_with_history(project_name, group_name, entity=None):
wb.PanelGrid(
panels=line_plot_panels,
runsets=[
wb.RunSet(project=project_name).set_filters_with_python_expr(
wb.Runset(project=project_name).set_filters_with_python_expr(
f'group == "{trial_path}"'
)
],
Expand Down
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def evaluate(self): # noqa: C901
prompts_sizes = []
lst_prompts = []
generate_time = time()
for prompts in tqdm(self.eval_dataloader, desc="Generating samples"):
for prompts in self.eval_dataloader:
if isinstance(prompts, torch.Tensor):
samples = self.generate_eval(prompts)
else:
Expand Down

0 comments on commit f71401e

Please sign in to comment.