Merge branch 'ranking_merge' of https://github.com/NVIDIA/NeMo into ranking_merge
yzhang123 committed Apr 8, 2022
2 parents 64b89de + 668eafc commit 715f95c
Showing 10 changed files with 105 additions and 60 deletions.
32 changes: 23 additions & 9 deletions Jenkinsfile
@@ -124,18 +124,18 @@ pipeline {
parallel {
stage('En TN grammars') {
steps {
- sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py "1" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3'
+ sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize.py "1" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4'
}
}
stage('En ITN grammars') {
steps {
- sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/inverse_text_normalization/inverse_normalize.py --language en "twenty" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3'
+ sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/inverse_text_normalization/inverse_normalize.py --language en "twenty" --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4'
}
}
stage('Test En non-deterministic TN & Run all En TN/ITN tests (restore grammars from cache)') {
steps {
- sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize_with_audio.py --text "\$.01" --n_tagged 2 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3'
- sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/en/ -m "not pleasefixme" --cpu --tn_cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3'
+ sh 'CUDA_VISIBLE_DEVICES="" python nemo_text_processing/text_normalization/normalize_with_audio.py --text "\$.01" --n_tagged 2 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4'
+ sh 'CUDA_VISIBLE_DEVICES="" pytest tests/nemo_text_processing/en/ -m "not pleasefixme" --cpu --tn_cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4'
}
}
}
@@ -152,7 +152,7 @@ pipeline {
parallel {
stage('L2: Eng TN') {
steps {
- sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_norm/output/ --grammars=tn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3 --language=en && ls -R /home/TestData/nlp/text_norm/output/ && echo ".far files created "|| exit 1'
+ sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_norm/output/ --grammars=tn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4 --language=en && ls -R /home/TestData/nlp/text_norm/output/ && echo ".far files created "|| exit 1'
sh 'cd nemo_text_processing/text_normalization/ && python run_predict.py --input=/home/TestData/nlp/text_norm/ci/test.txt --input_case="lower_cased" --language=en --output=/home/TestData/nlp/text_norm/output/test.pynini.txt --verbose'
sh 'cat /home/TestData/nlp/text_norm/output/test.pynini.txt'
sh 'cmp --silent /home/TestData/nlp/text_norm/output/test.pynini.txt /home/TestData/nlp/text_norm/ci/test_goal_py_12-10.txt || exit 1'
@@ -162,7 +162,7 @@ pipeline {

stage('L2: Eng ITN export') {
steps {
- sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_denorm/output/ --grammars=itn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3 --language=en && ls -R /home/TestData/nlp/text_denorm/output/ && echo ".far files created "|| exit 1'
+ sh 'cd tools/text_processing_deployment && python pynini_export.py --output=/home/TestData/nlp/text_denorm/output/ --grammars=itn_grammars --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4 --language=en && ls -R /home/TestData/nlp/text_denorm/output/ && echo ".far files created "|| exit 1'
sh 'cd nemo_text_processing/inverse_text_normalization/ && python run_predict.py --input=/home/TestData/nlp/text_denorm/ci/test.txt --language=en --output=/home/TestData/nlp/text_denorm/output/test.pynini.txt --verbose'
sh 'cmp --silent /home/TestData/nlp/text_denorm/output/test.pynini.txt /home/TestData/nlp/text_denorm/ci/test_goal_py.txt || exit 1'
sh 'rm -rf /home/TestData/nlp/text_denorm/output/*'
@@ -171,23 +171,23 @@
stage('L2: TN with Audio (audio and raw text)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
- python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3 --text "The total amounts to \\$4.76." \
+ python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4 --text "The total amounts to \\$4.76." \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /tmp/out_raw.txt 2>&1 && \
cmp --silent /tmp/out_raw.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
}
}
stage('L2: TN with Audio (audio and text file)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
- python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3 --text /home/TestData/nlp/text_norm/audio_based/text.txt \
+ python normalize_with_audio.py --language=en --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4 --text /home/TestData/nlp/text_norm/audio_based/text.txt \
--audio_data /home/TestData/nlp/text_norm/audio_based/audio.wav | tail -n2 | head -n1 > /tmp/out_file.txt 2>&1 && \
cmp --silent /tmp/out_file.txt /home/TestData/nlp/text_norm/audio_based/result.txt || exit 1'
}
}
stage('L2: TN with Audio (manifest)') {
steps {
sh 'cd nemo_text_processing/text_normalization && \
- python normalize_with_audio.py --language=en --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/2-3'
+ python normalize_with_audio.py --language=en --audio_data /home/TestData/nlp/text_norm/audio_based/manifest.json --n_tagged=120 --cache_dir /home/TestData/nlp/text_norm/ci/grammars/7-4'
}
}
}
@@ -2305,6 +2305,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='swiglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
@@ -2326,6 +2328,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='swiglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
@@ -2360,6 +2364,7 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='gelu' \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
@@ -2382,6 +2387,7 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='gelu' \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
@@ -2432,6 +2438,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='reglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]"
@@ -2452,6 +2460,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='reglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]"
@@ -2484,6 +2494,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='geglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]"
@@ -2505,6 +2517,8 @@ pipeline {
model.num_layers=4 \
model.hidden_size=64 \
model.num_attention_heads=8 \
+ model.activation='geglu' \
+ model.bias_gelu_fusion=False \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document]"
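The new model.activation overrides above pair every GLU-family activation with model.bias_gelu_fusion=False, because the fused bias-GELU kernel only applies to plain GELU; the transformer.py change below rejects the combination outright. What the test configs leave implicit is FFN sizing: a gated MLP carries a second up-projection, i.e. three weight matrices instead of two, so scaling ffn_hidden_size by 2/3 keeps the parameter count of a plain-GELU MLP. A minimal sketch of that arithmetic, in plain Python with illustrative names and biases omitted:

def mlp_params(hidden_size, ffn_hidden_size, gated):
    # up-projection(s) h -> ffn, plus the down-projection ffn -> h
    n_up = 2 if gated else 1  # gated variants add a second up-projection
    return n_up * hidden_size * ffn_hidden_size + ffn_hidden_size * hidden_size

hidden = 96
ffn = 4 * hidden        # conventional FFN width (384)
ffn_glu = ffn * 2 // 3  # the 2/3 scaling from the NOTE in ParallelMLP (256)
assert mlp_params(hidden, ffn, gated=False) == mlp_params(hidden, ffn_glu, gated=True)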
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+ import ntpath
import os
from typing import Dict, List, Optional

- import ntpath
import onnx
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
@@ -34,7 +33,6 @@
from nemo.collections.nlp.parts.utils_funcs import tensor2list
from nemo.core.classes import typecheck
from nemo.core.classes.common import PretrainedModelInfo
- from nemo.core.neural_types import NeuralType
from nemo.utils import logging


53 changes: 40 additions & 13 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -28,7 +28,6 @@
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func, erf_gelu
- from nemo.utils import logging

try:
from apex.transformer import parallel_state, tensor_parallel
@@ -82,8 +81,8 @@ def __init__(
super(ParallelMLP, self).__init__()
self.activation = activation

- if activation not in ['gelu', 'geglu']:
- raise ValueError(f"Activation {activation} not supported. Only gelu and geglu are supported.")
+ if activation not in ['gelu', 'geglu', 'reglu', 'swiglu']:
+ raise ValueError(f"Activation {activation} not supported. Only gelu, geglu, reglu, swiglu are supported.")

# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
@@ -95,28 +94,47 @@ def __init__(
use_cpu_initialization=use_cpu_initialization,
)

- if activation == 'geglu':
- # Separate linear layer for GEGLU activation.
+ if activation in ['geglu', 'reglu', 'swiglu']:
+ # Separate linear layer for *GLU activations.
# Source: https://github.com/huggingface/transformers/blob/bee361c6f1f7704f8c688895f2f86f6e5ff84727/src/transformers/models/t5/modeling_t5.py#L292
self.dense_h_to_4h_2 = tensor_parallel.ColumnParallelLinear(
hidden_size,
- ffn_hidden_size, # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
+ ffn_hidden_size, # NOTE: When using *glu, scale ffn dim by 2/3 to keep overall params the same.
gather_output=False,
init_method=init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)
+ glu_activation_family = True
+ else:
+ glu_activation_family = False

+ if glu_activation_family and bias_gelu_fusion:
+ raise ValueError(
+ f"Cannot use bias_gelu_fusion with {activation} activation. Please turn bias gelu fusion off."
+ )
+
+ if glu_activation_family and openai_gelu:
+ raise ValueError(
+ f"Cannot use openai_gelu with specified activation function: {activation}. Please turn openai gelu off."
+ )
+
+ if glu_activation_family and onnx_safe:
+ raise ValueError(
+ f"Cannot use onnx_safe with specified activation function: {activation}. Please turn onnx safe off."
+ )

self.bias_gelu_fusion = bias_gelu_fusion
- self.activation_func = F.gelu
- if activation == 'geglu':
- self.activation_func = 'geglu' # Implemented using F.gelu
- if bias_gelu_fusion:
- logging.warning("Bias Gelu Fusion is not supported for GEGLU activation. Running with pytorch F.gelu")
- if openai_gelu:

+ if activation == "gelu":
+ self.activation_func = F.gelu
+ elif openai_gelu:
self.activation_func = openai_gelu
elif onnx_safe:
self.activation_func = erf_gelu
+ else:
+ # Remaining activations are implemented in the forward function.
+ self.activation_func = None

# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
@@ -133,13 +151,22 @@ def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

- if self.activation == 'geglu':
+ if self.activation in ['geglu', 'reglu', 'swiglu']:
intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

if self.activation == 'geglu':
intermediate_parallel = F.gelu(intermediate_parallel + bias_parallel) * (
intermediate_parallel_2 + bias_parallel_2
)
+ elif self.activation == 'swiglu':
+ # SiLU, or sigmoid linear unit, is the same as swish with beta = 1 (which is what https://arxiv.org/pdf/2002.05202.pdf uses).
+ intermediate_parallel = F.silu(intermediate_parallel + bias_parallel) * (
+ intermediate_parallel_2 + bias_parallel_2
+ )
+ elif self.activation == 'reglu':
+ intermediate_parallel = F.relu(intermediate_parallel + bias_parallel) * (
+ intermediate_parallel_2 + bias_parallel_2
+ )
+ elif self.bias_gelu_fusion and self.activation == 'gelu':
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
else:
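Stripped of the tensor-parallel plumbing, the three gated variants all follow the pattern from https://arxiv.org/pdf/2002.05202.pdf: an elementwise nonlinearity applied to one projection gates a second, purely linear projection. A self-contained sketch of the same forward pass, where plain nn.Linear stands in for the column- and row-parallel layers and the class name is illustrative, not NeMo API:

import torch
import torch.nn.functional as F
from torch import nn

class GatedMLP(nn.Module):
    # Minimal mirror of the ParallelMLP forward pass for the *GLU activations.
    def __init__(self, hidden_size, ffn_hidden_size, activation='swiglu'):
        super().__init__()
        gate_fn = {'geglu': F.gelu, 'swiglu': F.silu, 'reglu': F.relu}
        self.act = gate_fn[activation]
        self.up = nn.Linear(hidden_size, ffn_hidden_size)    # plays dense_h_to_4h
        self.up_2 = nn.Linear(hidden_size, ffn_hidden_size)  # plays dense_h_to_4h_2
        self.down = nn.Linear(ffn_hidden_size, hidden_size)  # plays dense_4h_to_h

    def forward(self, x):
        # act(x W1 + b1) * (x W2 + b2), then project back to hidden_size
        return self.down(self.act(self.up(x)) * self.up_2(x))

x = torch.randn(2, 8, 64)
print(GatedMLP(64, 170)(x).shape)  # torch.Size([2, 8, 64])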
@@ -138,9 +138,8 @@ def __init__(
| pynutil.add_weight(range_graph, 1.1)
)

- roman_graph = RomanFst(deterministic=deterministic).fst
- # the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
- classify |= pynutil.add_weight(roman_graph, 1.1)
+ # roman_graph = RomanFst(deterministic=deterministic).fst
+ # classify |= pynutil.add_weight(roman_graph, 1.1)

if not deterministic:
abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst
@@ -157,31 +157,36 @@
).fst
v_word_graph = vWord(deterministic=deterministic).fst

+ sem_w = 1
+ word_w = 100
+ punct_w = 2
classify_and_verbalize = (
- pynutil.add_weight(whitelist_graph, 1.01)
- | pynutil.add_weight(pynini.compose(time_graph, v_time_graph), 1.1)
- | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), 1.1)
- | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), 1.1)
- | pynutil.add_weight(pynini.compose(cardinal_graph, v_cardinal_graph), 1.1)
- | pynutil.add_weight(pynini.compose(ordinal_graph, v_ordinal_graph), 1.1)
- | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), 1.1)
- | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), 1.1)
- | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), 1.1)
- | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), 1.1)
- | pynutil.add_weight(word_graph, 100)
- | pynutil.add_weight(pynini.compose(date_graph, v_date_graph), 1.09)
- | pynutil.add_weight(pynini.compose(range_graph, v_word_graph), 1.1)
+ pynutil.add_weight(whitelist_graph, sem_w)
+ | pynutil.add_weight(pynini.compose(time_graph, v_time_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(decimal_graph, v_decimal_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(measure_graph, v_measure_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(cardinal_graph, v_cardinal_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(ordinal_graph, v_ordinal_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(telephone_graph, v_telephone_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(electronic_graph, v_electronic_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(fraction_graph, v_fraction_graph), sem_w)
+ | pynutil.add_weight(pynini.compose(money_graph, v_money_graph), sem_w)
+ | pynutil.add_weight(word_graph, word_w)
+ | pynutil.add_weight(pynini.compose(date_graph, v_date_graph), sem_w - 0.01)
+ | pynutil.add_weight(pynini.compose(range_graph, v_word_graph), sem_w)
).optimize()

if not deterministic:
roman_graph = RomanFst(deterministic=deterministic).fst
# the weight matches the word_graph weight for "I" cases in long sentences with multiple semiotic tokens
- classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), 100)
+ classify_and_verbalize |= pynutil.add_weight(pynini.compose(roman_graph, v_roman_graph), word_w)

abbreviation_graph = AbbreviationFst(whitelist=whitelist, deterministic=deterministic).fst
- classify_and_verbalize |= pynutil.add_weight(pynini.compose(abbreviation_graph, v_abbreviation), 100)
+ classify_and_verbalize |= pynutil.add_weight(
+ pynini.compose(abbreviation_graph, v_abbreviation), word_w
+ )

- punct_only = pynutil.add_weight(punct_graph, weight=20.1)
+ punct_only = pynutil.add_weight(punct_graph, weight=punct_w)
punct = pynini.closure(
pynini.compose(pynini.closure(NEMO_WHITE_SPACE, 1), delete_extra_space)
| (pynutil.insert(" ") + punct_only),
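The named constants make the ranking policy explicit: classification keeps the lowest-cost path through the union, so semiotic classes (sem_w = 1) beat punctuation (punct_w = 2), the verbatim word fallback (word_w = 100) only fires when nothing else matches, and date's sem_w - 0.01 breaks ties in its favor among the semiotic classes. A toy illustration of the mechanism, using hypothetical strings rather than NeMo's grammars:

import pynini
from pynini.lib import pynutil

# Two competing rewrites of the same input; the shortest path keeps the cheaper one.
semiotic = pynutil.add_weight(pynini.cross("ii", "two"), 1)   # sem_w-style weight
fallback = pynutil.add_weight(pynini.cross("ii", "ii"), 100)  # word_w-style weight
union = (semiotic | fallback).optimize()

print(pynini.shortestpath("ii" @ union).string())  # -> "two"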
@@ -73,8 +73,8 @@ def __init__(self, deterministic: bool = True):
| whitelist_graph
)

- roman_graph = RomanFst(deterministic=deterministic).fst
- graph |= roman_graph
+ # roman_graph = RomanFst(deterministic=deterministic).fst
+ # graph |= roman_graph

if not deterministic:
abbreviation_graph = AbbreviationFst(deterministic=deterministic).fst