
Enable checkpointing with DCP #26

Merged (5 commits) on Feb 6, 2024.
7 changes: 6 additions & 1 deletion run_llama_train.sh
@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
MODEL="debugmodel"
NGPU=8
MP=4
# Set this to a non-empty folder path to enable checkpointing
CHECKPOINT_FOLDER=""
Contributor:

Can we change this to something like /tmp/torchtrain so that it saves somewhere when we run it locally?

Contributor Author (@fegin), Feb 1, 2024:

We should make this an opt-in feature so that people don't get surprised and end up saving too many files to /tmp when they share the same machine. Also, once a training run finishes there will be a checkpoint, and users may unknowingly skip all new training because an existing checkpoint already contains the last step. That happens a lot, so it's better to keep this opt-in.

Contributor:

Makes sense.

# Please adjust this to a longer interval for real runs. The unit of measurement is steps.
CHECKPOINT_INTERVAL=5

torchrun --nproc_per_node=${NGPU} \
train.py --steps 10
train.py --steps 10 --compile \
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
146 changes: 146 additions & 0 deletions torchtrain/checkpoint.py
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import enum
import os
import re
import time
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
from torchtrain.logging_utils import rank0_log


class IntervalType(enum.Enum):
    SECONDS = enum.auto()
    STEPS = enum.auto()


class ModelWrapper:
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def state_dict(self) -> Dict[str, Any]:
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_model_state_dict(self.model, state_dict)


class OptimizerWrapper:
    def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
        self.model = model
        self.optim = optim

    def state_dict(self) -> Dict[str, Any]:
        return get_optimizer_state_dict(self.model, self.optim)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class CheckpointManager:
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        states: Dict[str, Any],
        folder: str,
        interval_type: IntervalType,
        interval: int,
    ) -> None:
        self.folder = folder
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model),
                "optimizer": OptimizerWrapper(model, optimizer),
            }
        )
        self.interval_type = interval_type
        self.interval = interval
        self.begin = 0
        self.work = None
        # A separate gloo process group backs the async all_reduce that lets
        # all ranks agree on whether the time-based interval has elapsed.
        self.pg = dist.new_group(backend="gloo")
        self.doit = None

    def reset(self) -> None:
        self.begin = time.monotonic()

    def create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")

    def save(self, curr_step: int, force: bool = False) -> None:
        if not self.folder:
            return

        if not force:
            if self.interval_type == IntervalType.STEPS and not (
                curr_step % self.interval == 0
            ):
                return
            if self.interval_type == IntervalType.SECONDS:
Contributor:

I would prefer we get rid of the seconds handling, as mentioned in another comment, to keep our stack simple.

We can add it back once we feel this mode is needed for actual training.

Contributor Author (@fegin), Feb 1, 2024:

It is sometimes better to use time, because many things can change the per-iteration time, such as model type, batch size, and other settings. Using steps may require some tuning to avoid affecting overall performance.

We can change the default to steps so that users don't need to worry about it for now.

Contributor:

Sure, sounds good! My main motivation is to keep our library as simple as possible. We can evaluate whether we would use the time interval type once we start real trainings, and decide later whether we want to keep it.

                # Decide locally whether the time interval has elapsed, then
                # all_reduce that decision asynchronously so every rank reaches
                # the same conclusion without blocking the training step.
                doit = (time.monotonic() - self.begin) >= self.interval
                self.doit = torch.tensor(int(doit))
                if self.work is None:
                    self.work = dist.all_reduce(self.doit, group=self.pg, async_op=True)
                    return
                elif curr_step % 5 == 4:
                    # Collect the result a few steps later to hide the
                    # communication latency.
                    self.work.wait()
                    self.work = None
                    doit = self.doit.item()
                    self.doit = None
                    if doit == 0:
                        return
                else:
                    return

        if self.work:
            self.work.wait()
            self.work = None
            self.doit = None

        rank0_log(f"Saving a checkpoint in step {curr_step}.")
        begin = time.monotonic()
        dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
        self.reset()
        rank0_log(
            f"Finish saving the checkpoint in step {curr_step}. "
            f"{time.monotonic() - begin} seconds"
        )

    def load(self, step: int = -1) -> bool:
Contributor:

Why do we have a step arg here? It seems we don't use this arg either; we should remove it first.

Contributor Author:

We do use the step. In the case where more than one checkpoint is saved, users can specify the step to load a specific checkpoint.

        if not self.folder:
Contributor:

nit: we should check, either in train.py or in these save and load methods, that step % checkpoint_interval == 0 before doing anything, so that we skip the save/load logic when we don't need to save or load checkpoints.

            return False
        if not os.path.isdir(self.folder):
            return False
        if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
            return False

        if step == -1:
            # Pick the latest checkpoint by scanning the folder for step-<N> entries.
            step_counts = []
            for filename in os.listdir(self.folder):
                match = re.search(r"step-(\d+)", filename)
                if match:
                    step_counts.append(int(match.group(1)))
            if not step_counts:
                return False
            step = max(step_counts)

        rank0_log("Loading a checkpoint.")
        begin = time.monotonic()
        dcp.load(
            self.states,
            checkpoint_id=self.create_checkpoint_id(step),
        )
        rank0_log(f"Finish loading a checkpoint. {time.monotonic() - begin} seconds.")
        return True
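To see how the pieces above fit together, here is a minimal single-process sketch of driving CheckpointManager outside the training script. It is illustrative only: it assumes torchtrain is on the PYTHONPATH, a recent PyTorch with torch.distributed.checkpoint is installed, and no process group has been initialized yet. The toy nn.Linear model and the temporary folder are not part of this PR.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn

from torchtrain.checkpoint import CheckpointManager, IntervalType

# Single-process "distributed" setup so that dist.new_group() inside
# CheckpointManager has an initialized default process group to work with.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Take one optimizer step so there is real optimizer state to checkpoint.
model(torch.randn(2, 8)).sum().backward()
optimizer.step()

with tempfile.TemporaryDirectory() as folder:
    checkpoint = CheckpointManager(
        model=model,
        optimizer=optimizer,
        states={},
        folder=folder,
        interval_type=IntervalType.STEPS,
        interval=5,
    )
    checkpoint.reset()
    checkpoint.save(curr_step=5)               # saved: 5 is a multiple of the interval
    checkpoint.save(curr_step=7)               # skipped: not on the interval
    checkpoint.save(curr_step=10, force=True)  # force=True saves regardless of the interval
    assert checkpoint.load()         # no step given: loads the latest checkpoint (step-10)
    assert checkpoint.load(step=5)   # or pick a specific step explicitly

dist.destroy_process_group()
```

With STEPS as the interval type, save() returns immediately except on multiples of the interval, so calling it every iteration (as train.py does below) is cheap.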
60 changes: 59 additions & 1 deletion train.py
@@ -4,14 +4,16 @@
import argparse
import os
from dataclasses import dataclass, field
from typing import List, Union
from typing import Any, Dict, List, Union

# torch imports
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtrain.checkpoint import CheckpointManager, IntervalType

# torchtrain related
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
@@ -29,6 +31,18 @@ class TrainState:
    current_loss: float = -1
    losses: List[float] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
Contributor:

nit: to avoid confusion with the model/optim state dict, we should rename this to something like train_state

Contributor Author:

This naming is required by DCP.

        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
            "losses": torch.tensor(self.losses, dtype=torch.float32),
        }

    def load_state_dict(self, state_dict) -> None:
Contributor:

ditto: load_train_state, to avoid confusion with DCP.save/load_state_dict

Contributor Author:

This naming is required by DCP.

        self.step = state_dict["step"].item()
        self.current_loss = state_dict["current_loss"].item()
        self.losses = state_dict["losses"].tolist()
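For context on the naming discussion above: dcp.save() and dcp.load() treat any value in the passed state dict that exposes state_dict() and load_state_dict() methods (DCP's Stateful protocol) as a stateful object and call those methods during save and load, which is why TrainState, ModelWrapper, and OptimizerWrapper all use exactly these method names. A small sketch with a hypothetical CounterState class (not part of this PR):

```python
from typing import Any, Dict

import torch


class CounterState:
    """Toy stateful object, hypothetical and for illustration only."""

    def __init__(self) -> None:
        self.count = 0

    # DCP looks for exactly these two method names: any entry in the state
    # dict that provides them is saved and restored through these calls
    # instead of being written out directly.
    def state_dict(self) -> Dict[str, Any]:
        return {"count": torch.tensor(self.count, dtype=torch.int32)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"].item()


# Passing {"counter": CounterState(), ...} to dcp.save() makes DCP call
# counter.state_dict() at save time; dcp.load() later calls
# counter.load_state_dict() with the tensors read back from the checkpoint.
```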


def build_optimizer(model, args):
    # build optimizer

@@ -116,7 +130,22 @@ def main(args):
    # train loop
    model.train()

    checkpoint = CheckpointManager(
        model=model,
        optimizer=optimizer,
        states={"train_state": train_state},
        folder=args.checkpoint_folder,
        interval_type=(
            IntervalType.SECONDS
            if args.checkpoint_interval_type == "seconds"
            else IntervalType.STEPS
        ),
        interval=args.checkpoint_interval,
    )
    checkpoint.load()

    with maybe_run_profiler() as torch_profiler:
        checkpoint.reset()
        while train_state.step < args.steps or args.steps == -1:
            train_state.step += 1
            # get batch

@@ -161,6 +190,8 @@ def main(args):
            )
            scheduler.step()

            checkpoint.save(train_state.step, force=(train_state.step == args.steps))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")

@@ -224,6 +255,33 @@ def main(args):
    parser.add_argument(
        "--compile", action="store_true", help="Whether to compile the model."
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=3600,
        help=(
            "Checkpointing interval. The unit of measurement is in seconds or "
            "steps depending on --checkpoint-interval-type."
        ),
    )
    parser.add_argument(
        "--checkpoint-interval-type",
        type=str,
        default="steps",
        help=(
            "The checkpointing interval unit of measurement. "
            "The default value is steps."
        ),
    )
    parser.add_argument(
        "--checkpoint-folder",
        type=str,
        default="",
        help=(
            "The folder to store the checkpoints. If this is not specified or "
            "is an empty string, checkpointing is disabled."
        ),
    )

    args = parser.parse_args()
    main(args)
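As a quick, standalone illustration of how the new flags behave (the parser below only mirrors the three arguments above; the folder path is just an example):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint-interval", type=int, default=3600)
parser.add_argument("--checkpoint-interval-type", type=str, default="steps")
parser.add_argument("--checkpoint-folder", type=str, default="")

# Defaults: checkpointing stays off because the folder is empty.
args = parser.parse_args([])
assert args.checkpoint_folder == ""

# Opting in: with these values, CheckpointManager.save() writes
# /tmp/torchtrain_ckpt/step-10, step-20, ... during training.
args = parser.parse_args(
    ["--checkpoint-folder", "/tmp/torchtrain_ckpt", "--checkpoint-interval", "10"]
)
assert args.checkpoint_interval == 10 and args.checkpoint_interval_type == "steps"
```

Since the folder defaults to an empty string, checkpointing stays strictly opt-in, matching the discussion on run_llama_train.sh above.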