Enable checkpointing with DCP #26
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import enum
import os
import re
import time
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
from torchtrain.logging_utils import rank0_log


class IntervalType(enum.Enum):
    SECONDS = enum.auto()
    STEPS = enum.auto()


class ModelWrapper:
    def __init__(self, model: nn.Module) -> None:
        self.model = model

    def state_dict(self) -> Dict[str, Any]:
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_model_state_dict(self.model, state_dict)


class OptimizerWrapper:
    def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
        self.model = model
        self.optim = optim

    def state_dict(self) -> Dict[str, Any]:
        return get_optimizer_state_dict(self.model, self.optim)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class CheckpointManager:
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        states: Dict[str, Any],
        folder: str,
        interval_type: IntervalType,
        interval: int,
    ) -> None:
        self.folder = folder
        self.states = states
        self.states.update(
            {
                "model": ModelWrapper(model),
                "optimizer": OptimizerWrapper(model, optimizer),
            }
        )
        self.interval_type = interval_type
        self.interval = interval
        self.begin = 0
        self.work = None
        self.pg = dist.new_group(backend="gloo")
        self.doit = None

    def reset(self) -> None:
        self.begin = time.monotonic()

    def create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")

    def save(self, curr_step: int, force: bool = False) -> None:
        if not self.folder:
            return

        if not force:
            if self.interval_type == IntervalType.STEPS and not (
                curr_step % self.interval == 0
            ):
                return
            if self.interval_type == IntervalType.SECONDS:

[Review comment] I would prefer we get rid of the seconds handling, as mentioned in another comment, to keep our stack simple enough. We can add it back once we feel this mode is needed for actual training.
[Reply] It is sometimes better to use time, because a lot of features can change the per-iteration time, such as model type, batch size, and other settings. Using steps may require some tuning to avoid affecting overall performance. We can change the default to steps so that users don't need to worry about it now.
[Reply] Sure, sounds good! My main motivation is for our library to be as simple as possible; we can evaluate whether the time interval type is useful once we start real trainings, and decide later whether to keep it.

                doit = (time.monotonic() - self.begin) >= self.interval
                self.doit = torch.tensor(int(doit))
                if self.work is None:
                    self.work = dist.all_reduce(self.doit, group=self.pg, async_op=True)
                    return
                elif curr_step % 5 == 4:
                    self.work.wait()
                    self.work = None
                    doit = self.doit.item()
                    self.doit = None
                    if doit == 0:
                        return
                else:
                    return

        if self.work:
            self.work.wait()
            self.work = None
            self.doit = None

        rank0_log(f"Saving a checkpoint at step {curr_step}.")
        begin = time.monotonic()
        dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
        self.reset()
        rank0_log(
            f"Finished saving the checkpoint at step {curr_step} "
            f"in {time.monotonic() - begin} seconds."
        )

    def load(self, step: int = -1) -> bool:

[Review comment] why we have a …
[Reply] We do use the …

        if not self.folder:
            return False

[Review comment] nit: we should check either in …

        if not os.path.isdir(self.folder):
            return False
        if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
            return False

        if step == -1:
            step_counts = []
            for filename in os.listdir(self.folder):
                match = re.search(r"step-(\d+)", filename)
                if match:
                    step_counts.append(int(match.group(1)))
            if not step_counts:
                return False
            step = max(step_counts)

        rank0_log("Loading a checkpoint.")
        begin = time.monotonic()
        dcp.load(
            self.states,
            checkpoint_id=self.create_checkpoint_id(step),
        )
        rank0_log(f"Finished loading the checkpoint in {time.monotonic() - begin} seconds.")
        return True
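
The SECONDS branch above has to reach the same save/skip decision on every rank, because dcp.save is a collective operation. As the review thread on that branch discusses, each rank proposes a 0/1 "interval elapsed" flag, the flags are summed with an asynchronous all-reduce on a dedicated gloo (CPU) group, and the PR hides the reduction latency by only waiting for it a few steps later (the curr_step % 5 == 4 check). Below is a minimal standalone sketch of that rank-agreement pattern, not the PR's code: the function name should_save and the blocking wait() are illustrative simplifications, and the script assumes it is launched with torchrun so init_process_group can read its rendezvous settings from the environment.

import time

import torch
import torch.distributed as dist


def should_save(begin: float, interval_s: float, pg: dist.ProcessGroup) -> bool:
    # Each rank proposes a 0/1 flag based on its own clock.
    elapsed = (time.monotonic() - begin) >= interval_s
    flag = torch.tensor(int(elapsed))
    # Sum the flags across ranks; every rank sees the same total, so every
    # rank reaches the same save/skip decision even if their clocks drift.
    work = dist.all_reduce(flag, group=pg, async_op=True)
    # The PR overlaps this wait with a few training steps; here we just block.
    work.wait()
    return flag.item() > 0


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    pg = dist.new_group(backend="gloo")  # dedicated CPU-capable group, as in the PR
    start = time.monotonic()
    time.sleep(1)
    print(f"rank {dist.get_rank()}: save? {should_save(start, 0.5, pg)}")
    dist.destroy_process_group()
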
@@ -4,14 +4,16 @@
import argparse
import os
from dataclasses import dataclass, field
-from typing import List, Union
+from typing import Any, Dict, List, Union

# torch imports
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from torchtrain.checkpoint import CheckpointManager, IntervalType

# torchtrain related
from torchtrain.datasets import create_tokenizer, dataloader_fn
from torchtrain.logging_utils import init_logger, rank0_log
@@ -29,6 +31,18 @@ class TrainState:
    current_loss: float = -1
    losses: List[float] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
            "losses": torch.tensor(self.losses, dtype=torch.float32),
        }

[Review comment] nit: to avoid confusion with the model/optim state dict, we should rename this to something like …
[Reply] This naming is required by DCP.

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"].item()
        self.current_loss = state_dict["current_loss"].item()
        self.losses = state_dict["losses"].tolist()

[Review comment] ditto …
[Reply] This naming is required by DCP.


def build_optimizer(model, args):
    # build optimizer
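As the replies above note, the state_dict / load_state_dict names are not arbitrary: DCP treats any object placed in the saved state dictionary that exposes this pair of methods as a stateful component, calling state_dict() when saving and load_state_dict() when restoring. That is the contract TrainState here, and ModelWrapper / OptimizerWrapper in the checkpoint module, all satisfy. Below is a minimal sketch of the same contract; the RunningAverage class is hypothetical and only illustrates what a user-defined entry in the states dict would need to provide.

from typing import Any, Dict

import torch


class RunningAverage:
    """Hypothetical stateful object: anything exposing this pair of methods
    can be dropped into the states dict handled by CheckpointManager."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        # Called when the object is part of the state being saved.
        return {
            "total": torch.tensor(self.total, dtype=torch.float32),
            "count": torch.tensor(self.count, dtype=torch.int32),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called with the restored values on load.
        self.total = state_dict["total"].item()
        self.count = int(state_dict["count"].item())

Renaming the methods, as the nit suggests, would break this contract, which is why the reply points out that the naming is required by DCP.
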
@@ -116,7 +130,22 @@ def main(args):
    # train loop
    model.train()

    checkpoint = CheckpointManager(
        model=model,
        optimizer=optimizer,
        states={"train_state": train_state},
        folder=args.checkpoint_folder,
        interval_type=(
            IntervalType.SECONDS
            if args.checkpoint_interval_type == "seconds"
            else IntervalType.STEPS
        ),
        interval=args.checkpoint_interval,
    )
    checkpoint.load()

    with maybe_run_profiler() as torch_profiler:
        checkpoint.reset()
        while train_state.step < args.steps or args.steps == -1:
            train_state.step += 1
            # get batch
@@ -161,6 +190,8 @@ def main(args):
            )
            scheduler.step()

            checkpoint.save(train_state.step, force=(train_state.step == args.steps))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -224,6 +255,33 @@ def main(args):
    parser.add_argument(
        "--compile", action="store_true", help="Whether to compile the model."
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=3600,
        help=(
            "Checkpointing interval. The unit of measurement is in seconds or "
            "steps depending on --checkpoint-interval-type."
        ),
    )
    parser.add_argument(
        "--checkpoint-interval-type",
        type=str,
        default="steps",
        help=(
            "The checkpointing interval unit of measurement. "
            "The default value is steps."
        ),
    )
    parser.add_argument(
        "--checkpoint-folder",
        type=str,
        default="",
        help=(
            "The folder to store the checkpoints. If this is not specified or "
            "is an empty string, checkpointing is disabled."
        ),
    )

    args = parser.parse_args()
    main(args)

[Review comment] Can we change this to something like /tmp/torchtrain so that it saves somewhere when we run it locally?
[Reply] We should make this an opt-in feature, so that people don't get surprised and don't save too many files to /tmp when they share a machine. Also, once a training run finishes there will be a checkpoint, and users may unknowingly skip all of their new training because the existing last-step checkpoint gets loaded. That happens a lot. So it's better to make it opt-in.
[Reply] Makes sense.
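
The worry in the reply above, that a shared default such as /tmp would make a fresh run silently resume from a finished run's checkpoint, follows directly from how load() picks a step when none is given: it scans the folder for step-N entries and resumes from the largest N. The helper below is a small sketch of that discovery logic; latest_step is an illustrative name, not a function in the PR.

import os
import re


def latest_step(folder: str) -> int:
    """Return the largest N among 'step-N' entries in folder, or -1 if none."""
    steps = []
    for name in os.listdir(folder):
        match = re.search(r"step-(\d+)", name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else -1


# e.g. a folder containing step-100/ and step-250/ resumes from step 250,
# regardless of whether those checkpoints came from the current run.
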