From 7415da60c5b7da940ead0dabbbcbd80c11c417b5 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 6 Jun 2024 18:34:48 -0700 Subject: [PATCH 01/21] Fix assertions for adapter types Signed-off-by: smajumdar --- .../asr/models/aed_multitask_models.py | 11 +- .../transformer/transformer_encoders.py | 104 ++++++++++- .../asr/parts/mixins/asr_adapter_mixins.py | 157 +++++++---------- .../asr/parts/submodules/adapters/__init__.py | 4 + .../adapters/attention_adapter_mixin.py | 108 ++++++++++++ .../multi_head_attention_adapter_module.py | 2 + ...mer_multi_head_attention_adapter_module.py | 126 ++++++++++++++ .../asr/parts/submodules/conformer_modules.py | 61 +------ .../parts/submodules/squeezeformer_modules.py | 63 +------ .../asr/parts/utils/adapter_utils.py | 6 +- nemo/core/classes/mixins/adapter_mixins.py | 64 +++++-- .../mixins/adapters/test_asr_adapter_mixin.py | 163 +++++++++++++++++- 12 files changed, 635 insertions(+), 234 deletions(-) create mode 100644 nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py create mode 100644 nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index edb591921782..1f65961da264 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -30,7 +30,7 @@ ) from nemo.collections.asr.metrics import BLEU, WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel -from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin +from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin, ASRModuleMixin from nemo.collections.asr.parts.mixins.transcription import ( GenericTranscriptionType, InternalTranscribeConfig, @@ -114,7 +114,7 @@ def __post_init__(self): self.prompt = parse_multitask_prompt(self.prompt) -class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRTranscriptionMixin): +class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin): """Base class for AED multi-task models""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -224,6 +224,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False ) # Wer is handling logging + # Setup encoder adapters (from ASRAdapterModelMixin) + self.setup_adapters() + def change_decoding_strategy(self, decoding_cfg: DictConfig): """ Changes decoding strategy used during Multi Task decoding process. @@ -1003,6 +1006,10 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa text = [self.decoding.strip_special_tokens(t) for t in text] return text + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'transf_encoder', 'transf_decoder'] + def parse_multitask_prompt(prompt: dict | None) -> list[dict]: if prompt is None or not prompt: diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 544d561267cf..faf1fd09dc2d 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -13,17 +13,22 @@ # limitations under the License. 
import copy +from typing import Optional, List, Set +from omegaconf import DictConfig import torch import torch.nn as nn from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils +from nemo.core.classes.mixins import adapter_mixins from nemo.collections.common.parts import form_attention_mask __all__ = ["TransformerEncoder"] -class TransformerEncoderBlock(nn.Module): +class TransformerEncoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer encoder. @@ -59,6 +64,9 @@ def __init__( self.layer_norm_2 = nn.LayerNorm(hidden_size, eps=1e-5) self.second_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "abs_pos" + def forward_preln(self, encoder_query, encoder_mask, encoder_keys): """ Pre-LayerNorm block @@ -70,11 +78,31 @@ def forward_preln(self, encoder_query, encoder_mask, encoder_keys): self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + return output_states def forward_postln(self, encoder_query, encoder_mask, encoder_keys): @@ -84,10 +112,32 @@ def forward_postln(self, encoder_query, encoder_mask, encoder_keys): """ self_attn_output = self.first_sub_layer(encoder_query, encoder_keys, encoder_keys, encoder_mask) self_attn_output += encoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': encoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) output_states = self.second_sub_layer(self_attn_output) output_states += self_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + output_states = self.layer_norm_2(output_states) return output_states @@ -121,6 +171,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + layer = TransformerEncoderBlock( hidden_size, inner_size, @@ -172,3 +224,53 @@ def forward(self, encoder_states, encoder_mask, encoder_mems_list=None, return_m return cached_mems_list else: return cached_mems_list[-1] + + +class TransformerEncoderAdapter(TransformerEncoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + self.check_supported_adapter_type_(cfg) + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + 
transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_ENCODER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerEncoder) is None: + adapter_mixins.register_adapter(base_class=TransformerEncoder, adapter_class=TransformerEncoderAdapter) diff --git a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py index f452acd19847..ae36b2d882da 100644 --- a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py +++ b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py @@ -54,14 +54,10 @@ def setup_adapters(self): supports_adapters = False # At least the encoder must extend AdapterModuleMixin - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - supports_adapters |= True + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: @@ -87,24 +83,28 @@ def add_adapter(self, name: str, cfg: DictConfig): else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError(f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`") + # Update the model.cfg with information about the new adapter from cfg with open_dict(self.cfg): for module_name in module_names: # Check if encoder adapters should be added - if module_name in ('', 'encoder'): - # Dispatch the call to the encoder. - self.encoder.add_adapter(name=name, cfg=cfg) - - # Check if decoder adapters should be added - if module_name == 'decoder': - # Dispatch call to the decoder. - self.decoder.add_adapter(name=name, cfg=cfg) + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. 
+ getattr(self, default_module_name).add_adapter(name=name, cfg=cfg) - # Check if joint adapters should be added; - # Note: We need additional check if joint even exists in model (for CTC models) - if hasattr(self, 'joint') and module_name == 'joint': - # Dispatch call to the joint. - self.joint.add_adapter(name=name, cfg=cfg) + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. + getattr(self, module_name).add_adapter(name=name, cfg=cfg) def is_adapter_available(self) -> bool: """ @@ -116,15 +116,12 @@ def is_adapter_available(self) -> bool: """ config_contains_adapter = super().is_adapter_available() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - config_contains_adapter |= self.encoder.is_adapter_available() - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - config_contains_adapter |= self.decoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - config_contains_adapter |= self.joint.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + config_contains_adapter |= getattr(self, module_name).is_adapter_available() return config_contains_adapter @@ -160,23 +157,27 @@ def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) else: module_names = [module_name] + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError(f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`") + + # Forward the method call to the individual modules if they exist for module_name in module_names: # Check if encoder adapters should be used - # Dispatch the call to the encoder. - if name is None or module_name in ('', 'encoder'): - if self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the decoder. - if name is None or module_name == 'decoder': - if self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name=name, enabled=enabled) - - # Dispatch the call to the joint. - # Note: We need additional check for joint, since it may not exist (CTC models). - if name is None or module_name == 'joint': - if hasattr(self, 'joint') and self.joint.is_adapter_available(): - self.joint.set_enabled_adapters(name=name, enabled=enabled) + + if module_name == '': + if hasattr(self, default_module_name): + # Dispatch the call to the default model. + getattr(self, default_module_name).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name): + # Dispatch the call to the module. 
+ getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> List[str]: """ @@ -187,15 +188,12 @@ def get_enabled_adapters(self) -> List[str]: """ enabled_adapters = super().get_enabled_adapters() - # Check if encoder adapters should be used or are enabled - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - enabled_adapters.extend(self.encoder.get_enabled_adapters()) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - enabled_adapters.extend(self.decoder.get_enabled_adapters()) - - if hasattr(self, 'joint') and isinstance(self.joint, AdapterModuleMixin): - enabled_adapters.extend(self.joint.get_enabled_adapters()) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) enabled_adapters = list(sorted(list(set(enabled_adapters)))) @@ -208,44 +206,19 @@ def check_valid_model_with_adapter_support_(self): # Obtain the global adapter config if possible, otherwise use sensible defaults. global_cfg = self._get_global_cfg() - # Test whether the encoder supports adapters - use_encoder_adapter = global_cfg.get('check_encoder_adapter', True) - if use_encoder_adapter: - if not hasattr(self, 'encoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `encoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'encoder') and not isinstance(self.encoder, AdapterModuleMixin): - logging.warning( - f'{self.encoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the decoder supports adapters - use_decoder_adapter = global_cfg.get('check_decoder_adapter', True) - if use_decoder_adapter: - if not hasattr(self, 'decoder'): - logging.warning( - "Cannot add adapter to this object as it does not have an `decoder` sub-module!", - mode=logging_mode.ONCE, - ) - - if hasattr(self, 'decoder') and not isinstance(self.decoder, AdapterModuleMixin): - logging.warning( - f'{self.decoder.__class__.__name__} does not implement `AdapterModuleMixin`', - mode=logging_mode.ONCE, - ) - - # Test whether the joint supports adapters - use_joint_adapter = global_cfg.get('check_joint_adapter', True) - if use_joint_adapter: - # Joint is only for RNNT models, skip assertion that it must always exist. 
- if hasattr(self, 'joint') and not isinstance(self.joint, AdapterModuleMixin): - logging.warning( - f'{self.joint.__class__.__name__} does not implement `AdapterModuleMixin`', mode=logging_mode.ONCE - ) + valid_module_names = [x for x in self.adapter_module_names if x != ''] + + for module_name in valid_module_names: + check_adapter_support = global_cfg.get(f'check_{module_name}_adapter', True) + + if check_adapter_support: + # Test whether the module supports adapters + if hasattr(self, module_name) and not isinstance(getattr(self, module_name), AdapterModuleMixin): + logging.warning( + f'Module `{module_name}` exists, but {getattr(self, module_name).__class__.__name__} ' + f'does not implement `AdapterModuleMixin`', + mode=logging_mode.ONCE, + ) def resolve_adapter_module_name_(self, name: str) -> Tuple[str, str]: """ @@ -293,3 +266,7 @@ def _get_global_cfg(self): def adapter_module_names(self) -> List[str]: valid_module_names = ['', 'encoder', 'decoder', 'joint'] return valid_module_names + + @property + def default_adapter_module_name(self) -> str: + return 'encoder' diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index 6aa05d07dea1..52458dfb06c3 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -24,3 +24,7 @@ RelPositionMultiHeadAttentionAdapter, RelPositionMultiHeadAttentionAdapterConfig, ) +from nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module import ( + TransformerEncoderMultiHeadAttentionAdapter, + TransformerEncoderMultiHeadAttentionAdapterConfig, +) diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py new file mode 100644 index 000000000000..ed9a2906b6a8 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -0,0 +1,108 @@ +import torch + +from nemo.core.classes.mixins import adapter_mixins + + +class AttentionAdapterModuleMixin(adapter_mixins.AdapterModuleMixin): + """ + Utility class that implements a custom forward method for Modules that are attention based. + Attention based adapters can support either linear adapters, and Multi-Head Attention adapters. + + However, Multi Head Attention adapters require additional arguments, such as `att_mask` and `pos_emb`. + This utility class unifies the adapter forward pass for both types of adapters. + + .. Usage: + + To use this class, inherit from this class, and when calling self.foward_enabled_adapters() pass the following: + + .. code-block:: python + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': residual, + 'loc': 'mha', + 'att_mask': att_mask, + 'pos_emb': pos_emb, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + residual = pack_ip['x'] + + if self.is_adapter_available(): + # Call the Linear adapters + pack_ip = { + 'x': x, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + x = pack_ip['x'] + """ + + def forward_single_enabled_adapter_( + self, + input: dict, + adapter_module: torch.nn.Module, + *, + adapter_name: str, + adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', + ): + """ + Perform the forward step of a single adapter module on some input data. 
+ + **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. + + Args: + input: Dictionary of packed tensors. The dict should contain at least + `x`: output tensor + `loc`: Semantic location in module where this adapter was called + `att_mask`: Optional, Attention mask + `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. + The output tensor of the calling module is the input to the first adapter, whose output + is then chained to the next adapter until all adapters are consumed. + adapter_module: The adapter module that is currently required to perform the forward pass. + adapter_name: The resolved name of the adapter that is undergoing the current forward pass. + adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the + output of the adapter should be merged with the input, or if it should be merged at all. + + Returns: + The result tensor, after the current active adapter has finished its forward pass. + """ + if not hasattr(self, 'self_attention_model'): + raise RuntimeError("self_attention_model attribute not found in the module! Please set in the module " + "a string attribute 'self_attention_model' with value 'abs_pos', 'rel_pos' or " + "other supported self-attention model types.") + + # Collect imports to prevent circular imports + from nemo.collections.asr.parts.submodules import multi_head_attention as conformer_mha + from nemo.collections.asr.modules.transformer import transformer_modules as transformer_mha + + + # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') + x = input['x'] + loc = input['loc'] + att_mask = input.get('att_mask', None) + pos_emb = input.get('pos_emb', None) + + from nemo.collections.common.parts import adapter_modules + if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': + output = adapter_strategy(x, adapter_module, module=self) + + if isinstance(adapter_module, conformer_mha.MultiHeadAttention) and loc == 'mha': + if self.self_attention_model == 'rel_pos': + x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) + output = adapter_strategy(x, adapter_module, module=self) + + elif self.self_attention_model == 'abs_pos': + x = dict(query=x, key=x, value=x, mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + + else: + raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") + + else: + # No adapter compatible, skip + output = x + + input['x'] = output + + return input diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 3df51092ac4b..8afa5d283089 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -63,6 +63,7 @@ def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'Ada out = self.apply_stochastic_depth(out, input['value'], adapter, module=module) # Return the residual connection output = input + adapter(input) + print("Out", out.shape, "Input", input['value'].shape) result = input['value'] + out # If l2_lambda is activated, register the loss value @@ -390,3 +391,4 @@ class RelPositionalEncodingAdapterConfig: default_factory=lambda: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig() ) _target_: str = 
"{0}.{1}".format(RelPositionalEncodingAdapter.__module__, RelPositionalEncodingAdapter.__name__) + diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py new file mode 100644 index 000000000000..8a612de03023 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from torch import nn as nn + +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig +from nemo.collections.asr.modules.transformer import transformer_modules +from nemo.collections.common.parts import adapter_modules +from nemo.core.classes.mixins import adapter_mixins, adapter_mixin_strategies + + +class TransformerEncoderMultiHeadAttentionAdapter(transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil): + """Multi-Head Attention layer of Transformer Encoder. + + Args: + hidden_size (int): number of heads + num_attention_heads (int): size of the features + attn_score_dropout (float): dropout rate for the attention scores + attn_layer_dropout (float): dropout rate for the layer + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. 
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attn_score_dropout: float = 0.0, + attn_layer_dropout: float = 0.0, + proj_dim: Optional[int] = None, + adapter_strategy: MHAResidualAddAdapterStrategy = None, + ): + super().__init__(hidden_size=hidden_size, num_attention_heads=num_attention_heads, + attn_score_dropout=attn_score_dropout, attn_layer_dropout=attn_layer_dropout) + + self.pre_norm = nn.LayerNorm(hidden_size) + + # Set the projection dim to number of heads automatically + if proj_dim is not None and proj_dim < 1: + proj_dim = num_attention_heads + + self.proj_dim = proj_dim + + # Recompute weights for projection dim + if self.proj_dim is not None: + if self.proj_dim % num_attention_heads != 0: + raise ValueError(f"proj_dim ({proj_dim}) is not divisible by n_head ({num_attention_heads})") + + self.attn_head_size = self.proj_dim // num_attention_heads + self.attn_scale = math.sqrt(math.sqrt(self.attn_head_size)) + self.query_net = nn.Linear(num_attention_heads, self.proj_dim) + self.key_net = nn.Linear(num_attention_heads, self.proj_dim) + self.value_net = nn.Linear(num_attention_heads, self.proj_dim) + self.out_projection = nn.Linear(self.proj_dim, hidden_size) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_strategy) + + # reset parameters for Q to be identity operation + self.reset_parameters() + + def forward(self, queries, keys, values, attention_mask): + """Compute 'Scaled Dot Product Attention'. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + cache (torch.Tensor) : (batch, time_cache, size) + + returns: + output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention + cache (torch.Tensor) : (batch, time_cache_next, size) + """ + # Need to perform duplicate computations as at this point the tensors have been + # separated by the adapter forward + query = self.pre_norm(queries) + key = self.pre_norm(keys) + value = self.pre_norm(values) + + return super().forward(query, key, value, attention_mask) + + def reset_parameters(self): + with torch.no_grad(): + nn.init.zeros_(self.out_projection.weight) + nn.init.zeros_(self.out_projection.bias) + + def get_default_strategy_config(self) -> 'dataclass': + return MHAResidualAddAdapterStrategyConfig() + + +@dataclass +class TransformerEncoderMultiHeadAttentionAdapterConfig: + hidden_size: int + num_attention_heads: int + attn_score_dropout: float = 0.0 + attn_layer_dropout: float = 0.0 + proj_dim: Optional[int] = None + adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig()) + _target_: str = "{0}.{1}".format(TransformerEncoderMultiHeadAttentionAdapter.__module__, + TransformerEncoderMultiHeadAttentionAdapter.__name__) + + +""" Register the Adapter Modules to the Adapter Module Registry """ +if adapter_mixins.get_registered_adapter(TransformerEncoderMultiHeadAttentionAdapter) is None: + adapter_mixins.register_adapter(base_class=transformer_modules.MultiHeadAttention, + adapter_class=TransformerEncoderMultiHeadAttentionAdapter) \ No newline at end of file diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 093cde63c439..74fc6e77d68b 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ 
-24,6 +24,7 @@ RelPositionMultiHeadAttention, RelPositionMultiHeadAttentionLongformer, ) +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.utils.activations import Swish from nemo.collections.common.parts import adapter_modules from nemo.collections.common.parts.utils import activation_registry @@ -33,7 +34,7 @@ __all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer'] -class ConformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class ConformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Conformer encoder. Args: @@ -223,64 +224,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan else: return x, cache_last_channel, cache_last_time - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. - """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - class ConformerConvolution(nn.Module): """The convolution module for the Conformer model. 
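The diff above (and the Squeezeformer diff below) removes the duplicated per-layer adapter dispatch and lets ConformerLayer and SqueezeformerLayer inherit it from AttentionAdapterModuleMixin, while the earlier diffs teach EncDecMultiTaskModel to route adapters to 'encoder', 'transf_encoder' and 'transf_decoder'. A minimal, non-authoritative usage sketch of the resulting API, assuming a NeMo environment and an already-constructed multi-task model bound to `model` whose Conformer d_model and Transformer-encoder hidden size are both 128 (the `model` variable, adapter names and sizes are illustrative, not part of this patch):

    from nemo.collections.asr.parts.submodules.adapters import (
        multi_head_attention_adapter_module as mha_adapters,
        transformer_multi_head_attention_adapter_module as transf_adapters,
    )

    # Relative-position MHA adapter for the Conformer encoder; n_feat must
    # match the encoder's d_model, n_head is the adapter's own head count.
    conformer_cfg = mha_adapters.RelPositionMultiHeadAttentionAdapterConfig(n_head=4, n_feat=128)
    model.add_adapter(name='encoder:adapter_0', cfg=conformer_cfg)

    # MHA adapter for the Transformer encoder introduced by this patch;
    # hidden_size must match model_defaults.lm_enc_hidden in the model config.
    transf_cfg = transf_adapters.TransformerEncoderMultiHeadAttentionAdapterConfig(
        hidden_size=128, num_attention_heads=1
    )
    model.add_adapter(name='transf_encoder:adapter_1', cfg=transf_cfg)

    # Adapter names may carry a 'module:' prefix; a bare name resolves to the
    # default module, which these patches define as 'encoder'.
    model.set_enabled_adapters(name='encoder:adapter_0', enabled=True)
    model.set_enabled_adapters(name='transf_encoder:adapter_1', enabled=True)

Passing a Conformer-style adapter config to 'transf_encoder:' raises a ValueError, because TransformerEncoderAdapter only accepts the linear and Transformer-MHA adapter types it registers via get_accepted_adapter_types(); the unit test added near the end of this patch exercises exactly that path.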
diff --git a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py index ff2cf7c5b3cc..407714c6dca9 100644 --- a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py +++ b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py @@ -21,9 +21,8 @@ MultiHeadAttention, RelPositionMultiHeadAttention, ) -from nemo.collections.common.parts import adapter_modules +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['SqueezeformerLayer', 'ConformerFeedForward', 'SqueezeformerLayer'] @@ -57,7 +56,7 @@ def forward(self, x): return x * scale + bias -class SqueezeformerLayer(torch.nn.Module, AdapterModuleMixin, AccessMixin): +class SqueezeformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): """A single block of the Squeezeformer encoder. Args: @@ -197,64 +196,6 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None): return x - def forward_single_enabled_adapter_( - self, - input: dict, - adapter_module: torch.nn.Module, - *, - adapter_name: str, - adapter_strategy: 'nemo.core.classes.mixins.adapter_mixin_strategies.AbstractAdapterStrategy', - ): - """ - Perform the forward step of a single adapter module on some input data. - - **Note**: Subclasses can override this method to accommodate more complicate adapter forward steps. - - Args: - input: Dictionary of packed tensors. The dict should contain at least - `x`: output tensor - `loc`: Semantic location in module where this adapter was called - `att_mask`: Optional, Attention mask - `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. - The output tensor of the calling module is the input to the first adapter, whose output - is then chained to the next adapter until all adapters are consumed. - adapter_module: The adapter module that is currently required to perform the forward pass. - adapter_name: The resolved name of the adapter that is undergoing the current forward pass. - adapter_strategy: A subclass of `AbstractAdapterStrategy`, that determines how the - output of the adapter should be merged with the input, or if it should be merged at all. - - Returns: - The result tensor, after the current active adapter has finished its forward pass. 
- """ - # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') - x = input['x'] - loc = input['loc'] - att_mask = input.get('att_mask', None) - pos_emb = input.get('pos_emb', None) - - if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': - output = adapter_strategy(x, adapter_module, module=self) - - elif isinstance(adapter_module, MultiHeadAttention) and loc == 'mha': - if self.self_attention_model == 'rel_pos': - x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) - output = adapter_strategy(x, adapter_module, module=self) - - elif self.self_attention_model == 'abs_pos': - x = dict(query=x, key=x, value=x, mask=att_mask) - output = adapter_strategy(x, adapter_module, module=self) - - else: - raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") - - else: - # No adapter compatible, skip - output = x - - input['x'] = output - - return input - def reset_parameters(self): # Used for Squeezeformer initialization only self.feed_forward1.reset_parameters_ff() diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index 5b74a296419a..444e951d6816 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -21,6 +21,8 @@ # Constants LINEAR_ADAPTER_CLASSPATH = "nemo.collections.common.parts.adapter_modules.LinearAdapter" + +# Conformer Adapters MHA_ADAPTER_CLASSPATH = ( "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.MultiHeadAttentionAdapter" ) @@ -32,6 +34,8 @@ "nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.RelPositionalEncodingAdapter" ) +# Transformer Adapters +TRANSFORMER_ENCODER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapter" def convert_adapter_cfg_to_dict_config(cfg: DictConfig): # Convert to DictConfig from dict or Dataclass @@ -58,7 +62,7 @@ def update_adapter_cfg_input_dim(module: torch.nn.Module, cfg: DictConfig, *, mo """ cfg = convert_adapter_cfg_to_dict_config(cfg) - input_dim_valid_keys = ['in_features', 'n_feat'] + input_dim_valid_keys = ['in_features', 'n_feat', 'hidden_size'] input_key = None for key in input_dim_valid_keys: diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 2a05f374d464..808ed037c9f3 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -15,7 +15,7 @@ import inspect from abc import ABC from dataclasses import dataclass, is_dataclass -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Union, Iterable import torch import torch.nn as nn @@ -171,21 +171,7 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig], **kwargs cfg = DictConfig(cfg) adapter_types = self.get_accepted_adapter_types() - _pass_types = False - if len(adapter_types) > 0: - test = model_utils.import_class_by_path(cfg._target_) - for _type in adapter_types: - # TODO: (@adithyare) should revisit if subclass is the best check... 
- if issubclass(test, _type): - _pass_types = True - break - if not _pass_types: - raise ValueError( - f"Config: \n{OmegaConf.to_yaml(cfg)}\n" - f"It creates adapter class {test} \n" - f"that is not in the list of accepted adapter types.\n" - f"Accepted adapters: {[t for t in adapter_types]}" - ) + self.check_supported_adapter_type_(cfg, adapter_types) # Convert to DictConfig from dict or Dataclass if is_dataclass(cfg): @@ -543,6 +529,34 @@ def forward_single_enabled_adapter_( output = adapter_strategy(input, adapter_module, module=self) return output + def check_supported_adapter_type_(self, adapter_cfg: DictConfig, + supported_adapter_types: Optional[Iterable[type]] = None): + """ + Utility method to check if the adapter module is a supported type by the module. + + This method should be called by the subclass to ensure that the adapter module is a supported type. + """ + _pass_types = False + + if supported_adapter_types is None: + supported_adapter_types = self.get_accepted_adapter_types() + + if len(supported_adapter_types) > 0: + test = model_utils.import_class_by_path(adapter_cfg['_target_']) + for _type in supported_adapter_types: + # TODO: (@adithyare) should revisit if subclass is the best check... + if issubclass(test, _type): + _pass_types = True + break + + if not _pass_types: + raise ValueError( + f"Config: \n{OmegaConf.to_yaml(adapter_cfg)}\n" + f"It creates adapter class {test} \n" + f"that is not in the list of accepted adapter types.\n" + f"Accepted adapters: {[t for t in supported_adapter_types]}" + ) + class AdapterModelPTMixin(AdapterModuleMixin): """ Adapter Mixin that can augment a ModelPT subclass with Adapter support. @@ -982,6 +996,22 @@ def adapter_module_names(self) -> List[str]: Returns: A list of str, one for each of the adapter modules that are supported. By default, the subclass - should support the "global adapter" (''). + should support the "default adapter" (''). """ return [''] + + @property + def default_adapter_module_name(self) -> Optional[str]: + """ + Name of the adapter module that is used as "default" if a name of '' is provided. + + .. note:: + + Subclasses should override this property and return a str name of the module + that they wish to denote as the default. + + Returns: + A str name of a module, which is denoted as 'default' adapter or None. If None, then no default + adapter is supported. + """ + return None diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index c520bd4c1292..5701ea128cce 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import pytest import torch from omegaconf import DictConfig, ListConfig, OmegaConf -from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecRNNTModel -from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module +from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecRNNTModel, EncDecMultiTaskModel +from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module, transformer_multi_head_attention_adapter_module from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins.access_mixins import AccessMixin @@ -286,8 +287,130 @@ def rnnt_model(): return model_instance +@pytest.fixture() +def multitask_model(test_data_dir): + preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} + + # fmt: off + tokenizer = { + 'dir': None, + 'type': 'agg', + 'langs': { + 'spl_tokens': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'canary'), + 'type': 'bpe', + }, + 'en': { + 'dir': os.path.join(test_data_dir, 'asr', 'tokenizers', 'an4_spe_128'), + 'type': 'bpe', + } + }, + 'custom_tokenizer': { + '_target_': 'nemo.collections.common.tokenizers.canary_tokenizer.CanaryTokenizer', + 'tokenizers': None, + } + } + # fmt: on + + model_defaults = { + "asr_enc_hidden": 128, + "lm_enc_hidden": 128, + "lm_dec_hidden": 128 + } + + # Test case where Encoder (default) is not adapter compatible + encoder = { + '_target_': 'nemo.collections.asr.modules.ConformerEncoderAdapter', + 'feat_in': 64, + 'feat_out': -1, + 'n_layers': 2, + 'd_model': 128, + 'subsampling': 'striding', + 'subsampling_factor': 4, + 'self_attention_model': 'rel_pos', + 'n_heads': 4, + 'conv_kernel_size': 31, + } + + transf_encoder = { + "_target_": "nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoderAdapter", + "num_layers": 1, + "hidden_size": "${model_defaults.lm_enc_hidden}", + "inner_size": int(4 * model_defaults['lm_enc_hidden']), + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "mask_future": False, + "pre_ln": True, + "pre_ln_final_layer_norm": True + } + + transf_decoder = { + "_target_": "nemo.collections.asr.modules.transformer.get_nemo_transformer", + "model_name": None, + "pretrained": False, + "encoder": None, + "pre_ln_final_layer_norm": True, + "config_dict": { + "max_sequence_length": 512, + "num_token_types": 0, + "embedding_dropout": 0.1, + "learn_positional_encodings": False, + "hidden_size": "${model_defaults.lm_dec_hidden}", + "inner_size": "${multiply:${model_defaults.lm_dec_hidden}, 4}", + "num_layers": 24, + "num_attention_heads": 8, + "ffn_dropout": 0.1, + "attn_score_dropout": 0.1, + "attn_layer_dropout": 0.1, + "hidden_act": "relu", + "pre_ln": True, + "vocab_size": None # Will be set by the model at runtime + } + } + + head = { + "_target_": "nemo.collections.asr.parts.submodules.token_classifier.TokenClassifier", + "num_layers": 1, + "activation": "relu", + "log_softmax": True, + "hidden_size": "${transf_decoder.config_dict.hidden_size}", + "num_classes": None, # Will be set by the model at runtime + "dropout": 0.0, + "use_transformer_init": True + } + + decoding = {'strategy': 'beam', 'beam': {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}} + + loss = { + "_target_": "nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss", + 
"label_smoothing": 0.0, + "pad_id": None + } + + modelConfig = DictConfig( + { + 'sample_rate': 16000, + 'prompt_format': 'canary', + 'preprocessor': DictConfig(preprocessor), + 'model_defaults': DictConfig(model_defaults), + 'tokenizer': DictConfig(tokenizer), + 'encoder': DictConfig(encoder), + 'transf_encoder': DictConfig(transf_encoder), + 'transf_decoder': DictConfig(transf_decoder), + 'head': DictConfig(head), + 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), + } + ) + + model_instance = EncDecMultiTaskModel(cfg=modelConfig) + return model_instance + + def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **kwargs): - valid_types = ['linear', 'mha', 'relmha'] + valid_types = ['linear', 'mha', 'relmha', 'transf_mha'] if atype not in valid_types: raise ValueError(f"Invalid type. Valid types = {atype}") @@ -297,6 +420,10 @@ def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **k cfg = multi_head_attention_adapter_module.MultiHeadAttentionAdapterConfig( n_head=kwargs.get('n_head', 1), n_feat=in_features ) + elif atype == 'transf_mha': + cfg = transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapterConfig( + num_attention_heads=kwargs.get('n_head', 1), hidden_size=in_features + ) elif atype == 'relmha': cfg = multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapterConfig( n_head=kwargs.get('n_head', 1), n_feat=in_features @@ -467,6 +594,36 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 + @pytest.mark.unit + @pytest.mark.parametrize('name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0']) + def test_canary_forward_mha(self, multitask_model, name): + multitask_model.eval() + torch.random.manual_seed(0) + input_signal = torch.randn(2, 512) + input_signal_length = torch.tensor([512, 512], dtype=torch.int32) + transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) + transcript_len = torch.tensor([10, 9], dtype=torch.int32) + + origial_output = multitask_model(input_signal=input_signal, input_signal_length=input_signal_length, + transcript=transcript, transcript_length=transcript_len) + og_logprob = origial_output[0] + og_enc_out = origial_output[2] + + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype='transf_mha')) + + new_output = multitask_model(input_signal=input_signal, input_signal_length=input_signal_length, + transcript=transcript, transcript_length=transcript_len) + + new_logprob = new_output[0] + new_enc_out = new_output[2] + + assert torch.mean(torch.abs(og_logprob - new_logprob)) < 1e-5 + assert torch.mean(torch.abs(og_enc_out - new_enc_out)) < 1e-5 + + # Try to use incorrect adapter + with pytest.raises(ValueError): + multitask_model.add_adapter(name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha')) + @pytest.mark.unit @pytest.mark.parametrize('name1', ['adapter_0', 'encoder:adapter_0', 'decoder:adapter_0']) @pytest.mark.parametrize('name2', ['adapter_1', 'encoder:adapter_1', 'decoder:adapter_1']) From 9e2c5e26063087194df7818429f93e7699f8d100 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Fri, 7 Jun 2024 02:35:18 +0000 Subject: [PATCH 02/21] Apply isort and black reformatting Signed-off-by: titu1994 --- .../asr/models/aed_multitask_models.py | 2 +- .../transformer/transformer_encoders.py | 7 ++- .../asr/parts/mixins/asr_adapter_mixins.py | 14 +++-- 
.../adapters/attention_adapter_mixin.py | 12 +++-- .../multi_head_attention_adapter_module.py | 28 +++++----- ...mer_multi_head_attention_adapter_module.py | 51 +++++++++++------- .../asr/parts/submodules/conformer_modules.py | 2 +- .../parts/submodules/squeezeformer_modules.py | 2 +- .../asr/parts/utils/adapter_utils.py | 1 + nemo/core/classes/mixins/adapter_mixins.py | 19 ++++--- .../mixins/adapters/test_asr_adapter_mixin.py | 53 ++++++++++++------- 11 files changed, 114 insertions(+), 77 deletions(-) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 1f65961da264..bcb7758b1c13 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -30,7 +30,7 @@ ) from nemo.collections.asr.metrics import BLEU, WER from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel -from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRTranscriptionMixin, ASRModuleMixin +from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import ( GenericTranscriptionType, InternalTranscribeConfig, diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index faf1fd09dc2d..ff8ea0112c0d 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -13,17 +13,17 @@ # limitations under the License. import copy -from typing import Optional, List, Set -from omegaconf import DictConfig +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.utils import adapter_utils -from nemo.core.classes.mixins import adapter_mixins from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerEncoder"] @@ -268,7 +268,6 @@ def get_accepted_adapter_types(self) -> Set[type]: return types - """ Register any additional information """ diff --git a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py index ae36b2d882da..bd0607f2c4f3 100644 --- a/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py +++ b/nemo/collections/asr/parts/mixins/asr_adapter_mixins.py @@ -21,7 +21,7 @@ class ASRAdapterModelMixin(AdapterModelPTMixin): - """ ASR Adapter Mixin that can augment any Encoder module with Adapter module support. + """ASR Adapter Mixin that can augment any Encoder module with Adapter module support. This mixin class should be used only with a top level ModelPT subclass, that includes an `encoder` submodule. This mixin class adds several utility methods which are propagated to the `encoder`. @@ -88,8 +88,10 @@ def add_adapter(self, name: str, cfg: DictConfig): # Check if default module name is None or not if default_module_name is None: - raise ValueError(f"Default module name is None. Class {self.__class__.__name__} must implement " - f"`default_adapter_module_name`") + raise ValueError( + f"Default module name is None. 
Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) # Update the model.cfg with information about the new adapter from cfg with open_dict(self.cfg): @@ -162,8 +164,10 @@ def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True) # Check if default module name is None or not if default_module_name is None: - raise ValueError(f"Default module name is None. Class {self.__class__.__name__} must implement " - f"`default_adapter_module_name`") + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) # Forward the method call to the individual modules if they exist for module_name in module_names: diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py index ed9a2906b6a8..94fbcd5bf447 100644 --- a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -68,14 +68,15 @@ def forward_single_enabled_adapter_( The result tensor, after the current active adapter has finished its forward pass. """ if not hasattr(self, 'self_attention_model'): - raise RuntimeError("self_attention_model attribute not found in the module! Please set in the module " - "a string attribute 'self_attention_model' with value 'abs_pos', 'rel_pos' or " - "other supported self-attention model types.") + raise RuntimeError( + "self_attention_model attribute not found in the module! Please set in the module " + "a string attribute 'self_attention_model' with value 'abs_pos', 'rel_pos' or " + "other supported self-attention model types." + ) # Collect imports to prevent circular imports - from nemo.collections.asr.parts.submodules import multi_head_attention as conformer_mha from nemo.collections.asr.modules.transformer import transformer_modules as transformer_mha - + from nemo.collections.asr.parts.submodules import multi_head_attention as conformer_mha # (input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin') x = input['x'] @@ -84,6 +85,7 @@ def forward_single_enabled_adapter_( pos_emb = input.get('pos_emb', None) from nemo.collections.common.parts import adapter_modules + if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': output = adapter_strategy(x, adapter_module, module=self) diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 8afa5d283089..db6108a66f62 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -106,16 +106,16 @@ class MHAResidualAddAdapterStrategyConfig(adapter_mixin_strategies.ResidualAddAd class MultiHeadAttentionAdapter(mha.MultiHeadAttention, adapter_modules.AdapterModuleUtil): """Multi-Head Attention layer of Transformer. - Args: - n_head (int): number of heads - n_feat (int): size of the features - dropout_rate (float): dropout rate - proj_dim (int, optional): Optional integer value for projection before computing attention. - If None, then there is no projection (equivalent to proj_dim = n_feat). - If > 0, then will project the n_feat to proj_dim before calculating attention. 
- If <0, then will equal n_head, so that each head has a projected dimension of 1. - adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. - """ + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ def __init__( self, @@ -301,7 +301,6 @@ class RelPositionMultiHeadAttentionAdapterConfig: class PositionalEncodingAdapter(mha.PositionalEncoding, adapter_modules.AdapterModuleUtil): - """ Absolute positional embedding adapter. @@ -328,7 +327,11 @@ def __init__( ): super().__init__( - d_model=d_model, dropout_rate=0.0, max_len=max_len, xscale=xscale, dropout_rate_emb=0.0, + d_model=d_model, + dropout_rate=0.0, + max_len=max_len, + xscale=xscale, + dropout_rate_emb=0.0, ) # Setup adapter strategy @@ -391,4 +394,3 @@ class RelPositionalEncodingAdapterConfig: default_factory=lambda: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig() ) _target_: str = "{0}.{1}".format(RelPositionalEncodingAdapter.__module__, RelPositionalEncodingAdapter.__name__) - diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py index 8a612de03023..4c35645657dc 100644 --- a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -19,26 +19,31 @@ import torch from torch import nn as nn -from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig from nemo.collections.asr.modules.transformer import transformer_modules +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( + MHAResidualAddAdapterStrategy, + MHAResidualAddAdapterStrategyConfig, +) from nemo.collections.common.parts import adapter_modules -from nemo.core.classes.mixins import adapter_mixins, adapter_mixin_strategies +from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins -class TransformerEncoderMultiHeadAttentionAdapter(transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil): +class TransformerEncoderMultiHeadAttentionAdapter( + transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil +): """Multi-Head Attention layer of Transformer Encoder. - Args: - hidden_size (int): number of heads - num_attention_heads (int): size of the features - attn_score_dropout (float): dropout rate for the attention scores - attn_layer_dropout (float): dropout rate for the layer - proj_dim (int, optional): Optional integer value for projection before computing attention. - If None, then there is no projection (equivalent to proj_dim = n_feat). - If > 0, then will project the n_feat to proj_dim before calculating attention. - If <0, then will equal n_head, so that each head has a projected dimension of 1. 
- adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. - """ + Args: + hidden_size (int): number of heads + num_attention_heads (int): size of the features + attn_score_dropout (float): dropout rate for the attention scores + attn_layer_dropout (float): dropout rate for the layer + proj_dim (int, optional): Optional integer value for projection before computing attention. + If None, then there is no projection (equivalent to proj_dim = n_feat). + If > 0, then will project the n_feat to proj_dim before calculating attention. + If <0, then will equal n_head, so that each head has a projected dimension of 1. + adapter_strategy: By default, MHAResidualAddAdapterStrategyConfig. An adapter composition function object. + """ def __init__( self, @@ -49,8 +54,12 @@ def __init__( proj_dim: Optional[int] = None, adapter_strategy: MHAResidualAddAdapterStrategy = None, ): - super().__init__(hidden_size=hidden_size, num_attention_heads=num_attention_heads, - attn_score_dropout=attn_score_dropout, attn_layer_dropout=attn_layer_dropout) + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attn_score_dropout=attn_score_dropout, + attn_layer_dropout=attn_layer_dropout, + ) self.pre_norm = nn.LayerNorm(hidden_size) @@ -116,11 +125,13 @@ class TransformerEncoderMultiHeadAttentionAdapterConfig: attn_layer_dropout: float = 0.0 proj_dim: Optional[int] = None adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig()) - _target_: str = "{0}.{1}".format(TransformerEncoderMultiHeadAttentionAdapter.__module__, - TransformerEncoderMultiHeadAttentionAdapter.__name__) + _target_: str = "{0}.{1}".format( + TransformerEncoderMultiHeadAttentionAdapter.__module__, TransformerEncoderMultiHeadAttentionAdapter.__name__ + ) """ Register the Adapter Modules to the Adapter Module Registry """ if adapter_mixins.get_registered_adapter(TransformerEncoderMultiHeadAttentionAdapter) is None: - adapter_mixins.register_adapter(base_class=transformer_modules.MultiHeadAttention, - adapter_class=TransformerEncoderMultiHeadAttentionAdapter) \ No newline at end of file + adapter_mixins.register_adapter( + base_class=transformer_modules.MultiHeadAttention, adapter_class=TransformerEncoderMultiHeadAttentionAdapter + ) diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 74fc6e77d68b..56086a2cd62a 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -17,6 +17,7 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.batchnorm import FusedBatchNorm1d from nemo.collections.asr.parts.submodules.causal_convs import CausalConv1D from nemo.collections.asr.parts.submodules.multi_head_attention import ( @@ -24,7 +25,6 @@ RelPositionMultiHeadAttention, RelPositionMultiHeadAttentionLongformer, ) -from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.utils.activations import Swish from nemo.collections.common.parts import adapter_modules from nemo.collections.common.parts.utils import activation_registry diff --git a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py 
b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py index 407714c6dca9..212320e1f76f 100644 --- a/nemo/collections/asr/parts/submodules/squeezeformer_modules.py +++ b/nemo/collections/asr/parts/submodules/squeezeformer_modules.py @@ -16,12 +16,12 @@ from torch import nn as nn from torch.nn import LayerNorm +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.collections.asr.parts.submodules.conformer_modules import ConformerConvolution, ConformerFeedForward from nemo.collections.asr.parts.submodules.multi_head_attention import ( MultiHeadAttention, RelPositionMultiHeadAttention, ) -from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin from nemo.core.classes.mixins import AccessMixin __all__ = ['SqueezeformerLayer', 'ConformerFeedForward', 'SqueezeformerLayer'] diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index 444e951d6816..bd0cbcec191c 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -37,6 +37,7 @@ # Transformer Adapters TRANSFORMER_ENCODER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapter" + def convert_adapter_cfg_to_dict_config(cfg: DictConfig): # Convert to DictConfig from dict or Dataclass if is_dataclass(cfg): diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 808ed037c9f3..1cc902bd7586 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -15,7 +15,7 @@ import inspect from abc import ABC from dataclasses import dataclass, is_dataclass -from typing import List, Optional, Set, Tuple, Union, Iterable +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -124,7 +124,7 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict class AdapterModuleMixin(ABC): - """ Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. + """Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. This mixin class adds a hierarchical way to add any type of Adapter modules to a pre-existing module. Since Models are inherently also nn.Module, this mixin can be attached to any Model or Module. @@ -349,7 +349,9 @@ def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> N self._accepted_adapter_types = set(types) - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: """ Utility function to get the set of all classes that are accepted by the module. @@ -529,8 +531,9 @@ def forward_single_enabled_adapter_( output = adapter_strategy(input, adapter_module, module=self) return output - def check_supported_adapter_type_(self, adapter_cfg: DictConfig, - supported_adapter_types: Optional[Iterable[type]] = None): + def check_supported_adapter_type_( + self, adapter_cfg: DictConfig, supported_adapter_types: Optional[Iterable[type]] = None + ): """ Utility method to check if the adapter module is a supported type by the module. 
@@ -559,7 +562,7 @@ def check_supported_adapter_type_(self, adapter_cfg: DictConfig, class AdapterModelPTMixin(AdapterModuleMixin): - """ Adapter Mixin that can augment a ModelPT subclass with Adapter support. + """Adapter Mixin that can augment a ModelPT subclass with Adapter support. This mixin class should be used only with a top level ModelPT subclass. This mixin class adds several utility methods which should be subclassed and overriden to @@ -655,7 +658,9 @@ def add_adapter(self, name: str, cfg: Union[DictConfig, AdapterConfig]): self.cfg.adapters = OmegaConf.create({}) self.cfg.adapters = _prepare_default_adapter_config( - global_key=self.adapter_global_cfg_key, meta_key=self.adapter_metadata_cfg_key, cfg=self.cfg.adapters, + global_key=self.adapter_global_cfg_key, + meta_key=self.adapter_metadata_cfg_key, + cfg=self.cfg.adapters, ) # If the adapter is not being restored, force unique name to be provided for all adapters. diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index 5701ea128cce..155e51000128 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -13,12 +13,16 @@ # limitations under the License. import os + import pytest import torch from omegaconf import DictConfig, ListConfig, OmegaConf -from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecRNNTModel, EncDecMultiTaskModel -from nemo.collections.asr.parts.submodules.adapters import multi_head_attention_adapter_module, transformer_multi_head_attention_adapter_module +from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecMultiTaskModel, EncDecRNNTModel +from nemo.collections.asr.parts.submodules.adapters import ( + multi_head_attention_adapter_module, + transformer_multi_head_attention_adapter_module, +) from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import adapter_modules from nemo.core.classes.mixins.access_mixins import AccessMixin @@ -312,11 +316,7 @@ def multitask_model(test_data_dir): } # fmt: on - model_defaults = { - "asr_enc_hidden": 128, - "lm_enc_hidden": 128, - "lm_dec_hidden": 128 - } + model_defaults = {"asr_enc_hidden": 128, "lm_enc_hidden": 128, "lm_dec_hidden": 128} # Test case where Encoder (default) is not adapter compatible encoder = { @@ -343,7 +343,7 @@ def multitask_model(test_data_dir): "attn_layer_dropout": 0.1, "mask_future": False, "pre_ln": True, - "pre_ln_final_layer_norm": True + "pre_ln_final_layer_norm": True, } transf_decoder = { @@ -366,8 +366,8 @@ def multitask_model(test_data_dir): "attn_layer_dropout": 0.1, "hidden_act": "relu", "pre_ln": True, - "vocab_size": None # Will be set by the model at runtime - } + "vocab_size": None, # Will be set by the model at runtime + }, } head = { @@ -378,7 +378,7 @@ def multitask_model(test_data_dir): "hidden_size": "${transf_decoder.config_dict.hidden_size}", "num_classes": None, # Will be set by the model at runtime "dropout": 0.0, - "use_transformer_init": True + "use_transformer_init": True, } decoding = {'strategy': 'beam', 'beam': {'beam_size': 1, 'len_pen': 0.0, 'max_generation_delta': 50}} @@ -386,7 +386,7 @@ def multitask_model(test_data_dir): loss = { "_target_": "nemo.collections.common.losses.smoothed_cross_entropy.SmoothedCrossEntropyLoss", "label_smoothing": 0.0, - "pad_id": None + "pad_id": None, } modelConfig = DictConfig( @@ -507,7 +507,8 @@ def 
test_asr_model_constructor_joint_module_ctc_skip(self, model): assert new_num_params == original_num_params @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_asr_model_constructor_joint_module_rnnt(self, rnnt_model): @@ -604,15 +605,23 @@ def test_canary_forward_mha(self, multitask_model, name): transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) transcript_len = torch.tensor([10, 9], dtype=torch.int32) - origial_output = multitask_model(input_signal=input_signal, input_signal_length=input_signal_length, - transcript=transcript, transcript_length=transcript_len) + origial_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) og_logprob = origial_output[0] og_enc_out = origial_output[2] multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype='transf_mha')) - new_output = multitask_model(input_signal=input_signal, input_signal_length=input_signal_length, - transcript=transcript, transcript_length=transcript_len) + new_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) new_logprob = new_output[0] new_enc_out = new_output[2] @@ -622,7 +631,9 @@ def test_canary_forward_mha(self, multitask_model, name): # Try to use incorrect adapter with pytest.raises(ValueError): - multitask_model.add_adapter(name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha')) + multitask_model.add_adapter( + name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha') + ) @pytest.mark.unit @pytest.mark.parametrize('name1', ['adapter_0', 'encoder:adapter_0', 'decoder:adapter_0']) @@ -645,7 +656,8 @@ def test_asr_multi_adapter_forward(self, model, name1, name2): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.parametrize('name1', ['decoder:adapter_0', 'joint:adapter_0']) @pytest.mark.parametrize('name2', ['decoder:adapter_1', 'joint:adapter_1']) @@ -739,7 +751,8 @@ def test_constructor_pretrained(self): assert model.num_weights < 1e5 @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit From a31e9204e42b28dfc502e93694e53465c6d61859 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 6 Jun 2024 20:00:14 -0700 Subject: [PATCH 03/21] Cleanup Signed-off-by: smajumdar --- .../transformer/transformer_encoders.py | 2 +- .../asr/parts/submodules/adapters/__init__.py | 4 ++++ .../adapters/attention_adapter_mixin.py | 10 +++++++++- .../multi_head_attention_adapter_module.py | 18 +++++++++++++----- .../asr/parts/submodules/conformer_modules.py | 2 -- .../mixins/adapters/test_asr_adapter_mixin.py | 5 +++-- 6 files changed, 30 insertions(+), 11 deletions(-) diff --git 
a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index ff8ea0112c0d..377155585f0a 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -65,7 +65,7 @@ def __init__( self.second_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) # Information for the adapter module mixin - self.self_attention_model = "abs_pos" + self.self_attention_model = "transf_abs" def forward_preln(self, encoder_query, encoder_mask, encoder_keys): """ diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index 52458dfb06c3..7528b262c427 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +#fmt: off +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +#fmt: on + from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig, diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py index 94fbcd5bf447..a75e0c4cf3ee 100644 --- a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -1,6 +1,7 @@ import torch from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging, logging_mode class AttentionAdapterModuleMixin(adapter_mixins.AdapterModuleMixin): @@ -89,7 +90,7 @@ def forward_single_enabled_adapter_( if isinstance(adapter_module, adapter_modules.LinearAdapter) and loc == 'post': output = adapter_strategy(x, adapter_module, module=self) - if isinstance(adapter_module, conformer_mha.MultiHeadAttention) and loc == 'mha': + elif isinstance(adapter_module, conformer_mha.MultiHeadAttention) and loc == 'mha': if self.self_attention_model == 'rel_pos': x = dict(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb) output = adapter_strategy(x, adapter_module, module=self) @@ -101,8 +102,15 @@ def forward_single_enabled_adapter_( else: raise ValueError(f"Unsupported value of self_attention_model , provided {self.self_attention_model}!") + elif isinstance(adapter_module, transformer_mha.MultiHeadAttention) and loc == 'mha': + x = dict(queries=x, keys=x, values=x, attention_mask=att_mask) + output = adapter_strategy(x, adapter_module, module=self) + else: # No adapter compatible, skip + logging.warning("No adapter compatible with the current module. 
Skipping adapter forward pass.", + mode=logging_mode.ONCE) + output = x input['x'] = output diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index db6108a66f62..635c1a81f188 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -29,7 +29,7 @@ class MHAResidualAddAdapterStrategy(adapter_mixin_strategies.ResidualAddAdapterS An implementation of residual addition of an adapter module with its input for the MHA Adapters. """ - def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): + def forward(self, input: dict, adapter: torch.nn.Module, *, module: 'AdapterModuleMixin'): """ A basic strategy, comprising of a residual connection over the input, after forward pass by the underlying adapter. Additional work is done to pack and unpack the dictionary of inputs and outputs. @@ -55,19 +55,27 @@ def forward(self, input: torch.Tensor, adapter: torch.nn.Module, *, module: 'Ada """ out = self.compute_output(input, adapter, module=module) + value_name = None + if 'value' in input: + value_name = 'value' + elif 'values' in input: + value_name = 'values' + else: + raise ValueError("Input dictionary must contain 'value' or 'values' key for residual connection. Input " + f"dictionary keys: {input.keys()}") + # If not in training mode, or probability of stochastic depth is 0, skip step. p = self.stochastic_depth if not module.training or p == 0.0: pass else: - out = self.apply_stochastic_depth(out, input['value'], adapter, module=module) + out = self.apply_stochastic_depth(out, input[value_name], adapter, module=module) # Return the residual connection output = input + adapter(input) - print("Out", out.shape, "Input", input['value'].shape) - result = input['value'] + out + result = input[value_name] + out # If l2_lambda is activated, register the loss value - self.compute_auxiliary_losses(result, input['value'], adapter, module=module) + self.compute_auxiliary_losses(result, input[value_name], adapter, module=module) return result diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 56086a2cd62a..5d83cbd005bc 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -26,10 +26,8 @@ RelPositionMultiHeadAttentionLongformer, ) from nemo.collections.asr.parts.utils.activations import Swish -from nemo.collections.common.parts import adapter_modules from nemo.collections.common.parts.utils import activation_registry from nemo.core.classes.mixins import AccessMixin -from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin __all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerLayer'] diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index 155e51000128..d15dffc49055 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -359,7 +359,7 @@ def multitask_model(test_data_dir): "learn_positional_encodings": False, "hidden_size": "${model_defaults.lm_dec_hidden}", "inner_size": "${multiply:${model_defaults.lm_dec_hidden}, 4}", - 
"num_layers": 24, + "num_layers": 2, "num_attention_heads": 8, "ffn_dropout": 0.1, "attn_score_dropout": 0.1, @@ -614,7 +614,8 @@ def test_canary_forward_mha(self, multitask_model, name): og_logprob = origial_output[0] og_enc_out = origial_output[2] - multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype='transf_mha')) + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type)) new_output = multitask_model( input_signal=input_signal, From 06c372a9165576011fdf3f251d75d852b33481d2 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Fri, 7 Jun 2024 03:02:17 +0000 Subject: [PATCH 04/21] Apply isort and black reformatting Signed-off-by: titu1994 --- nemo/collections/asr/parts/submodules/adapters/__init__.py | 6 +++--- .../parts/submodules/adapters/attention_adapter_mixin.py | 5 +++-- .../adapters/multi_head_attention_adapter_module.py | 6 ++++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index 7528b262c427..c987e4d24de2 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -#fmt: off +# fmt: off from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin -#fmt: on - from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import ( MHAResidualAddAdapterStrategy, MHAResidualAddAdapterStrategyConfig, @@ -32,3 +30,5 @@ TransformerEncoderMultiHeadAttentionAdapter, TransformerEncoderMultiHeadAttentionAdapterConfig, ) + +# fmt: on diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py index a75e0c4cf3ee..0696c112529b 100644 --- a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -108,8 +108,9 @@ def forward_single_enabled_adapter_( else: # No adapter compatible, skip - logging.warning("No adapter compatible with the current module. Skipping adapter forward pass.", - mode=logging_mode.ONCE) + logging.warning( + "No adapter compatible with the current module. Skipping adapter forward pass.", mode=logging_mode.ONCE + ) output = x diff --git a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py index 635c1a81f188..2617ed6f575b 100644 --- a/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/multi_head_attention_adapter_module.py @@ -61,8 +61,10 @@ def forward(self, input: dict, adapter: torch.nn.Module, *, module: 'AdapterModu elif 'values' in input: value_name = 'values' else: - raise ValueError("Input dictionary must contain 'value' or 'values' key for residual connection. Input " - f"dictionary keys: {input.keys()}") + raise ValueError( + "Input dictionary must contain 'value' or 'values' key for residual connection. Input " + f"dictionary keys: {input.keys()}" + ) # If not in training mode, or probability of stochastic depth is 0, skip step. 
p = self.stochastic_depth From a1219d454deacb61c9ff2daf99ce028d8c0e687c Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 7 Jun 2024 16:51:05 -0700 Subject: [PATCH 05/21] Finalize support for decoder adapters Signed-off-by: smajumdar --- .../asr/modules/transformer/transformer.py | 46 +++++++- .../transformer/transformer_decoders.py | 103 +++++++++++++++++- .../transformer/transformer_encoders.py | 28 ++--- .../modules/transformer/transformer_utils.py | 9 +- .../asr/parts/utils/adapter_utils.py | 2 +- .../mixins/adapters/test_asr_adapter_mixin.py | 31 +++++- 6 files changed, 197 insertions(+), 22 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer.py b/nemo/collections/asr/modules/transformer/transformer.py index 718448aa1c7c..c98f8b671bdc 100644 --- a/nemo/collections/asr/modules/transformer/transformer.py +++ b/nemo/collections/asr/modules/transformer/transformer.py @@ -13,18 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, Optional +from typing import Dict, List, Optional import torch -from omegaconf.omegaconf import MISSING +from omegaconf.omegaconf import MISSING, DictConfig from nemo.collections.asr.modules.transformer.decoder_module import DecoderModule from nemo.collections.asr.modules.transformer.encoder_module import EncoderModule -from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder +from nemo.collections.asr.modules.transformer.transformer_decoders import TransformerDecoder, TransformerDecoderAdapter from nemo.collections.asr.modules.transformer.transformer_encoders import TransformerEncoder from nemo.collections.asr.modules.transformer.transformer_modules import TransformerEmbedding +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.core.classes.common import typecheck from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import adapter_mixins from nemo.core.neural_types import ChannelType, NeuralType @@ -155,6 +158,8 @@ def input_example(self, max_batch=1, max_dim=256): class TransformerDecoderNM(DecoderModule, Exportable): + DECODER_TYPE: type = TransformerDecoder + def __init__( self, vocab_size: int, @@ -192,7 +197,7 @@ def __init__( learn_positional_encodings=learn_positional_encodings, ) - self._decoder = TransformerDecoder( + self._decoder = self.DECODER_TYPE( hidden_size=self.hidden_size, num_layers=num_layers, inner_size=inner_size, @@ -274,3 +279,36 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: return {"last_hidden_states": NeuralType(('B', 'D', 'T', 'D'), ChannelType())} else: return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} + + +class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin): + DECODER_TYPE : type = TransformerDecoderAdapter + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + self._decoder.add_adapter(name, cfg) # type: adapter_mixins.AdapterModuleMixin + + def is_adapter_available(self) -> bool: + return self._decoder.is_adapter_available() # type: adapter_mixins.AdapterModuleMixin + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + self._decoder.set_enabled_adapters(name=name, enabled=enabled) # # type: adapter_mixins.AdapterModuleMixin + + def get_enabled_adapters(self) -> List[str]: 
+ names = set([]) + names.update(self._decoder.get_enabled_adapters()) # type: adapter_mixins.AdapterModuleMixin + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._hidden_size) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoderNM) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoderNM, adapter_class=TransformerDecoderNMAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_decoders.py b/nemo/collections/asr/modules/transformer/transformer_decoders.py index a5b2c299393c..cecb5ab6be1e 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -13,17 +13,22 @@ # limitations under the License. import copy +from omegaconf import DictConfig +from typing import Optional, List, Set import torch import torch.nn as nn from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF +from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin +from nemo.collections.asr.parts.utils import adapter_utils from nemo.collections.common.parts import form_attention_mask +from nemo.core.classes.mixins import adapter_mixins __all__ = ["TransformerDecoder"] -class TransformerDecoderBlock(nn.Module): +class TransformerDecoderBlock(nn.Module, AttentionAdapterModuleMixin): """ Building block of Transformer decoder. @@ -63,6 +68,9 @@ def __init__( self.layer_norm_3 = nn.LayerNorm(hidden_size, eps=1e-5) self.third_sub_layer = PositionWiseFF(hidden_size, inner_size, ffn_dropout, hidden_act) + # Information for the adapter module mixin + self.self_attention_model = "transf_abs" + def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): """ Pre-LayerNorm block @@ -74,6 +82,17 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += residual + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -84,6 +103,15 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state output_states = self.third_sub_layer(enc_dec_attn_output) output_states += residual + if self.is_adapter_available(): + # Call the Linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + return output_states def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -93,6 +121,18 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat """ self_attn_output = self.first_sub_layer(decoder_query, decoder_keys, decoder_keys, decoder_mask) self_attn_output += decoder_query + + if self.is_adapter_available(): + # Call the MHA adapters + pack_ip = { + 'x': self_attn_output, + 
'loc': 'mha', + 'att_mask': decoder_mask, + 'pos_emb': None, + } + pack_ip = self.forward_enabled_adapters(pack_ip) + self_attn_output = pack_ip['x'] + self_attn_output = self.layer_norm_1(self_attn_output) enc_dec_attn_output = self.second_sub_layer(self_attn_output, encoder_states, encoder_states, encoder_mask) @@ -101,6 +141,16 @@ def forward_postln(self, decoder_query, decoder_mask, decoder_keys, encoder_stat output_states = self.third_sub_layer(enc_dec_attn_output) output_states += enc_dec_attn_output + + if self.is_adapter_available(): + # Call the linear adapters + pack_ip = { + 'x': output_states, + 'loc': 'post', + } + pack_ip = self.forward_enabled_adapters(pack_ip) + output_states = pack_ip['x'] + return self.layer_norm_3(output_states) def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask): @@ -110,6 +160,20 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + + class TransformerDecoder(nn.Module): def __init__( self, @@ -131,6 +195,8 @@ def __init__( else: self.final_layer_norm = None + self.d_model = hidden_size + layer = TransformerDecoderBlock( hidden_size, inner_size, @@ -219,3 +285,38 @@ def input_example(self, max_batch=1, max_dim=256): input_ids = torch.randint(low=0, high=2048, size=(max_batch, max_dim, 1024), device=sample.device) encoder_mask = torch.randint(low=0, high=1, size=(max_batch, max_dim), device=sample.device) return tuple([input_ids, encoder_mask, input_ids, encoder_mask]) + + +class TransformerDecoderAdapter(TransformerDecoder, adapter_mixins.AdapterModuleMixin): + + # Higher level forwarding + def add_adapter(self, name: str, cfg: dict): + cfg = self._update_adapter_cfg_input_dim(cfg) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.add_adapter(name, cfg) + + def is_adapter_available(self) -> bool: + return any([transformer_layer.is_adapter_available() for transformer_layer in self.layers]) + + def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + transformer_layer.set_enabled_adapters(name=name, enabled=enabled) + + def get_enabled_adapters(self) -> List[str]: + names = set([]) + for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin + names.update(transformer_layer.get_enabled_adapters()) + + names = sorted(list(names)) + return names + + def _update_adapter_cfg_input_dim(self, cfg: DictConfig): + cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) + return cfg + + +""" +Register any additional information +""" +if adapter_mixins.get_registered_adapter(TransformerDecoder) is None: + adapter_mixins.register_adapter(base_class=TransformerDecoder, adapter_class=TransformerDecoderAdapter) diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 377155585f0a..6a7cfcf24502 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ 
b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -149,6 +149,20 @@ def forward(self, encoder_query, encoder_mask, encoder_keys): return self.forward_postln(encoder_query, encoder_mask, encoder_keys) + def get_accepted_adapter_types(self) -> Set[type]: + types = super().get_accepted_adapter_types() + + if len(types) == 0: + self.set_accepted_adapter_types( + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + adapter_utils.TRANSFORMER_MHA_ADAPTER_CLASSPATH, + ] + ) + types = self.get_accepted_adapter_types() + return types + + class TransformerEncoder(nn.Module): def __init__( self, @@ -230,7 +244,6 @@ class TransformerEncoderAdapter(TransformerEncoder, adapter_mixins.AdapterModule # Higher level forwarding def add_adapter(self, name: str, cfg: dict): - self.check_supported_adapter_type_(cfg) cfg = self._update_adapter_cfg_input_dim(cfg) for transformer_layer in self.layers: # type: adapter_mixins.AdapterModuleMixin transformer_layer.add_adapter(name, cfg) @@ -254,19 +267,6 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig): cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) return cfg - def get_accepted_adapter_types(self) -> Set[type]: - types = super().get_accepted_adapter_types() - - if len(types) == 0: - self.set_accepted_adapter_types( - [ - adapter_utils.LINEAR_ADAPTER_CLASSPATH, - adapter_utils.TRANSFORMER_ENCODER_MHA_ADAPTER_CLASSPATH, - ] - ) - types = self.get_accepted_adapter_types() - return types - """ Register any additional information diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index da9ffb8fbd00..f36bc07e623b 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -113,7 +113,14 @@ def get_nemo_transformer( else: raise ValueError(f"Unknown arch = {arch}") else: - model = TransformerDecoderNM( + + class_ = TransformerDecoderNM + + if cfg.get('adapter', False): + from nemo.core.classes.mixins.adapter_mixins import get_registered_adapter + class_ = get_registered_adapter(TransformerDecoderNM).adapter_class + + model = class_( vocab_size=cfg.get('vocab_size'), hidden_size=cfg.get('hidden_size'), num_layers=cfg.get('num_layers'), diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index bd0cbcec191c..39e7b8596625 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -35,7 +35,7 @@ ) # Transformer Adapters -TRANSFORMER_ENCODER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapter" +TRANSFORMER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapter" def convert_adapter_cfg_to_dict_config(cfg: DictConfig): diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index d15dffc49055..07420d9e464b 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -367,6 +367,7 @@ def multitask_model(test_data_dir): "hidden_act": "relu", "pre_ln": True, "vocab_size": None, # Will be set by the model at runtime + "adapter": True, # Add 
support for adapter class }, } @@ -596,7 +597,7 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.unit - @pytest.mark.parametrize('name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0']) + @pytest.mark.parametrize('name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0']) def test_canary_forward_mha(self, multitask_model, name): multitask_model.eval() torch.random.manual_seed(0) @@ -636,6 +637,34 @@ def test_canary_forward_mha(self, multitask_model, name): name="transf_encoder:adapter_1", cfg=get_adapter_cfg(in_features=128, atype='mha') ) + @pytest.mark.unit + @pytest.mark.parametrize('name', + ['transf_decoder:adapter_0']) + def test_canary_forward_mha_decoder_fails_without_support(self, multitask_model, name): + multitask_model.eval() + torch.random.manual_seed(0) + input_signal = torch.randn(2, 512) + input_signal_length = torch.tensor([512, 512], dtype=torch.int32) + transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) + transcript_len = torch.tensor([10, 9], dtype=torch.int32) + + origial_output = multitask_model( + input_signal=input_signal, + input_signal_length=input_signal_length, + transcript=transcript, + transcript_length=transcript_len, + ) + og_logprob = origial_output[0] + og_enc_out = origial_output[2] + + # Change internal class of transf_decoder module + adapter_class = multitask_model.transf_decoder.__class__ + multitask_model.transf_decoder.__class__ = get_registered_adapter(adapter_class).base_class + + with pytest.raises(AttributeError): + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type)) + @pytest.mark.unit @pytest.mark.parametrize('name1', ['adapter_0', 'encoder:adapter_0', 'decoder:adapter_0']) @pytest.mark.parametrize('name2', ['adapter_1', 'encoder:adapter_1', 'decoder:adapter_1']) From 38b00d50a4fe9b8d70e86ef664b821a1cf37feb5 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Fri, 7 Jun 2024 23:51:58 +0000 Subject: [PATCH 06/21] Apply isort and black reformatting Signed-off-by: titu1994 --- nemo/collections/asr/modules/transformer/transformer.py | 9 +++++++-- .../asr/modules/transformer/transformer_decoders.py | 5 ++--- .../asr/modules/transformer/transformer_encoders.py | 1 - .../asr/modules/transformer/transformer_utils.py | 1 + .../asr/mixins/adapters/test_asr_adapter_mixin.py | 7 ++++--- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer.py b/nemo/collections/asr/modules/transformer/transformer.py index c98f8b671bdc..0ea376340d18 100644 --- a/nemo/collections/asr/modules/transformer/transformer.py +++ b/nemo/collections/asr/modules/transformer/transformer.py @@ -212,7 +212,12 @@ def __init__( @typecheck() def forward( - self, input_ids, decoder_mask, encoder_embeddings, encoder_mask, decoder_mems=None, + self, + input_ids, + decoder_mask, + encoder_embeddings, + encoder_mask, + decoder_mems=None, ): start_pos = 0 if decoder_mems is not None: @@ -282,7 +287,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: class TransformerDecoderNMAdapter(TransformerDecoderNM, adapter_mixins.AdapterModuleMixin): - DECODER_TYPE : type = TransformerDecoderAdapter + DECODER_TYPE: type = TransformerDecoderAdapter # Higher level forwarding def add_adapter(self, name: str, cfg: dict): diff --git 
a/nemo/collections/asr/modules/transformer/transformer_decoders.py b/nemo/collections/asr/modules/transformer/transformer_decoders.py index cecb5ab6be1e..f8ddbb74169a 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -13,11 +13,11 @@ # limitations under the License. import copy -from omegaconf import DictConfig -from typing import Optional, List, Set +from typing import List, Optional, Set import torch import torch.nn as nn +from omegaconf import DictConfig from nemo.collections.asr.modules.transformer.transformer_modules import MultiHeadAttention, PositionWiseFF from nemo.collections.asr.parts.submodules.adapters.attention_adapter_mixin import AttentionAdapterModuleMixin @@ -159,7 +159,6 @@ def forward(self, decoder_query, decoder_mask, decoder_keys, encoder_states, enc else: return self.forward_postln(decoder_query, decoder_mask, decoder_keys, encoder_states, encoder_mask) - def get_accepted_adapter_types(self) -> Set[type]: types = super().get_accepted_adapter_types() diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 6a7cfcf24502..8c0ba933d5d2 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -148,7 +148,6 @@ def forward(self, encoder_query, encoder_mask, encoder_keys): else: return self.forward_postln(encoder_query, encoder_mask, encoder_keys) - def get_accepted_adapter_types(self) -> Set[type]: types = super().get_accepted_adapter_types() diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index f36bc07e623b..0c2f36cb7965 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -118,6 +118,7 @@ def get_nemo_transformer( if cfg.get('adapter', False): from nemo.core.classes.mixins.adapter_mixins import get_registered_adapter + class_ = get_registered_adapter(TransformerDecoderNM).adapter_class model = class_( diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index 07420d9e464b..3abd13186b94 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -597,7 +597,9 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.unit - @pytest.mark.parametrize('name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0']) + @pytest.mark.parametrize( + 'name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0'] + ) def test_canary_forward_mha(self, multitask_model, name): multitask_model.eval() torch.random.manual_seed(0) @@ -638,8 +640,7 @@ def test_canary_forward_mha(self, multitask_model, name): ) @pytest.mark.unit - @pytest.mark.parametrize('name', - ['transf_decoder:adapter_0']) + @pytest.mark.parametrize('name', ['transf_decoder:adapter_0']) def test_canary_forward_mha_decoder_fails_without_support(self, multitask_model, name): multitask_model.eval() torch.random.manual_seed(0) From 86c0edb6f94bf01d72ffe424560d37db1a5e6997 Mon Sep 17 00:00:00 2001 From: Weiqing Wang 
Date: Thu, 13 Jun 2024 13:25:41 -0700 Subject: [PATCH 07/21] fix the freeze/unfreeze problem by replacing as_frozen with torch.inference_mode --- .../transformer/transformer_generators.py | 4 +- .../parts/submodules/rnnt_beam_decoding.py | 50 +++++++++---------- .../parts/submodules/rnnt_greedy_decoding.py | 44 ++++++++-------- .../transformer/transformer_generators.py | 4 +- 4 files changed, 47 insertions(+), 55 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 4061f54a907a..c45f92ad1a84 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -173,7 +173,7 @@ def _forward( def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): - with self.as_frozen(): + with torch.inference_mode(): results = self._forward( decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores ) @@ -697,7 +697,7 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return tgt def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False): - with self.as_frozen(): + with torch.inference_mode(): return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index ef3a0cddb286..4f724e2dfd94 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -369,7 +369,7 @@ def __call__( return_hat_ilm_default = self.joint.return_hat_ilm self.joint.return_hat_ilm = self.hat_subtract_ilm - with torch.no_grad(): + with torch.inference_mode(): # Apply optional preprocessing encoder_output = encoder_output.transpose(1, 2) # (B, T, D) @@ -384,38 +384,34 @@ def __call__( unit='sample', ) as idx_gen: - # Freeze the decoder and joint to prevent recording of gradients - # during the beam loop. - with self.decoder.as_frozen(), self.joint.as_frozen(): + _p = next(self.joint.parameters()) + dtype = _p.dtype - _p = next(self.joint.parameters()) - dtype = _p.dtype + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] - # Decode every sample in the batch independently. 
- for batch_idx in idx_gen: - inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] - logitlen = encoded_lengths[batch_idx] + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) - if inseq.dtype != dtype: - inseq = inseq.to(dtype=dtype) + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - # Extract partial hypothesis if exists - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypothesis - # Execute the specific search strategy - nbest_hyps = self.search_algorithm( - inseq, logitlen, partial_hypotheses=partial_hypothesis - ) # sorted list of hypothesis + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) - # Prepare the list of hypotheses - nbest_hyps = pack_hypotheses(nbest_hyps) - - # Pack the result - if self.return_best_hypothesis: - best_hypothesis = nbest_hyps[0] # type: Hypothesis - else: - best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses - hypotheses.append(best_hypothesis) + # Pack the result + if self.return_best_hypothesis: + best_hypothesis = nbest_hyps[0] # type: Hypothesis + else: + best_hypothesis = NBestHypotheses(nbest_hyps) # type: NBestHypotheses + hypotheses.append(best_hypothesis) self.decoder.train(decoder_training_state) self.joint.train(joint_training_state) diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 420e49c96142..70ab74e7b014 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -383,14 +383,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -720,12 +719,11 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] + inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) @@ -2487,14 +2485,13 @@ def forward( hypotheses = [] # Process each sequence independently - with self.decoder.as_frozen(), 
self.joint.as_frozen(): - for batch_idx in range(encoder_output.size(0)): - inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] - logitlen = encoded_lengths[batch_idx] + for batch_idx in range(encoder_output.size(0)): + inseq = encoder_output[batch_idx, :, :].unsqueeze(1) # [T, 1, D] + logitlen = encoded_lengths[batch_idx] - partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None - hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) - hypotheses.append(hypothesis) + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + hypothesis = self._greedy_decode(inseq, logitlen, partial_hypotheses=partial_hypothesis) + hypotheses.append(hypothesis) # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, encoded_lengths) @@ -2775,11 +2772,10 @@ def forward( self.decoder.eval() self.joint.eval() - with self.decoder.as_frozen(), self.joint.as_frozen(): - inseq = encoder_output # [B, T, D] - hypotheses = self._greedy_decode( - inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses - ) + inseq = encoder_output # [B, T, D] + hypotheses = self._greedy_decode( + inseq, logitlen, device=inseq.device, partial_hypotheses=partial_hypotheses + ) # Pack the hypotheses results packed_result = pack_hypotheses(hypotheses, logitlen) diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index 6e17151dcd1b..3af450271898 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -173,7 +173,7 @@ def _forward( def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): - with self.as_frozen(): + with torch.inference_mode(): return self._forward( decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores ) @@ -687,7 +687,7 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return tgt def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False): - with self.as_frozen(): + with torch.inference_mode(): return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: From 3e1e8fccdf20c601585b7d49fb8eb6f0ebf851e6 Mon Sep 17 00:00:00 2001 From: weiqingw4ng Date: Thu, 13 Jun 2024 21:54:48 +0000 Subject: [PATCH 08/21] Apply isort and black reformatting Signed-off-by: weiqingw4ng --- .../transformer/transformer_generators.py | 40 +++++++++++-------- .../parts/submodules/rnnt_beam_decoding.py | 11 ++--- .../transformer/transformer_generators.py | 40 +++++++++++-------- 3 files changed, 52 insertions(+), 39 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index c45f92ad1a84..1a38e7fa4b6c 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -188,8 +188,7 @@ def __call__( return prefixes, scores, tgt def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. 
- """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -201,8 +200,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -357,13 +355,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -463,7 +461,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -639,13 +640,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -701,8 +702,7 @@ def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. 
- """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -718,8 +718,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -781,13 +780,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -863,13 +869,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 4f724e2dfd94..25becda6fa75 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -201,8 +201,7 @@ class BeamRNNTInfer(Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_lengths": NeuralType(tuple('B'), LengthsType()), @@ -211,8 +210,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -635,7 +633,10 @@ def default_beam_search( # keep those hypothesis that have scores greater than next search generation hyps_max = float(max(hyps, key=lambda x: x.score).score) - kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) # If enough hypothesis have scores greater than next search generation, # stop beam search. diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index 3af450271898..a76b1606fd1e 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -179,8 +179,7 @@ def __call__( ) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for param in self.embedding.parameters(): param.requires_grad = False self.embedding.eval() @@ -192,8 +191,7 @@ def freeze(self) -> None: self.log_softmax.eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for param in self.embedding.parameters(): param.requires_grad = True self.embedding.train() @@ -347,13 +345,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -453,7 +451,10 @@ def _one_step_forward_lm(self, decoder_input_ids=None, lm_mems_list=None, pos=0) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) return lm_log_probs, lm_mems_list @@ -629,13 +630,13 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, 
self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) @@ -691,8 +692,7 @@ def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> None: - """Freeze weights of embedding, decoder, and classification layers to prevent memory leak. - """ + """Freeze weights of embedding, decoder, and classification layers to prevent memory leak.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = False @@ -708,8 +708,7 @@ def freeze(self) -> None: self.encoders[model_num].eval() def unfreeze(self) -> None: - """Unfreeze weights of embedding, decoder, and classification layers. - """ + """Unfreeze weights of embedding, decoder, and classification layers.""" for model_num in range(self.num_models): for param in self.embeddings[model_num].parameters(): param.requires_grad = True @@ -771,13 +770,20 @@ def _one_step_forward( ): nmt_log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos, + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, ) input_mask = mask_padded_tokens(decoder_input_ids, self.pad).float() lm_hidden_states = self.language_model.encoder.embedding.forward(decoder_input_ids, start_pos=pos) lm_mems_list = self.language_model.encoder.encoder.forward( - lm_hidden_states, input_mask, lm_mems_list, return_mems=True, + lm_hidden_states, + input_mask, + lm_mems_list, + return_mems=True, ) lm_log_probs = self.language_model.log_softmax.forward(hidden_states=lm_mems_list[-1][:, -1:]) @@ -853,13 +859,13 @@ def _forward( # choose top-k hypotheses with length penalty applied len_penalties = self.compute_len_penalty(prefixes_len, self.len_pen) scores = scores / len_penalties - scores, indices_i = torch.topk(scores.view(-1, self.beam_size ** 2), self.beam_size, dim=1) + scores, indices_i = torch.topk(scores.view(-1, self.beam_size**2), self.beam_size, dim=1) scores = scores.view(-1, 1) * len_penalties # select prefixes which correspond to the chosen hypotheses prefixes = prefixes.unsqueeze(1).repeat(1, self.beam_size, 1) prefixes = torch.cat((prefixes, prefixes_i.unsqueeze(2)), dim=2) - prefixes = prefixes.view(batch_size, self.beam_size ** 2, -1) + prefixes = prefixes.view(batch_size, self.beam_size**2, -1) p_len = prefixes.size(2) prefixes_ids = indices_i.unsqueeze(2).repeat(1, 1, p_len) prefixes = prefixes.gather(1, prefixes_ids).view(-1, p_len) From 6a6d1d7e150f9a33c558672ad22a481ee3b4cf89 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 13 Jun 2024 18:44:55 -0700 Subject: [PATCH 09/21] Update tests to new generic way of module update Signed-off-by: smajumdar --- nemo/core/classes/mixins/adapter_mixins.py | 63 ++++++++ .../adapters/test_adapter_model_mixin.py | 137 ++++++++++++------ 2 files changed, 154 insertions(+), 46 deletions(-) diff --git 
a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 1cc902bd7586..2e9a0394a3f3 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -123,6 +123,55 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict return cfg +def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, update_config: bool = True): + """ + Recursively walks through the module and its children, checking if the class is registered in the adapter registry. + If it is, the module's class is swapped with the registered adapter class. + Also updates the config with the adapter classpath, if required. + + Args: + module: torch.nn.Module to recurse through. + cfg: DictConfig object or dict that contains the config of the module. + update_config: Bool, whether to update the config with the adapter classpath. + """ + def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str, adapter_class_path: str): + """ + Utility function to recursively walk through a dictionary and update the classpath if required. + Update is done inplace + + Args: + d: Dict to recurse through. + base_class_path: The str classpath of the base class. + adapter_class_path: The str classpath of the adapter class. + """ + for k, v in d.items(): + if isinstance(v, dict): + inplace_recursive_walk_dict(v, base_class_path, adapter_class_path) + elif k in ('target', '_target_') and isinstance(v, str) and v == base_class_path: + logging.info(f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " + f"class)") + d[k] = adapter_class_path + + if not isinstance(module, AdapterModuleMixin): + info = get_registered_adapter(mod.__class__) + if info is not None: + # Swap the registered class with its registered adapter class. + # Due to direct inheritance of the Adapter subclass from the original class, + # the module's class container will be replaced with the adapter class. + logging.info(f"Swapping class {info.base_class_path} with adapter compatible class: " + f"{info.adapter_class_path}") + adapter_cls = info.adapter_class + module.__class__ = adapter_cls + + if update_config: + # Update the adapter config with the registered adapter config + # Find the location where the original module was registered in config + # and replace it with the adapter classpath. + original_classpath = info.base_class_path + adapter_classpath = info.adapter_class_path + inplace_recursive_walk_dict(cfg, original_classpath, adapter_classpath) + + class AdapterModuleMixin(ABC): """Generic Adapter Mixin that can augment any torch.nn.Module with Adapter module support. @@ -989,6 +1038,20 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg + def replace_adapter_compatible_modules(self, module: Optional[nn.Module] = None, update_config: bool = True): + """ + Utility method to replace all modules with Adapter variants, if they exist. + + Args: + module: The module to be replaced with an Adapter variant. + update_config: A flag that determines if the config should be updated or not. 
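A standalone illustration of the class swap performed by update_module_class_with_adapter_class, in plain Python with hypothetical Encoder/EncoderAdapter names: because the registered adapter variant inherits directly from the base class, reassigning __class__ on an already-built instance grafts the adapter methods onto it without reconstructing the module or copying any weights.

class Encoder:                       # hypothetical base module
    def __init__(self, dim):
        self.dim = dim

class EncoderAdapter(Encoder):       # hypothetical adapter-compatible subclass
    def add_adapter(self, name):     # extra capability provided by the adapter mixin
        return f"added '{name}' to encoder with dim={self.dim}"

enc = Encoder(dim=128)               # built before adapters were requested
assert not hasattr(enc, "add_adapter")

enc.__class__ = EncoderAdapter       # same object, same state, new capabilities
print(enc.add_adapter("adapter_0"))  # added 'adapter_0' to encoder with dim=128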
+ """ + if module is None: + module = self + + for mod in module.modules(): + update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config) + @property def adapter_module_names(self) -> List[str]: """ diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index 87c6b4e4cfb3..922878b4bb72 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -14,12 +14,12 @@ import os import shutil import tempfile -from typing import Tuple +from typing import Tuple, List, Optional import pytest import torch from hydra.utils import instantiate -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict from nemo.core import ModelPT, NeuralModule from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins @@ -79,13 +79,13 @@ class DefaultModelAdapterMixin(AdapterModelPTMixin): def setup_adapters(self): supports_adapters = False - # Check the inheriting class' modules supports adapters or not - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - supports_adapters |= True - - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - supports_adapters |= True + # At least the encoder must extend AdapterModuleMixin + valid_adapter_names = [x for x in self.adapter_module_names if x != ''] + for module_name in valid_adapter_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + supports_adapters |= True + # If adapters are supported, setup the adapter config + any modules (pre-existing adapter modules) if supports_adapters: super().setup_adapters() @@ -96,66 +96,98 @@ def add_adapter(self, name: str, cfg: DictConfig): # Resolve module name and adapter name module_name, adapter_name = self.resolve_adapter_module_name_(name) - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # forward the method call to the individual modules - # If module name is empty, it is a global adapter, otherwise it is a local adapter - if (module_name == '' and global_config.get('encoder_adapter', True)) or (module_name == 'encoder'): - if hasattr(self, 'encoder'): - self.encoder.add_adapter(name, cfg) - - if (module_name == '' and global_config.get('decoder_adapter', False)) or (module_name == 'decoder'): - if hasattr(self, 'decoder'): - self.decoder.add_adapter(name, cfg) + # Use + as a splitter, in order to share one name across multiple modules + if '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Update the model.cfg with information about the new adapter from cfg + for module_name in module_names: + # Check if encoder adapters should be added + if module_name == '': + for default in default_module_name: # This model has multiple default modules + if hasattr(self, default): + # Dispatch the call to the default model. + getattr(self, default).add_adapter(name=name, cfg=cfg) + + elif module_name in valid_module_names: + # Check if module exists + if hasattr(self, module_name): + # Dispatch the call to the module. 
+ getattr(self, module_name).add_adapter(name=name, cfg=cfg) def set_enabled_adapters(self, name=None, enabled: bool = True): # check if valid model with some adapter support super().set_enabled_adapters(name, enabled) - # Resolve module name and adapter name + # Resolve the module name and adapter name if name is not None: module_name, _ = self.resolve_adapter_module_name_(name) else: module_name = None - # Try to retrieve global adapter config - global_config = self._get_global_cfg() - - # Forward the method call to the individual modules - if name is None or global_config.get('encoder_adapter', True) or module_name in ('', 'encoder'): - if hasattr(self, 'encoder') and self.encoder.is_adapter_available(): - self.encoder.set_enabled_adapters(name, enabled) - - if name is None or global_config.get('decoder_adapter', False) or module_name == 'decoder': - if hasattr(self, 'decoder') and self.decoder.is_adapter_available(): - self.decoder.set_enabled_adapters(name, enabled) + # Use + as a splitter, in order to share one name across multiple modules + if module_name is not None and '+' in module_name: + module_names = module_name.split('+') + else: + module_names = [module_name] + + valid_module_names = [x for x in self.adapter_module_names if x != ''] + default_module_name = self.default_adapter_module_name + + # Check if default module name is None or not + if default_module_name is None: + raise ValueError( + f"Default module name is None. Class {self.__class__.__name__} must implement " + f"`default_adapter_module_name`" + ) + + # Forward the method call to the individual modules if they exist + for module_name in module_names: + # Check if encoder adapters should be used + + if module_name == '': + for default in default_module_name: + if hasattr(self, default) and isinstance(getattr(self, default), AdapterModuleMixin): + if getattr(self, default).is_adapter_available(): + # Dispatch the call to the default model. + getattr(self, default).set_enabled_adapters(name=name, enabled=enabled) + + elif module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + if getattr(self, module_name).is_adapter_available(): + # Dispatch the call to the module. 
+ getattr(self, module_name).set_enabled_adapters(name=name, enabled=enabled) def get_enabled_adapters(self) -> list: enabled_adapters = super().get_enabled_adapters() - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - encoder_adapters = self.encoder.get_enabled_adapters() - enabled_adapters.extend(encoder_adapters) + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - decoder_adapters = self.decoder.get_enabled_adapters() - enabled_adapters.extend(decoder_adapters) + # Check if encoder adapters should be used or are enabled + for module_name in valid_module_names: + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + enabled_adapters.extend(getattr(self, module_name).get_enabled_adapters()) + + enabled_adapters = list(sorted(list(set(enabled_adapters)))) return enabled_adapters def is_adapter_available(self) -> bool: adapters_available = super().is_adapter_available() - # Try to retrieve global adapter config - # Forward the method call to the individual modules - if hasattr(self, 'encoder') and isinstance(self.encoder, AdapterModuleMixin): - print("Encoder is adapter available", self.encoder.is_adapter_available()) - adapters_available |= self.encoder.is_adapter_available() + valid_module_names = [x for x in self.adapter_module_names if x != ''] - if hasattr(self, 'decoder') and isinstance(self.decoder, AdapterModuleMixin): - adapters_available |= self.decoder.is_adapter_available() + # Forward the method call to the individual modules + for module_name in valid_module_names: + print("Module name", module_name) + if hasattr(self, module_name) and isinstance(getattr(self, module_name), AdapterModuleMixin): + adapters_available |= getattr(self, module_name).is_adapter_available() + print("Adapter available for module", module_name, getattr(self, module_name).is_adapter_available()) return adapters_available @@ -198,6 +230,19 @@ def adapter_module_names(self) -> list: valid_adapter_modules = ['', 'encoder', 'decoder'] return valid_adapter_modules + @property + def default_adapter_module_name(self) -> Optional[List[str]]: + global_config = self._get_global_cfg() + default_modules = [] + encoder_adapter = global_config.get('encoder_adapter', True) + decoder_adapter = global_config.get('decoder_adapter', False) + + if encoder_adapter: + default_modules.append('encoder') + if decoder_adapter: + default_modules.append('decoder') + return default_modules + class DefaultAdapterModel(ModelPT, DefaultModelAdapterMixin): def __init__(self, cfg, trainer=None): From 6d570491a2b388498e1762fca5066d4218a39987 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Thu, 13 Jun 2024 19:17:41 -0700 Subject: [PATCH 10/21] Finalize code for update module Signed-off-by: smajumdar --- nemo/core/classes/mixins/adapter_mixins.py | 59 +++++++++++-------- .../adapters/test_adapter_model_mixin.py | 35 +++++++++++ 2 files changed, 71 insertions(+), 23 deletions(-) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 2e9a0394a3f3..f033ae70fccd 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -123,7 +123,8 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict return cfg -def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, 
update_config: bool = True): +def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, update_config: bool = True, + verbose: bool = True): """ Recursively walks through the module and its children, checking if the class is registered in the adapter registry. If it is, the module's class is swapped with the registered adapter class. @@ -133,6 +134,7 @@ def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, u module: torch.nn.Module to recurse through. cfg: DictConfig object or dict that contains the config of the module. update_config: Bool, whether to update the config with the adapter classpath. + verbose: Bool, whether to log the changes made to the module and config. """ def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str, adapter_class_path: str): """ @@ -144,32 +146,40 @@ def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str base_class_path: The str classpath of the base class. adapter_class_path: The str classpath of the adapter class. """ - for k, v in d.items(): - if isinstance(v, dict): + for k, v in d.items(): # Loop through all k, v pairs + if isinstance(v, (dict, DictConfig)): # If value is a dict, recurse through it inplace_recursive_walk_dict(v, base_class_path, adapter_class_path) + + # If key is target and value is base class, update the value to adapter class elif k in ('target', '_target_') and isinstance(v, str) and v == base_class_path: - logging.info(f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " - f"class)") + if verbose: + logging.info(f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " + f"class)") + + # Update the value inplace d[k] = adapter_class_path - if not isinstance(module, AdapterModuleMixin): - info = get_registered_adapter(mod.__class__) - if info is not None: - # Swap the registered class with its registered adapter class. - # Due to direct inheritance of the Adapter subclass from the original class, - # the module's class container will be replaced with the adapter class. + if not isinstance(module, AdapterModuleMixin): + info = get_registered_adapter(module.__class__) + if info is not None: + if verbose: logging.info(f"Swapping class {info.base_class_path} with adapter compatible class: " f"{info.adapter_class_path}") - adapter_cls = info.adapter_class - module.__class__ = adapter_cls - if update_config: - # Update the adapter config with the registered adapter config - # Find the location where the original module was registered in config - # and replace it with the adapter classpath. - original_classpath = info.base_class_path - adapter_classpath = info.adapter_class_path - inplace_recursive_walk_dict(cfg, original_classpath, adapter_classpath) + # Swap the registered class with its registered adapter class. + # Due to direct inheritance of the Adapter subclass from the original class, + # the module's class container will be replaced with the adapter class. + + adapter_cls = info.adapter_class + module.__class__ = adapter_cls + + if update_config: + # Update the adapter config with the registered adapter config + # Find the location where the original module was registered in config + # and replace it with the adapter classpath. 
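A self-contained sketch of the in-place config walk described in the comments above, using a plain nested dict (a DictConfig iterates the same way through .items()); the classpaths are hypothetical.

def rewrite_target(d, base_path, adapter_path):
    for k, v in d.items():
        if isinstance(v, dict):
            rewrite_target(v, base_path, adapter_path)   # recurse into nested sections
        elif k in ('target', '_target_') and v == base_path:
            d[k] = adapter_path                          # update in place

cfg = {
    'model': {
        'encoder': {'_target_': 'my_pkg.TransformerEncoder', 'hidden_size': 256},
        'decoder': {'_target_': 'my_pkg.TransformerDecoder', 'hidden_size': 256},
    }
}
rewrite_target(cfg, 'my_pkg.TransformerEncoder', 'my_pkg.TransformerEncoderAdapter')
print(cfg['model']['encoder']['_target_'])  # my_pkg.TransformerEncoderAdapter
print(cfg['model']['decoder']['_target_'])  # my_pkg.TransformerDecoder (unchanged)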
+ original_classpath = info.base_class_path + adapter_classpath = info.adapter_class_path + inplace_recursive_walk_dict(cfg, original_classpath, adapter_classpath) class AdapterModuleMixin(ABC): @@ -1038,19 +1048,22 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg - def replace_adapter_compatible_modules(self, module: Optional[nn.Module] = None, update_config: bool = True): + def replace_adapter_compatible_modules(self, module: Optional[nn.Module] = None, update_config: bool = True, + verbose: bool = True): """ Utility method to replace all modules with Adapter variants, if they exist. Args: module: The module to be replaced with an Adapter variant. update_config: A flag that determines if the config should be updated or not. + verbose: A flag that determines if the method should log the changes made or not. """ if module is None: module = self - for mod in module.modules(): - update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config) + # Update the given module itself, and then all its children modules + for name, mod in module.named_modules(): + update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config, verbose=verbose) @property def adapter_module_names(self) -> List[str]: diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index 922878b4bb72..d36c7ea485d1 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -347,6 +347,41 @@ def test_base_model_no_support_for_adapters(self, caplog): logging._logger.propagate = False logging.set_verbosity(original_verbosity) + @pytest.mark.unit + def test_base_model_replace_adapter_compatible_modules(self, caplog): + cfg = get_model_config(in_features=50, update_adapter_cfg=False) + model = DefaultAdapterModel(cfg) + + with pytest.raises(AttributeError): + model.add_adapter(name='adapter_0', cfg=get_adapter_cfg()) + + # Replace the modules of the model dynamically to support adapters + model.replace_adapter_compatible_modules() + + assert isinstance(model.encoder, AdapterModuleMixin) + assert model.encoder.is_adapter_available() is False + + model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) + assert model.encoder.is_adapter_available() is True + + @pytest.mark.unit + def test_base_model_replace_adapter_compatible_encoder_only(self, caplog): + cfg = get_model_config(in_features=50, update_adapter_cfg=False) + model = DefaultAdapterModel(cfg) + + with pytest.raises(AttributeError): + model.add_adapter(name='adapter_0', cfg=get_adapter_cfg()) + + # Replace the modules of the model dynamically to support adapters + model.replace_adapter_compatible_modules(model.encoder, update_config=True) + + assert isinstance(model.encoder, AdapterModuleMixin) + assert model.encoder.is_adapter_available() is False + + model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) + + assert model.encoder.is_adapter_available() is True + @pytest.mark.unit def test_single_adapter(self): cfg = get_model_config(in_features=50) From f4a586432ceb7361cb01b184e5698a48340a9c29 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Fri, 14 Jun 2024 02:18:30 +0000 Subject: [PATCH 11/21] Apply isort and black reformatting Signed-off-by: titu1994 --- nemo/core/classes/mixins/adapter_mixins.py | 22 ++++++++++++------- .../adapters/test_adapter_model_mixin.py | 22 ++++++++++++++----- 2 files changed, 30 
insertions(+), 14 deletions(-) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index f033ae70fccd..ca29ed4b1313 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -123,8 +123,9 @@ def _prepare_default_adapter_config(*, global_key: str, meta_key: str, cfg: Dict return cfg -def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, update_config: bool = True, - verbose: bool = True): +def update_module_class_with_adapter_class( + module: nn.Module, cfg: DictConfig, update_config: bool = True, verbose: bool = True +): """ Recursively walks through the module and its children, checking if the class is registered in the adapter registry. If it is, the module's class is swapped with the registered adapter class. @@ -136,6 +137,7 @@ def update_module_class_with_adapter_class(module: nn.Module, cfg: DictConfig, u update_config: Bool, whether to update the config with the adapter classpath. verbose: Bool, whether to log the changes made to the module and config. """ + def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str, adapter_class_path: str): """ Utility function to recursively walk through a dictionary and update the classpath if required. @@ -153,8 +155,9 @@ def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str # If key is target and value is base class, update the value to adapter class elif k in ('target', '_target_') and isinstance(v, str) and v == base_class_path: if verbose: - logging.info(f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " - f"class)") + logging.info( + f"Updating config from {v} (base class) to {adapter_class_path} (adapter compatible " f"class)" + ) # Update the value inplace d[k] = adapter_class_path @@ -163,8 +166,10 @@ def inplace_recursive_walk_dict(d: Union[dict, DictConfig], base_class_path: str info = get_registered_adapter(module.__class__) if info is not None: if verbose: - logging.info(f"Swapping class {info.base_class_path} with adapter compatible class: " - f"{info.adapter_class_path}") + logging.info( + f"Swapping class {info.base_class_path} with adapter compatible class: " + f"{info.adapter_class_path}" + ) # Swap the registered class with its registered adapter class. # Due to direct inheritance of the Adapter subclass from the original class, @@ -1048,8 +1053,9 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg - def replace_adapter_compatible_modules(self, module: Optional[nn.Module] = None, update_config: bool = True, - verbose: bool = True): + def replace_adapter_compatible_modules( + self, module: Optional[nn.Module] = None, update_config: bool = True, verbose: bool = True + ): """ Utility method to replace all modules with Adapter variants, if they exist. 
diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index d36c7ea485d1..9d39dc0cb45c 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -14,7 +14,7 @@ import os import shutil import tempfile -from typing import Tuple, List, Optional +from typing import List, Optional, Tuple import pytest import torch @@ -28,7 +28,7 @@ class DefaultModule(NeuralModule): - """ Define a default neural module (without adapter support)""" + """Define a default neural module (without adapter support)""" def __init__(self): super().__init__() @@ -51,7 +51,7 @@ def num_params(self): class DefaultModuleAdapter(DefaultModule, AdapterModuleMixin): - """ Subclass the DefaultModule, adding adapter module support""" + """Subclass the DefaultModule, adding adapter module support""" def forward(self, x): x = super(DefaultModuleAdapter, self).forward(x) @@ -66,7 +66,7 @@ def forward(self, x): class DefaultModelAdapterMixin(AdapterModelPTMixin): - """ Mixin class that implements this model's specific overrides to AdapterModelPTMixin + """Mixin class that implements this model's specific overrides to AdapterModelPTMixin It will container two modules, an encoder and a decoder, and both can have adapters. By default, encoder adapters are enabled, and decoder adapters are diabled. Decoder adapters can be enabled via the global_cfg in model.cfg.adapters. @@ -1014,8 +1014,18 @@ def test_multiple_decoder_save_load_adapter_only_exact_name(self): assert (original_state_dict[ogkey] - restored_state_dict[newkey]).abs().mean() < 1e-6 @pytest.mark.unit - @pytest.mark.parametrize("decoder", ["adapter_0",]) # "decoder:adapter_0" - @pytest.mark.parametrize("encoder", ["adapter_1",]) # "encoder:adapter_1" + @pytest.mark.parametrize( + "decoder", + [ + "adapter_0", + ], + ) # "decoder:adapter_0" + @pytest.mark.parametrize( + "encoder", + [ + "adapter_1", + ], + ) # "encoder:adapter_1" def test_multiple_save_load_adapter_with_multiple_load(self, decoder, encoder): # create a model config, but do not add global_cfg to it # we want to test just module level adapter From f92c082b3c9c71c01ccc0a4e316fa2b54ae4902c Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 14 Jun 2024 10:24:19 -0700 Subject: [PATCH 12/21] Fix variable name Signed-off-by: smajumdar --- .../asr/modules/transformer/transformer_decoders.py | 12 ++++++------ .../asr/modules/transformer/transformer_encoders.py | 12 ++++++------ .../asr/parts/submodules/conformer_modules.py | 12 ++++++------ 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer_decoders.py b/nemo/collections/asr/modules/transformer/transformer_decoders.py index f8ddbb74169a..30c6179b85a6 100644 --- a/nemo/collections/asr/modules/transformer/transformer_decoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_decoders.py @@ -84,14 +84,14 @@ def forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state if self.is_adapter_available(): # Call the MHA adapters - pack_ip = { + pack_input = { 'x': self_attn_output, 'loc': 'mha', 'att_mask': decoder_mask, 'pos_emb': None, } - pack_ip = self.forward_enabled_adapters(pack_ip) - self_attn_output = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) @@ -105,12 +105,12 @@ def 
forward_preln(self, decoder_query, decoder_mask, decoder_keys, encoder_state if self.is_adapter_available(): # Call the Linear adapters - pack_ip = { + pack_input = { 'x': output_states, 'loc': 'post', } - pack_ip = self.forward_enabled_adapters(pack_ip) - output_states = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] return output_states diff --git a/nemo/collections/asr/modules/transformer/transformer_encoders.py b/nemo/collections/asr/modules/transformer/transformer_encoders.py index 8c0ba933d5d2..d3116db82482 100644 --- a/nemo/collections/asr/modules/transformer/transformer_encoders.py +++ b/nemo/collections/asr/modules/transformer/transformer_encoders.py @@ -80,14 +80,14 @@ def forward_preln(self, encoder_query, encoder_mask, encoder_keys): if self.is_adapter_available(): # Call the MHA adapters - pack_ip = { + pack_input = { 'x': self_attn_output, 'loc': 'mha', 'att_mask': encoder_mask, 'pos_emb': None, } - pack_ip = self.forward_enabled_adapters(pack_ip) - self_attn_output = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + self_attn_output = pack_input['x'] residual = self_attn_output self_attn_output = self.layer_norm_2(self_attn_output) @@ -96,12 +96,12 @@ def forward_preln(self, encoder_query, encoder_mask, encoder_keys): if self.is_adapter_available(): # Call the Linear adapters - pack_ip = { + pack_input = { 'x': output_states, 'loc': 'post', } - pack_ip = self.forward_enabled_adapters(pack_ip) - output_states = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + output_states = pack_input['x'] return output_states diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index 5d83cbd005bc..c2d897d63225 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -183,14 +183,14 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the MHA adapters - pack_ip = { + pack_input = { 'x': residual, 'loc': 'mha', 'att_mask': att_mask, 'pos_emb': pos_emb, } - pack_ip = self.forward_enabled_adapters(pack_ip) - residual = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + residual = pack_input['x'] x = self.norm_conv(residual) x = self.conv(x, pad_mask=pad_mask, cache=cache_last_time) @@ -206,12 +206,12 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan if self.is_adapter_available(): # Call the adapters - pack_ip = { + pack_input = { 'x': x, 'loc': 'post', } - pack_ip = self.forward_enabled_adapters(pack_ip) - x = pack_ip['x'] + pack_input = self.forward_enabled_adapters(pack_input) + x = pack_input['x'] if self.is_access_enabled(getattr(self, "model_guid", None)) and self.access_cfg.get( 'save_encoder_tensors', False From cf5207b88b6b66a683a47a2b66debb22b06e67cd Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 14 Jun 2024 19:23:04 -0700 Subject: [PATCH 13/21] Finalize projection support for transformer mha adapters Signed-off-by: smajumdar --- .../transformer/transformer_modules.py | 3 +- .../asr/parts/submodules/adapters/__init__.py | 4 +- ...mer_multi_head_attention_adapter_module.py | 16 +++--- .../asr/parts/utils/adapter_utils.py | 2 +- .../mixins/adapters/test_asr_adapter_mixin.py | 21 ++------ .../adapters/test_asr_adapter_modules.py | 49 +++++++++++++++++++ 6 files changed, 66 insertions(+), 29 
deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer_modules.py b/nemo/collections/asr/modules/transformer/transformer_modules.py index 25fb781f0cd4..d3dcce139ac1 100644 --- a/nemo/collections/asr/modules/transformer/transformer_modules.py +++ b/nemo/collections/asr/modules/transformer/transformer_modules.py @@ -203,8 +203,9 @@ def forward(self, queries, keys, values, attention_mask): attention_probs = self.attn_dropout(attention_probs) context = torch.matmul(attention_probs, value) + context_hidden_size = context.size()[-1] * self.num_attention_heads context = context.permute(0, 2, 1, 3).contiguous() - new_context_shape = context.size()[:-2] + (self.hidden_size,) + new_context_shape = context.size()[:-2] + (context_hidden_size,) context = context.view(*new_context_shape) # output projection diff --git a/nemo/collections/asr/parts/submodules/adapters/__init__.py b/nemo/collections/asr/parts/submodules/adapters/__init__.py index c987e4d24de2..c51d935bddd4 100644 --- a/nemo/collections/asr/parts/submodules/adapters/__init__.py +++ b/nemo/collections/asr/parts/submodules/adapters/__init__.py @@ -27,8 +27,8 @@ RelPositionMultiHeadAttentionAdapterConfig, ) from nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module import ( - TransformerEncoderMultiHeadAttentionAdapter, - TransformerEncoderMultiHeadAttentionAdapterConfig, + TransformerMultiHeadAttentionAdapter, + TransformerMultiHeadAttentionAdapterConfig, ) # fmt: on diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py index 4c35645657dc..473f5a095a2b 100644 --- a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -28,7 +28,7 @@ from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins -class TransformerEncoderMultiHeadAttentionAdapter( +class TransformerMultiHeadAttentionAdapter( transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil ): """Multi-Head Attention layer of Transformer Encoder. 
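A shape-only sketch, with illustrative sizes, of why the attention context is now flattened back to num_heads * head_size instead of the model's hidden_size: once an adapter projects queries, keys and values down to a smaller proj_dim, the context carries proj_dim features and only the final out_projection maps back to hidden_size.

import torch

batch, seq, hidden_size = 2, 7, 128
num_heads, proj_dim = 4, 16                 # adapter bottleneck width
head_size = proj_dim // num_heads           # 4

context = torch.randn(batch, num_heads, seq, head_size)   # result of attn_probs @ value
context_hidden_size = context.size(-1) * num_heads        # 16 == proj_dim, not 128

context = context.permute(0, 2, 1, 3).contiguous()
context = context.view(batch, seq, context_hidden_size)   # [2, 7, 16]

out_projection = torch.nn.Linear(context_hidden_size, hidden_size)
print(out_projection(context).shape)                      # torch.Size([2, 7, 128])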
@@ -76,9 +76,9 @@ def __init__( self.attn_head_size = self.proj_dim // num_attention_heads self.attn_scale = math.sqrt(math.sqrt(self.attn_head_size)) - self.query_net = nn.Linear(num_attention_heads, self.proj_dim) - self.key_net = nn.Linear(num_attention_heads, self.proj_dim) - self.value_net = nn.Linear(num_attention_heads, self.proj_dim) + self.query_net = nn.Linear(hidden_size, self.proj_dim) + self.key_net = nn.Linear(hidden_size, self.proj_dim) + self.value_net = nn.Linear(hidden_size, self.proj_dim) self.out_projection = nn.Linear(self.proj_dim, hidden_size) # Setup adapter strategy @@ -118,7 +118,7 @@ def get_default_strategy_config(self) -> 'dataclass': @dataclass -class TransformerEncoderMultiHeadAttentionAdapterConfig: +class TransformerMultiHeadAttentionAdapterConfig: hidden_size: int num_attention_heads: int attn_score_dropout: float = 0.0 @@ -126,12 +126,12 @@ class TransformerEncoderMultiHeadAttentionAdapterConfig: proj_dim: Optional[int] = None adapter_strategy: Optional[Any] = field(default_factory=lambda: MHAResidualAddAdapterStrategyConfig()) _target_: str = "{0}.{1}".format( - TransformerEncoderMultiHeadAttentionAdapter.__module__, TransformerEncoderMultiHeadAttentionAdapter.__name__ + TransformerMultiHeadAttentionAdapter.__module__, TransformerMultiHeadAttentionAdapter.__name__ ) """ Register the Adapter Modules to the Adapter Module Registry """ -if adapter_mixins.get_registered_adapter(TransformerEncoderMultiHeadAttentionAdapter) is None: +if adapter_mixins.get_registered_adapter(TransformerMultiHeadAttentionAdapter) is None: adapter_mixins.register_adapter( - base_class=transformer_modules.MultiHeadAttention, adapter_class=TransformerEncoderMultiHeadAttentionAdapter + base_class=transformer_modules.MultiHeadAttention, adapter_class=TransformerMultiHeadAttentionAdapter ) diff --git a/nemo/collections/asr/parts/utils/adapter_utils.py b/nemo/collections/asr/parts/utils/adapter_utils.py index 39e7b8596625..b85bdee7051a 100644 --- a/nemo/collections/asr/parts/utils/adapter_utils.py +++ b/nemo/collections/asr/parts/utils/adapter_utils.py @@ -35,7 +35,7 @@ ) # Transformer Adapters -TRANSFORMER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapter" +TRANSFORMER_MHA_ADAPTER_CLASSPATH = "nemo.collections.asr.parts.submodules.adapters.transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapter" def convert_adapter_cfg_to_dict_config(cfg: DictConfig): diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index 3abd13186b94..4c8d51450123 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -419,11 +419,11 @@ def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **k cfg = adapter_modules.LinearAdapterConfig(in_features=in_features, dim=dim, norm_position=norm_pos) elif atype == 'mha': cfg = multi_head_attention_adapter_module.MultiHeadAttentionAdapterConfig( - n_head=kwargs.get('n_head', 1), n_feat=in_features + n_head=kwargs.get('n_head', 1), n_feat=in_features, proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'transf_mha': - cfg = transformer_multi_head_attention_adapter_module.TransformerEncoderMultiHeadAttentionAdapterConfig( - num_attention_heads=kwargs.get('n_head', 1), hidden_size=in_features + cfg = 
transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapterConfig( + num_attention_heads=kwargs.get('n_head', 1), hidden_size=in_features, proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'relmha': cfg = multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapterConfig( @@ -618,7 +618,7 @@ def test_canary_forward_mha(self, multitask_model, name): og_enc_out = origial_output[2] adapter_type = 'transf_mha' if 'transf' in name else 'mha' - multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type)) + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type, proj_dim=4)) new_output = multitask_model( input_signal=input_signal, @@ -644,19 +644,6 @@ def test_canary_forward_mha(self, multitask_model, name): def test_canary_forward_mha_decoder_fails_without_support(self, multitask_model, name): multitask_model.eval() torch.random.manual_seed(0) - input_signal = torch.randn(2, 512) - input_signal_length = torch.tensor([512, 512], dtype=torch.int32) - transcript = torch.randint(0, multitask_model.tokenizer.vocab_size, size=(2, 10)) - transcript_len = torch.tensor([10, 9], dtype=torch.int32) - - origial_output = multitask_model( - input_signal=input_signal, - input_signal_length=input_signal_length, - transcript=transcript, - transcript_length=transcript_len, - ) - og_logprob = origial_output[0] - og_enc_out = origial_output[2] # Change internal class of transf_decoder module adapter_class = multitask_model.transf_decoder.__class__ diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py index c4ee4b97a2a6..3a91d7b4762f 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py @@ -111,6 +111,22 @@ def test_rel_pos_encoding_adapter_config(self): assert cls_subset is None assert dataclass_subset is None + @pytest.mark.unit + def test_transformer_mha_adapter_config(self): + IGNORED_ARGS = ['_target_'] + + result = config_utils.assert_dataclass_signature_match( + adapter_modules.TransformerMultiHeadAttentionAdapter, + adapter_modules.TransformerMultiHeadAttentionAdapterConfig, + ignore_args=IGNORED_ARGS, + ) + + signatures_match, cls_subset, dataclass_subset = result + + assert signatures_match + assert cls_subset is None + assert dataclass_subset is None + @pytest.mark.unit @pytest.mark.parametrize('n_head', [1, 2, 10]) @pytest.mark.parametrize('proj_dim', [None, -1]) @@ -194,6 +210,31 @@ def test_relpos_encoding_init(self): assert (out - x).sum().abs() <= 1e-8 assert out.shape == x.shape + @pytest.mark.unit + @pytest.mark.parametrize('n_head', [1, 2, 10]) + @pytest.mark.parametrize('proj_dim', [None, -1]) + def test_transformer_mha_adapter_init(self, n_head, proj_dim): + torch.random.manual_seed(0) + x = torch.randn(2, 32, 50) + lengths = torch.randint(1, x.size(1), size=(x.size(0),)) + lengths[torch.randint(0, x.size(0), size=(1,))[0]] = x.size(1) + + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=n_head, hidden_size=50, attn_layer_dropout=0.0, proj_dim=proj_dim + ) + + pad_mask, att_mask = get_mask(lengths) + att_mask = att_mask.unsqueeze(1) + + with torch.no_grad(): + assert adapter.out_projection.weight.sum() == 0 + if hasattr(adapter.out_projection, 'bias') and adapter.out_projection.bias is not None: + assert adapter.out_projection.bias.sum() == 0 + + out = 
adapter(x, x, x, att_mask) + assert out.sum().abs() <= 1e-8 + assert out.shape == x.shape + @pytest.mark.unit def test_mha_adapter_strategy(self): adapter = adapter_modules.MultiHeadAttentionAdapter(n_head=1, n_feat=50, dropout_rate=0.0) @@ -225,3 +266,11 @@ def test_relpos_encoding_adapter_strategy(self): assert adapter.adapter_strategy is not None # assert default strategy is set assert isinstance(adapter.adapter_strategy, adapter_mixin_strategies.ReturnResultAdapterStrategy) + + @pytest.mark.unit + def test_transformer_mha_adapter_strategy(self): + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter(num_attention_heads=1, hidden_size=50, attn_layer_dropout=0.0) + assert hasattr(adapter, 'adapter_strategy') + assert adapter.adapter_strategy is not None + # assert default strategy is set + assert isinstance(adapter.adapter_strategy, adapter_modules.MHAResidualAddAdapterStrategy) From 155619ad3eb31fee41882eea2f3d46ef42674af1 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Sat, 15 Jun 2024 02:23:58 +0000 Subject: [PATCH 14/21] Apply isort and black reformatting Signed-off-by: titu1994 --- .../asr/modules/transformer/transformer_modules.py | 4 +++- .../transformer_multi_head_attention_adapter_module.py | 4 +--- .../asr/mixins/adapters/test_asr_adapter_mixin.py | 8 ++++++-- .../asr/mixins/adapters/test_asr_adapter_modules.py | 4 +++- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/modules/transformer/transformer_modules.py b/nemo/collections/asr/modules/transformer/transformer_modules.py index d3dcce139ac1..d090604287cb 100644 --- a/nemo/collections/asr/modules/transformer/transformer_modules.py +++ b/nemo/collections/asr/modules/transformer/transformer_modules.py @@ -65,7 +65,9 @@ def forward(self, position_ids): f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.' ) self._build_pos_enc( - hidden_size=self._hidden_size, max_sequence_length=max_pos_id + 1, device=position_ids.device, + hidden_size=self._hidden_size, + max_sequence_length=max_pos_id + 1, + device=position_ids.device, ) embeddings = torch.embedding(self.pos_enc, position_ids) diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py index 473f5a095a2b..192879736c32 100644 --- a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -28,9 +28,7 @@ from nemo.core.classes.mixins import adapter_mixin_strategies, adapter_mixins -class TransformerMultiHeadAttentionAdapter( - transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil -): +class TransformerMultiHeadAttentionAdapter(transformer_modules.MultiHeadAttention, adapter_modules.AdapterModuleUtil): """Multi-Head Attention layer of Transformer Encoder. 
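A minimal sketch, using generic nn layers rather than the NeMo adapter classes, of the identity-at-initialization property the new tests check: with the output projection zero-initialized, the residual add leaves the frozen model's activations unchanged until training moves the adapter weights away from zero.

import torch
import torch.nn as nn

hidden, bottleneck = 50, 8
adapter = nn.Sequential(nn.Linear(hidden, bottleneck), nn.ReLU(), nn.Linear(bottleneck, hidden))

# Zero-init the output projection, mirroring what the tests assert for out_projection.
nn.init.zeros_(adapter[-1].weight)
nn.init.zeros_(adapter[-1].bias)

x = torch.randn(2, 32, hidden)
residual_out = x + adapter(x)           # residual-add strategy
print(torch.allclose(residual_out, x))  # True at initialization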
Args: diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index 4c8d51450123..b2811b082690 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -419,11 +419,15 @@ def get_adapter_cfg(in_features=50, dim=100, norm_pos='pre', atype='linear', **k cfg = adapter_modules.LinearAdapterConfig(in_features=in_features, dim=dim, norm_position=norm_pos) elif atype == 'mha': cfg = multi_head_attention_adapter_module.MultiHeadAttentionAdapterConfig( - n_head=kwargs.get('n_head', 1), n_feat=in_features, proj_dim=kwargs.get('proj_dim', None), + n_head=kwargs.get('n_head', 1), + n_feat=in_features, + proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'transf_mha': cfg = transformer_multi_head_attention_adapter_module.TransformerMultiHeadAttentionAdapterConfig( - num_attention_heads=kwargs.get('n_head', 1), hidden_size=in_features, proj_dim=kwargs.get('proj_dim', None), + num_attention_heads=kwargs.get('n_head', 1), + hidden_size=in_features, + proj_dim=kwargs.get('proj_dim', None), ) elif atype == 'relmha': cfg = multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapterConfig( diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py index 3a91d7b4762f..ffaf1e640f3e 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_modules.py @@ -269,7 +269,9 @@ def test_relpos_encoding_adapter_strategy(self): @pytest.mark.unit def test_transformer_mha_adapter_strategy(self): - adapter = adapter_modules.TransformerMultiHeadAttentionAdapter(num_attention_heads=1, hidden_size=50, attn_layer_dropout=0.0) + adapter = adapter_modules.TransformerMultiHeadAttentionAdapter( + num_attention_heads=1, hidden_size=50, attn_layer_dropout=0.0 + ) assert hasattr(adapter, 'adapter_strategy') assert adapter.adapter_strategy is not None # assert default strategy is set From 166f28ff1ed2cd47bea1fe1c9677699693662b4e Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Sat, 15 Jun 2024 16:38:15 -0700 Subject: [PATCH 15/21] Correct implementation of freeze restore Signed-off-by: smajumdar --- .../transformer/transformer_generators.py | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index a76b1606fd1e..2a2bf44212fb 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -173,7 +173,7 @@ def _forward( def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): - with torch.inference_mode(): + with self.as_frozen(): return self._forward( decoder_input_ids, encoder_hidden_states, encoder_input_mask, return_beam_scores=return_beam_scores ) @@ -688,7 +688,7 @@ def _forward(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_b return tgt def __call__(self, src_ids, encoder_input_mask, decoder_input_ids=None, return_beam_scores=False): - with torch.inference_mode(): + with self.as_frozen(): return self._forward(src_ids, encoder_input_mask, decoder_input_ids, return_beam_scores) def freeze(self) -> 
None: @@ -729,6 +729,40 @@ def as_frozen(self): Context manager which temporarily freezes embedding, decoder, and log_softmax modules, yields control and finally unfreezes the modules. """ + grad_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + training_mode_module_list = {'embeddings': {}, 'decoders': {}, 'log_softmaxes': {}, 'encoders': {}} + + def gather_grad_values(module_name): + map_values = [{} for _ in range(self.num_models)] + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + map_values[model_num][name].append(param.requires_grad) + return map_values + + def reset_grad_values(module_name, map_values, require_grad_default: bool): + for model_num in range(self.num_models): + for name, param in getattr(self, module_name)[model_num].named_parameters(): + if name in map_values[model_num]: + param.requires_grad = map_values[model_num].pop() + else: + param.requires_grad = require_grad_default + + def gather_reset_training_mode_values(module_name, map_values: dict = None): + map_values = [{} for _ in range(self.num_models)] if not map_values else map_values + get_values = len(map_values) == 0 + + for model_num in range(self.num_models): + if get_values: + map_values[model_num] = getattr(self, module_name)[model_num].training + else: + getattr(self, module_name)[model_num].train(map_values[model_num]) + return map_values + + # Cache the param.require_grad state of each module + for module_name in grad_module_list.keys(): + grad_module_list[module_name] = gather_grad_values(module_name) + training_mode_module_list[module_name] = gather_reset_training_mode_values(module_name) + self.freeze() try: @@ -736,6 +770,12 @@ def as_frozen(self): finally: self.unfreeze() + # Reset the param.require_grad state of each module + for module_name in grad_module_list.keys(): + reset_grad_values(module_name, grad_module_list[module_name], require_grad_default=True) + gather_reset_training_mode_values(module_name, map_values=training_mode_module_list[module_name]) + + class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator): def __init__( From 8b9c08ac51a2c24218fbe79189ed480040173b04 Mon Sep 17 00:00:00 2001 From: titu1994 Date: Sat, 15 Jun 2024 23:39:00 +0000 Subject: [PATCH 16/21] Apply isort and black reformatting Signed-off-by: titu1994 --- .../nlp/modules/common/transformer/transformer_generators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py index 2a2bf44212fb..9bac89f61135 100644 --- a/nemo/collections/nlp/modules/common/transformer/transformer_generators.py +++ b/nemo/collections/nlp/modules/common/transformer/transformer_generators.py @@ -776,7 +776,6 @@ def gather_reset_training_mode_values(module_name, map_values: dict = None): gather_reset_training_mode_values(module_name, map_values=training_mode_module_list[module_name]) - class BeamSearchSequenceGeneratorWithLanguageModel(GreedySequenceGenerator): def __init__( self, embedding, decoder, log_softmax, language_model, beam_size=1, len_pen=0, fusion_coef=0.0, **kwargs From f0e2d0801e1be8450ab104174fb56a52b52c4e04 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 24 Jun 2024 20:46:29 -0700 Subject: [PATCH 17/21] Corrects the implementation of replace_adapter_modules to limit to just the top level modules Signed-off-by: smajumdar --- 
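A generic, standalone sketch (not the NeMo implementation) of the freeze-and-restore pattern the corrected as_frozen() in the hunk further above is after: snapshot each parameter's requires_grad flag and each sub-module's training mode, freeze for the duration of the block, then restore the exact prior state rather than unconditionally unfreezing everything.

import torch.nn as nn
from contextlib import contextmanager

@contextmanager
def as_frozen(module: nn.Module):
    grad_state = {name: p.requires_grad for name, p in module.named_parameters()}
    train_state = {name: m.training for name, m in module.named_modules()}
    try:
        for p in module.parameters():
            p.requires_grad = False
        module.eval()
        yield module
    finally:
        for name, p in module.named_parameters():
            p.requires_grad = grad_state[name]
        for name, m in module.named_modules():
            m.train(train_state[name])

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
model[0].weight.requires_grad = False          # pretend this layer was already frozen
with as_frozen(model):
    assert not any(p.requires_grad for p in model.parameters())
assert model[0].weight.requires_grad is False  # prior state restored, not blindly unfrozen
assert model[1].weight.requires_grad is True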
nemo/core/classes/mixins/adapter_mixins.py | 11 ++++------ .../mixins/adapters/test_asr_adapter_mixin.py | 22 +++++++++++++++---- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index ca29ed4b1313..72d2eaa65035 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -1054,21 +1054,18 @@ def update_adapter_cfg(self, cfg: DictConfig): module.adapter_cfg = cfg def replace_adapter_compatible_modules( - self, module: Optional[nn.Module] = None, update_config: bool = True, verbose: bool = True + self, update_config: bool = True, verbose: bool = True ): """ - Utility method to replace all modules with Adapter variants, if they exist. + Utility method to replace all child modules with Adapter variants, if they exist. + Does NOT recurse through children of children modules (only immediate children). Args: - module: The module to be replaced with an Adapter variant. update_config: A flag that determines if the config should be updated or not. verbose: A flag that determines if the method should log the changes made or not. """ - if module is None: - module = self - # Update the given module itself, and then all its children modules - for name, mod in module.named_modules(): + for name, mod in self.named_children(): update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config, verbose=verbose) @property diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index b2811b082690..bf7e10586603 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -320,7 +320,7 @@ def multitask_model(test_data_dir): # Test case where Encoder (default) is not adapter compatible encoder = { - '_target_': 'nemo.collections.asr.modules.ConformerEncoderAdapter', + '_target_': 'nemo.collections.asr.modules.ConformerEncoder', 'feat_in': 64, 'feat_out': -1, 'n_layers': 2, @@ -333,7 +333,7 @@ def multitask_model(test_data_dir): } transf_encoder = { - "_target_": "nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoderAdapter", + "_target_": "nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder", "num_layers": 1, "hidden_size": "${model_defaults.lm_enc_hidden}", "inner_size": int(4 * model_defaults['lm_enc_hidden']), @@ -407,6 +407,9 @@ def multitask_model(test_data_dir): ) model_instance = EncDecMultiTaskModel(cfg=modelConfig) + + # Execute the model class swap logic + model_instance.replace_adapter_compatible_modules() return model_instance @@ -601,10 +604,11 @@ def test_squeezeformer_forward_mha(self, squeezeformer_ctc_adapter, name): assert torch.mean(torch.abs(origial_output - new_output)) < 1e-5 @pytest.mark.unit + @pytest.mark.parametrize('adapter_type', ['linear', 'attn']) @pytest.mark.parametrize( 'name', ['adapter_0', 'encoder:adapter_0', 'transf_encoder:adapter_0', 'transf_decoder:adapter_0'] ) - def test_canary_forward_mha(self, multitask_model, name): + def test_canary_forward_mha(self, multitask_model, name, adapter_type): multitask_model.eval() torch.random.manual_seed(0) input_signal = torch.randn(2, 512) @@ -621,7 +625,9 @@ def test_canary_forward_mha(self, multitask_model, name): og_logprob = origial_output[0] og_enc_out = origial_output[2] - adapter_type = 'transf_mha' if 'transf' in name else 'mha' + 
if adapter_type == 'attn': + adapter_type = 'transf_mha' if 'transf' in name else 'mha' + multitask_model.add_adapter(name=name, cfg=get_adapter_cfg(in_features=128, atype=adapter_type, proj_dim=4)) new_output = multitask_model( @@ -637,6 +643,14 @@ def test_canary_forward_mha(self, multitask_model, name): assert torch.mean(torch.abs(og_logprob - new_logprob)) < 1e-5 assert torch.mean(torch.abs(og_enc_out - new_enc_out)) < 1e-5 + if 'linear' in adapter_type: + mod_name = name.split(":")[-1] + for mod in multitask_model.modules(): + if isinstance(mod, AdapterModuleMixin): + amodule = mod.get_adapter_module(mod_name) + if amodule is not None: + assert isinstance(amodule, adapter_modules.LinearAdapter) + # Try to use incorrect adapter with pytest.raises(ValueError): multitask_model.add_adapter( From a4f08fa0e4f3259efef84ba3298003f74509687b Mon Sep 17 00:00:00 2001 From: titu1994 Date: Tue, 25 Jun 2024 03:47:15 +0000 Subject: [PATCH 18/21] Apply isort and black reformatting Signed-off-by: titu1994 --- nemo/core/classes/mixins/adapter_mixins.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 72d2eaa65035..1fbd37f01139 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -1053,9 +1053,7 @@ def update_adapter_cfg(self, cfg: DictConfig): if isinstance(module, AdapterModuleMixin): module.adapter_cfg = cfg - def replace_adapter_compatible_modules( - self, update_config: bool = True, verbose: bool = True - ): + def replace_adapter_compatible_modules(self, update_config: bool = True, verbose: bool = True): """ Utility method to replace all child modules with Adapter variants, if they exist. Does NOT recurse through children of children modules (only immediate children). 
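For context on the traversal change in the surrounding patches: replace_adapter_compatible_modules iterates self.named_children() (immediate submodules only) after PATCH 17, and is switched back to self.named_modules() (the root module plus all descendants) in the next patch. The following minimal sketch is plain PyTorch, not part of any patch above, and uses hypothetical module names purely to show the difference between the two iterators:

    import torch.nn as nn

    # Hypothetical two-level module tree, used only to illustrate the traversal difference.
    model = nn.Module()
    model.encoder = nn.Module()
    model.encoder.layer_0 = nn.Linear(4, 4)

    # named_children() yields only the immediate submodules of `model`.
    print([name for name, _ in model.named_children()])  # ['encoder']

    # named_modules() yields the root module itself ('') plus every descendant, recursively.
    print([name for name, _ in model.named_modules()])  # ['', 'encoder', 'encoder.layer_0']

Under named_children(), only top-level modules such as `encoder` are candidates for the adapter class swap; under named_modules(), nested blocks such as `encoder.layer_0` are visited as well.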
From e8e6092b070503530a71e38c134dc9d7e13bf982 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 24 Jun 2024 21:05:16 -0700 Subject: [PATCH 19/21] Remove registration of Transformer MHA Signed-off-by: smajumdar --- .../transformer_multi_head_attention_adapter_module.py | 7 ------- nemo/core/classes/mixins/adapter_mixins.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py index 192879736c32..4319a6962f4f 100644 --- a/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py +++ b/nemo/collections/asr/parts/submodules/adapters/transformer_multi_head_attention_adapter_module.py @@ -126,10 +126,3 @@ class TransformerMultiHeadAttentionAdapterConfig: _target_: str = "{0}.{1}".format( TransformerMultiHeadAttentionAdapter.__module__, TransformerMultiHeadAttentionAdapter.__name__ ) - - -""" Register the Adapter Modules to the Adapter Module Registry """ -if adapter_mixins.get_registered_adapter(TransformerMultiHeadAttentionAdapter) is None: - adapter_mixins.register_adapter( - base_class=transformer_modules.MultiHeadAttention, adapter_class=TransformerMultiHeadAttentionAdapter - ) diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 1fbd37f01139..05ac9b429d85 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -1063,7 +1063,7 @@ def replace_adapter_compatible_modules(self, update_config: bool = True, verbose verbose: A flag that determines if the method should log the changes made or not. 
""" # Update the given module itself, and then all its children modules - for name, mod in self.named_children(): + for name, mod in self.named_modules(): update_module_class_with_adapter_class(mod, cfg=self.cfg, update_config=update_config, verbose=verbose) @property From fceff7020c8dc0b91fdb1e4fe8971e1b060a827e Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Mon, 24 Jun 2024 22:30:12 -0700 Subject: [PATCH 20/21] Remove registration of Transformer MHA Signed-off-by: smajumdar --- .../adapters/test_adapter_model_mixin.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/core/mixins/adapters/test_adapter_model_mixin.py b/tests/core/mixins/adapters/test_adapter_model_mixin.py index 9d39dc0cb45c..20ced653ceb6 100644 --- a/tests/core/mixins/adapters/test_adapter_model_mixin.py +++ b/tests/core/mixins/adapters/test_adapter_model_mixin.py @@ -364,24 +364,6 @@ def test_base_model_replace_adapter_compatible_modules(self, caplog): model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) assert model.encoder.is_adapter_available() is True - @pytest.mark.unit - def test_base_model_replace_adapter_compatible_encoder_only(self, caplog): - cfg = get_model_config(in_features=50, update_adapter_cfg=False) - model = DefaultAdapterModel(cfg) - - with pytest.raises(AttributeError): - model.add_adapter(name='adapter_0', cfg=get_adapter_cfg()) - - # Replace the modules of the model dynamically to support adapters - model.replace_adapter_compatible_modules(model.encoder, update_config=True) - - assert isinstance(model.encoder, AdapterModuleMixin) - assert model.encoder.is_adapter_available() is False - - model.add_adapter(name='encoder:adapter_0', cfg=get_adapter_cfg()) - - assert model.encoder.is_adapter_available() is True - @pytest.mark.unit def test_single_adapter(self): cfg = get_model_config(in_features=50) From b133a59f129545876b7ced055aeae787fe4705fd Mon Sep 17 00:00:00 2001 From: smajumdar Date: Fri, 28 Jun 2024 16:01:14 -0700 Subject: [PATCH 21/21] Address reviewer comments Signed-off-by: smajumdar --- nemo/collections/asr/models/ctc_models.py | 4 ++++ .../asr/modules/transformer/transformer_utils.py | 9 +-------- .../parts/submodules/adapters/attention_adapter_mixin.py | 2 +- .../asr/mixins/adapters/test_asr_adapter_mixin.py | 3 ++- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 093419c3ca0c..7540532d371b 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -879,6 +879,10 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: return results + @property + def adapter_module_names(self) -> List[str]: + return ['', 'encoder', 'decoder'] + @property def wer(self): return self._wer diff --git a/nemo/collections/asr/modules/transformer/transformer_utils.py b/nemo/collections/asr/modules/transformer/transformer_utils.py index 0c2f36cb7965..5de1652ee1b0 100644 --- a/nemo/collections/asr/modules/transformer/transformer_utils.py +++ b/nemo/collections/asr/modules/transformer/transformer_utils.py @@ -114,14 +114,7 @@ def get_nemo_transformer( raise ValueError(f"Unknown arch = {arch}") else: - class_ = TransformerDecoderNM - - if cfg.get('adapter', False): - from nemo.core.classes.mixins.adapter_mixins import get_registered_adapter - - class_ = get_registered_adapter(TransformerDecoderNM).adapter_class - - model = class_( + model = TransformerDecoderNM( vocab_size=cfg.get('vocab_size'), 
hidden_size=cfg.get('hidden_size'), num_layers=cfg.get('num_layers'), diff --git a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py index 0696c112529b..0c1852773072 100644 --- a/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py +++ b/nemo/collections/asr/parts/submodules/adapters/attention_adapter_mixin.py @@ -55,7 +55,7 @@ def forward_single_enabled_adapter_( Args: input: Dictionary of packed tensors. The dict should contain at least `x`: output tensor - `loc`: Semantic location in module where this adapter was called + `loc`: Semantic location in module where this adapter was called. Can be 'mha' or 'post'. `att_mask`: Optional, Attention mask `pos_emb`: Optional, Positional Embedding for Relative Positional Encoding. The output tensor of the calling module is the input to the first adapter, whose output diff --git a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py index bf7e10586603..cac1eb2fcdf3 100644 --- a/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py +++ b/tests/collections/asr/mixins/adapters/test_asr_adapter_mixin.py @@ -510,7 +510,8 @@ def test_asr_model_constructor_joint_module_ctc_skip(self, model): original_num_params = model.num_weights # this step should exit without adding adapters and without errors - model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) + with pytest.raises(ValueError): + model.add_adapter(name='joint:adapter_0', cfg=get_adapter_cfg()) new_num_params = model.num_weights assert new_num_params == original_num_params
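As a closing illustration of the name handling these tests now exercise: model-level adapter names are scoped as 'module:adapter_name', and the module prefix is expected to appear in the model's adapter_module_names property (e.g. ['', 'encoder', 'decoder'] for the CTC model above); per the updated test, an unsupported prefix such as 'joint' now raises ValueError instead of being silently skipped. The helper below is an illustrative sketch of that resolution logic only, not the NeMo implementation:

    from typing import List, Tuple

    def resolve_adapter_name(name: str, valid_modules: List[str]) -> Tuple[str, str]:
        # Split 'module:adapter_name'; an unscoped name maps to the default module ''.
        if ':' in name:
            module_name, adapter_name = name.split(':', 1)
        else:
            module_name, adapter_name = '', name

        if module_name not in valid_modules:
            raise ValueError(
                f"Adapter module '{module_name}' is not supported; valid modules: {valid_modules}"
            )
        return module_name, adapter_name

    # 'encoder:adapter_0' resolves cleanly, while 'joint:adapter_0' raises ValueError,
    # mirroring the updated expectation in test_asr_model_constructor_joint_module_ctc_skip.
    print(resolve_adapter_name('encoder:adapter_0', ['', 'encoder', 'decoder']))
    try:
        resolve_adapter_name('joint:adapter_0', ['', 'encoder', 'decoder'])
    except ValueError as err:
        print(err)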