diff --git a/sample_factory/cfg/cfg.py b/sample_factory/cfg/cfg.py index ae8e7a30..dbbd3f1f 100644 --- a/sample_factory/cfg/cfg.py +++ b/sample_factory/cfg/cfg.py @@ -744,6 +744,12 @@ def add_wandb_args(p: ArgumentParser): nargs="*", help="Tags can help with finding experiments in WandB web console", ) + p.add_argument( + "--wandb_dir", + default=join(os.getcwd(), "wandb"), + type=str, + help="Logging Directory for WandB", + ) def add_pbt_args(p: ArgumentParser): diff --git a/sample_factory/utils/utils.py b/sample_factory/utils/utils.py index a6119c97..d9de7591 100644 --- a/sample_factory/utils/utils.py +++ b/sample_factory/utils/utils.py @@ -402,6 +402,10 @@ def experiments_dir(cfg, mkdir=True) -> str: return maybe_ensure_dir_exists(cfg.train_dir, mkdir) +def wandb_dir(cfg, mkdir=True) -> str: + return maybe_ensure_dir_exists(cfg.wandb_dir, mkdir) + + def experiment_dir(cfg, mkdir=True) -> str: experiment = cfg.experiment experiments_root = experiments_dir(cfg, mkdir) diff --git a/sample_factory/utils/wandb_utils.py b/sample_factory/utils/wandb_utils.py index 73d7a17e..f0208c8b 100644 --- a/sample_factory/utils/wandb_utils.py +++ b/sample_factory/utils/wandb_utils.py @@ -1,6 +1,6 @@ from datetime import datetime -from sample_factory.utils.utils import log, retry +from sample_factory.utils.utils import log, retry, wandb_dir def init_wandb(cfg): @@ -44,6 +44,7 @@ def init_wandb_func(): tags=cfg.wandb_tags, resume="allow", settings=wandb.Settings(start_method="fork"), + dir=wandb_dir(cfg, True), ) log.debug("Initializing WandB...")