diff --git a/examples/asr/jasper_infer.py b/examples/asr/jasper_eval.py
similarity index 74%
rename from examples/asr/jasper_infer.py
rename to examples/asr/jasper_eval.py
index 1ac17c32cca2..6cda39da5b34 100644
--- a/examples/asr/jasper_infer.py
+++ b/examples/asr/jasper_eval.py
@@ -4,15 +4,14 @@
 import argparse
 import copy
 import os
-import pickle
 
 from ruamel.yaml import YAML
+import numpy as np
 
 import nemo
 import nemo_asr
 from nemo_asr.helpers import word_error_rate, post_process_predictions, \
     post_process_transcripts
-import numpy as np
 
 
 def main():
@@ -51,10 +50,6 @@ def main():
         required=False, default=0.1)
     parser.add_argument(
         "--beam_width", default=128, type=int)
-    parser.add_argument(
-        '--mode',
-        help='either \'eval\' (default) or \'infer\'',
-        default='infer')
 
     args = parser.parse_args()
     batch_size = args.batch_size
@@ -137,24 +132,10 @@ def main():
     eval_tensors = [log_probs_e1, predictions_e1, transcript_e1,
                     transcript_len_e1, encoded_len_e1]
 
-    if args.lm_path and args.mode == 'infer':
-        beam_width = args.beam_width
-        alpha = args.alpha
-        beta = args.beta
-        beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
-            vocab=vocab,
-            beam_width=beam_width,
-            alpha=alpha,
-            beta=beta,
-            lm_path=args.lm_path,
-            num_cpus=max(os.cpu_count(), 1))
-        beam_predictions_e1 = beam_search_with_lm(
-            log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
-        eval_tensors.append(beam_predictions_e1)
-
     evaluated_tensors = neural_factory.infer(
         tensors=eval_tensors,
         checkpoint_dir=load_dir,
+        cache=True
     )
 
     greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
@@ -163,30 +144,7 @@ def main():
     wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
     logger.info("Greedy WER {:.2f}%".format(wer*100))
 
-    if args.mode == 'infer':
-        if args.lm_path:
-            beam_hypotheses = []
-            # Over mini-batch
-            for i in evaluated_tensors[-1]:
-                # Over samples
-                for j in i:
-                    beam_hypotheses.append(j[0][1])
-
-            wer = word_error_rate(
-                hypotheses=beam_hypotheses, references=references)
-            logger.info("Beam WER {:.2f}".format(wer*100))
-
-        if args.save_logprob:
-            # Convert logits to list of numpy arrays
-            logprob = []
-            for i, batch in enumerate(evaluated_tensors[0]):
-                for j in range(batch.shape[0]):
-                    logprob.append(
-                        batch[j][:evaluated_tensors[4][i][j], :].cpu().numpy())
-            with open(args.save_logprob, 'wb') as f:
-                pickle.dump(logprob, f, protocol=pickle.HIGHEST_PROTOCOL)
-
-    if args.mode == 'eval':
+    if args.lm_path:
         if args.alpha_max is None:
             args.alpha_max = args.alpha
         # include alpha_max in tuning range
@@ -198,14 +156,11 @@ def main():
         args.beta_max += args.beta_step/10.0
 
         beam_wers = []
-        checkpoints_loaded = False
 
         for alpha in np.arange(args.alpha, args.alpha_max, args.alpha_step):
             for beta in np.arange(args.beta, args.beta_max, args.beta_step):
-                logger.info(f'infering with (alpha, beta): ({alpha}, {beta})')
-                eval_tensors = [log_probs_e1, predictions_e1,
-                                transcript_e1, transcript_len_e1,
-                                encoded_len_e1]
+                logger.info('================================')
+                logger.info(f'Inferring with (alpha, beta): ({alpha}, {beta})')
                 beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                     vocab=vocab,
                     beam_width=args.beam_width,
@@ -215,21 +170,13 @@ def main():
                     num_cpus=max(os.cpu_count(), 1))
                 beam_predictions_e1 = beam_search_with_lm(
                     log_probs=log_probs_e1, log_probs_length=encoded_len_e1)
-                eval_tensors.append(beam_predictions_e1)
-                if checkpoints_loaded:
-                    checkpoint_dir = None
-                else:
-                    checkpoint_dir = load_dir
-                    checkpoints_loaded = True
 
                 evaluated_tensors = neural_factory.infer(
-                    tensors=eval_tensors,
-                    checkpoint_dir=checkpoint_dir,
+                    tensors=[beam_predictions_e1],
+                    use_cache=True,
+                    verbose=False
                 )
-                references = post_process_transcripts(
-                    evaluated_tensors[2], evaluated_tensors[3], vocab)
-
                 beam_hypotheses = []
                 # Over mini-batch
                 for i in evaluated_tensors[-1]:
@@ -239,7 +186,7 @@ def main():
 
                 wer = word_error_rate(
                     hypotheses=beam_hypotheses, references=references)
-                logger.info("Beam WER {:.2f}".format(wer*100))
+                logger.info("Beam WER {:.2f}%".format(wer*100))
                 beam_wers.append(((alpha, beta), wer*100))
 
         logger.info('Beam WER for (alpha, beta)')
@@ -249,7 +196,7 @@ def main():
         best_beam_wer = min(beam_wers, key=lambda x: x[1])
         logger.info('Best (alpha, beta): '
                     f'{best_beam_wer[0]}, '
-                    f'WER: {best_beam_wer[1]:.2f}')
+                    f'WER: {best_beam_wer[1]:.2f}%')
 
 
 if __name__ == "__main__":
diff --git a/nemo/nemo/backends/pytorch/actions.py b/nemo/nemo/backends/pytorch/actions.py
index f9a76a9760fb..58e379110707 100644
--- a/nemo/nemo/backends/pytorch/actions.py
+++ b/nemo/nemo/backends/pytorch/actions.py
@@ -75,6 +75,7 @@ def __init__(self, local_rank=None, tb_writer=None,
         self.optimizers = []
         self.tb_writer = tb_writer
         self._modules = set()
+        self.cache = None
 
     @property
     def modules(self):
@@ -387,8 +388,16 @@ def __nm_graph_forward_pass(self,
                                 call_chain,
                                 registered_tensors,
                                 mode=ModelMode.train,
-                                disable_allreduce=False):
+                                disable_allreduce=False,
+                                use_cache=False):
         for ind in range(1, len(call_chain)):
+            if use_cache:
+                in_cache = True
+                for tensor in call_chain[ind][2].values():
+                    if tensor.unique_name not in registered_tensors:
+                        in_cache = False
+                if in_cache:
+                    continue
             call_args = call_chain[ind][1]
             # module = call_chain[ind][0]
             m_id = call_chain[ind][0].unique_instance_id
@@ -667,10 +676,27 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
         for key, val in vals_to_log.items():
             callback.swriter.add_scalar(key, val, step)
 
-    def _infer(self, tensors_to_return, step, verbose=False):
+    def _infer(self,
+               tensors_to_return,
+               verbose=False,
+               cache=False,
+               use_cache=False,
+               offload_to_cpu=True):
         """
         Does the same as _eval() just with tensors instead of eval callback.
         """
+        # Checking that cache is used properly
+        if cache and use_cache:
+            raise ValueError("cache and use_cache were both set. However cache"
+                             " must first be created prior to using it.")
+        if cache:
+            if self.cache is not None:
+                raise ValueError("cache was set but was not empty")
+            self.cache = []
+        if use_cache:
+            if not self.cache:
+                raise ValueError("use_cache was set, but cache was empty")
+
         with torch.no_grad():
             # each call chain corresponds to a tensor in tensors_2_evaluate
             dl_nm = None
@@ -685,6 +711,9 @@ def _infer(self, tensors_to_return, step, verbose=False):
             is_distributed = False
             world_size = None
             if dl_nm.placement == DeviceType.AllGpu:
+                if cache or use_cache:
+                    raise NotImplementedError(
+                        "Caching is not available for distributed training.")
                 assert dist.is_initialized()
                 is_distributed = True
                 world_size = torch.distributed.get_world_size()
@@ -709,7 +738,9 @@ def _infer(self, tensors_to_return, step, verbose=False):
                 else:
                     eval_dataloader = dl_nm.data_iterator
                 eval_dataloader.sampler.set_epoch(0)
-            else:  # Not distributed
+            elif not use_cache:  # Not distributed and not using cache
+                # There is no need for a dataloader when using the cache,
+                # since the cache already holds the dataloader's outputs
                 if dl_nm.dataset is not None:
                     # Todo: remove local_parameters
                     eval_dataloader = torch.utils.data.DataLoader(
@@ -737,8 +768,14 @@ def _infer(self, tensors_to_return, step, verbose=False):
             dl_device = dl_nm._device
 
             # Evaluation mini-batch for loop
-            num_batches = len(eval_dataloader)
-            for epoch_i, data in enumerate(eval_dataloader, 0):
+            if use_cache:
+                num_batches = len(self.cache)
+                loop_iterator = self.cache
+            else:
+                num_batches = len(eval_dataloader)
+                loop_iterator = eval_dataloader
+
+            for epoch_i, data in enumerate(loop_iterator, 0):
                 if verbose and (
                         num_batches < 10 or (
                         epoch_i % int(num_batches / 10) == 0)
@@ -746,24 +783,42 @@ def _infer(self, tensors_to_return, step, verbose=False):
                     print(
                         f"Evaluating batch {epoch_i} out of {num_batches}")
                 tensors = []
-                if isinstance(data, torch.Tensor):
-                    data = (data,)
-                for d in data:
-                    if isinstance(d, torch.Tensor):
-                        tensors.append(d.to(dl_device))
-                    else:
-                        tensors.append(d)
+                if use_cache:
+                    registered_e_tensors = data
+                    # Drop the requested tensors so they are recomputed below
+                    for t in tensors_to_return:
+                        if t.unique_name in registered_e_tensors:
+                            del registered_e_tensors[t.unique_name]
+                else:
+                    if isinstance(data, torch.Tensor):
+                        data = (data,)
+                    for d in data:
+                        if isinstance(d, torch.Tensor):
+                            tensors.append(d.to(dl_device))
+                        else:
+                            tensors.append(d)
 
-                registered_e_tensors = {t.unique_name: d for t, d in
-                                        zip(call_chain[0][2].values(), tensors)
-                                        if t is not None
-                                        }
+                    registered_e_tensors = {
+                        t.unique_name: d for t, d in
+                        zip(call_chain[0][2].values(), tensors)
+                        if t is not None
+                    }
                 self.__nm_graph_forward_pass(
                     call_chain=call_chain,
                     registered_tensors=registered_e_tensors,
                     mode=ModelMode.eval,
+                    use_cache=use_cache
                 )
 
+                if offload_to_cpu:
+                    # Move any CUDA tensors in registered_e_tensors to the
+                    # CPU to save GPU memory
+                    for name, tensor in registered_e_tensors.items():
+                        if isinstance(tensor, torch.Tensor):
+                            registered_e_tensors[name] = tensor.cpu()
+                if cache:
+                    self.append_to_cache(registered_e_tensors)
+
                 # If distributed. For the outer loop, we need to ensure that
                 # all processes loop through the elements in the same order
                 for t2e in tensors_to_return:
@@ -833,6 +888,18 @@ def _infer(self, tensors_to_return, step, verbose=False):
             # For all other ranks
             return None
 
+    def append_to_cache(self, registered_tensors: dict):
+        """Simple helper function to add results of __nm_graph_forward_pass to
+        the current cache.
+        """
+        self.cache.append(registered_tensors)
+
+    def clear_cache(self):
+        """Simple helper function to clear the cache by setting self.cache to
+        None.
+        """
+        self.cache = None
+
     def save_state_to(self, path: str):
         """ Saves current state such as step, epoch and optimizer parameters
         """
@@ -1245,7 +1312,13 @@ def infer(self,
               tensors,
               checkpoint_dir=None,
               ckpt_pattern='',
-              logger=None):
+              logger=None,
+              verbose=True,
+              cache=False,
+              use_cache=False,
+              offload_to_cpu=True):
+        """See NeuralModuleFactory.infer()
+        """
 
         if checkpoint_dir:
             # Find all modules that need to be restored
@@ -1287,4 +1360,8 @@ def infer(self,
         )
 
         # Run infer
-        return self._infer(tensors_to_return=tensors, step=0, verbose=True)
+        return self._infer(tensors_to_return=tensors,
+                           verbose=verbose,
+                           cache=cache,
+                           use_cache=use_cache,
+                           offload_to_cpu=offload_to_cpu)
diff --git a/nemo/nemo/core/neural_factory.py b/nemo/nemo/core/neural_factory.py
index 2bbdca196929..17e813fb127f 100644
--- a/nemo/nemo/core/neural_factory.py
+++ b/nemo/nemo/core/neural_factory.py
@@ -551,11 +551,52 @@ def eval(self,
             optimization_params={'num_epochs': 1}
         )
 
-    def infer(self, tensors: List[NmTensor], checkpoint_dir=None,
-              ckpt_pattern=''):
+    def infer(self,
+              tensors: List[NmTensor],
+              checkpoint_dir=None,
+              ckpt_pattern='',
+              verbose=True,
+              cache=False,
+              use_cache=False,
+              offload_to_cpu=True):
+        """Runs inference to obtain values for tensors.
+
+        Args:
+            tensors (list[NmTensor]): List of NeMo tensors that we want to get
+                values of.
+            checkpoint_dir (str): Path to checkpoint directory. Default is None
+                which does not load checkpoints.
+            ckpt_pattern (str): Pattern used to check for checkpoints inside
+                checkpoint_dir. Default is '' which matches any checkpoints
+                inside checkpoint_dir.
+            verbose (bool): Controls printing. Defaults to True.
+            cache (bool): If True, cache all `tensors` and intermediate tensors
+                so that future calls with use_cache=True can skip
+                recomputing them. Defaults to False.
+            use_cache (bool): If True, reuse the tensors stored by a previous
+                call with cache=True. The values requested in `tensors` are
+                recomputed, but any Neural Module whose outputs are already in
+                the cache is skipped. Defaults to False.
+            offload_to_cpu (bool): If True, all evaluated tensors are moved to
+                CPU memory after each inference batch. Defaults to True.
+
+        Returns:
+            List of evaluated tensors. Each element of the outer list
+            corresponds to one of `tensors` and holds one value per batch.
+        """
         return self._trainer.infer(
-            tensors=tensors, checkpoint_dir=checkpoint_dir,
-            ckpt_pattern=ckpt_pattern, logger=self.logger)
+            tensors=tensors,
+            checkpoint_dir=checkpoint_dir,
+            ckpt_pattern=ckpt_pattern,
+            verbose=verbose,
+            logger=self.logger,
+            cache=cache,
+            use_cache=use_cache,
+            offload_to_cpu=offload_to_cpu)
+
+    def clear_cache(self):
+        """Helper function to clear the inference cache."""
+        self._trainer.clear_cache()
 
     def _get_trainer(self, tb_writer=None):
         if self._backend == Backend.PyTorch:
diff --git a/tests/test_infer.py b/tests/test_infer.py
new file mode 100644
index 000000000000..f16711fa4536
--- /dev/null
+++ b/tests/test_infer.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2019 NVIDIA Corporation
+import torch
+
+from nemo.backends.pytorch.nm import NonTrainableNM
+from nemo.core.neural_types import *
+
+from .context import nemo
+from .common_setup import NeMoUnitTest
+
+
+class AddsTen(NonTrainableNM):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    @staticmethod
+    def create_ports():
+        input_ports = {
+            "mod_in": NeuralType({0: AxisType(BatchTag),
+                                  1: AxisType(BaseTag, dim=1)})
+        }
+        output_ports = {
+            "mod_out": NeuralType({0: AxisType(BatchTag),
+                                   1: AxisType(BaseTag, dim=1)})
+        }
+
+        return input_ports, output_ports
+
+    def forward(self, mod_in):
+        return mod_in + 10
+
+
+class SubtractsTen(NonTrainableNM):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    @staticmethod
+    def create_ports():
+        input_ports = {
+            "mod_in": NeuralType({0: AxisType(BatchTag),
+                                  1: AxisType(BaseTag, dim=1)})
+        }
+        output_ports = {
+            "mod_out": NeuralType({0: AxisType(BatchTag),
+                                   1: AxisType(BaseTag, dim=1)})
+        }
+
+        return input_ports, output_ports
+
+    def forward(self, mod_in):
+        return mod_in - 10
+
+
+class TestInfer(NeMoUnitTest):
+    def test_infer_caching(self):
+        neural_factory = nemo.core.neural_factory.NeuralModuleFactory(
+            backend=nemo.core.Backend.PyTorch, create_tb_writer=False)
+
+        data_source = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=1,
+            dtype=torch.FloatTensor,
+            batch_size=1,
+            output_ports={
+                "dl_out": NeuralType({0: AxisType(BatchTag),
+                                      1: AxisType(BaseTag, dim=1)})})
+        addten = AddsTen()
+        minusten = SubtractsTen()
+
+        zero_tensor = data_source()
+        ten_tensor = addten(mod_in=zero_tensor)
+        twenty_tensor = addten(mod_in=ten_tensor)
+        thirty_tensor = addten(mod_in=twenty_tensor)
+
+        evaluated_tensors = neural_factory.infer(
+            tensors=[twenty_tensor, thirty_tensor],
+            verbose=False,
+            cache=True
+        )
+        self.assertEqual(evaluated_tensors[0][0].squeeze().data, 20)
+        self.assertEqual(evaluated_tensors[1][0].squeeze().data, 30)
+
+        new_ten_tensor = minusten(mod_in=twenty_tensor)
+        evaluated_tensors = neural_factory.infer(
+            tensors=[new_ten_tensor],
+            verbose=False,
+            use_cache=True
+        )
+        self.assertEqual(evaluated_tensors[0][0].squeeze().data, 10)
+
+    def test_infer_errors(self):
+        neural_factory = nemo.core.neural_factory.NeuralModuleFactory(
+            backend=nemo.core.Backend.PyTorch, create_tb_writer=False)
+
+        data_source = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=1,
+            dtype=torch.FloatTensor,
+            batch_size=1,
+            output_ports={
+                "dl_out": NeuralType({0: AxisType(BatchTag),
+                                      1: AxisType(BaseTag, dim=1)})})
+        addten = AddsTen()
+        minusten = SubtractsTen()
+
+        zero_tensor = data_source()
+        ten_tensor = addten(mod_in=zero_tensor)
+        twenty_tensor = addten(mod_in=ten_tensor)
+        thirty_tensor = addten(mod_in=twenty_tensor)
+
+        with self.assertRaisesRegex(ValueError,
+                                    "use_cache was set, but cache was empty"):
+            evaluated_tensors = neural_factory.infer(
+                tensors=[twenty_tensor, thirty_tensor],
+                verbose=False,
+                use_cache=True
+            )
+
+        new_ten_tensor = minusten(mod_in=twenty_tensor)
+        evaluated_tensors = neural_factory.infer(
+            tensors=[new_ten_tensor],
+            verbose=False,
+            cache=True
+        )
+
+        with self.assertRaisesRegex(ValueError,
+                                    "cache was set but was not empty"):
+            evaluated_tensors = neural_factory.infer(
+                tensors=[twenty_tensor, thirty_tensor],
+                verbose=False,
+                cache=True
+            )
+
+        neural_factory.clear_cache()
+        evaluated_tensors = neural_factory.infer(
+            tensors=[new_ten_tensor],
+            verbose=False,
+            cache=True
+        )
+
+        with self.assertRaisesRegex(ValueError,
+                                    "cache and use_cache were both set."):
+            evaluated_tensors = neural_factory.infer(
+                tensors=[twenty_tensor, thirty_tensor],
+                verbose=False,
+                cache=True,
+                use_cache=True
+            )
+        self.assertEqual(evaluated_tensors[0][0].squeeze().data, 10)
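
Editor's note — usage sketch (not part of the diff). The snippet below outlines how the new flags introduced above are meant to be combined, mirroring the jasper_eval.py changes and the NeuralModuleFactory.infer() docstring. Graph construction is elided, and names such as neural_factory, log_probs, predictions, transcript, transcript_len, encoded_len, vocab, load_dir and lm_path are placeholders for whatever the assembled DAG defines.

    import os
    import numpy as np
    import nemo_asr

    # 1) Run the acoustic model once and cache every tensor evaluated along
    #    the way. With offload_to_cpu=True (the default) the cached tensors
    #    are kept in host memory rather than on the GPU.
    evaluated = neural_factory.infer(
        tensors=[log_probs, predictions, transcript, transcript_len,
                 encoded_len],
        checkpoint_dir=load_dir,
        cache=True)

    # 2) Sweep the LM weights. Each call reuses the cache, so only the newly
    #    attached beam-search decoder is actually executed.
    for alpha in np.arange(1.0, 3.5, 0.5):
        for beta in np.arange(0.0, 1.5, 0.5):
            beam_search = nemo_asr.BeamSearchDecoderWithLM(
                vocab=vocab, beam_width=128, alpha=alpha, beta=beta,
                lm_path=lm_path, num_cpus=max(os.cpu_count(), 1))
            beam_preds = beam_search(log_probs=log_probs,
                                     log_probs_length=encoded_len)
            beam_out = neural_factory.infer(
                tensors=[beam_preds], use_cache=True, verbose=False)

    # 3) Reset before caching a different graph; otherwise the next call with
    #    cache=True raises "cache was set but was not empty".
    neural_factory.clear_cache()

Design-wise, infer(cache=True) stores one dict of evaluated tensors per batch in self.cache (moved to CPU when offload_to_cpu=True), and infer(use_cache=True) replays those dicts instead of a dataloader, which is why the sweep above never re-runs the encoder.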
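
Editor's note — how the use_cache skip works (illustrative helper, not in the codebase). In __nm_graph_forward_pass above, a Neural Module is skipped only when every one of its output tensors (call_chain[ind][2]) is already present in registered_tensors; and because _infer(use_cache=True) first deletes the requested `tensors` from each cached batch dict, exactly the modules that produce those tensors are re-executed. A hypothetical can_skip helper restating that check:

    def can_skip(call_chain_entry, registered_tensors):
        # call_chain_entry[2] maps output port names to NmTensors; skip the
        # module only if every output it would produce is already registered
        # (i.e. restored from the cache).
        for tensor in call_chain_entry[2].values():
            if tensor.unique_name not in registered_tensors:
                return False
        return True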