[torchtrain] add gradient clipping #28
```diff
@@ -1,7 +1,7 @@
 import argparse
 import os
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Union

 # torch imports
 import torch
```

```diff
@@ -113,6 +113,8 @@ def main(args):
 input_ids = input_ids.cuda()
 labels = labels.cuda()

+optimizer.zero_grad()
+
```
Collaborator: Why do we need to put this `optimizer.zero_grad()` before the forward pass?

Contributor: You can do it up front or at the end; overall it has the same net effect.
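A minimal sketch of the equivalence being discussed, with a toy model, optimizer, and data (illustrative, not the PR's actual training loop): gradients are cleared exactly once per iteration either way, so both placements behave identically.

```python
import torch
import torch.nn as nn

# Toy setup (hypothetical stand-ins for the PR's model/optimizer/data).
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]

# Variant 1: clear grads at the top of each step (as in this PR).
for x, y in data:
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

# Variant 2: clear grads after the update; same net effect.
for x, y in data:
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```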
```diff
 # forward
 pred = model(input_ids)
 tok_loss = F.cross_entropy(
```

```diff
@@ -123,12 +125,14 @@ def main(args):
 # backward on scaled loss to create scaled gradients
 scaler.scale(loss).backward()

+# clip gradients (after unscaling gradients of the optimizer's params)
+scaler.unscale_(optimizer)
```
Contributor: Since we don't have unit tests, can you double check that this unscale command is a no-op for bf16 and not a utility-like command?

Contributor: (nvm, I double checked: it is indeed a no-op if the scaler is not enabled.)
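To illustrate the point checked above: with the scaler disabled (the bf16 case, where no loss scaling is needed), `scale()` returns the loss unchanged and `unscale_()`/`update()` return immediately, so the same code path works in both modes. A hedged sketch with a toy CPU model, using the generic clipping utility rather than the FSDP method used in the PR:

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

# Toy stand-ins (hypothetical; not the PR's model or optimizer).
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# enabled=False corresponds to bf16 training: no loss scaling needed.
scaler = GradScaler(enabled=False)

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)

# With the scaler disabled, scale() returns the loss unchanged and
# unscale_() / update() are no-ops, so the fp16 and bf16 paths share code.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # generic clip; the PR calls FSDP's method
scaler.step(optimizer)   # falls through to optimizer.step() when disabled
scaler.update()
optimizer.zero_grad()
```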
```diff
+model.clip_grad_norm_(args.max_norm)
```
Collaborator: Since …

Contributor (author): Right now I did …
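The thread above is truncated, but it concerns the clipping call itself. `model.clip_grad_norm_(...)` appears to be the method FSDP exposes on the wrapped module; for an unwrapped module the comparable call is the functional utility, sketched below with a toy model (an illustration, not the PR's FSDP setup):

```python
import torch
import torch.nn as nn

# Toy unwrapped module (the PR's model is presumably FSDP-wrapped,
# which is why it can call clip_grad_norm_ as a method).
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# Functional equivalent: computes the total gradient norm over the given
# parameters and rescales gradients in place if it exceeds max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(total_norm)  # the pre-clipping norm, handy for logging
```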
```diff
 # optimizer step
 # scaler.step() first unscales gradients of the optimizer's params.
-# If gradients don't contain infs/NaNs, optimizer.step() is then called,
+# If gradients don't contain infs/NaNs, optimizer.step() is then called;
 # otherwise, optimizer.step() is skipped.
 scaler.step(optimizer)
-optimizer.zero_grad()

 # updates the scale for next iteration
 scaler.update()
```

```diff
@@ -168,6 +172,9 @@ def main(args):
     "--optimizer", type=str, default="AdamW", help="optimizer to use"
 )
 parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+parser.add_argument(
+    "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+)
 parser.add_argument(
     "--steps", type=int, default=-1, help="how many train steps to run"
 )
```
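One side note on the new `--max_norm` flag: argparse's `type=` expects a callable that converts the command-line string, and `typing.Union[float, int]` cannot be called that way, so passing `--max_norm 2.0` explicitly would fail at parse time; plain `type=float` accepts both integer and float inputs. A hedged sketch of that variant (a suggestion, not the code in this PR):

```python
import argparse

parser = argparse.ArgumentParser()
# type=float parses both "1" and "1.0", so a Union type is unnecessary
# (and not usable as an argparse converter anyway).
parser.add_argument(
    "--max_norm", type=float, default=1.0, help="max norm for gradient clipping"
)
args = parser.parse_args(["--max_norm", "2.0"])
print(args.max_norm)  # 2.0
```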
Comment: We should use/expose the "set to None" option here as a potential mild perf boost.
https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html

Comment: Is there any known case where we want `zero_grad(set_to_none=False)`?

Author: It seems `zero_grad()` has `set_to_none=True` by default. I'll leave it to another PR to expose this option if needed.
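A small illustration of the default being discussed, with a toy model (illustrative, not part of this PR): `set_to_none=True` drops the gradient tensors entirely, while `set_to_none=False` keeps them and fills them with zeros.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# set_to_none=True frees the grad tensors, skipping the memset; this is
# the mild perf win mentioned above.
optimizer.zero_grad(set_to_none=True)
print(model.weight.grad)  # None

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# set_to_none=False keeps the grad tensors and zeroes them in place.
optimizer.zero_grad(set_to_none=False)
print(model.weight.grad.abs().sum())  # tensor(0.)
```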