From a24b6918f105170806262cbd1986215ad39abcd3 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Thu, 26 Jan 2023 02:12:19 +0000 Subject: [PATCH 1/3] Make experiment tracking optional --- trlx/trainer/accelerate_base_trainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5ce918e96..35c9a9f1f 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -76,17 +76,13 @@ def __init__(self, config, **kwargs): run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - if self.accelerator.is_main_process and not ray.is_initialized(): + is_tracking = config.train.tracker is not None + if is_tracking and self.accelerator.is_main_process and not ray.is_initialized(): config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) config_dict["distributed"] = dist_config init_trackers_kwargs = {} - if config.train.tracker not in ("wandb", "tensorboard"): - raise ValueError( - f"Only supported trackers are wandb and tensorboard, got {config.train.tracker}" - ) - if config.train.tracker == "wandb": init_trackers_kwargs["wandb"] = { "name": run_name, @@ -101,7 +97,7 @@ def __init__(self, config, **kwargs): config=config_dict, init_kwargs=init_trackers_kwargs, ) - else: # only other supported tracker is tensorboard + elif config.train.tracker == "tensorboard": config_dict_flat = flatten_dict( config_dict ) # flatten config for tensorboard, split list in hparams into flatten config @@ -116,6 +112,12 @@ def __init__(self, config, **kwargs): project_name=self.config.train.project_name, config=config_dict_flat, ) + else: + raise ValueError( + f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`" + "Set `tracker` to `None` to disable tracking." + ) + def setup_model(self): """ From 90abaefbaa46d83ed1ce744e0c2ff6d1270b5176 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Thu, 26 Jan 2023 02:40:55 +0000 Subject: [PATCH 2/3] Run pre-commit --- trlx/trainer/accelerate_base_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 35c9a9f1f..8e9ad740b 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -77,7 +77,11 @@ def __init__(self, config, **kwargs): run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" is_tracking = config.train.tracker is not None - if is_tracking and self.accelerator.is_main_process and not ray.is_initialized(): + if ( + is_tracking + and self.accelerator.is_main_process + and not ray.is_initialized() + ): config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) config_dict["distributed"] = dist_config @@ -117,7 +121,6 @@ def __init__(self, config, **kwargs): f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`" "Set `tracker` to `None` to disable tracking." ) - def setup_model(self): """ From 13da1c6ff0899b0dc4fe614486d823a5e53db827 Mon Sep 17 00:00:00 2001 From: jon-tow Date: Fri, 27 Jan 2023 01:20:43 +0000 Subject: [PATCH 3/3] Fix formatting --- trlx/trainer/accelerate_base_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 8e9ad740b..7979c1e95 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -76,9 +76,8 @@ def __init__(self, config, **kwargs): run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - is_tracking = config.train.tracker is not None if ( - is_tracking + config.train.tracker is not None and self.accelerator.is_main_process and not ray.is_initialized() ): @@ -118,7 +117,7 @@ def __init__(self, config, **kwargs): ) else: raise ValueError( - f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`" + f"Only supported trackers are `wandb` and `tensorboard`. Got: `{config.train.tracker}`. " "Set `tracker` to `None` to disable tracking." )