add TensorBoard logging with loss and wps
ghstack-source-id: cdfe4c2c496feae23399019ec2a63b443fb3b6a9
Pull Request resolved: #57
tianyu-l committed Feb 13, 2024
1 parent da50d34 commit 5d37ce0
Showing 6 changed files with 119 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -25,5 +25,5 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
torchrun --nproc_per_node=${NGPU} \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --steps 10 --compile \
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
49 changes: 49 additions & 0 deletions torchtrain/metrics.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from datetime import datetime
from typing import Any, Dict, Optional

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtrain.logging_utils import rank0_log
from torchtrain.profiling import get_config_from_toml


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)

def close(self):
if self.writer is not None:
self.writer.close()


def build_metric_logger(tag: Optional[str] = None):
config = get_config_from_toml()

dump_dir = config["global"]["dump_folder"]
save_tb_folder = config["metrics"]["save_tb_folder"]
# since we don't have run id yet, use current minute as identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = config["metrics"].get("enable_tensorboard", False)
if enable_tb:
rank0_log(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
)

rank_str = f"rank_{torch.distributed.get_rank()}"
return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
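For orientation, here is a minimal usage sketch of the logger added above. It is not part of this commit: it assumes torchrun has already initialized torch.distributed, and that the TOML config provides the [global] dump_folder plus the [metrics] keys shown in train_config.toml below; the "train" tag and the toy loop are illustrative only.

# Hypothetical usage sketch of build_metric_logger / MetricLogger (not in this commit).
from torchtrain.metrics import build_metric_logger

metric_logger = build_metric_logger(tag="train")     # tag name is an arbitrary example
for step in range(1, 11):
    fake_loss = 1.0 / step                           # placeholder standing in for a real training loss
    metric_logger.log({"loss": fake_loss}, step=step)
metric_logger.close()                                # flushes the SummaryWriter when TensorBoard is enabled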
4 changes: 4 additions & 0 deletions torchtrain/train_configs/train_config.toml
@@ -7,3 +7,7 @@ run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10

[metrics]
enable_tensorboard = true
save_tb_folder = "tb"
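build_metric_logger (added above) reads this [metrics] table together with a [global] dump_folder entry defined elsewhere in the same file. A hedged, standalone check of those keys, assuming the repo-relative path below and the tomli package from requirements.txt:

# Hypothetical sanity check of the config keys consumed by build_metric_logger.
import tomli  # on Python 3.11+, the stdlib tomllib module works the same way

with open("torchtrain/train_configs/train_config.toml", "rb") as f:
    config = tomli.load(f)

assert config["metrics"]["enable_tensorboard"] is True
assert config["metrics"]["save_tb_folder"] == "tb"
print(config["global"]["dump_folder"])  # root folder the TensorBoard logs are nested under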
19 changes: 19 additions & 0 deletions torchtrain/utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Union

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)


def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
tensor = torch.tensor(x).cuda()
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
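A rough usage sketch of these two helpers, not part of the commit: it assumes the job was launched with torchrun (so WORLD_SIZE is set and each rank has a CUDA device) and builds a one-dimensional device mesh purely for illustration.

# Hypothetical usage of dist_mean / dist_max across a 1-D device mesh.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

from torchtrain.utils import dist_max, dist_mean

world_size = int(os.environ["WORLD_SIZE"])          # set by torchrun
mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))

local_loss = 0.123                                  # per-rank scalar, e.g. recent average loss
global_avg = dist_mean(local_loss, mesh)            # mean of local_loss over every rank in the mesh
global_max = dist_max(local_loss, mesh)             # max of local_loss over every rank in the mesh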
46 changes: 45 additions & 1 deletion train.py
@@ -4,8 +4,11 @@
import argparse
import os
from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List, Union

import numpy as np

# torch imports
import torch
import torch.nn.functional as F
@@ -18,11 +21,13 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.metrics import build_metric_logger

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import dist_max, dist_mean


@dataclass
@@ -116,7 +121,7 @@ def main(args):

scaler = build_grad_scaler(model)

# TODO: add metrics
metric_logger = build_metric_logger()

# torch.compile model for improved performance
if args.compile:
@@ -146,13 +151,18 @@ def main(args):

with maybe_run_profiler() as torch_profiler:
checkpoint.reset()
# variables used to keep info for metrics logging
losses_since_last_log: List[float] = []
nwords_since_last_log = 0
time_last_log = timer()
while train_state.step < args.steps or args.steps == -1:
train_state.step += 1
# get batch
batch = next(iter(data_loader))
input_ids, labels = batch
input_ids = input_ids.cuda()
labels = labels.cuda()
nwords_since_last_log += labels.numel()

optimizer.zero_grad()

@@ -184,6 +194,32 @@

train_state.current_loss = loss.item()
train_state.losses.append(train_state.current_loss)
losses_since_last_log.append(train_state.current_loss)

# log metrics
if (train_state.step - 1) % args.log_freq == 0:
avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
losses_since_last_log
)
global_avg_loss, global_max_loss = dist_mean(
avg_loss, world_mesh
), dist_max(max_loss, world_mesh)

time_delta = timer() - time_last_log
wps = nwords_since_last_log / (
time_delta * parallel_dims.sp * parallel_dims.pp
)

metrics = {
"global_avg_loss": global_avg_loss,
"global_max_loss": global_max_loss,
"wps": wps,
}
metric_logger.log(metrics, step=train_state.step)

losses_since_last_log.clear()
nwords_since_last_log = 0
time_last_log = timer()

rank0_log(
f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -192,6 +228,8 @@ def main(args):

checkpoint.save(train_state.step, force=(train_state.step == args.steps))

metric_logger.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -282,6 +320,12 @@ def main(args):
"is an empty string, checkpointing is disabled."
),
)
parser.add_argument(
"--log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)

args = parser.parse_args()
main(args)
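As a sanity check on the words-per-second metric introduced above, a small worked example with made-up numbers; the batch shape, timing, and parallel degrees are assumptions for illustration only.

# Worked example of the wps computation from the training loop (hypothetical values).
nwords_since_last_log = 10 * 8 * 2048      # 10 steps, batch size 8, sequence length 2048
time_delta = 12.5                          # seconds elapsed since the last log
sp, pp = 2, 1                              # sequence- and pipeline-parallel degrees
wps = nwords_since_last_log / (time_delta * sp * pp)
print(f"{wps:,.1f} words per second")      # 6,553.6 words per second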
