add TensorBoard logging with loss and wps #57

Merged: 3 commits, Feb 15, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
sentencepiece
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -25,5 +25,5 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
torchrun --nproc_per_node=${NGPU} \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --steps 10 --compile \
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
49 changes: 49 additions & 0 deletions torchtrain/metrics.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from datetime import datetime
from typing import Any, Dict, Optional

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtrain.logging_utils import rank0_log
from torchtrain.profiling import get_config_from_toml


class MetricLogger:
    def __init__(self, log_dir, tag, enable_tb):
        self.tag = tag
        self.writer: Optional[SummaryWriter] = None
        if enable_tb:
            self.writer = SummaryWriter(log_dir, max_queue=1000)

    def log(self, metrics: Dict[str, Any], step: int):
        if self.writer is not None:
            for k, v in metrics.items():
                tag = k if self.tag is None else f"{self.tag}/{k}"
                self.writer.add_scalar(tag, v, step)

    def close(self):
        if self.writer is not None:
            self.writer.close()


def build_metric_logger(tag: Optional[str] = None):
    config = get_config_from_toml()

    dump_dir = config["global"]["dump_folder"]
    save_tb_folder = config["metrics"]["save_tb_folder"]
    # since we don't have run id yet, use current minute as identifier
    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

    enable_tb = config["metrics"].get("enable_tensorboard", False)
    if enable_tb:
        rank0_log(
            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
        )

    rank_str = f"rank_{torch.distributed.get_rank()}"
    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
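A minimal usage sketch of the logger above, mirroring the call sites added to train.py further down in this diff (it assumes torch.distributed is already initialized, since build_metric_logger calls torch.distributed.get_rank()):

from torchtrain.metrics import build_metric_logger

# build once per rank; the log dir and enable flag come from the [global]/[metrics] TOML sections
metric_logger = build_metric_logger()
# values can be Python floats or scalar tensors; SummaryWriter.add_scalar accepts both
metric_logger.log({"global_avg_loss": 2.31, "wps": 15000.0}, step=10)
# flush and close the underlying writer at the end of training
metric_logger.close()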
4 changes: 4 additions & 0 deletions torchtrain/train_configs/train_config.toml
@@ -7,3 +7,7 @@ run_profiler = true
save_traces_folder = "profiling/traces"
# profiling frequency - example: 10 means every 10th iter will be profiled
profile_every_x_iter = 10

[metrics]
enable_tensorboard = true
save_tb_folder = "tb"
19 changes: 19 additions & 0 deletions torchtrain/utils.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Union

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh


def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
    tensor = torch.tensor(x).cuda()
    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)


def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
    tensor = torch.tensor(x).cuda()
    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
46 changes: 45 additions & 1 deletion train.py
@@ -4,8 +4,11 @@
import argparse
import os
from dataclasses import dataclass, field
from timeit import default_timer as timer
from typing import Any, Dict, List, Union

import numpy as np

# torch imports
import torch
import torch.nn.functional as F
@@ -18,11 +21,13 @@
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
from torchtrain.lr_scheduling import get_lr_scheduler
from torchtrain.metrics import build_metric_logger

from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims

from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import dist_max, dist_mean


@dataclass
@@ -116,7 +121,7 @@ def main(args):

    scaler = build_grad_scaler(model)

    # TODO: add metrics
    metric_logger = build_metric_logger()

    # torch.compile model for improved performance
    if args.compile:
@@ -146,13 +151,18 @@ def main(args):

    with maybe_run_profiler() as torch_profiler:
        checkpoint.reset()
        # variables used to keep info for metrics logging
        losses_since_last_log: List[float] = []
        nwords_since_last_log = 0
        time_last_log = timer()
        while train_state.step < args.steps or args.steps == -1:
            train_state.step += 1
            # get batch
            batch = next(iter(data_loader))
            input_ids, labels = batch
            input_ids = input_ids.cuda()
            labels = labels.cuda()
            nwords_since_last_log += labels.numel()

            optimizer.zero_grad()

@@ -184,6 +194,32 @@

            train_state.current_loss = loss.item()
            train_state.losses.append(train_state.current_loss)
            losses_since_last_log.append(train_state.current_loss)

            # log metrics
            if (train_state.step - 1) % args.log_freq == 0:
                avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
                    losses_since_last_log
                )
                global_avg_loss, global_max_loss = dist_mean(
                    avg_loss, world_mesh
                ), dist_max(max_loss, world_mesh)

                time_delta = timer() - time_last_log
                wps = nwords_since_last_log / (
                    time_delta * parallel_dims.sp * parallel_dims.pp
Contributor:
a neater way is to define a model_parallel_size in the ParallelDims class that returns this number directly (i.e. a cached property)
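A rough sketch of what that suggestion could look like; the field names follow the pp/sp/dp degrees used in this PR, and the cached_property itself is hypothetical, not part of this change:

from dataclasses import dataclass
from functools import cached_property


@dataclass
class ParallelDims:
    pp: int
    sp: int
    dp: int

    @cached_property
    def model_parallel_size(self) -> int:
        # degrees that split each batch's tokens across ranks,
        # which is the factor the words-per-second count is divided by
        return self.pp * self.sp

With such a property, the expression above would read nwords_since_last_log / (time_delta * parallel_dims.model_parallel_size).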
                )

                metrics = {
                    "global_avg_loss": global_avg_loss,
                    "global_max_loss": global_max_loss,
                    "wps": wps,
                }
                metric_logger.log(metrics, step=train_state.step)

                losses_since_last_log.clear()
                nwords_since_last_log = 0
                time_last_log = timer()

            rank0_log(
                f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -192,6 +228,8 @@ def main(args):

            checkpoint.save(train_state.step, force=(train_state.step == args.steps))

    metric_logger.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -282,6 +320,12 @@ def main(args):
"is an empty string, checkpointing is disabled."
),
)
parser.add_argument(
"--log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)

args = parser.parse_args()
main(args)
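Note on --log_freq: the check added above is (train_state.step - 1) % args.log_freq == 0, so with the default of 10 metrics are written on steps 1, 11, 21, and so on; the very first training step is always logged.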