Enable encoder adapters for Canary and MultiTaskAED models #9409

Merged: 27 commits, merged on Jun 29, 2024.
Changes shown from 7 commits.

Commits (27)
7415da6  Fix assertions for adapter types (titu1994, Jun 7, 2024)
9e2c5e2  Apply isort and black reformatting (titu1994, Jun 7, 2024)
a31e920  Cleanup (titu1994, Jun 7, 2024)
06c372a  Apply isort and black reformatting (titu1994, Jun 7, 2024)
a1219d4  Finalize support for decoder adapters (titu1994, Jun 7, 2024)
748334a  Merge branch 'main' into canary_adapters (titu1994, Jun 7, 2024)
38b00d5  Apply isort and black reformatting (titu1994, Jun 7, 2024)
f137d72  Merge branch 'main' into canary_adapters (titu1994, Jun 11, 2024)
86c0edb  fix the freeze/unfreeze problem by replacing as_frozen with torch.inf… (weiqingw4ng, Jun 13, 2024)
3e1e8fc  Apply isort and black reformatting (weiqingw4ng, Jun 13, 2024)
6a6d1d7  Update tests to new generic way of module update (titu1994, Jun 14, 2024)
6d57049  Finalize code for update module (titu1994, Jun 14, 2024)
f4a5864  Apply isort and black reformatting (titu1994, Jun 14, 2024)
f92c082  Fix variable name (titu1994, Jun 14, 2024)
5f24cc6  Merge branch 'main' into canary_adapters (titu1994, Jun 15, 2024)
cf5207b  Finalize projection support for transformer mha adapters (titu1994, Jun 15, 2024)
155619a  Apply isort and black reformatting (titu1994, Jun 15, 2024)
166f28f  Correct implementation of freeze restore (titu1994, Jun 15, 2024)
8b9c08a  Apply isort and black reformatting (titu1994, Jun 15, 2024)
f0e2d08  Corrects the implementation of replace_adapter_modules to limit to ju… (titu1994, Jun 25, 2024)
a4f08fa  Apply isort and black reformatting (titu1994, Jun 25, 2024)
e8e6092  Remove registration of Transformer MHA (titu1994, Jun 25, 2024)
7992cfe  Merge branch 'main' into canary_adapters (titu1994, Jun 25, 2024)
fceff70  Remove registration of Transformer MHA (titu1994, Jun 25, 2024)
0a9f5fd  Merge branch 'main' into canary_adapters (titu1994, Jun 25, 2024)
2b386ae  Merge branch 'main' into canary_adapters (titu1994, Jun 26, 2024)
b133a59  Address reviewer comments (titu1994, Jun 28, 2024)
11 changes: 9 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -30,7 +30,7 @@
)
from nemo.collections.asr.metrics import BLEU, WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin
from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin
from nemo.collections.asr.parts.mixins.transcription import (
GenericTranscriptionType,
InternalTranscribeConfig,
@@ -114,7 +114,7 @@ def __post_init__(self):
self.prompt = parse_multitask_prompt(self.prompt)


class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin):
class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin):
"""Base class for AED multi-task models"""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -224,6 +224,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False
) # Wer is handling logging

# Setup encoder adapters (from ASRAdapterModelMixin)
Review comment (Collaborator):
Couldn't find how this calls the setup of the encoder adapter; could you please shed some light on how it specifically sets up the "encoder" layer adapters?

Reply (Author):
As to your question about the "specific encoder adapter": that is not this class's responsibility but ASRModuleMixin's, which inherits ASRAdapterModelMixin. (Hmm, we need some tool to generate class diagrams.)

In any case, ASRAdapterModelMixin determines how to direct the adapter name (encoder:xyz) to the target module. See the definition of ASRAdapterModelMixin below, and the usage sketch after this thread.
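For readers following the thread, a minimal usage sketch (an illustration, not code from this PR; the checkpoint name, adapter size, and the LinearAdapter class path are assumptions):

# Usage sketch only. Assumes a NeMo install; names below are illustrative.
from omegaconf import OmegaConf
from nemo.collections.asr.models import EncDecMultiTaskModel

model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")  # assumed checkpoint

adapter_cfg = OmegaConf.create({
    '_target_': 'nemo.collections.common.parts.adapter_modules.LinearAdapter',
    'in_features': model.cfg.encoder.d_model,  # rewritten per target module by the mixin
    'dim': 32,
})

# ASRAdapterModelMixin parses the "module:name" prefix and routes each call
# to the matching entry in adapter_module_names.
model.add_adapter('encoder:enc_adapter', adapter_cfg)         # Conformer encoder
model.add_adapter('transf_decoder:dec_adapter', adapter_cfg)  # Transformer decoder (this PR)

model.set_enabled_adapters(enabled=True)  # enable all added adapters
model.freeze()                            # freeze base model weights
model.unfreeze_enabled_adapters()         # leave only adapter parameters trainable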

self.setup_adapters()

def change_decoding_strategy(self, decoding_cfg: DictConfig):
"""
Changes decoding strategy used during Multi Task decoding process.
@@ -1003,6 +1006,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa
text = [self.decoding.strip_special_tokens(t) for t in text]
return text

@property
def adapter_module_names(self) -> List[str]:
return ['', 'encoder', 'transf_encoder', 'transf_decoder']
Review comment (krishnacpuvvada, Collaborator, Jun 28, 2024):
Reason for adding the empty string to the list? ...nvm, saw the later piece of code where this is used.

Reply (Author):
It is the default adapter target when the user provides just a name and no module prefix (add_adapter('xyz', adapter_cfg)). In that case the module name resolves to '' and it usually defaults to the encoder adapters; a short sketch follows below.
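Continuing the hypothetical sketch from the earlier thread, the default target looks like this:

# With no module prefix, the name resolves to the default module '' from
# adapter_module_names, which for this model routes to the encoder adapters.
model.add_adapter('default_adapter', adapter_cfg)
# ...behaves like an explicit 'encoder:default_adapter' target.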



def parse_multitask_prompt(prompt: dict | None) -> list[dict]:
if prompt is None or not prompt:
53 changes: 48 additions & 5 deletions nemo/collections/asr/modules/transformer/transformer.py
@@ -13,18 +13,21 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional

import torch
from omegaconf.omegaconf import MISSING
from omegaconf.omegaconf import MISSING, DictConfig

from nemo.collections.asr.modules.transformer.decoder_module import DecoderModule
from nemo.collections.asr.modules.transformer.encoder_module import EncoderModule
from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder
from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder, TransformerDecoderAdapter
from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder
from nemo.collections.asr.modules.transformer.transformer_modules import TransformerEmbedding
from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin

Check notice (Code scanning / CodeQL): Unused import. Import of 'AttentionAdapterModuleMixin' is not used.
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import adapter_mixins
from nemo.core.neural_types import ChannelType, NeuralType


@@ -155,6 +158,8 @@


class TransformerDecoderNM(DecoderModule, Exportable):
DECODER_TYPE: type = TransformerDecoder

def __init__(
self,
vocab_size: int,
@@ -192,7 +197,7 @@
learn_positional_encodings=learn_positional_encodings,
)

self._decoder = TransformerDecoder(
self._decoder = self.DECODER_TYPE(
hidden_size=self.hidden_size,
num_layers=num_layers,
inner_size=inner_size,
@@ -207,7 +212,12 @@

@typecheck()
def forward(
self, input_ids, decoder_mask, encoder_embeddings, encoder_mask, decoder_mems=None,
self,
input_ids,
decoder_mask,
encoder_embeddings,
encoder_mask,
decoder_mems=None,
):
start_pos = 0
if decoder_mems is not None:
@@ -274,3 +284,36 @@
return {"last_hidden_states": NeuralType(('B', 'D', 'T', 'D'), ChannelType())}
else:
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}


class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin):
DECODER_TYPE: type = TransformerDecoderAdapter

# Higher level forwarding
def add_adapter(self, name: str, cfg: dict):
cfg = self._update_adapter_cfg_input_dim(cfg)
self._decoder.add_adapter(name, cfg) # type: adapter_mixins.AdapterModuleMixin

def is_adapter_available(self) -> bool:
return self._decoder.is_adapter_available() # type: adapter_mixins.AdapterModuleMixin

def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
self._decoder.set_enabled_adapters(name=name, enabled=enabled)  # type: adapter_mixins.AdapterModuleMixin

def get_enabled_adapters(self) -> List[str]:
names = set([])
names.update(self._decoder.get_enabled_adapters()) # type: adapter_mixins.AdapterModuleMixin

names = sorted(list(names))
return names

def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._hidden_size)
return cfg


"""
Register any additional information
"""
if adapter_mixins.get_registered_adapter(TransformerDecoderNM) is None:
adapter_mixins.register_adapter(base_class=TransformerDecoderNM, adapter_class=TransformerDecoderNMAdapter)
102 changes: 101 additions & 1 deletion nemo/collections/asr/modules/transformer/transformer_decoders.py
@@ -13,17 +13,22 @@
# limitations under the License.

import copy
from typing import List, Optional, Set

import torch
import torch.nn as nn
from omegaconf import DictConfig

from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF
from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.collections.common.parts import form_attention_mask
from nemo.core.classes.mixins import adapter_mixins

__all__ = ["TransformerDecoder"]


class TransformerDecoderBlock(nn.Module):
class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin):
"""
Building block of Transformer decoder.

@@ -63,6 +68,9 @@ def __init__(
self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5)
self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act)

# Information for the adapter module mixin
self.self_attention_model = "transf_abs"
Review comment (Collaborator):
Any reason for hard-coding self.self_attention_model?

Reply (Author):
This is slightly a hack: the adapter forward looks up the type of the attention model to decide what kind of adapter to call (and what arguments it takes). It simply ignores the value if it is not 'rel_pos_attention', 'local_rel_pos_attention', or some such variant. A rough illustration follows below.
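A rough illustration of that dispatch (a hypothetical helper, not the actual NeMo mixin code; the variant names are taken from the reply above):

def call_attention_adapter(adapter, x, att_mask, pos_emb, self_attention_model: str):
    # Relative-position attention adapters also consume the positional embedding.
    if self_attention_model in ('rel_pos_attention', 'local_rel_pos_attention'):
        return adapter(x, att_mask=att_mask, pos_emb=pos_emb)
    # 'transf_abs' (this decoder block) and other absolute variants ignore pos_emb.
    return adapter(x, att_mask=att_mask, pos_emb=None)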


def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
"""
Pre-LayerNorm block
@@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += residual

if self.is_adapter_available():
# Call the MHA adapters
pack_ip = {
'x': self_attn_output,
'loc': 'mha',
'att_mask': decoder_mask,
'pos_emb': None,
}
pack_ip = self.forward_enabled_adapters(pack_ip)
self_attn_output = pack_ip['x']

residual = self_attn_output
self_attn_output = self.layer_norm_2(self_attn_output)
enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
@@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state
output_states = self.third_sub_layer(enc_dec_attn_output)
output_states += residual

if self.is_adapter_available():
# Call the Linear adapters
pack_ip = {
'x': output_states,
'loc': 'post',
}
pack_ip = self.forward_enabled_adapters(pack_ip)
output_states = pack_ip['x']

return output_states

def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
Expand All @@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat
"""
self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask)
self_attn_output += decoder_query

if self.is_adapter_available():
# Call the MHA adapters
pack_ip = {
'x': self_attn_output,
'loc': 'mha',
'att_mask': decoder_mask,
'pos_emb': None,
}
pack_ip = self.forward_enabled_adapters(pack_ip)
self_attn_output = pack_ip['x']

self_attn_output = self.layer_norm_1(self_attn_output)

enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask)
@@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat

output_states = self.third_sub_layer(enc_dec_attn_output)
output_states += enc_dec_attn_output

if self.is_adapter_available():
# Call the linear adapters
pack_ip = {
'x': output_states,
'loc': 'post',
}
pack_ip = self.forward_enabled_adapters(pack_ip)
output_states = pack_ip['x']

return self.layer_norm_3(output_states)

def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask):
Expand All @@ -109,6 +159,19 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc
else:
return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask)

def get_accepted_adapter_types(self) -> Set[type]:
types = super().get_accepted_adapter_types()

if len(types) == 0:
self.set_accepted_adapter_types(
[
adapter_utils.LINEAR_ADAPTER_CLASSPATH,
adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH,
]
)
types = self.get_accepted_adapter_types()
return types


class TransformerDecoder(nn.Module):
def __init__(
Expand All @@ -131,6 +194,8 @@ def __init__(
else:
self.final_layer_norm = None

self.d_model = hidden_size

layer = TransformerDecoderBlock(
hidden_size,
inner_size,
Expand Down Expand Up @@ -219,3 +284,38 @@ def input_example(self, max_batch=1, max_dim=256):
input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device)
encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device)
return tuple([input_ids, encoder_mask, input_ids, encoder_mask])


class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin):

# Higher level forwarding
def add_adapter(self, name: str, cfg: dict):
cfg = self._update_adapter_cfg_input_dim(cfg)
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
transformer_layer.add_adapter(name, cfg)

def is_adapter_available(self) -> bool:
return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers])

def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
transformer_layer.set_enabled_adapters(name=name, enabled=enabled)

def get_enabled_adapters(self) -> List[str]:
names = set([])
for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin
names.update(transformer_layer.get_enabled_adapters())

names = sorted(list(names))
return names

def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)
return cfg


"""
Register any additional information
"""
if adapter_mixins.get_registered_adapter(TransformerDecoder) is None:
adapter_mixins.register_adapter(base_class=TransformerDecoder, adapter_class=TransformerDecoderAdapter)
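
A closing note on the registration blocks above: a minimal sketch of how such a registration is typically consumed to swap a base module for its adapter-capable variant (an assumption based on nemo.core.classes.mixins.adapter_mixins; the attribute name on the registry entry may differ).

from nemo.core.classes.mixins import adapter_mixins
from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder

# Importing the module above runs the registration at the bottom of the file.
info = adapter_mixins.get_registered_adapter(TransformerDecoder)
decoder_cls = info.adapter_class if info is not None else TransformerDecoder
# decoder_cls is expected to be TransformerDecoderAdapter once registered.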