Add change_vocabulary and save_tokenizers() support to Multitask ASR models #8357

Merged · 3 commits · Feb 7, 2024
135 changes: 134 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import json
import os
@@ -22,7 +23,7 @@
import editdistance
import torch
import torch.distributed as dist
-from omegaconf import DictConfig, OmegaConf, open_dict
+from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from torchmetrics.text import SacreBLEUScore
from tqdm.auto import tqdm
@@ -212,6 +213,138 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):

logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

def change_vocabulary(
self,
new_tokenizer_dir: Union[str, DictConfig],
new_tokenizer_type: str,
decoding_cfg: Optional[DictConfig] = None,
prompt_format: Optional[str] = None,
):
"""
Changes vocabulary used during AED decoding process. Use this method when fine-tuning on from pre-trained model.
This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would
use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need
model to learn capitalization, punctuation and/or special characters.

Args:
new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`)
new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`.
decoding_cfg: A config for the decoding, which is optional. If the decoding type
needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
prompt_format: A string alias of the object that represents the prompt structure.
If not None, it will be used to update the prompt format.
"""
if isinstance(new_tokenizer_dir, (dict, DictConfig)):
if new_tokenizer_type == 'agg':
if not isinstance(new_tokenizer_dir, DictConfig):
new_tokenizer_dir = OmegaConf.create(new_tokenizer_dir)

new_tokenizer_cfg = new_tokenizer_dir
else:
raise ValueError(
f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}'
)
else:
new_tokenizer_cfg = None

if new_tokenizer_cfg is not None:
tokenizer_cfg = new_tokenizer_cfg
else:
if not os.path.isdir(new_tokenizer_dir):
raise NotADirectoryError(
f'New tokenizer dir must be non-empty path to a directory. But instead got: {new_tokenizer_dir}'
)

if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
raise ValueError(f'New tokenizer type must be either `bpe` or `wpe`')

tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

if prompt_format is None:
prompt_format = self.cfg.prompt_format

# Setup the tokenizer
self._setup_tokenizer(tokenizer_cfg)

# Initialize a dummy vocabulary
vocabulary = self.tokenizer.tokenizer.get_vocab()

# Setup Decoder
transf_decoder_cfg_dict = self.transf_decoder.to_config_dict()

vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8)
Review comment (Collaborator): why number 8 here and not another int?

Reply (Collaborator, Author): This is from the original code. @krishnacpuvvada

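For context on the thread above: rounding a vocabulary size up to the nearest multiple of 8 is a common trick for keeping embedding and classifier dimensions friendly to GPU tensor cores. That rationale is an assumption here; the thread only establishes that the constant was carried over from the original code. A minimal sketch of the arithmetic:

from math import ceil

def pad_vocab_size(vocab_size: int, multiple: int = 8) -> int:
    # Mirrors the `8 * ceil(self.tokenizer.vocab_size / 8)` expression above.
    # Rounding up to a multiple of 8 is assumed to target tensor-core-friendly shapes.
    return multiple * ceil(vocab_size / multiple)

assert pad_vocab_size(1023) == 1024  # padded up to the next multiple of 8
assert pad_vocab_size(1024) == 1024  # already aligned, unchanged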

        # Auto inject vocab size for `get_transformer`
        with open_dict(transf_decoder_cfg_dict):
            if 'config_dict' in transf_decoder_cfg_dict:
                transf_decoder_cfg_dict['config_dict']['vocab_size'] = vocab_size

        original_decoder_state_dict = self.transf_decoder.state_dict()
        self.transf_decoder = EncDecMultiTaskModel.from_config_dict(transf_decoder_cfg_dict)

        # Partially load the original state dict into the new decoder
        decoder_state_dict = self.transf_decoder.state_dict()
        for og_key, og_value in original_decoder_state_dict.items():
            if og_key in decoder_state_dict and og_value.shape == decoder_state_dict[og_key].shape:
                decoder_state_dict[og_key] = og_value
            else:
                logging.warning(
                    f"Skipping key `{og_key}` in the `transf_decoder` module from original state dict due "
                    f"to shape mismatch after change in vocabulary.\n"
                    f"Original shape: {og_value.shape}, New shape: {decoder_state_dict[og_key].shape}"
                )

        self.transf_decoder.load_state_dict(decoder_state_dict)

        # Setup token classifier
        with open_dict(self.cfg.head):
            self.cfg.head.num_classes = vocab_size

        del self.log_softmax
        self.log_softmax = EncDecMultiTaskModel.from_config_dict(self.cfg.head)

        # Weight tying - if using TokenClassifier only
        if isinstance(self.log_softmax, TokenClassifier):
            self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight

        # Initialize weights of token classifier
        std_init_range = 1 / self.cfg.model_defaults.lm_dec_hidden ** 0.5
        self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range))

        # Setup Decoding class
        if decoding_cfg is None:
            # Assume same decoding config as before
            decoding_cfg = self.cfg.decoding

        # Assert the decoding config with all hyper parameters
        decoding_cls = OmegaConf.structured(MultiTaskDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        del self.decoding
        self.decoding = MultiTaskDecoding(
            decoding_cfg=decoding_cfg,
            transformer_decoder=self.transf_decoder,
            log_softmax_module=self.log_softmax,
            tokenizer=self.tokenizer,
        )

        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        # Setup loss
        with open_dict(self.cfg.loss):
            self.cfg.loss.pad_id = self.tokenizer.pad_id

        del self.loss
        self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss)

        # Update config
        with open_dict(self.cfg):
            self.cfg.prompt_format = prompt_format

        logging.info(f"Changed decoder to output to {vocabulary} vocabulary.")

    @torch.no_grad()
    def transcribe(
        self,
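A hedged usage sketch of the new change_vocabulary API; the checkpoint name and tokenizer directory below are hypothetical placeholders, not values taken from this PR:

from nemo.collections.asr.models import EncDecMultiTaskModel

# Hypothetical checkpoint name; any EncDecMultiTaskModel checkpoint would do.
model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-1b")

# Swap the decoder vocabulary while keeping the pretrained encoder.
# `bpe` and `wpe` tokenizers take a directory path; `agg` takes a tokenizer config instead.
model.change_vocabulary(
    new_tokenizer_dir="/path/to/new_tokenizer_dir",  # placeholder path
    new_tokenizer_type="bpe",
)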
95 changes: 94 additions & 1 deletion nemo/collections/asr/parts/mixins/mixins.py
@@ -13,6 +13,8 @@
# limitations under the License.

import os
import shutil
import tarfile
from abc import ABC, abstractmethod
from typing import List

@@ -25,7 +27,7 @@
from nemo.collections.asr.parts.utils import asr_module_utils
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
from nemo.collections.common import tokenizers
-from nemo.utils import logging
+from nemo.utils import app_state, logging


class ASRBPEMixin(ABC):
@@ -372,6 +374,97 @@ def _cleanup_aggregate_config_and_artifacts_if_needed(self):
                if akey.startswith('tokenizer.' + self.AGGREGATE_TOKENIZERS_DICT_PREFIX + '.'):
                    self.artifacts.pop(akey)

    def save_tokenizers(self, directory: str):
        """
        Save the model tokenizer(s) to the specified directory.

        Args:
            directory: The directory to save the tokenizer(s) to.
        """
        if not hasattr(self, 'cfg'):
            raise RuntimeError(
                "The model has not been initialized with a tokenizer yet. Please call the model's "
                "__init__ and _setup_tokenizer methods first."
            )

        if self.tokenizer_type == 'agg':
            for lang in self.tokenizer.langs:
                subconfig = self.cfg.tokenizer.langs.get(lang)
                new_dir = os.path.join(directory, lang)
                self._extract_tokenizer_from_config(subconfig, new_dir)
        else:
            self._extract_tokenizer_from_config(self.cfg.tokenizer, directory)

    def _extract_tokenizer_from_config(self, tokenizer_cfg: DictConfig, dir: str):
        """
        Extracts the tokenizer from the config and writes the objects to dir.
        The files may come from a local path (new model init) or from a .nemo file (restored model).
        If from a newly initialized model, the files are copied to dir.
        If from a restored model, the files are extracted from the .nemo file and copied to dir.

        Args:
            tokenizer_cfg: The tokenizer config to extract the tokenizer from.
            dir: The directory to write the tokenizer objects to.
        """
        if not os.path.exists(dir):
            os.makedirs(dir, exist_ok=True)

        nemo_file_objects = []

        for k, v in tokenizer_cfg.items():
            # Check if the value is a filepath (new model init) or has `nemo:` in it (restored model)
            if isinstance(v, str) and os.path.exists(v):
                # local file from first instantiation
                loc = shutil.copy2(v, dir)
                logging.info(f"Saved {k} at {loc}")

            if isinstance(v, str) and v.startswith('nemo:'):
Review comment (Collaborator): why `:` after nemo?

Reply (Collaborator, Author): .nemo files modify the config of registered items like this to denote registered artifacts.
(See the short sketch after this method for an illustration.)

                nemo_object_name = v[5:]
                nemo_file_objects.append(nemo_object_name)

        if len(nemo_file_objects) > 0:
            logging.debug(f"Copying the following nemo file objects to {dir}: {nemo_file_objects}")

            if not hasattr(self, 'model_guid'):
                raise ValueError(
                    "The model does not have a model_guid attribute. "
                    "Please ensure that the model has been restored from a .nemo file."
                )

            appstate = app_state.AppState()
            restore_path = appstate.get_model_metadata_from_guid(self.model_guid).restoration_path
            if restore_path is None:
                raise ValueError(
                    "The model has not been restored from a .nemo file. Cannot extract the tokenizer "
                    "as the nemo file cannot be located."
                )

            # Read the nemo file without fully extracting all contents.
            # We start with an assumption of an uncompressed tar,
            # which should be true for versions 1.7.0 and above.
            tar_header = "r:"
            try:
                tar_test = tarfile.open(restore_path, tar_header)
                tar_test.close()
            except tarfile.ReadError:
                # can be an older checkpoint => try a compressed tar
                tar_header = "r:gz"
            tar = tarfile.open(restore_path, tar_header)

            for nemo_object_name in nemo_file_objects:
                members = [x for x in tar.getmembers() if nemo_object_name in x.name]
                for member in members:
                    tar.extract(member, dir)

                    # Registered artifacts appear to be stored as `<hash>_<original_name>`;
                    # stripping the leading chunk recovers the original filename.
                    new_name = member.name.split("_")[1:]
                    if len(new_name) > 1:
                        new_name = "_".join(new_name)
                    else:
                        new_name = new_name[0]
                    os.rename(os.path.join(dir, member.name), os.path.join(dir, new_name))

                    logging.info(f"Saved {nemo_object_name} at {os.path.join(dir, new_name)}")


class ASRModuleMixin(ASRAdapterModelMixin):
"""
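Finally, a hedged end-to-end sketch of the new save_tokenizers() entry point; the paths are placeholders:

from nemo.collections.asr.models import EncDecMultiTaskModel

# Placeholder path to a trained checkpoint.
model = EncDecMultiTaskModel.restore_from("/path/to/model.nemo")

# For an `agg` tokenizer this writes one subdirectory per language;
# otherwise the tokenizer files land directly in the target directory.
model.save_tokenizers("/path/to/exported_tokenizers")  # placeholder output directory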