Skip to content

Commit

Permalink
Enable encoder adapters for Canary and MultiTaskAED models (NVIDIA#9409)
Browse files Browse the repository at this point in the history
* Fix assertions for adapter types

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Cleanup

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Finalize support for decoder adapters

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Fix the freeze/unfreeze problem by replacing as_frozen with torch.inference_mode

* Apply isort and black reformatting

Signed-off-by: weiqingw4ng <[email protected]>

* Update tests to new generic way of module update

Signed-off-by: smajumdar <[email protected]>

* Finalize code for update module

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Fix variable name

Signed-off-by: smajumdar <[email protected]>

* Finalize projection support for transformer mha adapters

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Correct implementation of freeze restore

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Correct the implementation of replace_adapter_modules so that it is limited to just the top-level modules

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Remove registration of Transformer MHA

Signed-off-by: smajumdar <[email protected]>

* Remove registration of Transformer MHA

Signed-off-by: smajumdar <[email protected]>

* Address reviewer comments

Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
Signed-off-by: weiqingw4ng <[email protected]>
Co-authored-by: Weiqing Wang <[email protected]>
Co-authored-by: weiqingw4ng <[email protected]>
Signed-off-by: Alex Cui <[email protected]>
  • Loading branch information
3 people authored and BuyuanCui committed Jul 12, 2024
1 parent 8d0659e commit 4d0f0ae
Show file tree
Hide file tree
Showing 23 changed files with 1,300 additions and 419 deletions.
11 changes: 9 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from nemo.collections.asr.metrics import BLEU, WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin
from nemo.collections.asr.parts.mixins.transcription import (
GenericTranscriptionType,
InternalTranscribeConfig,
Expand Down Expand Up @@ -115,7 +115,7 @@ def __post_init__(self):
self.prompt = parse_multitask_prompt(self.prompt)


class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin):
class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin):
"""Base class for AED multi-task models"""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
Expand Down Expand Up @@ -225,6 +225,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False
) # Wer is handling logging

# Setup encoder adapters (from ASRAdapterModelMixin)
self.setup_adapters()

