Skip to content

Commit

Permalink
Make experiment tracking optional (#226)
Browse files Browse the repository at this point in the history
* Make experiment tracking optional
  • Loading branch information
jon-tow authored Jan 27, 2023
1 parent aeab7cc commit 3bb6c15
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,16 @@ def __init__(self, config, **kwargs):

run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}"

if self.accelerator.is_main_process and not ray.is_initialized():
if (
config.train.tracker is not None
and self.accelerator.is_main_process
and not ray.is_initialized()
):
config_dict = self.config.to_dict()
dist_config = get_distributed_config(self.accelerator)
config_dict["distributed"] = dist_config
init_trackers_kwargs = {}

if config.train.tracker not in ("wandb", "tensorboard"):
raise ValueError(
f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}"
)

if config.train.tracker == "wandb":
init_trackers_kwargs["wandb"] = {
"name": run_name,
Expand All @@ -101,7 +100,7 @@ def __init__(self, config, **kwargs):
config=config_dict,
init_kwargs=init_trackers_kwargs,
)
else: # only other supported tracker is tensorboard
elif config.train.tracker == "tensorboard":
config_dict_flat = flatten_dict(
config_dict
) # flatten config for tensorboard, split list in hparams into flatten config
Expand All @@ -116,6 +115,11 @@ def __init__(self, config, **kwargs):
project_name=self.config.train.project_name,
config=config_dict_flat,
)
else:
raise ValueError(
f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`. "
"Set `tracker` to `None` to disable tracking."
)

def setup_model(self):
"""
Expand Down

0 comments on commit 3bb6c15

Please sign in to comment.