diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5ce918e96..7979c1e95 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -76,17 +76,16 @@ def __init__(self, config, **kwargs): run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - if self.accelerator.is_main_process and not ray.is_initialized(): + if ( + config.train.tracker is not None + and self.accelerator.is_main_process + and not ray.is_initialized() + ): config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) config_dict["distributed"] = dist_config init_trackers_kwargs = {} - if config.train.tracker not in ("wandb", "tensorboard"): - raise ValueError( - f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}" - ) - if config.train.tracker == "wandb": init_trackers_kwargs["wandb"] = { "name": run_name, @@ -101,7 +100,7 @@ def __init__(self, config, **kwargs): config=config_dict, init_kwargs=init_trackers_kwargs, ) - else: # only other supported tracker is tensorboard + elif config.train.tracker == "tensorboard": config_dict_flat = flatten_dict( config_dict ) # flatten config for tensorboard, split list in hparams into flatten config @@ -116,6 +115,11 @@ def __init__(self, config, **kwargs): project_name=self.config.train.project_name, config=config_dict_flat, ) + else: + raise ValueError( + f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`. " + "Set `tracker` to `None` to disable tracking." + ) def setup_model(self): """