Skip to content

Commit

Permalink
Merge branch 'add_fvad_doc' of https://github.com/NVIDIA/NeMo into ad…
Browse files Browse the repository at this point in the history
…d_fvad_doc
  • Loading branch information
stevehuang52 committed Jun 22, 2023
2 parents 9555cfd + 9adcbad commit fcfdf01
Show file tree
Hide file tree
Showing 15 changed files with 656 additions and 554 deletions.
718 changes: 363 additions & 355 deletions Jenkinsfile

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ Transformer Engine already supports Flash Attention for GPT models. If you want
pip install flash-attn
pip install triton==2.0.0.dev20221202
NLP inference UI
~~~~~~~~~~~~~~~~~~~~
To launch the inference web UI server, please install the gradio `gradio <https://gradio.app/>`_.

.. code-block:: bash
pip install gradio==3.34.0
NeMo Text Processing
~~~~~~~~~~~~~~~~~~~~
NeMo Text Processing, specifically (Inverse) Text Normalization, is now a separate repository `https://github.com/NVIDIA/NeMo-text-processing <https://github.com/NVIDIA/NeMo-text-processing>`_.
Expand Down
18 changes: 17 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@
predict_ds.batch_size=16 \
output_path=/tmp/
Example for Hybrid-CTC/RNNT models with non-tarred datasets:
python transcribe_speech_parallel.py \
model=stt_en_fastconformer_hybrid_large \
decoder_type=ctc \
predict_ds.manifest_filepath=/dataset/manifest_file.json \
predict_ds.batch_size=16 \
output_path=/tmp/
Example for tarred datasets:
python transcribe_speech_parallel.py \
Expand Down Expand Up @@ -73,7 +82,7 @@
from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter
from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig
from nemo.core.config import TrainerConfig, hydra_runner
from nemo.utils import logging
Expand All @@ -92,6 +101,10 @@ class ParallelTranscriptionConfig:

# decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()

# decoder for hybrid models, must be one of 'ctc', 'rnnt' if not None
decoder_type: Optional[str] = None

trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp")


Expand Down Expand Up @@ -137,6 +150,9 @@ def main(cfg: ParallelTranscriptionConfig):
)
model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")

if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None:
model.change_decoding_strategy(decoder_type=cfg.decoder_type)

trainer = ptl.Trainer(**cfg.trainer)

