73 changes: 10 additions & 63 deletions examples/asr/jasper_infer.py → examples/asr/jasper_eval.py
@@ -4,15 +4,14 @@
import argparse
import copy
import os
import pickle

from ruamel.yaml import YAML
import numpy as np

import nemo
import nemo_asr
from nemo_asr.helpers import word_error_rate, post_process_predictions, \
post_process_transcripts
import numpy as np


def main():
@@ -51,10 +50,6 @@ def main():
required=False, default=0.1)
parser.add_argument(
"--beam_width", default=128, type=int)
parser.add_argument(
'--mode',
help='either \'eval\' (default) or \'infer\'',
default='infer')

args = parser.parse_args()
batch_size = args.batch_size
@@ -137,24 +132,10 @@ def main():
eval_tensors = [log_probs_e1, predictions_e1,
transcript_e1, transcript_len_e1, encoded_len_e1]

if args.lm_path and args.mode == 'infer':
beam_width = args.beam_width
alpha = args.alpha
beta = args.beta
beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
vocab=vocab,
beam_width=beam_width,
alpha=alpha,
beta=beta,
lm_path=args.lm_path,
num_cpus=max(os.cpu_count(), 1))
beam_predictions_e1 = beam_search_with_lm(
log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
eval_tensors.append(beam_predictions_e1)

evaluated_tensors = neural_factory.infer(
tensors=eval_tensors,
checkpoint_dir=load_dir,
cache=True
)

greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
@@ -163,30 +144,7 @@ def main():
wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
logger.info("Greedy WER {:.2f}%".format(wer*100))

if args.mode == 'infer':
if args.lm_path:
beam_hypotheses = []
# Over mini-batch
for i in evaluated_tensors[-1]:
# Over samples
for j in i:
beam_hypotheses.append(j[0][1])

wer = word_error_rate(
hypotheses=beam_hypotheses, references=references)
logger.info("Beam WER {:.2f}".format(wer*100))

if args.save_logprob:
# Convert logits to list of numpy arrays
logprob = []
for i, batch in enumerate(evaluated_tensors[0]):
for j in range(batch.shape[0]):
logprob.append(
batch[j][:evaluated_tensors[4][i][j], :].cpu().numpy())
with open(args.save_logprob, 'wb') as f:
pickle.dump(logprob, f, protocol=pickle.HIGHEST_PROTOCOL)

if args.mode == 'eval':
if args.lm_path:
if args.alpha_max is None:
args.alpha_max = args.alpha
# include alpha_max in tuning range
@@ -198,14 +156,11 @@ def main():
args.beta_max += args.beta_step/10.0

beam_wers = []
checkpoints_loaded = False

for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
for beta in np.arange(args.beta, args.beta_max, args.beta_step):
logger.info(f'infering with (alpha, beta): ({alpha}, {beta})')
eval_tensors = [log_probs_e1, predictions_e1,
transcript_e1, transcript_len_e1,
encoded_len_e1]
logger.info('================================')
logger.info(f'Inferring with (alpha, beta): ({alpha}, {beta})')
beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
vocab=vocab,
beam_width=args.beam_width,
@@ -215,21 +170,13 @@
num_cpus=max(os.cpu_count(), 1))
beam_predictions_e1 = beam_search_with_lm(
log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
eval_tensors.append(beam_predictions_e1)
if checkpoints_loaded:
checkpoint_dir = None
else:
checkpoint_dir = load_dir
checkpoints_loaded = True

evaluated_tensors = neural_factory.infer(
tensors=eval_tensors,
checkpoint_dir=checkpoint_dir,
tensors=[beam_predictions_e1],
use_cache=True,
verbose=False
)

references = post_process_transcripts(
evaluated_tensors[2], evaluated_tensors[3], vocab)

beam_hypotheses = []
# Over mini-batch
for i in evaluated_tensors[-1]:
@@ -239,7 +186,7 @@ def main():

wer = word_error_rate(
hypotheses=beam_hypotheses, references=references)
logger.info("Beam WER {:.2f}".format(wer*100))
logger.info("Beam WER {:.2f}%".format(wer*100))
beam_wers.append(((alpha, beta), wer*100))

logger.info('Beam WER for (alpha, beta)')
@@ -249,7 +196,7 @@ def main():
best_beam_wer = min(beam_wers, key=lambda x: x[1])
logger.info('Best (alpha, beta): '
f'{best_beam_wer[0]}, '
f'WER: {best_beam_wer[1]:.2f}')
f'WER: {best_beam_wer[1]:.2f}%')


if __name__ == "__main__":
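
A note on the (alpha, beta) sweep above: np.arange excludes its stop value, which is why the script nudges alpha_max and beta_max upward by a tenth of the step before looping. A minimal illustration with arbitrary values:

import numpy as np

alpha, alpha_max, alpha_step = 1.0, 2.0, 0.5
# Without the adjustment the configured maximum is never evaluated.
print(np.arange(alpha, alpha_max, alpha_step))                      # [1.  1.5]
# Bumping the stop by a tenth of the step keeps alpha_max in the sweep.
print(np.arange(alpha, alpha_max + alpha_step / 10.0, alpha_step))  # [1.  1.5 2. ]
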
113 changes: 95 additions & 18 deletions nemo/nemo/backends/pytorch/actions.py
@@ -75,6 +75,7 @@ def __init__(self, local_rank=None, tb_writer=None,
self.optimizers = []
self.tb_writer = tb_writer
self._modules = set()
self.cache = None

@property
def modules(self):
@@ -387,8 +388,16 @@ def __nm_graph_forward_pass(self,
call_chain,
registered_tensors,
mode=ModelMode.train,
disable_allreduce=False):
disable_allreduce=False,
use_cache=False):
for ind in range(1, len(call_chain)):
if use_cache:
in_cache = True
for tensor in call_chain[ind][2].values():
if tensor.unique_name not in registered_tensors:
in_cache = False
if in_cache:
continue
call_args = call_chain[ind][1]
# module = call_chain[ind][0]
m_id = call_chain[ind][0].unique_instance_id
@@ -667,10 +676,27 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
for key, val in vals_to_log.items():
callback.swriter.add_scalar(key, val, step)

def _infer(self, tensors_to_return, step, verbose=False):
def _infer(self,
tensors_to_return,
verbose=False,
cache=False,
use_cache=False,
offload_to_cpu=True):
"""
Does the same as _eval(), but returns the values of the requested
tensors instead of using an eval callback.
"""
# Checking that cache is used properly
if cache and use_cache:
raise ValueError("cache and use_cache were both set. However cache"
" must first be created prior to using it.")
if cache:
if self.cache is not None:
raise ValueError("cache was set but was not empty")
self.cache = []
if use_cache:
if not self.cache:
raise ValueError("use_cache was set, but cache was empty")

with torch.no_grad():
# each call chain corresponds to a tensor in tensors_2_evaluate
dl_nm = None
@@ -685,6 +711,9 @@ def _infer(self, tensors_to_return, step, verbose=False):
is_distributed = False
world_size = None
if dl_nm.placement == DeviceType.AllGpu:
if self.cache or use_cache:
raise NotImplementedError(
"Caching is not available for distributed training.")
assert dist.is_initialized()
is_distributed = True
world_size = torch.distributed.get_world_size()
@@ -709,7 +738,9 @@ def _infer(self, tensors_to_return, step, verbose=False):
else:
eval_dataloader = dl_nm.data_iterator
eval_dataloader.sampler.set_epoch(0)
else: # Not distributed
elif not use_cache: # Not distributed and not using cache
# There is no need for a dataloader when using the cache, since the
# caching pass stores all of the dataloader's outputs
if dl_nm.dataset is not None:
# Todo: remove local_parameters
eval_dataloader = torch.utils.data.DataLoader(
@@ -737,33 +768,57 @@ def _infer(self, tensors_to_return, step, verbose=False):
dl_device = dl_nm._device

# Evaluation mini-batch for loop
num_batches = len(eval_dataloader)
for epoch_i, data in enumerate(eval_dataloader, 0):
if use_cache:
num_batches = len(self.cache)
loop_iterator = self.cache
else:
num_batches = len(eval_dataloader)
loop_iterator = eval_dataloader

for epoch_i, data in enumerate(loop_iterator, 0):
if verbose and (
num_batches < 10 or (
epoch_i % int(num_batches / 10) == 0)
):
print(
f"Evaluating batch {epoch_i} out of {num_batches}")
tensors = []
if isinstance(data, torch.Tensor):
data = (data,)
for d in data:
if isinstance(d, torch.Tensor):
tensors.append(d.to(dl_device))
else:
tensors.append(d)
if use_cache:
registered_e_tensors = data
# Remove the requested tensors from the cached batch so they
# are recomputed by the forward pass
for t in tensors_to_return:
if t.unique_name in registered_e_tensors:
del registered_e_tensors[t.unique_name]
else:
if isinstance(data, torch.Tensor):
data = (data,)
for d in data:
if isinstance(d, torch.Tensor):
tensors.append(d.to(dl_device))
else:
tensors.append(d)

registered_e_tensors = {t.unique_name: d for t, d in
zip(call_chain[0][2].values(), tensors)
if t is not None
}
registered_e_tensors = {
t.unique_name: d for t, d in
zip(call_chain[0][2].values(), tensors)
if t is not None
}
self.__nm_graph_forward_pass(
call_chain=call_chain,
registered_tensors=registered_e_tensors,
mode=ModelMode.eval,
use_cache=use_cache
)

if offload_to_cpu:
# Move all CUDA tensors in registered_e_tensors to CPU
# to save GPU memory
for name, tensor in registered_e_tensors.items():
if isinstance(tensor, torch.Tensor):
registered_e_tensors[name] = tensor.cpu()
if cache:
self.append_to_cache(registered_e_tensors)

# If distributed. For the outer loop, we need to ensure that
# all processes loop through the elements in the same order
for t2e in tensors_to_return:
@@ -833,6 +888,18 @@ def _infer(self, tensors_to_return, step, verbose=False):
# For all other ranks
return None

def append_to_cache(self, registered_tensors: dict):
"""Simpler helper function to add results of __nm_graph_forward_pass to
current cache.
"""
self.cache.append(registered_tensors)

def clear_cache(self):
""" Simple helpful function to clear cache by setting self.cache to
None
"""
self.cache = None

def save_state_to(self, path: str):
"""
Saves current state such as step, epoch and optimizer parameters
@@ -1245,7 +1312,13 @@ def infer(self,
tensors,
checkpoint_dir=None,
ckpt_pattern='',
logger=None):
logger=None,
verbose=True,
cache=False,
use_cache=False,
offload_to_cpu=True):
"""See NeuralModuleFactory.infer()
"""

if checkpoint_dir:
# Find all modules that need to be restored
@@ -1287,4 +1360,8 @@
)

# Run infer
return self._infer(tensors_to_return=tensors, step=0, verbose=True)
return self._infer(tensors_to_return=tensors,
verbose=verbose,
cache=cache,
use_cache=use_cache,
offload_to_cpu=offload_to_cpu)
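
To make the use_cache path in __nm_graph_forward_pass and _infer above more concrete, here is a minimal, self-contained sketch of the underlying idea: a module in the call chain is skipped whenever all of its outputs are already present among the registered tensors. The function and module names below are illustrative only, not NeMo API.

# Minimal sketch of the skip-if-cached idea behind use_cache in
# __nm_graph_forward_pass; all names here are illustrative, not NeMo API.

def forward_pass(call_chain, registered_tensors, use_cache=False):
    """call_chain: list of (fn, input_names, output_names) tuples."""
    for fn, input_names, output_names in call_chain:
        if use_cache and all(name in registered_tensors
                             for name in output_names):
            # Every output of this module is already cached -> skip it.
            continue
        inputs = [registered_tensors[name] for name in input_names]
        outputs = fn(*inputs)
        registered_tensors.update(zip(output_names, outputs))
    return registered_tensors

# Toy usage: the "encoder" result is served from the cache, so only the
# "decoder" step actually runs.
cached_batch = {"audio": [1.0, 2.0], "encoded": [2.0, 4.0]}
call_chain = [
    (lambda x: ([2 * v for v in x],), ["audio"], ["encoded"]),    # encoder
    (lambda x: ([v + 1 for v in x],), ["encoded"], ["decoded"]),  # decoder
]
result = forward_pass(call_chain, dict(cached_batch), use_cache=True)
print(result["decoded"])  # [3.0, 5.0]
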
49 changes: 45 additions & 4 deletions nemo/nemo/core/neural_factory.py
@@ -551,11 +551,52 @@ def eval(self,
optimization_params={'num_epochs': 1}
)

def infer(self, tensors: List[NmTensor], checkpoint_dir=None,
ckpt_pattern=''):
def infer(self,
tensors: List[NmTensor],
checkpoint_dir=None,
ckpt_pattern='',
verbose=True,
cache=False,
use_cache=False,
offload_to_cpu=True):
"""Runs inference to obtain values for tensors

Args:
tensors (list[NmTensor]): List of NeMo tensors that we want to get
values of.
checkpoint_dir (str): Path to checkpoint directory. Default is None
which does not load checkpoints.
ckpt_pattern (str): Pattern used to check for checkpoints inside
checkpoint_dir. Default is '' which matches any checkpoints
inside checkpoint_dir.
verbose (bool): Controls printing. Defaults to True.
cache (bool): If True, cache all `tensors` and intermediate tensors
so that future calls that have use_cache set will avoid
computation. Defaults to False.
use_cache (bool): If True, reuse the cache created by a previous
call with `cache=True`. Values requested in `tensors` are removed
from the cache, and the call chain is re-run while skipping any
Neural Module whose outputs are already cached. Defaults to False.
offload_to_cpu (bool): If True, all evaluated tensors are moved to
cpu memory after each inference batch. Defaults to True.

Returns:
List of evaluated tensors. Each element of the list is itself a
list containing the per-batch values of the corresponding tensor.
"""
return self._trainer.infer(
tensors=tensors, checkpoint_dir=checkpoint_dir,
ckpt_pattern=ckpt_pattern, logger=self.logger)
tensors=tensors,
checkpoint_dir=checkpoint_dir,
ckpt_pattern=ckpt_pattern,
verbose=verbose,
logger=self.logger,
cache=cache,
use_cache=use_cache,
offload_to_cpu=offload_to_cpu)

def clear_cache(self):
"""Helper function to clean inference cache."""
self._trainer.clear_cache()
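
For context, a usage sketch of the new caching workflow at the factory level, mirroring jasper_eval.py above; the tensor, module, and variable names are assumptions taken from that script, not a fixed API.

# 1) Run the acoustic model once, restoring checkpoints and caching every
#    evaluated tensor on CPU.
evaluated_tensors = neural_factory.infer(
    tensors=[log_probs_e1, predictions_e1,
             transcript_e1, transcript_len_e1, encoded_len_e1],
    checkpoint_dir=load_dir,
    cache=True,
)

# 2) For each (alpha, beta) pair, only the newly built beam-search module
#    runs; everything already in the cache is reused, not recomputed.
beam_predictions_e1 = beam_search_with_lm(
    log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
beam_results = neural_factory.infer(
    tensors=[beam_predictions_e1],
    use_cache=True,
    verbose=False,
)

# 3) Release the cached tensors once the sweep is finished.
neural_factory.clear_cache()
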

def _get_trainer(self, tb_writer=None):
if self._backend == Backend.PyTorch: