diff --git a/run_llama_train.sh b/run_llama_train.sh
index 723ddec8..c814f9f2 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -23,10 +23,8 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
 # Please adjust this to a longer interval period. The unit of measurement is in steps.
 CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
+CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}
+
 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 \
---model ${MODEL} --model_conf ${MODEL_CONF} \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
---compile \
---checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
+train.py --job.config_file ${CONFIG_FILE}
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_job_config.py b/test/test_job_config.py
new file mode 100644
index 00000000..ea47e54a
--- /dev/null
+++ b/test/test_job_config.py
@@ -0,0 +1,21 @@
+import pytest
+from torchtrain.config_manager import JobConfig
+
+
+class TestJobConfig:
+    def test_command_line_args(self):
+        config = JobConfig()
+        config.parse_args([])
+        assert config.model.name == "llama"
+
+    def test_job_config_file(self):
+        config = JobConfig()
+        config.parse_args(
+            ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
+        )
+        assert config.model.name == "llama"
+
+    def test_job_file_does_not_exist(self):
+        with pytest.raises(FileNotFoundError):
+            config = JobConfig()
+            config.parse_args(["--job.config_file", "ohno.toml"])
diff --git a/test/test_test.py b/test/test_test.py
index 773ccc2c..ed48a6b9 100644
--- a/test/test_test.py
+++ b/test/test_test.py
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+
 # delete me after adding real tests..
 class Test:
     def test_test(self):
diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
new file mode 100644
index 00000000..ff8afe8f
--- /dev/null
+++ b/torchtrain/config_manager.py
@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+import argparse
+import sys
+from collections import defaultdict
+from typing import Union
+
+try:
+    import tomllib
+except ModuleNotFoundError:
+    import tomli as tomllib
+
+
+class JobConfig:
+    """
+    A helper class to manage the train configuration.
+    Semantics:
+    - Default config is loaded from a toml file. If no toml file is provided,
+      then the default config is loaded from argparse defaults.
+ """ + + def parse_args(self, args_list: list = sys.argv[1:]): + args = JobConfig.init_args_from_command_line(args_list) + config_file = getattr(args, "job.config_file", None) + if config_file is None: + args_dict = self._args_to_two_level_dict(args) + else: + with open(config_file, "rb") as f: + args_dict = tomllib.load(f) + for k, v in args_dict.items(): + class_type = type(k.title(), (), v) + setattr(self, k, class_type()) + self._validate_config() + + def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict: + args_dict = defaultdict(defaultdict) + for k, v in vars(args).items(): + first_level_key, second_level_key = k.split(".", 1) + args_dict[first_level_key][second_level_key] = v + return args_dict + + def _validate_config(self): + # TODO: Add more mandatory validations + assert self.model.name and self.model.flavor and self.model.tokenizer_path + return True + + @staticmethod + def init_args_from_command_line( + args_list: list = sys.argv[1:], + ) -> argparse.Namespace: + """ + Each argument starts with _ which is the section name in the toml file + followed by name of the option in the toml file. For ex, + model.name translates to: + [model] + name + in the toml file + """ + parser = argparse.ArgumentParser(description="TorchTrain arg parser.") + parser.add_argument( + "--job.config_file", + type=str, + default=None, + help="job config file", + ) + + # misc configs + parser.add_argument( + "--job.dump_folder", + type=str, + default="./torchtrain/outputs", + help="folder to dump job outputs", + ) + + # profiling configs + parser.add_argument( + "--profiling.run_profiler", + action="store_true", + help="enable pytorch profiler", + ) + parser.add_argument( + "--profiling.save_traces_folder", + type=str, + default="profiling/traces", + help="trace file location", + ) + parser.add_argument( + "--profiling.profile_every_x_iter", + type=int, + default=10, + help="collect profiler traces every x iterations", + ) + # metrics configs + parser.add_argument( + "--metrics.log_freq", + type=int, + default=10, + help="how often to log metrics to TensorBoard", + ) + parser.add_argument( + "--metrics.enable_tensorboard", + action="store_true", + help="how often to log metrics to TensorBoard", + ) + parser.add_argument( + "--metrics.save_tb_folder", + type=str, + default="tb", + help="folder to dump tensorboard state", + ) + + # model configs + parser.add_argument( + "--model.name", + type=str, + default="llama", + help="which model to train", + ) + parser.add_argument( + "--model.flavor", + type=str, + default="debugmodel", + help="which model config to train", + ) + parser.add_argument( + "--model.tokenizer_path", + type=str, + default="./torchtrain/datasets/tokenizer/tokenizer.model", + help="tokenizer path", + ) + + # optimizer configs + parser.add_argument( + "--optimizer.name", type=str, default="AdamW", help="optimizer to use" + ) + parser.add_argument( + "--optimizer.lr", type=float, default=8e-4, help="learning rate to use" + ) + + # training configs + parser.add_argument( + "--training.dataset", type=str, default="alpaca", help="dataset to use" + ) + parser.add_argument( + "--training.batch_size", type=int, default=8, help="batch size" + ) + parser.add_argument( + "--training.seq_len", type=int, default=2048, help="sequence length" + ) + parser.add_argument( + "--training.warmup_pct", + type=float, + default=0.20, + help="percentage of total training steps to use for warmup", + ) + parser.add_argument( + "--training.max_norm", + type=Union[float, int], + default=1.0, + 
help="max norm for gradient clipping", + ) + parser.add_argument( + "--training.steps", type=int, default=-1, help="how many train steps to run" + ) + parser.add_argument( + "--training.data_parallel_degree", + type=int, + default=-1, + help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.", + ) + parser.add_argument( + "--training.sequence_parallel_degree", + type=int, + default=1, + help="Sequence Parallelism degree. 1 means disabled.", + ) + parser.add_argument( + "--training.pipeline_parallel_degree", + type=int, + default=1, + help="Pipeline Parallelism degree (default of 1 means disabled)", + ) + parser.add_argument( + "--training.compile", + action="store_true", + help="Whether to compile the model.", + ) + parser.add_argument( + "--training.checkpoint_interval", + type=int, + default=3600, + help=( + "Checkpointing interval. The unit of measurement is in seconds or " + "steps depending on --training.checkpoint-internval-type." + ), + ) + parser.add_argument( + "--training.checkpoint_interval_type", + type=str, + default="steps", + help=( + "The checkpointing interval unit of measurement." + "The default value is step." + ), + ) + parser.add_argument( + "--training.checkpoint_folder", + type=str, + default="", + help=( + "The folder to store the checkpoints. If this is not specified or " + "is an empty string, checkpointing is disabled." + ), + ) + return parser.parse_args(args_list) diff --git a/torchtrain/datasets/tokenizer.py b/torchtrain/datasets/tokenizer.py index aa645ae6..07593b55 100644 --- a/torchtrain/datasets/tokenizer.py +++ b/torchtrain/datasets/tokenizer.py @@ -8,8 +8,8 @@ import os from abc import ABC, abstractmethod -from typing import List from logging import getLogger +from typing import List from sentencepiece import SentencePieceProcessor @@ -48,6 +48,7 @@ def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf: class SentencePieceTokenizer(TokenizerIf): """tokenizing and encoding/decoding text using SentencePiece.""" + def __init__(self, tokenizer_path: str): """ Initializes the Tokenizer with a SentencePiece model. diff --git a/torchtrain/logging_utils.py b/torchtrain/logging_utils.py index 34d5a1ff..f2ed30b8 100644 --- a/torchtrain/logging_utils.py +++ b/torchtrain/logging_utils.py @@ -1,6 +1,7 @@ -import torch import logging +import torch + logger = logging.getLogger() diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py index c477e342..b76bd6d3 100644 --- a/torchtrain/lr_scheduling.py +++ b/torchtrain/lr_scheduling.py @@ -2,6 +2,7 @@ # All rights reserved. 
+from torchtrain.config_manager import JobConfig
 
 # global states for scheduling
 # these are needed as LambdaLR does not support argument passing
@@ -29,11 +30,13 @@ def linear_warmup_linear_decay(current_step: int) -> float:
     return curr_adjustment
 
 
-def get_lr_scheduler(optimizer, args):
+def get_lr_scheduler(optimizer, job_config: JobConfig):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
-    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
-    _decay_steps = float(max(1, args.steps - _warmup_steps))
+    _warmup_steps = max(
+        int(job_config.training.steps * job_config.training.warmup_pct), 2
+    )
+    _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
 
     warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
     return warmup_scheduler
diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
index cad98e85..b2ad3cc9 100644
--- a/torchtrain/metrics.py
+++ b/torchtrain/metrics.py
@@ -12,9 +12,9 @@
 import torch
 import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
 
+from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
-from torchtrain.profiling import get_config_from_toml
 
 _gb_in_bytes = 1024 * 1024 * 1024
 _mb_in_bytes = 1024 * 1024
@@ -214,16 +214,14 @@ def close(self):
         self.writer.close()
 
 
-def build_metric_logger(tag: Optional[str] = None):
-    config = get_config_from_toml()
-
-    dump_dir = config["global"]["dump_folder"]
-    save_tb_folder = config["metrics"]["save_tb_folder"]
+def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
+    dump_dir = config.job.dump_folder
+    save_tb_folder = config.metrics.save_tb_folder
     # since we don't have run id yet, use current minute as identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
 
-    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    enable_tb = config.metrics.enable_tensorboard
     if enable_tb:
         rank0_log(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 6538d418..805bfa87 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -32,6 +32,7 @@
     PrepareModuleInput,
     RowwiseParallel,
 )
+from torchtrain.config_manager import JobConfig
 
 from torchtrain.logging_utils import rank0_log
 
@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):
 
 
 # Uses PTD FSDP AC wrapper
-def checkpoint_wrapper(module, config):
+# TODO: why is config needed here?
+def checkpoint_wrapper(module, job_config: JobConfig):
     return ptd_checkpoint_wrapper(
         module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
     )
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, args):
+def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply parallelisms to the model, including PTD parallelisms, and AC.
 
@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.sp_enabled:
         # First we apply Sequence Parallelism if it's enabled
         tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
-        sp_degree = args.sp_degree
+        sp_degree = job_config.training.sequence_parallel_degree
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
@@ -169,7 +171,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
             # apply AC to each layer
             # before wrapping with FSDP, we need to make sure the layer is on GPU
             transformer_block = transformer_block.cuda()
-            transformer_block = checkpoint_wrapper(transformer_block, args)
+            transformer_block = checkpoint_wrapper(transformer_block, job_config)
 
             # Wraps each layer with FSDP
             model.layers[layer_id] = wrap(transformer_block)
diff --git a/torchtrain/profiling.py b/torchtrain/profiling.py
index 167067e0..ed1a9b5a 100644
--- a/torchtrain/profiling.py
+++ b/torchtrain/profiling.py
@@ -5,38 +5,20 @@
 import os
 
 import torch
-
-try:
-    import tomllib
-except ModuleNotFoundError:
-    import tomli as tomllib
-
+from torchtrain.config_manager import JobConfig
 from torchtrain.logging_utils import rank0_log
 
-_config_file = "./torchtrain/train_configs/train_config.toml"
-
-
-def get_config_from_toml(config_path: str = _config_file) -> dict:
-    """
-    Reads a config file in TOML format and returns a dictionary.
-    """
-    with open(config_path, "rb") as f:
-        config = tomllib.load(f)
-    return config
-
 
 @contextlib.contextmanager
-def maybe_run_profiler(*pos_args, **kwargs):
-    config = get_config_from_toml()
-
+def maybe_run_profiler(config: JobConfig, *pos_args, **kwargs):
     # get user defined profiler settings
-    run_profiler = config["profiling"].get("run_profiler", False)
+    run_profiler = config.profiling.run_profiler
 
     if run_profiler:
-        dump_dir = config["global"]["dump_folder"]
-        save_trace_dir = config["profiling"]["save_traces_folder"]
+        dump_dir = config.job.dump_folder
+        save_trace_dir = config.profiling.save_traces_folder
         trace_dir = os.path.join(dump_dir, save_trace_dir)
-        iter_frequency = config["profiling"]["profile_every_x_iter"]
+        iter_frequency = config.profiling.profile_every_x_iter
 
         _global_iter_count = 0
diff --git a/torchtrain/train_configs/train_config.toml b/torchtrain/train_configs/train_config.toml
index da0161e0..bb57a7cf 100644
--- a/torchtrain/train_configs/train_config.toml
+++ b/torchtrain/train_configs/train_config.toml
@@ -1,5 +1,5 @@
 # TorchTrain Config.toml
-[global]
+[job]
 dump_folder = "./torchtrain/outputs"
 
 [profiling]
@@ -11,3 +11,29 @@ profile_every_x_iter = 10
 [metrics]
 enable_tensorboard = true
 save_tb_folder = "tb"
+log_freq = 10
+
+[model]
+name = "llama"
+flavor = "debugmodel"
+tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+
+
+[training]
+batch_size = 8
+seq_len = 2048
+warmup_pct = 0.20
+max_norm = 1.0
+steps = 10
+data_parallel_degree = -1
+sequence_parallel_degree = 1
+pipeline_parallel_degree = 1
+compile = false
+checkpoint_interval = 3600
+checkpoint_interval_type = "steps"
+checkpoint_folder = ""
+dataset = "alpaca"
diff --git a/train.py b/train.py
index 845e55f4..f3f3389e 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import argparse
 import os
 from dataclasses import dataclass, field
 from timeit import default_timer as timer
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List
 
 import numpy as np
 
@@ -16,6 +15,7 @@
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 from torchtrain.checkpoint import CheckpointManager, IntervalType
+from torchtrain.config_manager import JobConfig
 
 # torchtrain related
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 
@@ -49,14 +49,16 @@ def load_state_dict(self, state_dict) -> None:
         self.losses = state_dict["losses"].tolist()
 
 
-def build_optimizer(model, args):
+def build_optimizer(model, job_config: JobConfig):
     # build optimizer
-    if args.optimizer == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-    elif args.optimizer == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    name = job_config.optimizer.name
+    lr = job_config.optimizer.lr
+    if name == "Adam":
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    elif name == "AdamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     else:
-        raise NotImplementedError(f"optimizer {args.optimizer} not added")
+        raise NotImplementedError(f"optimizer {name} not added")
 
     return optimizer
 
@@ -73,31 +75,34 @@ def build_grad_scaler(model):
     return ShardedGradScaler(enabled=enable_grad_scaling)
 
 
-def main(args):
+def main(job_config: JobConfig):
     init_logger()
     # init world mesh
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
-        dp=args.dp_degree, sp=args.sp_degree, pp=args.pp_degree, world_size=world_size
+        dp=job_config.training.data_parallel_degree,
+        sp=job_config.training.sequence_parallel_degree,
+        pp=job_config.training.pipeline_parallel_degree,
+        world_size=world_size,
     )
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
-    model_name = args.model
+    model_name = job_config.model.name
     rank0_log(f"Building {model_name}")
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
+    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
     # build dataloader
     # need dp world size and rank
     # TODO: dp might not always be 0 so we need to handle that more carefully
     dp_degree = world_mesh.size(0)
     dp_rank = world_mesh.get_local_rank(0)
-    build_dataloader_fn = dataloader_fn[args.dataset]
+    build_dataloader_fn = dataloader_fn[job_config.training.dataset]
     data_loader = build_dataloader_fn(
         tokenizer,
-        args.batch_size,
-        args.seq_len,
+        job_config.training.batch_size,
+        job_config.training.seq_len,
         dp_degree,
         dp_rank,
     )
@@ -105,7 +110,7 @@ def main(args):
     # build model
     # TODO: add meta initialization
     model_cls = model_name_to_cls[model_name]
-    model_config = models_config[model_name][args.model_conf]
+    model_config = models_config[model_name][job_config.model.flavor]
     model_config.vocab_size = tokenizer.n_words
 
     model = model_cls.from_model_args(model_config)
@@ -113,27 +118,29 @@ def main(args):
     # log model size
     model_param_count = get_num_params(model)
     rank0_log(
-        f"Model {model_name} {args.model_conf} size: {model_param_count:,} total parameters"
+        f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
     )
     gpu_metrics = GPUMemoryMonitor("cuda")
     rank0_log(f"GPU memory usage: {gpu_metrics}")
 
     # apply PTD parallelisms + AC
-    model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
+    model = models_parallelize_fns[model_name](
+        model, world_mesh, parallel_dims, job_config
+    )
 
     # to use FSDP-customized gradient scaler and gradient clipping solutions
     assert isinstance(model, FSDP)
 
     # build optimizer after apply parallelisms to the model
-    optimizer = build_optimizer(model, args)
-    scheduler = get_lr_scheduler(optimizer, args)
+    optimizer = build_optimizer(model, job_config)
+    scheduler = get_lr_scheduler(optimizer, job_config)
 
     scaler = build_grad_scaler(model)
 
-    metric_logger = build_metric_logger()
+    metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
-    if args.compile:
+    if job_config.training.compile:
         rank0_log(f"Compiling model {model_name} with torch.compile...")
         model = torch.compile(
             model,
@@ -148,23 +155,26 @@ def main(args):
         model=model,
         optimizer=optimizer,
         states={"train_state": train_state},
-        folder=args.checkpoint_folder,
+        folder=job_config.training.checkpoint_folder,
         interval_type=(
             IntervalType.SECONDS
-            if args.checkpoint_interval_type == "seconds"
+            if job_config.training.checkpoint_interval_type == "seconds"
            else IntervalType.STEPS
        ),
-        interval=args.checkpoint_interval,
+        interval=job_config.training.checkpoint_interval,
     )
     checkpoint.load()
 
-    with maybe_run_profiler() as torch_profiler:
+    with maybe_run_profiler(job_config) as torch_profiler:
         checkpoint.reset()
         # variables used to keep info for metrics logging
         losses_since_last_log: List[float] = []
         nwords_since_last_log = 0
         time_last_log = timer()
-        while train_state.step < args.steps or args.steps == -1:
+        while (
+            train_state.step < job_config.training.steps
+            or job_config.training.steps == -1
+        ):
             train_state.step += 1
             # get batch
             batch = next(iter(data_loader))
@@ -187,7 +197,7 @@ def main(args):
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
-            model.clip_grad_norm_(args.max_norm)
+            model.clip_grad_norm_(job_config.training.max_norm)
 
             # optimizer step
             # If gradients don't contain infs/NaNs, optimizer.step() is then called;
@@ -206,7 +216,7 @@ def main(args):
             losses_since_last_log.append(train_state.current_loss)
 
             # log metrics
-            if (train_state.step - 1) % args.log_freq == 0:
+            if (train_state.step - 1) % job_config.metrics.log_freq == 0:
                 avg_loss, max_loss = (
                     np.mean(losses_since_last_log),
                     np.max(losses_since_last_log),
@@ -246,107 +256,15 @@ def main(args):
             )
             scheduler.step()
 
-        checkpoint.save(train_state.step, force=(train_state.step == args.steps))
+        checkpoint.save(
+            train_state.step, force=(train_state.step == job_config.training.steps)
+        )
 
     metric_logger.close()
     rank0_log(f"{gpu_metrics.get_current_stats()}")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
-    LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
-
-    parser.add_argument(
-        "--model", type=str, default="llama", help="which model to train"
-    )
-    parser.add_argument(
-        "--model_conf",
-        type=str,
-        default="debugmodel",
-        help="which model config to train",
-    )
-    parser.add_argument("--dataset", type=str, default="alpaca", help="dataset to use")
-    parser.add_argument(
-        "--tokenizer_path",
-        type=str,
-        default="./torchtrain/datasets/tokenizer/tokenizer.model",
-        help="tokenizer path",
-    )
-    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
-    parser.add_argument("--seq_len", type=int, default=2048, help="sequence length")
-    parser.add_argument(
-        "--optimizer", type=str, default="AdamW", help="optimizer to use"
-    )
-    parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
-    parser.add_argument(
-        "--warmup_pct",
-        type=float,
-        default=0.20,
-        help="percentage of total training steps to use for warmup",
-    )
-    parser.add_argument(
-        "--max_norm",
-        type=Union[float, int],
-        default=1.0,
-        help="max norm for gradient clipping",
-    )
-    parser.add_argument(
-        "--steps", type=int, default=-1, help="how many train steps to run"
-    )
-    parser.add_argument(
-        "--dp_degree",
-        type=int,
-        default=-1,
-        help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
-    )
-    parser.add_argument(
-        "--sp_degree",
-        type=int,
-        default=1,
-        help="Sequence Parallelism degree. 1 means disabled.",
-    )
-    parser.add_argument(
-        "--pp_degree",
-        type=int,
-        default=1,
-        help="Pipeline Parallelism degree (default of 1 means disabled)",
-    )
-    parser.add_argument(
-        "--compile", action="store_true", help="Whether to compile the model."
-    )
-    parser.add_argument(
-        "--checkpoint-interval",
-        type=int,
-        default=3600,
-        help=(
-            "Checkpointing interval. The unit of measurement is in seconds or "
-            "steps depending on --checkpoint-internval-type."
-        ),
-    )
-    parser.add_argument(
-        "--checkpoint-interval-type",
-        type=str,
-        default="steps",
-        help=(
-            "The checkpointing interval unit of measurement."
-            "The default value is step."
-        ),
-    )
-    parser.add_argument(
-        "--checkpoint-folder",
-        type=str,
-        default="",
-        help=(
-            "The folder to store the checkpoints. If this is not specified or "
-            "is an empty string, checkpointing is disabled."
-        ),
-    )
-    parser.add_argument(
-        "--log_freq",
-        type=int,
-        default=10,
-        help="how often to log metrics to TensorBoard",
-    )
-
-    args = parser.parse_args()
-    main(args)
+    config = JobConfig()
+    config.parse_args()
+    main(config)
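
Usage note (illustrative, not part of the diff): a minimal sketch of how the new JobConfig is consumed once this change is applied, assuming it is run from the repo root with the default train_config.toml shown above in place:

    from torchtrain.config_manager import JobConfig

    config = JobConfig()
    # either rely on the argparse defaults via config.parse_args([]),
    # or point at a toml file, as train.py and run_llama_train.sh now do
    config.parse_args(
        ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
    )

    # each [section] of the toml becomes an attribute namespace on the config
    print(config.model.name)           # "llama"
    print(config.training.batch_size)  # 8
    print(config.optimizer.lr)         # 0.0008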