def change_decoding_strategy(self, decoding_cfg: DictConfig):
"""
Changes decoding strategy used during Multi Task decoding process.
Expand Down Expand Up @@ -1057,6 +1060,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa
text = [self.decoding.strip_special_tokens(t) for t in text]
return text

@property
def adapter_module_names(self) -> List[str]:
    """List the module names this model exposes for adapters.

    Returns:
        A new list on every access containing, in order: ``''``, ``'encoder'``,
        ``'transf_encoder'`` and ``'transf_decoder'``.
    """
    # NOTE(review): the empty string presumably selects the default adapter
    # target - confirm against the adapter mixin that consumes this list.
    module_names = ['', 'encoder', 'transf_encoder', 'transf_decoder']
    return module_names


def parse_multitask_prompt(prompt: dict | None) -> list[dict]:
if prompt is None or not prompt:
Expand Down
4 changes: 4 additions & 0 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,10 @@ def list_available_models(cls) -> List[PretrainedModelInfo]:

return results

@property
def adapter_module_names(self) -> List[str]:
    """List the module names this model exposes for adapters.

    Returns:
        A new list on every access containing, in order: ``''``, ``'encoder'``
        and ``'decoder'``.
    """
    # NOTE(review): the empty string presumably selects the default adapter
    # target - confirm against the adapter mixin that consumes this list.
    module_names = ['', 'encoder', 'decoder']
    return module_names

@property
def wer(self):
    """Return the object currently stored in ``self._wer``."""
    return self._wer
Expand Down
53 changes: 48 additions & 5 deletions nemo/collections/asr/modules/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional

import torch
from omegaconf.omegaconf import MISSING
from omegaconf.omegaconf import MISSING, DictConfig

from nemo.collections.asr.modules.transformer.decoder_module import DecoderModule
from nemo.collections.asr.modules.transformer.encoder_module import EncoderModule
from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder
from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder, TransformerDecoderAdapter
from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder
from nemo.collections.asr.modules.transformer.transformer_modules import TransformerEmbedding
from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import adapter_mixins
from nemo.core.neural_types import ChannelType, NeuralType


Expand Down Expand Up @@ -155,6 +158,8 @@ def input_example(self, max_batch=1, max_dim=256):


class TransformerDecoderNM(DecoderModule, Exportable):
DECODER_TYPE: type = TransformerDecoder

def __init__(
self,
vocab_size: int,
Expand Down Expand Up @@ -192,7 +197,7 @@ def __init__(
learn_positional_encodings=learn_positional_encodings,
)

self._decoder = TransformerDecoder(
self._decoder = self.DECODER_TYPE(
hidden_size=self.hidden_size,
num_layers=num_layers,
inner_size=inner_size,
Expand All @@ -207,7 +212,12 @@ def __init__(

@typecheck()
def forward(
self, input_ids, decoder_mask, encoder_embeddings, encoder_mask, decoder_mems=None,
self,
input_ids,
decoder_mask,
encoder_embeddings,
encoder_mask,
decoder_mems=None,
):
start_pos = 0
if decoder_mems is not None:
Expand Down Expand Up @@ -274,3 +284,36 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('B', 'D', 'T', 'D'), ChannelType())}
else:
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}


class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin):
    """Adapter-capable variant of ``TransformerDecoderNM``.

    Overrides ``DECODER_TYPE`` with ``TransformerDecoderAdapter`` and forwards every
    adapter operation to the wrapped ``self._decoder`` module.
    """

    DECODER_TYPE: type = TransformerDecoderAdapter

    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
        """Run ``cfg`` through ``adapter_utils.update_adapter_cfg_input_dim``.

        The helper is called with ``module_dim=self._hidden_size``; its result is returned as-is.
        """
        return adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._hidden_size)

    # ------------------------------------------------------------------
    # Adapter API: each call is delegated to the inner decoder.
    # ------------------------------------------------------------------
    def add_adapter(self, name: str, cfg: dict):
        """Add adapter ``name`` to the inner decoder after updating ``cfg``'s input dimension.

        Args:
            name: Name of the adapter to add.
            cfg: Adapter config; it is first passed through ``_update_adapter_cfg_input_dim``.
        """
        updated_cfg = self._update_adapter_cfg_input_dim(cfg)
        self._decoder.add_adapter(name, updated_cfg)

    def is_adapter_available(self) -> bool:
        """Return whatever the inner decoder reports from ``is_adapter_available()``."""
        return self._decoder.is_adapter_available()

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        """Forward the enable/disable request for adapter ``name`` to the inner decoder.

        Args:
            name: Adapter name, forwarded unchanged (``None`` by default).
            enabled: Flag forwarded unchanged to the inner decoder.
        """
        self._decoder.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        """Return the inner decoder's enabled adapter names, de-duplicated and sorted."""
        unique_names = set(self._decoder.get_enabled_adapters())
        return sorted(unique_names)


"""
Register any additional information
"""
# Register TransformerDecoderNMAdapter as the adapter class for TransformerDecoderNM,
# but only when the registry has no entry for that base class yet.
if adapter_mixins.get_registered_adapter(TransformerDecoderNM) is None:
    adapter_mixins.register_adapter(
        base_class=TransformerDecoderNM,
        adapter_class=TransformerDecoderNMAdapter,
    )
102 changes: 101 additions & 1 deletion nemo/collections/asr/modules/transformer/transformer_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
# limitations under the License.

import copy
from typing import List, Optional, Set

import torch
import torch.nn as nn
from omegaconf import DictConfig

from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF
from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.collections.common.parts import form_attention_mask
from nemo.core.classes.mixins import adapter_mixins

__all__ = ["TransformerDecoder"]


class TransformerDecoderBlock(nn.Module):
class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin):
"""
Building block of Transformer decoder.
Expand Down Expand Up @@ -63,6 +68,9 @@ def __init__(
self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5)
self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act)

# Information for the adapter module mixin
self.self_attention_model = "transf_abs"

def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
"""
Pre-LayerNorm block
Expand All @@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += residual

if self.is_adapter_available():
# Call the MHA adapters
pack_input = {
'x': self_attn_output,
'loc': 'mha',
'att_mask': decoder_mask,
'pos_emb': None,
}
pack_input = self.forward_enabled_adapters(pack_input)
self_attn_output = pack_input['x']

residual = self_attn_output
self_attn_output = self.layer_norm_2(self_attn_output)
enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
Expand All @@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
output_states = self.third_sub_layer(enc_dec_attn_output)
output_states += residual

