-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable encoder adapters for Canary and MultiTaskAED models #9409
Changes from 13 commits
7415da6
9e2c5e2
a31e920
06c372a
a1219d4
748334a
38b00d5
f137d72
86c0edb
3e1e8fc
6a6d1d7
6d57049
f4a5864
f92c082
5f24cc6
cf5207b
155619a
166f28f
8b9c08a
f0e2d08
a4f08fa
e8e6092
7992cfe
fceff70
0a9f5fd
2b386ae
b133a59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ | |
) | ||
from nemo.collections.asr.metrics import BLEU, WER | ||
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel | ||
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin | ||
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin | ||
from nemo.collections.asr.parts.mixins.transcription import ( | ||
GenericTranscriptionType, | ||
InternalTranscribeConfig, | ||
|
@@ -114,7 +114,7 @@ def __post_init__(self): | |
self.prompt = parse_multitask_prompt(self.prompt) | ||
|
||
|
||
class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin): | ||
class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin): | ||
"""Base class for AED multi-task models""" | ||
|
||
def __init__(self, cfg: DictConfig, trainer: Trainer = None): | ||
|
@@ -224,6 +224,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): | |
self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False | ||
) # Wer is handling logging | ||
|
||
# Setup encoder adapters (from ASRAdapterModelMixin) | ||
self.setup_adapters() | ||
|
||
def change_decoding_strategy(self, decoding_cfg: DictConfig): | ||
""" | ||
Changes decoding strategy used during Multi Task decoding process. | ||
|
@@ -1003,6 +1006,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa | |
text = [self.decoding.strip_special_tokens(t) for t in text] | ||
return text | ||
|
||
@property | ||
def adapter_module_names(self) -> List[str]: | ||
return ['', 'encoder', 'transf_encoder', 'transf_decoder'] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reason for adding empty string in list ? ...nvm.. saw the later piece of code where this is used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its default adapter if user just provides the name of the module but not a target (add_adapter('xyz', adapter_cfg)). In that case adapter name is '' and it usually defaults to |
||
|
||
|
||
def parse_multitask_prompt(prompt: dict | None) -> list[dict]: | ||
if prompt is None or not prompt: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,17 +13,22 @@ | |
# limitations under the License. | ||
|
||
import copy | ||
from typing import List, Optional, Set | ||
|
||
import torch | ||
import torch.nn as nn | ||
from omegaconf import DictConfig | ||
|
||
from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF | ||
from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin | ||
from nemo.collections.asr.parts.utils import adapter_utils | ||
from nemo.collections.common.parts import form_attention_mask | ||
from nemo.core.classes.mixins import adapter_mixins | ||
|
||
__all__ = ["TransformerDecoder"] | ||
|
||
|
||
class TransformerDecoderBlock(nn.Module): | ||
class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin): | ||
""" | ||
Building block of Transformer decoder. | ||
|
||
|
@@ -63,6 +68,9 @@ def __init__( | |
self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5) | ||
self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) | ||
|
||
# Information for the adapter module mixin | ||
self.self_attention_model = "transf_abs" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any reason for hard coding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is slightly a hack, the adapter forward looks up the type of the attention model to see what kind of adapter to call (and what arguments it takes). It just ignores it if its not |
||
|
||
def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): | ||
""" | ||
Pre-LayerNorm block | ||
|
@@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state | |
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) | ||
self_attn_output += residual | ||
|
||
if self.is_adapter_available(): | ||
# Call the MHA adapters | ||
pack_ip = { | ||
'x': self_attn_output, | ||
'loc': 'mha', | ||
'att_mask': decoder_mask, | ||
'pos_emb': None, | ||
} | ||
pack_ip = self.forward_enabled_adapters(pack_ip) | ||
self_attn_output = pack_ip['x'] | ||
|
||
residual = self_attn_output | ||
self_attn_output = self.layer_norm_2(self_attn_output) | ||
enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) | ||
|
@@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state | |
output_states = self.third_sub_layer(enc_dec_attn_output) | ||
output_states += residual | ||
|
||
if self.is_adapter_available(): | ||
krishnacpuvvada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Call the Linear adapters | ||
pack_ip = { | ||
'x': output_states, | ||
'loc': 'post', | ||
} | ||
pack_ip = self.forward_enabled_adapters(pack_ip) | ||
output_states = pack_ip['x'] | ||
|
||
return output_states | ||
|
||
def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): | ||
|
@@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat | |
""" | ||
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) | ||
self_attn_output += decoder_query | ||
|
||
if self.is_adapter_available(): | ||
# Call the MHA adapters | ||
pack_ip = { | ||
'x': self_attn_output, | ||
'loc': 'mha', | ||
'att_mask': decoder_mask, | ||
'pos_emb': None, | ||
} | ||
pack_ip = self.forward_enabled_adapters(pack_ip) | ||
self_attn_output = pack_ip['x'] | ||
|
||
self_attn_output = self.layer_norm_1(self_attn_output) | ||
|
||
enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) | ||
|
@@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat | |
|
||
output_states = self.third_sub_layer(enc_dec_attn_output) | ||
output_states += enc_dec_attn_output | ||
|
||
if self.is_adapter_available(): | ||
# Call the linear adapters | ||
pack_ip = { | ||
'x': output_states, | ||
'loc': 'post', | ||
} | ||
pack_ip = self.forward_enabled_adapters(pack_ip) | ||
output_states = pack_ip['x'] | ||
|
||
return self.layer_norm_3(output_states) | ||
|
||
def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): | ||
|
@@ -109,6 +159,19 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc | |
else: | ||
return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask) | ||
|
||
def get_accepted_adapter_types(self) -> Set[type]: | ||
types = super().get_accepted_adapter_types() | ||
|
||
if len(types) == 0: | ||
self.set_accepted_adapter_types( | ||
[ | ||
adapter_utils.LINEAR_ADAPTER_CLASSPATH, | ||
adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, | ||
] | ||
) | ||
types = self.get_accepted_adapter_types() | ||
return types | ||
|
||
|
||
class TransformerDecoder(nn.Module): | ||
def __init__( | ||
|
@@ -131,6 +194,8 @@ def __init__( | |
else: | ||
self.final_layer_norm = None | ||
|
||
self.d_model = hidden_size | ||
|
||
layer = TransformerDecoderBlock( | ||
hidden_size, | ||
inner_size, | ||
|
@@ -219,3 +284,38 @@ def input_example(self, max_batch=1, max_dim=256): | |
input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device) | ||
encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device) | ||
return tuple([input_ids, encoder_mask, input_ids, encoder_mask]) | ||
|
||
|
||
class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin): | ||
|
||
# Higher level forwarding | ||
def add_adapter(self, name: str, cfg: dict): | ||
cfg = self._update_adapter_cfg_input_dim(cfg) | ||
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin | ||
transformer_layer.add_adapter(name, cfg) | ||
|
||
def is_adapter_available(self) -> bool: | ||
return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) | ||
|
||
def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): | ||
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin | ||
transformer_layer.set_enabled_adapters(name=name, enabled=enabled) | ||
|
||
def get_enabled_adapters(self) -> List[str]: | ||
names = set([]) | ||
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin | ||
names.update(transformer_layer.get_enabled_adapters()) | ||
|
||
names = sorted(list(names)) | ||
return names | ||
|
||
def _update_adapter_cfg_input_dim(self, cfg: DictConfig): | ||
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) | ||
return cfg | ||
|
||
|
||
""" | ||
Register any additional information | ||
""" | ||
if adapter_mixins.get_registered_adapter(TransformerDecoder) is None: | ||
adapter_mixins.register_adapter(base_class=TransformerDecoder, adapter_class=TransformerDecoderAdapter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
couldn;t find how its calling set up of encoder adapter, could you please shed some light here on how this is calling specifically "encoder" layer adapter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the core entrypoint of adapters.
Users should mostly focus on https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/core/adapters/intro.html#adapters and https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/core/adapters/intro.html#using-the-adapter-module
Devs should mostly focus on https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/02_NeMo_Adapters.ipynb
I think it would be good to link the dev tutorial somewhere in the core docs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As to your question about "specific encoder adapter" - that's not this classes responsibility but the
ASRModuleMixin
which inheritsASRAdapterModelMixin
. Hmm we need some tool to generate class diagrams..In any case, ASRAdapterModelMixin determines how to direct the
adapter_name
(encoder:xyz
) to the module. See below definition ofASRAdapterModelMixin
.