Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import os
from dataclasses import dataclass, field
from typing import List
from typing import List, Union

# torch imports
import torch
Expand Down Expand Up @@ -113,6 +113,8 @@ def main(args):
input_ids = input_ids.cuda()
labels = labels.cuda()

optimizer.zero_grad()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should use/expose 'set to None' option here as potential mild perf boost.
https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any known case where we want zero_grad(set_to_none=False)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems zero_grad() has set_to_none=True by default. I'll leave it to another PR to expose this option if needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need to put this optimizer.zero_grad to the place before forward?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can do it up front or at the end...overall same net effect.
I see more people doing it up front though now, vs a few years ago it was more common at the end.
imo - Advantage of up front is you guarantee gradients are clear at start (they should be but a way to be safe), and if we are doing the set to None, then maybe we get that gain on the first pass as well.


# forward
pred = model(input_ids)
tok_loss = F.cross_entropy(
Expand All @@ -123,12 +125,14 @@ def main(args):
# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# clip gradients (after unscaling gradients of the optimizer's params)
scaler.unscale_(optimizer)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we don't have unit tests, can you just double check that this unscale command is a noop for bf16 and not a utility like command?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model.clip_grad_norm_(args.max_norm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since clip_grad_norm_ is not a method on nn.Module, is this assuming that isinstance(model, FullyShardedDataParallel)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now I did assert isinstance(model, FSDP) on line 85 since FSDP is always enabled. Moving forward, we probably need to use config to decide.


# optimizer step
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# If gradients don't contain infs/NaNs, optimizer.step() is then called;
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
optimizer.zero_grad()

# updates the scale for next iteration
scaler.update()
Expand Down Expand Up @@ -168,6 +172,9 @@ def main(args):
"--optimizer", type=str, default="AdamW", help="optimizer to use"
)
parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
parser.add_argument(
"--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
)
parser.add_argument(
"--steps", type=int, default=-1, help="how many train steps to run"
)
Expand Down