if self.is_adapter_available():
# Call the Linear adapters
pack_input = {
'x': output_states,
'loc': 'post',
}
pack_input = self.forward_enabled_adapters(pack_input)
output_states = pack_input['x']

return output_states

def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
Expand All @@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat
"""
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += decoder_query

if self.is_adapter_available():
# Call the MHA adapters
pack_ip = {
'x': self_attn_output,
'loc': 'mha',
'att_mask': decoder_mask,
'pos_emb': None,
}
pack_ip = self.forward_enabled_adapters(pack_ip)
self_attn_output = pack_ip['x']

self_attn_output = self.layer_norm_1(self_attn_output)

enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
Expand All @@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat

output_states = self.third_sub_layer(enc_dec_attn_output)
output_states += enc_dec_attn_output

if self.is_adapter_available():
# Call the linear adapters
pack_ip = {
'x': output_states,
'loc': 'post',
}
pack_ip = self.forward_enabled_adapters(pack_ip)
output_states = pack_ip['x']

return self.layer_norm_3(output_states)

def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
Expand All @@ -109,6 +159,19 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc
else:
return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask)

def get_accepted_adapter_types(self) -> Set[type]:
    """Return the accepted adapter types, installing defaults when none are set.

    If the parent implementation returns a non-empty collection it is returned
    unchanged. Otherwise the linear and transformer-MHA adapter classpaths from
    ``adapter_utils`` are registered via ``set_accepted_adapter_types`` and the
    lookup is repeated.
    """
    accepted = super().get_accepted_adapter_types()
    if len(accepted) > 0:
        return accepted

    # Nothing configured yet: register the default classpaths, then re-query
    # so the caller receives whatever the setter actually stored.
    default_classpaths = [
        adapter_utils.LINEAR_ADAPTER_CLASSPATH,
        adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH,
    ]
    self.set_accepted_adapter_types(default_classpaths)
    return self.get_accepted_adapter_types()


class TransformerDecoder(nn.Module):
def __init__(
Expand All @@ -131,6 +194,8 @@ def __init__(
else:
self.final_layer_norm = None

self.d_model = hidden_size

layer = TransformerDecoderBlock(
hidden_size,
inner_size,
Expand Down Expand Up @@ -219,3 +284,38 @@ def input_example(self, max_batch=1, max_dim=256):
input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device)
encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device)
return tuple([input_ids, encoder_mask, input_ids, encoder_mask])


class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin):
    """Adapter-capable ``TransformerDecoder``.

    Every adapter operation is fanned out to each entry of ``self.layers``.
    """

    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
        """Run ``cfg`` through ``adapter_utils.update_adapter_cfg_input_dim``.

        The helper is called with ``module_dim=self.d_model``; its result is returned as-is.
        """
        return adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)

    def add_adapter(self, name: str, cfg: dict):
        """Add adapter ``name`` to every layer, using the dimension-updated config.

        Args:
            name: Name of the adapter to add.
            cfg: Adapter config; it is updated once and the same result is given to each layer.
        """
        updated_cfg = self._update_adapter_cfg_input_dim(cfg)
        for block in self.layers:
            block.add_adapter(name, updated_cfg)

    def is_adapter_available(self) -> bool:
        """Return ``True`` if at least one layer reports an available adapter."""
        # Materialize the answers first so that every layer is queried.
        per_layer = [block.is_adapter_available() for block in self.layers]
        return any(per_layer)

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        """Forward the enable/disable request for adapter ``name`` to every layer.

        Args:
            name: Adapter name, forwarded unchanged (``None`` by default).
            enabled: Flag forwarded unchanged to each layer.
        """
        for block in self.layers:
            block.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        """Return the sorted union of enabled adapter names across all layers."""
        unique_names = {
            adapter_name for block in self.layers for adapter_name in block.get_enabled_adapters()
        }
        return sorted(unique_names)


"""
Register any additional information
"""
# Register TransformerDecoderAdapter as the adapter class for TransformerDecoder,
# but only when the registry has no entry for that base class yet.
if adapter_mixins.get_registered_adapter(TransformerDecoder) is None:
    adapter_mixins.register_adapter(
        base_class=TransformerDecoder,
        adapter_class=TransformerDecoderAdapter,
    )
Loading

0 comments on commit 4d0f0ae

Please sign in to comment.