Skip to content

Commit

Permalink
support customized run_name in tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
聂靖入 committed Oct 24, 2023
1 parent bcbcdac commit 4e668ca
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions trlx/data/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class TrainConfig:
trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer

project_name: str = "trlx"
run_name: Optional[str] = None
entity_name: Optional[str] = None
group_name: Optional[str] = None

Expand Down
4 changes: 3 additions & 1 deletion trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def __init__(self, config, **kwargs): # noqa: C901
num_gpus = f"{self.accelerator.num_processes}gpus"
branch = get_git_tag()[0]

run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}"
run_name = self.config.train.run_name
if not run_name:
run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}"

if self.accelerator.is_main_process:
config_dict = self.config.to_dict()
Expand Down

0 comments on commit 4e668ca

Please sign in to comment.