examples/asr/jasper.py (28 changes: 16 additions & 12 deletions)

@@ -7,7 +7,7 @@
 from ruamel.yaml import YAML

 import nemo
-from nemo.utils.lr_policies import SquareAnnealing
+from nemo.utils.lr_policies import CosineAnnealing
 import nemo.utils.argparse as nm_argparse
 import nemo_asr
 from nemo_asr.helpers import monitor_asr_train_progress, \
@@ -30,9 +30,10 @@ def parse_args():
     )

     # Overwrite default args
-    parser.add_argument("--num_epochs", type=int, default=None, required=True,
-                        help="number of epochs to train. You should specify"
-                             "either num_epochs or max_steps")
+    parser.add_argument("--max_steps", type=int, default=None, required=False,
+                        help="max number of steps to train")
+    parser.add_argument("--num_epochs", type=int, default=None, required=False,
+                        help="number of epochs to train")
     parser.add_argument("--model_config", type=str, required=True,
                         help="model configuration file: model.yaml")

@@ -43,18 +44,18 @@ def parse_args():
     parser.add_argument("--warmup_steps", default=0, type=int)

     args = parser.parse_args()
-    if args.max_steps is not None:
-        raise ValueError("Jasper uses num_epochs instead of max_steps")
+    if args.max_steps is not None and args.num_epochs is not None:
+        raise ValueError("Either max_steps or num_epochs should be provided.")

     return args


-def construct_name(name, lr, batch_size, num_epochs, wd, optimizer,
+def construct_name(name, lr, batch_size, max_steps, wd, optimizer,
                    iter_per_step):
-    return ("{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
+    return ("{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
         name, lr,
         batch_size,
-        num_epochs,
+        max_steps,
         wd,
         optimizer,
         iter_per_step))
@@ -241,7 +242,7 @@ def main():
         args.exp_name,
         args.lr,
         args.batch_size,
-        args.num_epochs,
+        args.max_steps,
         args.weight_decay,
         args.optimizer,
         args.iter_per_step)
@@ -275,11 +276,14 @@ def main():
     neural_factory.train(
         tensors_to_optimize=[train_loss],
         callbacks=callbacks,
-        lr_policy=SquareAnnealing(args.num_epochs * steps_per_epoch,
-                                  warmup_steps=args.warmup_steps),
+        lr_policy=CosineAnnealing(
+            args.max_steps if args.max_steps is not None else
+            args.num_epochs * steps_per_epoch,
+            warmup_steps=args.warmup_steps),
         optimizer=args.optimizer,
         optimization_params={
-            "num_epochs": args.num_epochs,
+            "max_steps": args.max_steps,
             "lr": args.lr,
             "betas": (
                 args.beta1,
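
With both flags now optional, jasper.py derives the schedule length handed to CosineAnnealing from whichever budget was given: max_steps wins when provided, otherwise num_epochs * steps_per_epoch. A minimal standalone sketch of that selection follows; the function name and the sample numbers are illustrative, not part of the PR.

def total_training_steps(max_steps, num_epochs, steps_per_epoch):
    """Number of optimizer steps the lr policy should span (mirrors the
    selection expression in jasper.py's call to CosineAnnealing)."""
    if max_steps is not None:
        return max_steps
    if num_epochs is None:
        raise ValueError("You must specify either max_steps or num_epochs")
    return num_epochs * steps_per_epoch

# e.g. --num_epochs 50 at 1000 steps per epoch spans 50000 steps;
# --max_steps 30000 would take precedence and span 30000 steps instead.
print(total_training_steps(max_steps=None, num_epochs=50, steps_per_epoch=1000))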
nemo/nemo/backends/pytorch/actions.py (10 changes: 6 additions & 4 deletions)

@@ -929,8 +929,10 @@ def train(self,
               stop_on_nan_loss=False):
         if not optimization_params:
             optimization_params = {}
-        num_epochs = optimization_params.get("num_epochs", 1)
+        num_epochs = optimization_params.get("num_epochs", None)
         max_steps = optimization_params.get("max_steps", None)
+        if num_epochs is None and max_steps is None:
+            raise ValueError("You must specify either max_steps or num_epochs")
         grad_norm_clip = optimization_params.get('grad_norm_clip', None)

         if batches_per_step is None:
@@ -1108,8 +1110,8 @@ def train(self,

         # MAIN TRAINING LOOP
         # iteration over epochs
-        for epoch_ind in range(self.epoch_num, num_epochs):
-            self.epoch_num = epoch_ind
+        self.epoch_num = 0
+        while num_epochs is None or self.epoch_num < num_epochs:
             if train_sampler is not None:
                 train_sampler.set_epoch(self.epoch_num)
             if max_steps is not None and self.step >= max_steps:
@@ -1230,9 +1232,9 @@ def train(self,
                 self._perform_on_iteration_end(callbacks=callbacks)
                 self.step += 1
             # End of epoch for loop
-
             # Register epochs end with callbacks
             self._perform_on_epoch_end(callbacks=callbacks)
+            self.epoch_num += 1
         self._perform_on_action_end(callbacks=callbacks)

     def infer(self,
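
The rewritten loop turns train() from a fixed for-loop over epochs into an open-ended while-loop: training stops at whichever budget is exhausted first, and it can run on max_steps alone with no epoch count at all. A rough standalone sketch of just that control flow follows; everything else in PtActions.train (samplers, callbacks, the real optimizer step) is omitted, and steps_per_epoch is a made-up stand-in for the length of the data loader.

def run_training(num_epochs=None, max_steps=None, steps_per_epoch=10):
    """Sketch of the new stopping logic: epochs advance until num_epochs is
    reached or the global step counter hits max_steps, whichever comes first;
    with num_epochs=None the step budget alone bounds training."""
    if num_epochs is None and max_steps is None:
        raise ValueError("You must specify either max_steps or num_epochs")
    step, epoch_num = 0, 0
    while num_epochs is None or epoch_num < num_epochs:
        if max_steps is not None and step >= max_steps:
            break
        for _ in range(steps_per_epoch):  # stand-in for iterating the data loader
            if max_steps is not None and step >= max_steps:
                break
            step += 1  # stand-in for one forward/backward/optimizer step
        epoch_num += 1
    return step, epoch_num

print(run_training(max_steps=25))   # step budget wins: (25, 3)
print(run_training(num_epochs=2))   # epoch budget wins: (20, 2)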
nemo/nemo/core/neural_factory.py (3 changes: 2 additions & 1 deletion)

@@ -547,7 +547,8 @@ def eval(self,
         self.train(
             tensors_to_optimize=None,
             optimizer='sgd',
-            callbacks=callbacks
+            callbacks=callbacks,
+            optimization_params={'num_epochs': 1}
         )

     def infer(self, tensors: List[NmTensor], checkpoint_dir=None,
tests/test_neural_factory.py (5 changes: 2 additions & 3 deletions)

@@ -17,7 +17,6 @@ def test_creation(self):
             instance, nemo.backends.pytorch.tutorials.TaylorNet))

     def test_simple_example(self):
-        #######################################################################
         neural_factory = nemo.core.neural_factory.NeuralModuleFactory(
             backend=nemo.core.Backend.PyTorch,
             local_rank=None,
@@ -36,5 +35,5 @@ def test_simple_example(self):

         optimizer = neural_factory.get_trainer()
         optimizer.train([loss_tensor], optimizer="sgd",
-                        optimization_params={"lr": 1e-3})
-        #######################################################################
+                        optimization_params={"lr": 1e-3,
+                                             "num_epochs": 1})