cfg.predict_ds.return_sample_id = True
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def write_on_batch_end(
item = {}
sample = self.dataset.get_manifest_sample(sample_id)
item["audio_filepath"] = sample.audio_file
item["offset"] = sample.offset
item["duration"] = sample.duration
item["text"] = sample.text_raw
item["pred_text"] = transcribed_text
Expand Down
5 changes: 3 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from tqdm.auto import tqdm

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
Expand Down Expand Up @@ -421,6 +420,8 @@ def transcribe_partial_audio(
input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
)
logits, logits_len = outputs[0], outputs[1]
if isinstance(asr_model, EncDecHybridRNNTCTCModel) and decoder_type == "ctc":
logits = asr_model.ctc_decoder(encoder_output=logits)
if logprobs:
# dump log probs per file
for idx in range(logits.shape[0]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ def _do_init(self, path, skip_warmup=True, delay_data_mmap=False):
self._create_data_mmap(skip_warmup)
else:
logging.info(" skip creating data numpy buffer of mmap...")
self._bin_buffer_mmap = None
self._bin_buffer = None

def _create_data_mmap(self, skip_warmup):
if not skip_warmup:
Expand All @@ -524,7 +526,8 @@ def _create_data_mmap(self, skip_warmup):
self._bin_buffer = memoryview(self._bin_buffer_mmap)

def __del__(self):
self._bin_buffer_mmap._mmap.close()
if self._bin_buffer_mmap is not None:
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.name_key_to_cfg = {AdapterName.PTUNING_ADAPTER: adapter_cfg}
super().__init__(cfg, trainer)
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens
self.trainable_keys = self.adapter_keys - set(
[
"model.language_model.adapter_layer.ptuning_adapter.inference_table.prompt_table.taskname.prompt_embeddings.weight"
]
)
# we exclude the above parameter from training because it is present for backward compatibility for inference using FasterTransformer (@adithyare)

def init_peft_modules(self,):
"""
Expand Down Expand Up @@ -268,7 +274,15 @@ def load_state_dict(self, state_dict, strict: bool = True):

def setup_optimizer_param_groups(self):
if self.first_stage_of_pipeline():
super().setup_optimizer_param_groups()
# super().setup_optimizer_param_groups()
self.freeze() # Freeze the entire model
opt_params = []
for n, p in self.named_parameters():
if n in self.trainable_keys:
p.requires_grad = True
opt_params.append(p)

self._optimizer_param_groups = ({"params": opt_params},)
else:
self.freeze() # Freeze the entire model
self._optimizer_param_groups = ({"params": []},)
Expand Down
22 changes: 21 additions & 1 deletion nemo/collections/nlp/modules/common/chatbot_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,29 @@
"""
from __future__ import annotations

from gradio.components import *
import warnings

from markdown2 import Markdown

try:
from typing import Any, Callable, Dict, List, Literal, Tuple

from gradio.components import (
Changeable,
Component,
Enum,
EventListenerMethod,
IOComponent,
JSONSerializable,
Selectable,
document,
processing_utils,
)

GRADIO_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
GRADIO_AVAILABLE = False


class _Keywords(Enum):
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nemo.collections.common.parts.utils import activation_registry
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.utils import init_method_const, init_method_normal
from nemo.collections.nlp.modules.common.prompt_encoder import InferenceTable
from nemo.core.classes.mixins import adapter_mixin_strategies

try:
Expand Down Expand Up @@ -65,13 +66,11 @@ class AdapterName(str, enum.Enum):


class InfusedAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self, in_features: int, adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
) -> None:
def __init__(self, in_features: int,) -> None:
super().__init__()
self.scalers = nn.Parameter(torch.ones(in_features))
# Setup adapter strategy
self.setup_adapter_strategy(adapter_strategy)
self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

def forward(self, x):
x = x * self.scalers[None, None, :]
Expand All @@ -90,7 +89,6 @@ class MLPInfusedAdapter(InfusedAdapter):
@dataclass
class InfusedAdapterConfig:
in_features: int
adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(InfusedAdapter.__module__, InfusedAdapter.__name__)


Expand All @@ -112,7 +110,6 @@ def __init__(
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
dropout: float = 0.0,
adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
):
super().__init__()
if not HAVE_APEX:
Expand Down Expand Up @@ -153,7 +150,7 @@ def __init__(
self.dropout = None

# Setup adapter strategy
self.setup_adapter_strategy(adapter_strategy)
self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

def _get_init_fn(self, init_method: str):
if init_method == 'xavier':
Expand Down Expand Up @@ -196,7 +193,6 @@ class ParallelLinearAdapterConfig:
row_init_method: str = 'zero'
gather_output: bool = True
dropout: float = 0.0
adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)


Expand Down Expand Up @@ -250,13 +246,7 @@ class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""

def __init__(
self,
virtual_tokens: int,
bottleneck_dim: int,
embedding_dim: int,
init_std: float,
output_dim: int,
adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
self, virtual_tokens: int, bottleneck_dim: int, embedding_dim: int, init_std: float, output_dim: int,
):
"""
Initializes the Tensor Model parallel MLP PromptEncoderMLP module.
Expand All @@ -278,6 +268,7 @@ def __init__(
# (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module.
self.register_buffer("indices", torch.LongTensor(list(range(self.virtual_tokens))), persistent=False)
self.embedding = torch.nn.Embedding(self.virtual_tokens, self.embedding_dim)
self.inference_table = InferenceTable("taskname", self.embedding_dim, self.virtual_tokens)
self.first = ColumnParallelLinear(
self.embedding_dim,
self.bottleneck_dim,
Expand All @@ -301,15 +292,47 @@ def __init__(
gradient_accumulation_fusion=gradient_accumulation_fusion,
)
# Setup adapter strategy
self.setup_adapter_strategy(adapter_strategy)
self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

def set_inference_table(self, prompt_representation: torch.Tensor):
"""
This method caches the output representation from the Encoder and saves it inside `self.inference_table`.
"""
prompt_representation = prompt_representation.detach().clone()
self.inference_table.set_prompt_table(prompt_representation)

def clear_inference_table(self,):
self.inference_table.clear_prompt_table()

def get_inference_table(self,):
return self.inference_table.get_prompt_table()

def forward(self, batch_size):
def inner_forward(self,):
input_embeds = self.embedding(self.indices).unsqueeze(0)
intermediate_parallel, bias_parallel = self.first(input_embeds)
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
output_embeds, bias_parallel = self.second(intermediate_parallel)
output_embeds = output_embeds + bias_parallel
output_embeds = output_embeds.transpose(0, 1)
return output_embeds

def forward(self, batch_size: int, use_cached_reps: bool = False) -> torch.Tensor:
"""
Forward pass through the encoder with caching of prompt representations
"""
if use_cached_reps:
output_embeds = self.get_inference_table().unsqueeze(1)
else:
if self.training:
if self.inference_table.is_inference_ready:
self.clear_inference_table()
output_embeds = self.inner_forward()
else:
if not self.inference_table.is_inference_ready:
output_embeds = self.inner_forward()
self.set_inference_table(output_embeds.squeeze(1))
output_embeds = self.get_inference_table().unsqueeze(1)

output_embeds = output_embeds.expand(self.virtual_tokens, batch_size, self.output_dim)
return output_embeds

Expand All @@ -321,5 +344,4 @@ class PromptEncoderAdapterConfig:
embedding_dim: int
init_std: float
output_dim: int
adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
_target_: str = "{0}.{1}".format(PromptEncoderAdapter.__module__, PromptEncoderAdapter.__name__)
Original file line number Diff line number Diff line change
Expand Up @@ -746,10 +746,7 @@ def forward(
ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER)
v = ptuning_adapter.virtual_tokens
if ptuning_adapter and _sq >= v: # The sequence should be longer the v to insert virtual embeddings.
strategy = ptuning_adapter.adapter_strategy
virtual_embeddings = self.forward_single_enabled_adapter_(
_bs, ptuning_adapter, adapter_name=AdapterName.PTUNING_ADAPTER, adapter_strategy=strategy,
)
virtual_embeddings = ptuning_adapter(_bs)
encoder_input = encoder_input[
v:, :, :
] # the first v tokens are pads so that they can be swapped out with virtual embeddings.
Expand Down
21 changes: 7 additions & 14 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,13 +549,9 @@ def forward(
if self.is_adapter_available():
adapter_1 = self.get_adapter_module(AdapterName.PRE_ATTN_ADAPTER)
if adapter_1:
strategy = adapter_1.adapter_strategy
attention_output = self.forward_single_enabled_adapter_(
attention_output,
adapter_1,
adapter_name=AdapterName.PRE_ATTN_ADAPTER,
adapter_strategy=strategy,
)
attention_output = (
adapter_1(attention_output) + attention_output
) # simple adapter call with residual connection

layernorm_input = bias_dropout_add_func(attention_output, attention_bias, residual, self.hidden_dropout)
# print(f"Layer: {self.layer_number} Attention checksum {layernorm_input.sum()}")
Expand Down Expand Up @@ -626,15 +622,12 @@ def forward(
layernorm_input = normalization_output
# MLP.
mlp_output, mlp_bias = self.mlp(normalization_output)
if (
self.is_adapter_available()
): # TODO: (@adithyre) was able to move adapter_2 back to the end of the transformer after ptl 1.7 update.
if self.is_adapter_available():
# TODO: (@adithyre) was able to move adapter_2 back to the end of the transformer after ptl 1.7 update.
adapter_2 = self.get_adapter_module(AdapterName.POST_ATTN_ADAPTER)
if adapter_2:
strategy = adapter_2.adapter_strategy
mlp_output = self.forward_single_enabled_adapter_(
mlp_output, adapter_2, adapter_name=AdapterName.POST_ATTN_ADAPTER, adapter_strategy=strategy
)
mlp_output = adapter_2(mlp_output) + mlp_output # simple adapter call with residual connection

residual = layernorm_input

bias_dropout_add_func = self._get_bias_droput_add_func(
Expand Down
Loading

0 comments on commit fcfdf01

Please sign in to comment.