-
Notifications
You must be signed in to change notification settings - Fork 252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable checkpointing with DCP #26
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! thanks for doing this super fast! I have a few suggestions inlined.
@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain} | |||
MODEL="debugmodel" | |||
NGPU=8 | |||
MP=4 | |||
# Change this string to a meaningful one to enable checkpoint | |||
CHECKPOINT_FOLDER="" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we change this to something like a /tmp/torchtrain
so that it saves somewhere when we locally run it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should set this be an opt-in feature so that people won't get surprise and may save too many files to /tmp
when people are using the same machine. And since if the training finishes, there will be a checkpoint, users may unconsciously ignore all the new training because of an existing checkpoint with last_step. That happens a lot. So it's better to do an opt-in feature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make sense
train.py
Outdated
) | ||
parser.add_argument( | ||
"--checkpoint-interval-type", | ||
type=str, default="seconds", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to use interval type seconds here? I think maybe a simple checkpoint-interval
that documented as number of iterations should be enough? as we ultimately don't know how much time a model fwd/bwd/optim time would take, I think a number of iterations is more sound than seconds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can keep the time feature as described above but use step as the default one.
train.py
Outdated
rank0_log(f"current loss: {train_state.current_loss}") | ||
|
||
checkpoint.save(train_state.step) | ||
|
||
if train_state.step == args.steps: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
iiuc this is after all steps we save a final checkpoint?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
and (curr_step - self.begin) < self.interval | ||
): | ||
return | ||
if self.interval_type == IntervalType.SECONDS: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer we get rid of the seconds handling as mentioned in another comment to keep our stack simple enough.
We can add it back once we feel this mode is needed for the actual training
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is sometimes better to use time because a lot of features can change the per-iteration time like model type, batch size and other stuffs. Using steps may require some tuning to avoid affect the overall performance.
We can change the default to steps so that users don't need to worry about it now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure sounds good! My main motivation is that our library to be as simple as possible, we can evaluate once we start real trainings if we would use time interval type, and decide later whether we want to keep it or not
) | ||
|
||
def load(self, step: int = -1) -> bool: | ||
if not self.folder: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we should check either in train.py
or in this save
and load
method to only save/load when step % checkpoint_interval == 0
, so that we skip the save/load logic when we don't need to save/load checkpoints.
a0257bc
to
017680f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! have a few more minor comments inlined
@@ -30,6 +31,18 @@ class TrainState: | |||
current_loss: float = -1 | |||
losses: List[float] = field(default_factory=list) | |||
|
|||
def state_dict(self) -> Dict[str, Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: to avoid confusion with the model/optim state dict, we should rename this to sth like train_state
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is naming is required by DCP.
"losses": torch.tensor(self.current_loss, dtype=torch.float32), | ||
} | ||
|
||
def load_state_dict(self, state_dict) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: load_train_state
to avoid confusion with DCP.save/load_state_dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is naming is required by DCP.
f"{time.monotonic() - begin} seconds" | ||
) | ||
|
||
def load(self, step: int = -1) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why we have a step
arg here? seems like we don't use this arg too, we should remove it first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do use the step
. In the case where there are more than one checkpoint saved, users can specify the step
to load a specific checkpoint.
and (curr_step - self.begin) < self.interval | ||
): | ||
return | ||
if self.interval_type == IntervalType.SECONDS: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure sounds good! My main motivation is that our library to be as simple as possible, we can evaluate once we start real trainings if we would use time interval type, and decide later whether we want to keep it or not
@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain} | |||
MODEL="debugmodel" | |||
NGPU=8 | |||
MP=4 | |||
# Change this string to a meaningful one to enable checkpoint | |||
CHECKPOINT_FOLDER="" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make sense
could you try rebasing before you merge? you'll pick up the linter CI that way and itll force you to lint your new files. Hopefully not too many conflicts to resolve since most of the hairy linter changes were in model.py |
Summary: This PR enable checkpointing. The PR only enables checkpointing in the local storages. Only when DCP enables automatic storage detection can this checkpoint manager support remote storages. This PR didn't checkpoint dataloader. Test Plan: Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran ./run_llama_train.sh twice. The first run ran through all 100 steps and the checkpoints were saved. The second run loaded the checkpoint back and detected the saved step count is 100. No training was done for the second step. Reviewers: Subscribers: Tasks: Tags:
…t of measurement.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
0541be2
to
fe2e1c6
Compare
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: This PR enable checkpointing. The PR only enables checkpointing in the local storages. Only when DCP enables automatic storage detection can this checkpoint manager support remote storages. This PR didn't checkpoint dataloader. Test Plan: Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran ./run_llama_train.sh twice. The first run ran through all 100 steps and the checkpoints were saved. The second run loaded the checkpoint back and detected the saved step count is 100. No training was done for the second step.
Summary:
This PR enable checkpointing. The PR only enables checkpointing in the local storages. Only when DCP enables automatic storage detection can this checkpoint manager support remote storages.
This PR didn't checkpoint dataloader.
Test Plan:
Changed CHECKPOINT_FOLDER to /tmp/checkpoint_chienchin and ran ./run_llama_train.sh twice. The first run ran through all 100 steps and the checkpoints were saved. The second run loaded the checkpoint back and detected the saved step count is 100. No training was done for the second step.