
Commit bec909b

Authored by gnadathur
Unified config manager for toml and command line (pytorch#76)
Summary: This PR implements a unified config manager.

- Command-line args and TOML file args are now unified.
- Defaults can be loaded from either source.
- Options such as `training.batch_size` are available as `config.training.batch_size`, where `config` is a config manager object.

Test Plan:

```
============================= test session starts ==============================
platform linux -- Python 3.10.13, pytest-8.0.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python
cachedir: .pytest_cache
rootdir: /data/users/gnadathur/a/torchtrain
configfile: pyproject.toml
plugins: cov-4.1.0
collecting ... collected 5 items

test/test_job_config.py::TestJobConfig::test_command_line_args PASSED               [ 20%]
test/test_job_config.py::TestJobConfig::test_command_line_args_with_override PASSED [ 40%]
test/test_job_config.py::TestJobConfig::test_job_config_file PASSED                 [ 60%]
test/test_job_config.py::TestJobConfig::test_job_config_file_with_override PASSED   [ 80%]
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED         [100%]

---------- coverage: platform linux, python 3.10.13-final-0 ----------
Coverage XML written to file coverage.xml

============================= slowest 20 durations =============================
0.01s call     test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
============================== 5 passed in 0.10s ===============================
```

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: gnadathur <[email protected]>
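A minimal usage sketch of the unified lookup described above (a sketch only; the values in the comments assume the argparse defaults added in this PR):

```python
from torchtrain.config_manager import JobConfig

# With no TOML file, defaults come from argparse.
config = JobConfig()
config.parse_args(["--training.batch_size", "16"])
print(config.model.name)           # "llama" (argparse default)
print(config.training.batch_size)  # 16 (taken from the command line)

# Alternatively, load every section from a TOML file.
file_config = JobConfig()
file_config.parse_args(
    ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
)
# Each [section] in the file becomes an attribute: file_config.<section>.<option>
```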
1 parent 0dda1a8 commit bec909b

13 files changed (+339 −173 lines)

run_llama_train.sh

+3 −5

```diff
@@ -23,10 +23,8 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
 # Please adjust this to a longer interval period. The unit of measurement is in steps.
 CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
+CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}
+
 torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 \
---model ${MODEL} --model_conf ${MODEL_CONF} \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
---compile \
---checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
+train.py --job.config_file ${CONFIG_FILE}
```

test/__init__.py

Whitespace-only changes.

test/test_job_config.py

+21 (new file)

```python
import pytest
from torchtrain.config_manager import JobConfig


class TestJobConfig:
    def test_command_line_args(self):
        config = JobConfig()
        config.parse_args([])
        assert config.model.name == "llama"

    def test_job_config_file(self):
        config = JobConfig()
        config.parse_args(
            ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
        )
        assert config.model.name == "llama"

    def test_job_file_does_not_exist(self):
        with pytest.raises(FileNotFoundError):
            config = JobConfig()
            config.parse_args(["--job.config_file", "ohno.toml"])
```

test/test_test.py

+1

```diff
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+
 # delete me after adding real tests..
 class Test:
     def test_test(self):
```

torchtrain/config_manager.py

+215 (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import sys
from collections import defaultdict
from typing import Union

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


class JobConfig:
    """
    A helper class to manage the train configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    """

    def parse_args(self, args_list: list = sys.argv[1:]):
        args = JobConfig.init_args_from_command_line(args_list)
        config_file = getattr(args, "job.config_file", None)
        if config_file is None:
            args_dict = self._args_to_two_level_dict(args)
        else:
            with open(config_file, "rb") as f:
                args_dict = tomllib.load(f)
        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
            setattr(self, k, class_type())
        self._validate_config()

    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self):
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    @staticmethod
    def init_args_from_command_line(
        args_list: list = sys.argv[1:],
    ) -> argparse.Namespace:
        """
        Each argument starts with <prefix>, which is the section name in the toml
        file, followed by the name of the option in that section. For example,
        model.name translates to:
            [model]
            name
        in the toml file.
        """
        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
        parser.add_argument(
            "--job.config_file",
            type=str,
            default=None,
            help="job config file",
        )

        # misc configs
        parser.add_argument(
            "--job.dump_folder",
            type=str,
            default="./torchtrain/outputs",
            help="folder to dump job outputs",
        )

        # profiling configs
        parser.add_argument(
            "--profiling.run_profiler",
            action="store_true",
            help="enable pytorch profiler",
        )
        parser.add_argument(
            "--profiling.save_traces_folder",
            type=str,
            default="profiling/traces",
            help="trace file location",
        )
        parser.add_argument(
            "--profiling.profile_every_x_iter",
            type=int,
            default=10,
            help="collect profiler traces every x iterations",
        )
        # metrics configs
        parser.add_argument(
            "--metrics.log_freq",
            type=int,
            default=10,
            help="how often to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.enable_tensorboard",
            action="store_true",
            help="whether to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.save_tb_folder",
            type=str,
            default="tb",
            help="folder to dump tensorboard state",
        )

        # model configs
        parser.add_argument(
            "--model.name",
            type=str,
            default="llama",
            help="which model to train",
        )
        parser.add_argument(
            "--model.flavor",
            type=str,
            default="debugmodel",
            help="which model config to train",
        )
        parser.add_argument(
            "--model.tokenizer_path",
            type=str,
            default="./torchtrain/datasets/tokenizer/tokenizer.model",
            help="tokenizer path",
        )

        # optimizer configs
        parser.add_argument(
            "--optimizer.name", type=str, default="AdamW", help="optimizer to use"
        )
        parser.add_argument(
            "--optimizer.lr", type=float, default=8e-4, help="learning rate to use"
        )

        # training configs
        parser.add_argument(
            "--training.dataset", type=str, default="alpaca", help="dataset to use"
        )
        parser.add_argument(
            "--training.batch_size", type=int, default=8, help="batch size"
        )
        parser.add_argument(
            "--training.seq_len", type=int, default=2048, help="sequence length"
        )
        parser.add_argument(
            "--training.warmup_pct",
            type=float,
            default=0.20,
            help="percentage of total training steps to use for warmup",
        )
        parser.add_argument(
            "--training.max_norm",
            type=Union[float, int],
            default=1.0,
            help="max norm for gradient clipping",
        )
        parser.add_argument(
            "--training.steps", type=int, default=-1, help="how many train steps to run"
        )
        parser.add_argument(
            "--training.data_parallel_degree",
            type=int,
            default=-1,
            help="Data Parallelism degree. -1 means leftover ranks will be used (after SP/PP). 1 means disabled.",
        )
        parser.add_argument(
            "--training.sequence_parallel_degree",
            type=int,
            default=1,
            help="Sequence Parallelism degree. 1 means disabled.",
        )
        parser.add_argument(
            "--training.pipeline_parallel_degree",
            type=int,
            default=1,
            help="Pipeline Parallelism degree (default of 1 means disabled)",
        )
        parser.add_argument(
            "--training.compile",
            action="store_true",
            help="Whether to compile the model.",
        )
        parser.add_argument(
            "--training.checkpoint_interval",
            type=int,
            default=3600,
            help=(
                "Checkpointing interval. The unit of measurement is in seconds or "
                "steps depending on --training.checkpoint_interval_type."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_interval_type",
            type=str,
            default="steps",
            help=(
                "The checkpointing interval unit of measurement. "
                "The default value is steps."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_folder",
            type=str,
            default="",
            help=(
                "The folder to store the checkpoints. If this is not specified or "
                "is an empty string, checkpointing is disabled."
            ),
        )
        return parser.parse_args(args_list)
```
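To illustrate how a TOML section lines up with the attribute access above, here is a hedged sketch; the table contents below are hypothetical stand-ins, not the contents of the real train_config.toml:

```python
try:
    import tomllib  # Python 3.11+
except ModuleNotFoundError:
    import tomli as tomllib  # backport used on Python 3.10, as in config_manager.py

# Hypothetical config text for illustration only.
toml_text = """
[model]
name = "llama"
flavor = "debugmodel"
tokenizer_path = "./torchtrain/datasets/tokenizer/tokenizer.model"

[training]
batch_size = 8
steps = 10
"""

args_dict = tomllib.loads(toml_text)
# JobConfig.parse_args turns each top-level table into a dynamically created
# object, so args_dict["training"]["batch_size"] is exposed as
# config.training.batch_size.
print(args_dict["training"]["batch_size"])  # 8
```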

torchtrain/datasets/tokenizer.py

+2 −1

```diff
@@ -8,8 +8,8 @@
 
 import os
 from abc import ABC, abstractmethod
-from typing import List
 from logging import getLogger
+from typing import List
 
 from sentencepiece import SentencePieceProcessor
 
@@ -48,6 +48,7 @@ def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:
 
 class SentencePieceTokenizer(TokenizerIf):
     """tokenizing and encoding/decoding text using SentencePiece."""
+
     def __init__(self, tokenizer_path: str):
         """
         Initializes the Tokenizer with a SentencePiece model.
```

torchtrain/logging_utils.py

+2 −1

```diff
@@ -1,6 +1,7 @@
-import torch
 import logging
 
+import torch
+
 logger = logging.getLogger()
 
 
```
torchtrain/lr_scheduling.py

+6 −3

```diff
@@ -2,6 +2,7 @@
 # All rights reserved.
 
 from torch.optim.lr_scheduler import LambdaLR
+from torchtrain.config_manager import JobConfig
 
 # global states for scheduling
 # these are needed as LambdaLR does not support argument passing
@@ -29,11 +30,13 @@ def linear_warmup_linear_decay(current_step: int) -> float:
     return curr_adjustment
 
 
-def get_lr_scheduler(optimizer, args):
+def get_lr_scheduler(optimizer, job_config: JobConfig):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
-    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
-    _decay_steps = float(max(1, args.steps - _warmup_steps))
+    _warmup_steps = max(
+        int(job_config.training.steps * job_config.training.warmup_pct), 2
+    )
+    _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
 
     warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
     return warmup_scheduler
```
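As a quick sanity check of the scheduler arithmetic above, a sketch with hypothetical values; in train.py the real values come from the parsed JobConfig:

```python
# Hypothetical values standing in for job_config.training.steps / warmup_pct.
steps, warmup_pct = 10, 0.20

warmup_steps = max(int(steps * warmup_pct), 2)     # max(2, 2) = 2
decay_steps = float(max(1, steps - warmup_steps))  # 8.0
print(warmup_steps, decay_steps)                   # 2 8.0
```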

torchtrain/metrics.py

+5 −7

```diff
@@ -12,9 +12,9 @@
 import torch
 import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
+from torchtrain.config_manager import JobConfig
 
 from torchtrain.logging_utils import rank0_log
-from torchtrain.profiling import get_config_from_toml
 
 _gb_in_bytes = 1024 * 1024 * 1024
 _mb_in_bytes = 1024 * 1024
@@ -214,16 +214,14 @@ def close(self):
         self.writer.close()
 
 
-def build_metric_logger(tag: Optional[str] = None):
-    config = get_config_from_toml()
-
-    dump_dir = config["global"]["dump_folder"]
-    save_tb_folder = config["metrics"]["save_tb_folder"]
+def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
+    dump_dir = config.job.dump_folder
+    save_tb_folder = config.metrics.save_tb_folder
     # since we don't have run id yet, use current minute as identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
 
-    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    enable_tb = config.metrics.enable_tensorboard
     if enable_tb:
         rank0_log(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
```

torchtrain/parallelisms/parallelize_llama.py

+6 −4

```diff
@@ -32,6 +32,7 @@
     PrepareModuleInput,
     RowwiseParallel,
 )
+from torchtrain.config_manager import JobConfig
 
 from torchtrain.logging_utils import rank0_log
 
@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):
 
 
 # Uses PTD FSDP AC wrapper
-def checkpoint_wrapper(module, config):
+# TODO: why is config needed here?
+def checkpoint_wrapper(module, job_config: JobConfig):
     return ptd_checkpoint_wrapper(
         module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
     )
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, args):
+def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply parallelisms to the model, including PTD parallelisms, and AC.
 
@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
     if parallel_dims.sp_enabled:
         # First we apply Sequence Parallelism if it's enabled
         tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
-        sp_degree = args.sp_degree
+        sp_degree = job_config.training.sequence_parallelism_degree
         # First:
         # 1. parallelize the first embedding and the last linear proj layer
         # 2. shard the first layer of transformer block
@@ -169,7 +171,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
             # apply AC to each layer
             # before wrapping with FSDP, we need to make sure the layer is on GPU
             transformer_block = transformer_block.cuda()
-            transformer_block = checkpoint_wrapper(transformer_block, args)
+            transformer_block = checkpoint_wrapper(transformer_block, job_config)
 
             # Wraps each layer with FSDP
             model.layers[layer_id] = wrap(transformer_block)
```
