diff --git a/train.py b/train.py
index 447fc7436d..18c460723f 100644
--- a/train.py
+++ b/train.py
@@ -113,6 +113,8 @@ def main(args):
         input_ids = input_ids.cuda()
         labels = labels.cuda()
 
+        optimizer.zero_grad()
+
         # forward
         pred = model(input_ids)
         tok_loss = F.cross_entropy(
@@ -123,12 +125,14 @@ def main(args):
         # backward on scaled loss to create scaled gradients
         scaler.scale(loss).backward()
 
+        # clip gradients (after unscaling gradients of the optimizer's params)
+        scaler.unscale_(optimizer)
+        model.clip_grad_norm_(args.max_norm)
+
         # optimizer step
-        # scaler.step() first unscales gradients of the optimizer's params.
-        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+        # If gradients don't contain infs/NaNs, optimizer.step() is then called;
         # otherwise, optimizer.step() is skipped.
         scaler.step(optimizer)
-        optimizer.zero_grad()
 
         # updates the scale for next iteration
         scaler.update()
@@ -168,6 +172,9 @@ def main(args):
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
     parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument(
+        "--max_norm", type=float, default=1.0, help="max norm for gradient clipping"
+    )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
     )
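For context, the ordering this change establishes follows the standard AMP recipe for gradient clipping: zero grads, scaled backward, unscale, clip, then `scaler.step()` and `scaler.update()`. Below is a minimal, self-contained sketch of that pattern, not train.py itself: the model, data, and `max_norm` value are placeholders, and it uses the generic `torch.nn.utils.clip_grad_norm_` where train.py calls the module-level `model.clip_grad_norm_`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

# placeholder model/optimizer/data standing in for the real ones in train.py
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
max_norm = 1.0  # corresponds to the new --max_norm flag

loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(3)]

for inputs, targets in loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()

    # forward under autocast
    with torch.autocast(device_type=device, enabled=use_amp):
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)

    # backward on the scaled loss to create scaled gradients
    scaler.scale(loss).backward()

    # unscale first so clipping sees the true gradient magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # step() skips optimizer.step() if grads contain infs/NaNs;
    # update() then adjusts the scale for the next iteration
    scaler.step(optimizer)
    scaler.update()
```

With the new flag, a run would look something like `python train.py --lr 2e-5 --max_norm 1.0 --steps 100`.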