7 changes: 5 additions & 2 deletions collections/nemo_asr/nemo_asr/data_layer.py
@@ -4,8 +4,11 @@
processing
"""
import torch
from apex import amp

try:
    from apex import amp
except (ImportError, AttributeError):
    print("Unable to import APEX. Mixed precision and distributed training "
          "will not work.")
from nemo.backends.pytorch.nm import DataLayerNM, TrainableNM, NonTrainableNM
from nemo.core import Optimization, DeviceType
from nemo.core.neural_types import *
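The hunk above turns Apex into a soft dependency. As a minimal sketch of the pattern (not NeMo's exact code), the failed import can leave a sentinel behind so call sites branch on availability:

try:
    from apex import amp
except (ImportError, AttributeError):
    amp = None  # sentinel: AMP features are unavailable
    print("Unable to import APEX. Mixed precision and distributed training "
          "will not work.")

def backward(loss, optimizer):
    # Plain FP32 backward when Apex is missing; scaled loss otherwise.
    if amp is None:
        loss.backward()
    else:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()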
8 changes: 7 additions & 1 deletion collections/nemo_nlp/nemo_nlp/transformer/modules.py
@@ -30,8 +30,14 @@
'PositionWiseFF']

import math
try:
    from apex.normalization import FusedLayerNorm
except (ImportError, AttributeError):
    # this is a lie: nothing is actually fused in this case
    print("Unable to import APEX. Mixed precision, distributed training and "
          "FusedLayerNorm are not available.")
    from torch.nn import LayerNorm as FusedLayerNorm

from apex.normalization import FusedLayerNorm
import torch
from torch import nn

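The alias works because torch.nn.LayerNorm and apex.normalization.FusedLayerNorm share a constructor and forward signature, so downstream model code is unchanged; only the fused CUDA kernel is lost. A small illustration (not part of the diff):

import torch

try:
    from apex.normalization import FusedLayerNorm
except (ImportError, AttributeError):
    from torch.nn import LayerNorm as FusedLayerNorm

# Either implementation accepts the same arguments and normalizes over
# the trailing dimension(s).
norm = FusedLayerNorm(512)
out = norm(torch.randn(8, 16, 512))
print(out.shape)  # torch.Size([8, 16, 512])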
54 changes: 18 additions & 36 deletions examples/asr/ASR_made_simple.ipynb
@@ -25,8 +25,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_manifest = \"an4data/an4_train.json\"\n",
"val_manifest = \"an4data/an4_val.json\""
"train_manifest = \"an4_dataset/an4_train.json\"\n",
"val_manifest = \"an4_dataset/an4_val.json\""
]
},
{
@@ -56,6 +56,20 @@
"### Instantiate necessary Neural Modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First step is to instantiate a NeuralModuleFactory\n",
"# If torch is installed without CUDA and Apex CPU will be used\n",
"# and training is impractically slow even for this dataset\n",
"from nemo.core import DeviceType\n",
"import torch\n",
"nf = nemo.core.NeuralModuleFactory(placement=DeviceType.GPU if torch.cuda.is_available() else DeviceType.CPU)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -192,44 +206,12 @@
"metadata": {},
"outputs": [],
"source": [
"# instantiate Neural Factory with supported backend\n",
"neural_factory = nemo.core.NeuralModuleFactory(backend=nemo.core.Backend.PyTorch)\n",
"# \n",
"# neural_factory = nemo.core.NeuralModuleFactory(\n",
"# backend=nemo.core.Backend.PyTorch,\n",
"# local_rank=args.local_rank,\n",
"# optimization_level=nemo.core.Optimization.mxprO1,\n",
"# placement=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"neural_factory.train(tensors_to_optimize=[loss],\n",
"nf.train(tensors_to_optimize=[loss],\n",
" callbacks=[train_callback],\n",
" optimizer=\"novograd\",\n",
" optimization_params={\"num_epochs\": 30, \"lr\": 1e-2,\n",
" \"weight_decay\": 1e-3})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"jasper_encoder.save_to('jasper_encoder.pt')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -248,7 +230,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.7.3"
}
},
"nbformat": 4,
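The new notebook cell is easier to read outside the JSON escaping. Restated as plain Python, the idiom is: request GPU placement only when CUDA is actually usable, otherwise fall back to CPU (slow, but functional for a demo dataset):

import torch
import nemo
from nemo.core import DeviceType

placement = DeviceType.GPU if torch.cuda.is_available() else DeviceType.CPU
nf = nemo.core.NeuralModuleFactory(placement=placement)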
19 changes: 16 additions & 3 deletions examples/asr/InferenceWithBeamSearch.ipynb
@@ -30,6 +30,20 @@
"labels = jasper_model_definition['labels']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First step is to instantiate a NeuralModuleFactory\n",
"# If torch is installed without CUDA and Apex CPU will be used\n",
"# and training is impractically slow even for this dataset\n",
"from nemo.core import DeviceType\n",
"import torch\n",
"nf = nemo.core.NeuralModuleFactory(placement=DeviceType.GPU if torch.cuda.is_available() else DeviceType.CPU)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -66,7 +80,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Download checkpoint from here: https://drive.google.com/drive/folders/1b-TQYY7o8_CQgZsVEe-8_2kHWU0lYJ-z?usp=sharing\n",
"# Download checkpoint from here: https://ngc.nvidia.com/catalog/models/nvidia:quartznet15x5\n",
"import os\n",
"# Instantiate BeamSearch NM\n",
"beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(\n",
@@ -88,7 +102,6 @@
"source": [
"from nemo_asr.helpers import post_process_predictions, \\\n",
" post_process_transcripts, word_error_rate\n",
"neural_factory = nemo.core.NeuralModuleFactory(backend=nemo.core.Backend.PyTorch)\n",
"\n",
"evaluated_tensors = neural_factory.infer(\n",
" tensors=eval_tensors,\n",
@@ -138,7 +151,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.7.3"
}
},
"nbformat": 4,
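For context, the helpers imported in this notebook are typically used as below. The tensor indices and argument order are assumptions based on other NeMo ASR examples, not part of this diff:

from nemo_asr.helpers import post_process_predictions, \
    post_process_transcripts, word_error_rate

# Hypothetical layout: evaluated_tensors holds [loss, predictions,
# transcripts, transcript_lengths] in comparable NeMo ASR examples.
hypotheses = post_process_predictions(evaluated_tensors[1], labels)
references = post_process_transcripts(
    evaluated_tensors[2], evaluated_tensors[3], labels)
wer = word_error_rate(hypotheses=hypotheses, references=references)
print(f"Greedy WER: {wer * 100:.2f}%")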
4 changes: 3 additions & 1 deletion examples/start_here/chatbot_example.py
@@ -1,5 +1,4 @@
import os
import sys
import gzip
import shutil
import nemo
@@ -32,6 +31,9 @@

# instantiate neural factory
nf = nemo.core.NeuralModuleFactory()
# To run CPU-only, do:
# from nemo.core import DeviceType
# nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)

# instantiate neural modules
dl = nemo.tutorials.DialogDataLayer(**config)
5 changes: 3 additions & 2 deletions examples/start_here/simplest_example.py
@@ -1,8 +1,9 @@
# Copyright (c) 2019 NVIDIA Corporation
import nemo

# instantiate Neural Factory with supported backend
nf = nemo.core.NeuralModuleFactory()
# To run CPU-only, do:
# from nemo.core import DeviceType
# nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)

# instantiate necessary neural modules
# RealFunctionDataLayer defaults to f=torch.sin, sampling from x=[-4, 4]
71 changes: 51 additions & 20 deletions nemo/nemo/backends/pytorch/actions.py
@@ -1,4 +1,5 @@
# Copyright (c) 2019 NVIDIA Corporation
import importlib
import itertools
import logging
import os
@@ -21,14 +22,10 @@
from ...core.neural_factory import Actions, ModelMode, Optimization
from ...utils.helpers import get_checkpoint_from_dir

try:
    import apex
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel.LARC import LARC
    from apex import amp
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex")
# these imports will happen on an as-needed basis
amp = None
DDP = None
LARC = None

AmpOptimizations = {
    Optimization.mxprO0: "O0",
@@ -45,6 +42,28 @@
class PtActions(Actions):
    def __init__(self, local_rank=None, tb_writer=None,
                 optimization_level=Optimization.mxprO0):
        need_apex = local_rank is not None or \
            optimization_level != Optimization.mxprO0
        if need_apex:
            try:
                apex = importlib.import_module('apex')
                if optimization_level != Optimization.mxprO0:
                    global amp
                    amp = importlib.import_module('apex.amp')
                if local_rank is not None:
                    global DDP
                    global LARC
                    parallel = importlib.import_module('apex.parallel')
                    DDP = parallel.DistributedDataParallel
                    LARC = parallel.LARC

            except ImportError:
                raise ImportError(
                    "NVIDIA Apex is necessary for distributed training and "
                    "mixed precision training. It only works on GPUs. "
                    "Please install Apex from "
                    "https://www.github.com/nvidia/apex")

        super(PtActions, self).__init__(
            local_rank=local_rank,
            optimization_level=optimization_level)
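The constructor above defers Apex imports until a feature that needs them is requested. A standalone sketch of the same lazy-import idiom (names here are illustrative, not NeMo's):

import importlib

amp = None  # module-level placeholder, rebound on demand

def require_amp():
    # Import apex.amp the first time mixed precision is requested.
    global amp
    if amp is None:
        try:
            amp = importlib.import_module('apex.amp')
        except ImportError:
            raise ImportError("NVIDIA Apex is required for mixed precision "
                              "training; see https://github.com/nvidia/apex")
    return amp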
@@ -340,8 +359,12 @@ def __initialize_amp(
        self, optimizer, optim_level, amp_min_loss_scale=1.0
    ):
        if optim_level not in AmpOptimizations:
            raise ValueError("__initialize_amp() was called but optim_level "
                             "was set to float32.")
            raise ValueError(f"__initialize_amp() was called with unknown "
                             f"optim_level={optim_level}")
        # in this case, nothing to do here
        if optim_level == Optimization.mxprO0:
            return optimizer

        if len(self.modules) < 1:
            raise ValueError("There were no modules to initialize")
        pt_modules = []
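For reference, the Apex entry point that __initialize_amp() ultimately wraps is amp.initialize. A sketch with a toy model and optimizer, assuming a CUDA device is present:

import torch
from torch import nn
from apex import amp

model = nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# "O1" corresponds to AmpOptimizations[Optimization.mxprO1] above.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  min_loss_scale=1.0)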
@@ -371,11 +394,12 @@ def __nm_graph_forward_pass(self,
            m_id = call_chain[ind][0].unique_instance_id
            pmodule = self.module_reference_table[m_id][1]

            if isinstance(pmodule, DDP):
                if disable_allreduce:
                    pmodule.disable_allreduce()
                else:
                    pmodule.enable_allreduce()
            if self._local_rank is not None:
                if isinstance(pmodule, DDP):
                    if disable_allreduce:
                        pmodule.disable_allreduce()
                    else:
                        pmodule.enable_allreduce()

            if mode == ModelMode.train:
                # if module.is_trainable():
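The extra local_rank guard matters because DDP is now None on CPU-only runs, and isinstance(x, None) raises a TypeError. A sketch of the guarded toggle (helper name is hypothetical):

try:
    from apex.parallel import DistributedDataParallel as DDP
except ImportError:
    DDP = None  # CPU-only / no-Apex runs

def toggle_allreduce(pmodule, local_rank, disable_allreduce):
    # Guard first: isinstance(pmodule, None) would raise a TypeError.
    if local_rank is None or DDP is None or not isinstance(pmodule, DDP):
        return
    if disable_allreduce:
        pmodule.disable_allreduce()  # accumulate gradients locally
    else:
        pmodule.enable_allreduce()   # all-reduce gradients on backward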
@@ -1164,7 +1188,8 @@ def train(self,
                    final_loss += registered_tensors[tensor.unique_name]
                if nan:
                    continue
                if self._optim_level in AmpOptimizations:
                if self._optim_level in AmpOptimizations \
                        and self._optim_level != Optimization.mxprO0:
                    with amp.scale_loss(
                        final_loss,
                        curr_optimizer,
@@ -1178,10 +1203,15 @@
                            continue
                        scaled_loss.backward(
                            bps_scale.to(scaled_loss.get_device()))
                # no AMP optimizations needed
                else:
                    final_loss.backward(
                        bps_scale.to(
                            final_loss.get_device()))
                    # multi-GPU, float32
                    if self._local_rank is not None:
                        final_loss.backward(
                            bps_scale.to(final_loss.get_device()))
                    # single device (CPU or GPU)
                    else:
                        final_loss.backward()

                batch_counter += 1

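The three backward paths above reduce to: AMP-scaled loss when mixed precision is active, a batch-scale-adjusted FP32 backward for multi-GPU runs, and a plain backward on a single device, where bps_scale would needlessly force a device transfer. A condensed sketch (the wrapper function is illustrative):

amp = None  # rebound lazily when an AMP level is active, as in PtActions

def run_backward(final_loss, curr_optimizer, bps_scale, use_amp, local_rank):
    if use_amp:
        # Mixed precision: scale the loss to avoid FP16 underflow.
        with amp.scale_loss(final_loss, curr_optimizer) as scaled_loss:
            scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
    elif local_rank is not None:
        # Multi-GPU, float32.
        final_loss.backward(bps_scale.to(final_loss.get_device()))
    else:
        # Single device (CPU or GPU): no gradient rescaling needed.
        final_loss.backward()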
@@ -1233,7 +1263,8 @@ def infer(self,
                    mod.restore_from(checkpoint, self._local_rank)

        # Init Amp
        if self._optim_level in AmpOptimizations:
        if self._optim_level in AmpOptimizations and \
                self._optim_level != Optimization.mxprO0:
            pt_modules = []
            for i in range(len(call_chain)):
                if isinstance(call_chain[i][0], nn.Module):
10 changes: 10 additions & 0 deletions nemo/nemo/core/neural_factory.py
@@ -292,6 +292,16 @@ def __init__(
        if backend == Backend.PyTorch:
            # TODO: Move all framework specific code from this file
            import torch
            if self._placement != DeviceType.CPU:
                if not torch.cuda.is_available():
                    raise ValueError("You requested to use GPUs but CUDA is "
                                     "not installed. You can try running "
                                     "CPU-only. To do this, instantiate your "
                                     "factory with placement=DeviceType.CPU"
                                     "\n"
                                     "Note that this is slow and is not "
                                     "well supported.")

            torch.backends.cudnn.benchmark = cudnn_benchmark
            if random_seed is not None and cudnn_benchmark:
                raise ValueError("cudnn_benchmark can not be set to True"
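The effect of the new check, as a usage sketch: requesting GPU placement on a CUDA-less install now fails fast with an actionable message instead of a later, more cryptic torch error.

import nemo
from nemo.core import DeviceType

try:
    nf = nemo.core.NeuralModuleFactory(placement=DeviceType.GPU)
except ValueError as err:
    print(err)  # suggests retrying with placement=DeviceType.CPU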
33 changes: 33 additions & 0 deletions scripts/install_decoders_MacOS.sh
@@ -0,0 +1,33 @@
#!/bin/sh
# Make sure swig is installed first. On Anaconda do:
# conda install swig
set -xe
brew update
brew install wget
brew install boost
brew install cmake
export CFLAGS="-stdlib=libc++"
export MACOSX_DEPLOYMENT_TARGET=10.14
git clone https://github.com/PaddlePaddle/DeepSpeech
cd DeepSpeech
git checkout a76fc69
cd ..
mv DeepSpeech/decoders/swig_wrapper.py DeepSpeech/decoders/swig/ctc_decoders.py
mv DeepSpeech/decoders/swig ./decoders
rm -rf DeepSpeech
cd decoders
sed -i'.original' -e "s/\.decode('utf-8')//g" ctc_decoders.py
sed -i'.original' -e 's/\.decode("utf-8")//g' ctc_decoders.py
sed -i'.original' -e "s/name='swig_decoders'/name='ctc_decoders'/g" setup.py
sed -i'.original' -e "s/-space_prefixes\[i\]->approx_ctc/space_prefixes\[i\]->score/g" decoder_utils.cpp
sed -i'.original' -e "s/py_modules=\['swig_decoders'\]/py_modules=\['ctc_decoders', 'swig_decoders'\]/g" setup.py
chmod +x setup.sh
./setup.sh
echo 'Installing kenlm'
cd kenlm
mkdir build
cd build
cmake ..
make -j
cd ..
cd ..
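A quick smoke test once the script finishes; the module name follows from the name='ctc_decoders' rename in setup.py above, so treat it as an assumption:

# Raises ImportError if the Swig build or install failed.
import ctc_decoders  # noqa: F401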