Unified config manager for toml and command line (#76)
Summary:

This PR implements a unified config manager.

- Command line args and toml file args are now unified.
- Defaults can be loaded from either.

Options like `training.batch_size` become available as
`config.training.batch_size`, where `config` is a config manager (`JobConfig`) object.
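A rough usage sketch (based on the `JobConfig` API added in this PR; the override value and the printed results are illustrative, not output from a real run):

```python
from torchtrain.config_manager import JobConfig

# With no --job.config_file, defaults come from the argparse definitions,
# and any dotted flag can override them.
config = JobConfig()
config.parse_args(["--training.batch_size", "16"])
print(config.training.batch_size)  # 16
print(config.model.name)  # "llama" (argparse default)

# With --job.config_file, defaults are loaded from the TOML file instead;
# a [model] section with a "name" key surfaces as config.model.name.
config = JobConfig()
config.parse_args(
    ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
)
print(config.model.name)
```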

Test Plan:

============================= test session starts ==============================
platform linux -- Python 3.10.13, pytest-8.0.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python
cachedir: .pytest_cache
rootdir: /data/users/gnadathur/a/torchtrain
configfile: pyproject.toml
plugins: cov-4.1.0
collecting ... collected 5 items

test/test_job_config.py::TestJobConfig::test_command_line_args PASSED [ 20%]
test/test_job_config.py::TestJobConfig::test_command_line_args_with_override PASSED [ 40%]
test/test_job_config.py::TestJobConfig::test_job_config_file PASSED [ 60%]
test/test_job_config.py::TestJobConfig::test_job_config_file_with_override PASSED [ 80%]
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED [100%]

---------- coverage: platform linux, python 3.10.13-final-0 ----------
Coverage XML written to file coverage.xml

============================= slowest 20 durations =============================
0.01s call     test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s call     test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file_with_override
============================== 5 passed in 0.10s ===============================

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: gnadathur <[email protected]>
gnadathur and gnadathur authored Feb 24, 2024
1 parent acb50a9 commit a85b9d4
Showing 13 changed files with 339 additions and 173 deletions.
8 changes: 3 additions & 5 deletions run_llama_train.sh
@@ -23,10 +23,8 @@ CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-""}
# Please adjust this to a longer interval period. The unit of measurement is in steps.
CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}

CONFIG_FILE=${CONFIG_FILE:-"./torchtrain/train_configs/train_config.toml"}

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --steps 10 \
--model ${MODEL} --model_conf ${MODEL_CONF} \
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
--compile \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
train.py --job.config_file ${CONFIG_FILE}
Empty file added test/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions test/test_job_config.py
@@ -0,0 +1,21 @@
import pytest
from torchtrain.config_manager import JobConfig


class TestJobConfig:
    def test_command_line_args(self):
        config = JobConfig()
        config.parse_args([])
        assert config.model.name == "llama"

    def test_job_config_file(self):
        config = JobConfig()
        config.parse_args(
            ["--job.config_file", "./torchtrain/train_configs/train_config.toml"]
        )
        assert config.model.name == "llama"

    def test_job_file_does_not_exist(self):
        with pytest.raises(FileNotFoundError):
            config = JobConfig()
            config.parse_args(["--job.config_file", "ohno.toml"])
1 change: 1 addition & 0 deletions test/test_test.py
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.


# delete me after adding real tests..
class Test:
    def test_test(self):
215 changes: 215 additions & 0 deletions torchtrain/config_manager.py
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import sys
from collections import defaultdict
from typing import Union

try:
    import tomllib
except ModuleNotFoundError:
    import tomli as tomllib


class JobConfig:
    """
    A helper class to manage the train configuration.
    Semantics:
    - Default config is loaded from a toml file. If no toml file is provided,
      then the default config is loaded from argparse defaults.
    """

    def parse_args(self, args_list: list = sys.argv[1:]):
        args = JobConfig.init_args_from_command_line(args_list)
        config_file = getattr(args, "job.config_file", None)
        if config_file is None:
            args_dict = self._args_to_two_level_dict(args)
        else:
            with open(config_file, "rb") as f:
                args_dict = tomllib.load(f)
        for k, v in args_dict.items():
            class_type = type(k.title(), (), v)
            setattr(self, k, class_type())
        self._validate_config()

    def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
        args_dict = defaultdict(defaultdict)
        for k, v in vars(args).items():
            first_level_key, second_level_key = k.split(".", 1)
            args_dict[first_level_key][second_level_key] = v
        return args_dict

    def _validate_config(self):
        # TODO: Add more mandatory validations
        assert self.model.name and self.model.flavor and self.model.tokenizer_path
        return True

    @staticmethod
    def init_args_from_command_line(
        args_list: list = sys.argv[1:],
    ) -> argparse.Namespace:
        """
        Each argument starts with <prefix>. which is the section name in the toml file,
        followed by the name of the option in the toml file. For example,
        model.name translates to:
            [model]
            name
        in the toml file.
        """
        parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
        parser.add_argument(
            "--job.config_file",
            type=str,
            default=None,
            help="job config file",
        )

        # misc configs
        parser.add_argument(
            "--job.dump_folder",
            type=str,
            default="./torchtrain/outputs",
            help="folder to dump job outputs",
        )

        # profiling configs
        parser.add_argument(
            "--profiling.run_profiler",
            action="store_true",
            help="enable pytorch profiler",
        )
        parser.add_argument(
            "--profiling.save_traces_folder",
            type=str,
            default="profiling/traces",
            help="trace file location",
        )
        parser.add_argument(
            "--profiling.profile_every_x_iter",
            type=int,
            default=10,
            help="collect profiler traces every x iterations",
        )
        # metrics configs
        parser.add_argument(
            "--metrics.log_freq",
            type=int,
            default=10,
            help="how often to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.enable_tensorboard",
            action="store_true",
            help="whether to log metrics to TensorBoard",
        )
        parser.add_argument(
            "--metrics.save_tb_folder",
            type=str,
            default="tb",
            help="folder to dump tensorboard state",
        )

        # model configs
        parser.add_argument(
            "--model.name",
            type=str,
            default="llama",
            help="which model to train",
        )
        parser.add_argument(
            "--model.flavor",
            type=str,
            default="debugmodel",
            help="which model config to train",
        )
        parser.add_argument(
            "--model.tokenizer_path",
            type=str,
            default="./torchtrain/datasets/tokenizer/tokenizer.model",
            help="tokenizer path",
        )

        # optimizer configs
        parser.add_argument(
            "--optimizer.name", type=str, default="AdamW", help="optimizer to use"
        )
        parser.add_argument(
            "--optimizer.lr", type=float, default=8e-4, help="learning rate to use"
        )

        # training configs
        parser.add_argument(
            "--training.dataset", type=str, default="alpaca", help="dataset to use"
        )
        parser.add_argument(
            "--training.batch_size", type=int, default=8, help="batch size"
        )
        parser.add_argument(
            "--training.seq_len", type=int, default=2048, help="sequence length"
        )
        parser.add_argument(
            "--training.warmup_pct",
            type=float,
            default=0.20,
            help="percentage of total training steps to use for warmup",
        )
        parser.add_argument(
            "--training.max_norm",
            type=Union[float, int],
            default=1.0,
            help="max norm for gradient clipping",
        )
        parser.add_argument(
            "--training.steps", type=int, default=-1, help="how many train steps to run"
        )
        parser.add_argument(
            "--training.data_parallel_degree",
            type=int,
            default=-1,
            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
        )
        parser.add_argument(
            "--training.sequence_parallel_degree",
            type=int,
            default=1,
            help="Sequence Parallelism degree. 1 means disabled.",
        )
        parser.add_argument(
            "--training.pipeline_parallel_degree",
            type=int,
            default=1,
            help="Pipeline Parallelism degree (default of 1 means disabled)",
        )
        parser.add_argument(
            "--training.compile",
            action="store_true",
            help="Whether to compile the model.",
        )
        parser.add_argument(
            "--training.checkpoint_interval",
            type=int,
            default=3600,
            help=(
                "Checkpointing interval. The unit of measurement is in seconds or "
                "steps depending on --training.checkpoint_interval_type."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_interval_type",
            type=str,
            default="steps",
            help=(
                "The checkpointing interval unit of measurement. "
                "The default value is steps."
            ),
        )
        parser.add_argument(
            "--training.checkpoint_folder",
            type=str,
            default="",
            help=(
                "The folder to store the checkpoints. If this is not specified or "
                "is an empty string, checkpointing is disabled."
            ),
        )
        return parser.parse_args(args_list)
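The `type(k.title(), (), v)` call in `parse_args` turns each top-level section of the two-level dict into a dynamically created class whose attributes are that section's options; this is what makes dotted access such as `config.training.batch_size` work. A minimal standalone sketch of the same idea (the dict contents below are made up for illustration):

```python
# Stand-alone illustration of the per-section attribute objects built in
# JobConfig.parse_args; the dict below is illustrative only.
args_dict = {"training": {"batch_size": 8, "seq_len": 2048}}

class Config:
    pass

config = Config()
for section, options in args_dict.items():
    # e.g. creates a class named "Training" with batch_size/seq_len attributes
    section_cls = type(section.title(), (), options)
    setattr(config, section, section_cls())

print(config.training.batch_size)  # 8
print(config.training.seq_len)  # 2048
```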
3 changes: 2 additions & 1 deletion torchtrain/datasets/tokenizer.py
@@ -8,8 +8,8 @@

import os
from abc import ABC, abstractmethod
from typing import List
from logging import getLogger
from typing import List

from sentencepiece import SentencePieceProcessor

@@ -48,6 +48,7 @@ def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> TokenizerIf:

class SentencePieceTokenizer(TokenizerIf):
    """tokenizing and encoding/decoding text using SentencePiece."""

    def __init__(self, tokenizer_path: str):
        """
        Initializes the Tokenizer with a SentencePiece model.
3 changes: 2 additions & 1 deletion torchtrain/logging_utils.py
@@ -1,6 +1,7 @@
import torch
import logging

import torch

logger = logging.getLogger()


9 changes: 6 additions & 3 deletions torchtrain/lr_scheduling.py
@@ -2,6 +2,7 @@
# All rights reserved.

from torch.optim.lr_scheduler import LambdaLR
from torchtrain.config_manager import JobConfig

# global states for scheduling
# these are needed as LambdaLR does not support argument passing
@@ -29,11 +30,13 @@ def linear_warmup_linear_decay(current_step: int) -> float:
    return curr_adjustment


def get_lr_scheduler(optimizer, args):
def get_lr_scheduler(optimizer, job_config: JobConfig):
    """Build a linear warmup and linear decay scheduler"""
    global _warmup_steps, _decay_steps
    _warmup_steps = max(int(args.steps * args.warmup_pct), 2)
    _decay_steps = float(max(1, args.steps - _warmup_steps))
    _warmup_steps = max(
        int(job_config.training.steps * job_config.training.warmup_pct), 2
    )
    _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))

    warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
    return warmup_scheduler
12 changes: 5 additions & 7 deletions torchtrain/metrics.py
@@ -12,9 +12,9 @@
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log
from torchtrain.profiling import get_config_from_toml

_gb_in_bytes = 1024 * 1024 * 1024
_mb_in_bytes = 1024 * 1024
@@ -214,16 +214,14 @@ def close(self):
self.writer.close()


def build_metric_logger(tag: Optional[str] = None):
    config = get_config_from_toml()

    dump_dir = config["global"]["dump_folder"]
    save_tb_folder = config["metrics"]["save_tb_folder"]
def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
    dump_dir = config.job.dump_folder
    save_tb_folder = config.metrics.save_tb_folder
    # since we don't have run id yet, use current minute as identifier
    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

    enable_tb = config["metrics"].get("enable_tensorboard", False)
    enable_tb = config.metrics.enable_tensorboard
    if enable_tb:
        rank0_log(
            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
10 changes: 6 additions & 4 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -32,6 +32,7 @@
    PrepareModuleInput,
    RowwiseParallel,
)
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log

@@ -67,13 +68,14 @@ def partition_fn(name, module, device_mesh):


# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
# TODO: why is config needed here?
def checkpoint_wrapper(module, job_config: JobConfig):
    return ptd_checkpoint_wrapper(
        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )


def parallelize_llama(model, world_mesh, parallel_dims, args):
def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.
@@ -87,7 +89,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
    if parallel_dims.sp_enabled:
        # First we apply Sequence Parallelism if it's enabled
        tp_mesh = world_mesh["sp"] if world_mesh.ndim > 1 else world_mesh
        sp_degree = args.sp_degree
        sp_degree = job_config.training.sequence_parallel_degree
        # First:
        # 1. parallelize the first embedding and the last linear proj layer
        # 2. shard the first layer of transformer block
@@ -169,7 +171,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
# apply AC to each layer
# before wrapping with FSDP, we need to make sure the layer is on GPU
transformer_block = transformer_block.cuda()
transformer_block = checkpoint_wrapper(transformer_block, args)
transformer_block = checkpoint_wrapper(transformer_block, job_config)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)