Support nested NeMo models #5671

Merged
merged 44 commits into from
Jan 23, 2023
Changes from 42 commits
21f9e79
Nested NeMo models support: save-restore with artifacts
artbataev Dec 19, 2022
a830aee
Fix import
artbataev Dec 19, 2022
cd494d1
Improve tests
artbataev Dec 19, 2022
a19197c
Add test for double nested model
artbataev Dec 19, 2022
f8519a4
Fix target for mock model with children
artbataev Dec 19, 2022
3295d75
MockModel config: revert to original
artbataev Dec 19, 2022
b114920
Add comments. Clean up
artbataev Dec 20, 2022
7b661d3
Merge branch 'main' into support_nested_models
okuchaiev Dec 20, 2022
b722181
Add docstring. Test number of artifacts
artbataev Dec 22, 2022
35ffb35
Merge branch 'main' into support_nested_models
artbataev Jan 9, 2023
f7d42f4
Handle cases when child model can change config
artbataev Jan 10, 2023
b8e0f82
Add test for multiple test-restore passes
artbataev Jan 10, 2023
47e5e89
Merge branch 'main' into support_nested_models
artbataev Jan 10, 2023
3d5e13d
Do not use artifacts from models not in config
artbataev Jan 10, 2023
49d2895
Fix for the case when `artifacts` attribute is not assigned
artbataev Jan 11, 2023
f775a5d
Merge branch 'main' into support_nested_models
ericharper Jan 14, 2023
10549cf
Merge branch 'main' into support_nested_models
artbataev Jan 16, 2023
7a5d83a
Improve documentation. Clarify 2 cases for model construction.
artbataev Jan 16, 2023
0136793
Avoid unpacking duplicated restoration paths
artbataev Jan 16, 2023
2347400
Fix docs
artbataev Jan 16, 2023
40c5437
Merge branch 'main' into support_nested_models
artbataev Jan 16, 2023
b276cb9
Implement explicit submodule registration
artbataev Jan 18, 2023
20067cb
Temporary fix pytorch-lightning
artbataev Jan 18, 2023
77bb11b
Merge branch 'main' into support_nested_models
artbataev Jan 19, 2023
da7de90
Use stt_en_conformer_ctc_small for testing nested models + from_pretr…
artbataev Jan 19, 2023
b5960d5
Test different config path and attribute name
artbataev Jan 19, 2023
7337fd1
Explicitly disallow registering .nemo checkpoint file as an artifact.…
artbataev Jan 19, 2023
5581f18
Add todo
artbataev Jan 19, 2023
f6f84aa
Fix named_nemo_modules
artbataev Jan 19, 2023
d970b50
Add documentation
artbataev Jan 19, 2023
13ba08f
Revert disallowing .nemo model as an artifact
artbataev Jan 19, 2023
96ffd43
Fix unused import
artbataev Jan 19, 2023
75c9622
Disallow registering .nemo file as an artifact
artbataev Jan 20, 2023
d498a26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2023
07dc349
Fix docs
artbataev Jan 20, 2023
050a57c
Merge branch 'support_nested_models' of github.com:artbataev/NeMo int…
artbataev Jan 20, 2023
0cae024
Merge branch 'main' into support_nested_models
artbataev Jan 20, 2023
0589c33
Merge branch 'main' into support_nested_models
artbataev Jan 23, 2023
764f76f
Revert test for nested RNNT model, fix according to the new approach
artbataev Jan 23, 2023
5bf9b0e
Fix documentation
artbataev Jan 23, 2023
d929d3f
Fix documentation
artbataev Jan 23, 2023
0e68a46
Merge branch 'main' into support_nested_models
artbataev Jan 23, 2023
80a6706
Fix documentation for register_nemo_submodule
artbataev Jan 23, 2023
eb6ec43
Protect inner nemo submodules mapping
artbataev Jan 23, 2023
59 changes: 58 additions & 1 deletion docs/source/core/core.rst
@@ -635,7 +635,64 @@ The resulting .nemo file will then have the following file:
4978b28103264263a03439aaa6560e5e_tokenizer.model

If ``verify_src_exists`` is set to ``False``, then the artifact is optional. This means that ``.register_artifact`` will return ``None``
if the ``src`` cannot be found.

Nested NeMo Models
------------------

In some cases, it may be helpful to use NeMo models inside other NeMo models. For example, a language model can be incorporated into an ASR model and used during decoding to improve accuracy, or a hybrid ASR-TTS model can generate audio from text on the fly to train or fine-tune the ASR model.

There are 3 ways to instantiate child models inside parent models:

- use subconfig directly
- use the ``.nemo`` checkpoint path to load the child model
- use a pretrained NeMo model

To register a child model, use the ``register_nemo_submodule`` method of the parent model. The method assigns the child model to the given attribute of the parent model. During serialization, it ensures that the child's artifacts are handled correctly and that the child model config is stored in the parent model config under ``config_field``.

.. code-block:: python

    from typing import Optional

    from nemo.core.classes import ModelPT

    class ChildModel(ModelPT):
        ...  # implement necessary methods

    class ParentModel(ModelPT):
        def __init__(self, cfg, trainer=None):
            super().__init__(cfg=cfg, trainer=trainer)

            # optionally annotate type for IDE autocompletion and type checking
            self.child_model: Optional[ChildModel]
            if cfg.get("child_model") is not None:
                # load directly from config
                # either if config provided initially, or automatically
                # after model restoration
                self.register_nemo_submodule(
                    name="child_model",
                    config_field="child_model",
                    model=ChildModel(self.cfg.child_model, trainer=trainer),
                )
            elif cfg.get("child_model_path") is not None:
                # load from .nemo model checkpoint
                # while saving, config will be automatically assigned/updated
                # in cfg.child_model
                self.register_nemo_submodule(
                    name="child_model",
                    config_field="child_model",
                    model=ChildModel.restore_from(self.cfg.child_model_path, trainer=trainer),
                )
            elif cfg.get("child_model_name") is not None:
                # load from pretrained model
                # while saving, config will be automatically assigned/updated
                # in cfg.child_model
                self.register_nemo_submodule(
                    name="child_model",
                    config_field="child_model",
                    model=ChildModel.from_pretrained(self.cfg.child_model_name, trainer=trainer),
                )
            else:
                self.child_model = None

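The branches in the documentation example are checked in a fixed order, so a parent whose config already contains ``child_model`` (for instance after a save and restore) takes the first branch. The following is a minimal sketch of that precedence using plain dictionaries. `child_source` is a hypothetical helper written for illustration and is not part of NeMo, and it is an assumption that the `child_model_path` key stays in the config after restoration.

```python
from typing import Any, Mapping, Optional


def child_source(cfg: Mapping[str, Any]) -> Optional[str]:
    """Return which branch of the documented ParentModel.__init__ would run.

    Hypothetical helper for illustration only; NeMo has no such function.
    """
    if cfg.get("child_model") is not None:
        return "subconfig"
    if cfg.get("child_model_path") is not None:
        return "nemo_checkpoint"
    if cfg.get("child_model_name") is not None:
        return "pretrained"
    return None


# First construction: only a checkpoint path is given (the path is made up).
initial_cfg = {"child_model_path": "child.nemo"}

# After saving, the child config is stored under `child_model`, so the
# subconfig branch wins even if the path key is still present (assumption).
restored_cfg = {"child_model_path": "child.nemo", "child_model": {"hidden": 8}}
```

With these inputs, `child_source(initial_cfg)` selects the checkpoint branch and `child_source(restored_cfg)` selects the subconfig branch, which is why a restored parent does not need the original child checkpoint.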

Neural Modules
==============
106 changes: 103 additions & 3 deletions nemo/core/classes/modelPT.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import inspect
@@ -19,7 +20,7 @@
from abc import abstractmethod
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import hydra
import torch
@@ -35,6 +36,7 @@
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import get_rank, is_global_rank_zero

__all__ = ['ModelPT']
@@ -110,6 +112,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self._cfg = cfg

# init mapping submodule attribute -> config_field for nested NeMo models
self.nemo_submodule_name_to_config_field = dict()

self.save_hyperparameters("cfg")
self._train_dl = None
self._validation_dl = None
@@ -221,11 +226,15 @@ def register_artifact(
str: If src is not None or empty it always returns absolute path which is guaranteed to exist during model instance life
"""

app_state = AppState()

if src is None or src == "":
return src

if Path(src).suffix == ".nemo":
raise NeMoBaseException(
"Registering .nemo files as artifacts not supported. "
"If you are trying to make a nested model, use `register_nemo_submodule`."
)

if not hasattr(self, 'artifacts'):
self.artifacts = {}

@@ -240,6 +249,97 @@ def register_artifact(

return self._save_restore_connector.register_artifact(self, config_path, src, verify_src_exists)
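
The new guard in `register_artifact` looks only at the file suffix of `src`. Below is a standalone sketch of that check so its edge cases are easy to see. `SubmoduleArtifactError` and `check_artifact_src` are hypothetical stand-ins (the PR raises `NeMoBaseException` inside `register_artifact`), used here so the snippet runs without NeMo installed.

```python
from pathlib import Path
from typing import Optional


class SubmoduleArtifactError(Exception):
    """Stand-in for NeMoBaseException so the sketch runs without NeMo."""


def check_artifact_src(src: Optional[str]) -> Optional[str]:
    # Empty sources pass through unchanged, as in the diff above.
    if src is None or src == "":
        return src
    # Only the final suffix is inspected, so "child.nemo" is rejected
    # while "model.nemo.bak" or "tokenizer.model" are accepted.
    if Path(src).suffix == ".nemo":
        raise SubmoduleArtifactError(
            "Registering .nemo files as artifacts is not supported; "
            "use register_nemo_submodule for nested models."
        )
    return src
```

Because the comparison is an exact match on the last suffix, a differently cased extension such as `.NEMO` would not be caught by this sketch.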

    def has_artifacts(self) -> bool:
        """Returns True if model has artifacts registered"""
        return hasattr(self, 'artifacts') and self.artifacts is not None and len(self.artifacts) > 0

    def has_native_or_submodules_artifacts(self) -> bool:
        """Returns True if it has artifacts or any of the submodules have artifacts"""
        for module in self.modules():
            if (
                isinstance(module, ModelPT)
                and hasattr(module, 'artifacts')
                and module.artifacts is not None
                and len(module.artifacts) > 0
            ):
                return True
        return False

    def register_nemo_submodule(self, name: str, config_field: str, model: "ModelPT") -> None:
        """
        Adds NeMo model as a submodule. Submodule can be accessed via the `name` attribute on self.
        In the saving process, the whole parent model (self) is treated as a single model that includes
        the artifacts of the child submodule; the submodule config is saved to the `config_field`
        of the parent model config.
        This method is necessary to create a nested model, e.g.

        .. code-block:: python

            class ParentModel(ModelPT):
                def __init__(self, cfg, trainer=None):
                    super().__init__(cfg=cfg, trainer=trainer)

                    # annotate type for autocompletion and type checking (optional)
                    self.child_model: Optional[ChildModel] = None
                    if cfg.get("child_model") is not None:
                        self.register_nemo_submodule(
                            name="child_model",
                            config_field="child_model",
                            model=ChildModel(self.cfg.child_model, trainer=trainer),
                        )
                    # ... other code

        Args:
            name: name of the attribute for the submodule
            config_field: field in config, where submodule config should be saved
            model: NeMo model, instance of ModelPT
        """
        # check it is a real NeMo model
        if not isinstance(model, ModelPT):
            raise NeMoBaseException(
                f"Model is not an instance of ModelPT, so can't be registered. Got {type(model).__name__}"
            )
        # check if it is called after __init__
        if not hasattr(self, "nemo_submodule_name_to_config_field"):
            raise NeMoBaseException(
                "You are trying to register a submodule before the model is initialized. This is not allowed. "
                "Did you forget to call `super().__init__`?"
            )
        # assign attribute to self
        setattr(self, name, model)
        # add to the submodules mapping
        self.nemo_submodule_name_to_config_field[name] = config_field

    def named_nemo_modules(
        self, prefix_name: str = "", prefix_config: str = ""
    ) -> Iterator[Tuple[str, str, "ModelPT"]]:
        """
        Returns an iterator over all NeMo submodules recursively, yielding
        tuples of (attribute path, path in config, submodule), starting from the core module

        Args:
            prefix_name: prefix for the name path
            prefix_config: prefix for the path in config

        Returns:
            Iterator over (attribute path, path in config, submodule), starting from (prefix, self)
        """
        if not hasattr(self, "nemo_submodule_name_to_config_field"):
            raise NeMoBaseException(
                "Model is not fully initialized. Calling `named_nemo_modules` before __init__ is not allowed. "
                "Did you forget to call `super().__init__`?"
            )

        yield prefix_name, prefix_config, self

        # recursive iteration over all NeMo submodules
        for name, config_field in self.nemo_submodule_name_to_config_field.items():
            attribute_path = f"{prefix_name}.{name}" if prefix_name else name
            config_path = f"{prefix_config}.{config_field}" if prefix_config else config_field
            module: ModelPT = getattr(self, name)
            for submodule_name, subconfig_path, submodule in module.named_nemo_modules(
                prefix_name=attribute_path, prefix_config=config_path
            ):
                yield submodule_name, subconfig_path, submodule

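To make the traversal order of `named_nemo_modules` concrete, here is a toy sketch of the same recursion on a doubly nested model. `Node`, `register`, and `named_modules` are hypothetical names for a stand-in class that keeps only the attribute-to-config-field mapping; it is not NeMo code and skips the validation done in the PR.

```python
from typing import Dict, Iterator, Tuple


class Node:
    """Toy stand-in for ModelPT that keeps only the submodule mapping."""

    def __init__(self) -> None:
        # attribute name -> config field, mirroring the mapping added in the PR
        self.submodules: Dict[str, str] = {}

    def register(self, name: str, config_field: str, model: "Node") -> None:
        setattr(self, name, model)
        self.submodules[name] = config_field

    def named_modules(
        self, prefix_name: str = "", prefix_config: str = ""
    ) -> Iterator[Tuple[str, str, "Node"]]:
        # The module itself comes first, so parents precede their children.
        yield prefix_name, prefix_config, self
        for name, config_field in self.submodules.items():
            attr_path = f"{prefix_name}.{name}" if prefix_name else name
            conf_path = f"{prefix_config}.{config_field}" if prefix_config else config_field
            yield from getattr(self, name).named_modules(attr_path, conf_path)


# Double nesting; attribute names and config fields differ on purpose.
root, middle, leaf = Node(), Node(), Node()
middle.register("lm", "lm_cfg", leaf)
root.register("asr", "asr_cfg", middle)
paths = [(attr, conf) for attr, conf, _ in root.named_modules()]
```

Here `paths` is `[("", ""), ("asr", "asr_cfg"), ("asr.lm", "asr_cfg.lm_cfg")]`: the root is reported with empty paths, and each level appends its own attribute name and config field to the respective prefix.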
def save_to(self, save_path: str):
"""
Saves model instance (weights and configuration) into .nemo file
105 changes: 79 additions & 26 deletions nemo/core/connectors/save_restore_connector.py
@@ -12,19 +12,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations # necessary for lazy types evaluation

import os
import shutil
import tarfile
import tempfile
import uuid
from typing import Optional, Union
from typing import Optional, Set, Union

import torch
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer

from nemo.core import classes as nemo_classes # to avoid circular import do not import ModelPT directly
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.get_rank import is_global_rank_zero
@@ -37,14 +39,14 @@ def __init__(self) -> None:
self._model_weights_ckpt = "model_weights.ckpt"
self._model_extracted_dir = None

def save_to(self, model, save_path: str):
def save_to(self, model: "nemo_classes.ModelPT", save_path: str):
"""
Saves model instance (weights and configuration) into .nemo file.
You can use "restore_from" method to fully restore instance from .nemo file.

.nemo file is an archive (tar.gz) with the following:
model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor
model_wights.chpt - model checkpoint
model_weights.ckpt - model checkpoint

Args:
model: ModelPT object to be saved.
@@ -56,7 +58,9 @@ def save_to(self, model, save_path: str):
config_yaml = os.path.join(tmpdir, self.model_config_yaml)
model_weights = os.path.join(tmpdir, self.model_weights_ckpt)
model.to_config_file(path2yaml_file=config_yaml)
if hasattr(model, 'artifacts') and model.artifacts is not None:
# update subconfigs if there are child models, since a child model can change its config
self._update_subconfigs(model, path2yaml_file=config_yaml)
if model.has_native_or_submodules_artifacts():
self._handle_artifacts(model, nemo_file_folder=tmpdir)
# We should not update self._cfg here - the model can still be in use
self._update_artifact_paths(model, path2yaml_file=config_yaml)
@@ -400,40 +404,70 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists
def _handle_artifacts(self, model, nemo_file_folder):
tarfile_artifacts = []
app_state = AppState()
for conf_path, artiitem in model.artifacts.items():
if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
if not os.path.exists(artiitem.path):
raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

# Generate new uniq artifact name and copy it to nemo_file_folder
# Note uuid.uuid4().hex is guaranteed to be 32 character long
artifact_base_name = os.path.basename(artiitem.path)
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))
# aggregate artifacts from self and all children recursively
artifacts_containers = []
for _, config_path, module in model.named_nemo_modules():
if module.has_artifacts(): # NeMo model with artifacts
artifacts_containers.append((config_path, module.artifacts))

if len(artifacts_containers) > 0 and (not hasattr(model, "artifacts") or model.artifacts is None):
# model has no artifacts, but submodules have some
model.artifacts = dict()
for config_path, artifacts in artifacts_containers:
for subconf_path, artiitem in artifacts.items():
conf_path = f"{config_path}.{subconf_path}" if config_path else f"{subconf_path}"
if artiitem.path_type == model_utils.ArtifactPathType.LOCAL_PATH:
if not os.path.exists(artiitem.path):
raise FileNotFoundError(f"Artifact {conf_path} not found at location: {artiitem.path}")

# Generate new uniq artifact name and copy it to nemo_file_folder
# Note uuid.uuid4().hex is guaranteed to be 32 character long
artifact_base_name = os.path.basename(artiitem.path)
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))

# Update artifacts registry
artiitem.hashed_path = "nemo:" + artifact_uniq_name
model.artifacts[conf_path] = artiitem

elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
# process all tarfile artifacts in one go, so preserve key-value pair
tarfile_artifacts.append((conf_path, artiitem))
if subconf_path: # artifact from submodule
model.artifacts[conf_path] = artiitem

# Update artifacts registry
artiitem.hashed_path = "nemo:" + artifact_uniq_name
model.artifacts[conf_path] = artiitem

elif artiitem.path_type == model_utils.ArtifactPathType.TAR_PATH:
# process all tarfile artifacts in one go, so preserve key-value pair
tarfile_artifacts.append((conf_path, artiitem))

else:
raise ValueError(f"Directly referencing artifacts from other nemo files isn't supported yet")
else:
raise ValueError(f"Directly referencing artifacts from other nemo files isn't supported yet")

# Process current tarfile artifacts by unpacking the previous tarfile and extract the artifacts
# that are currently required.
# artifacts can be native (from the model itself) and from submodules
restoration_paths: Set[str] = set() # model + submodules restoration paths, handle only unique paths
model_metadata = app_state.get_model_metadata_from_guid(model.model_guid)
if len(tarfile_artifacts) > 0 and model_metadata.restoration_path is not None:
if model_metadata.restoration_path is not None:
restoration_paths.add(model_metadata.restoration_path)
# aggregate restoration paths for all submodules recursively
for module in model.modules():
if isinstance(module, nemo_classes.ModelPT): # if NeMo model
submodule_restoration_path = app_state.get_model_metadata_from_guid(module.model_guid).restoration_path
if submodule_restoration_path is not None:
restoration_paths.add(submodule_restoration_path)
if len(tarfile_artifacts) > 0 and len(restoration_paths) == 0:
# TODO: see cases when this can occur, and if we can fix them
logging.warning("Model contains registered artifacts, but no restoration paths found")
if len(tarfile_artifacts) > 0 and len(restoration_paths) > 0:
# Need to step into nemo archive to extract file
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
cwd = os.getcwd()
try:
# Step into the nemo archive to try and find the file
with tempfile.TemporaryDirectory() as archive_dir:
self._unpack_nemo_file(path2file=model_metadata.restoration_path, out_folder=archive_dir)
# unpack all restorations paths (nemo checkpoints)
# in nemo checkpoints all resources contain hash in name, so there should be no collisions
for path in restoration_paths:

Review comment (Collaborator): Restoration paths were tempdirs, do we recreate those tempdirs?

Reply (Collaborator, Author): restoration_paths are only paths to .nemo checkpoints, since we don't use .nemo files inside the parent .nemo file. Using .nemo in .nemo will still break the code and should be avoided. I changed it to a set to unpack each checkpoint only once.

Review comment (Collaborator): We need to add a test where we attempt to register a .nemo file. Maybe check and raise an error?

self._unpack_nemo_file(path2file=path, out_folder=archive_dir)
os.chdir(archive_dir)
for conf_path, artiitem in tarfile_artifacts:
# Get basename and copy it to nemo_file_folder
@@ -454,8 +488,27 @@ def _handle_artifacts(self, model, nemo_file_folder):
# change back working directory
os.chdir(cwd)

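The reworked `_handle_artifacts` collects artifacts from the model and all of its NeMo submodules, and prefixes each artifact key with the config path of the submodule that owns it. A minimal sketch of just that key construction follows. `merge_artifact_keys` is a hypothetical helper, and plain strings stand in for the artifact items that NeMo stores.

```python
from typing import Dict, List, Tuple


def merge_artifact_keys(containers: List[Tuple[str, Dict[str, str]]]) -> Dict[str, str]:
    """Flatten (config path, artifacts) pairs into one registry keyed by full config path."""
    merged: Dict[str, str] = {}
    for config_path, artifacts in containers:
        for sub_path, item in artifacts.items():
            # The root model has an empty config path, so its keys stay unchanged.
            key = f"{config_path}.{sub_path}" if config_path else sub_path
            merged[key] = item
    return merged


# Parent and child both register a tokenizer under the same relative key.
containers = [
    ("", {"tokenizer.model_path": "parent_tokenizer.model"}),
    ("child_model", {"tokenizer.model_path": "child_tokenizer.model"}),
]
merged = merge_artifact_keys(containers)
```

The prefix keeps the two entries apart: the parent's artifact stays under `tokenizer.model_path`, while the child's ends up under `child_model.tokenizer.model_path`, which matches where the child config lives inside the parent config.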
    @staticmethod
    def _update_subconfigs(model: "nemo_classes.ModelPT", path2yaml_file):
        """
        Update subconfigs of the model if ModelPT has submodules
        Should be called before updating artifacts paths
        """
        # check if there are submodules
        if len(model.nemo_submodule_name_to_config_field) == 0:
            return
        conf = OmegaConf.load(path2yaml_file)
        # update subconfigs for all children recursively
        # parent configs updated before children
        for _, conf_path, submodule in model.named_nemo_modules():
            if not conf_path:  # self
                continue
            OmegaConf.update(conf, conf_path, submodule.cfg)
        with open(path2yaml_file, 'w', encoding='utf-8') as fout:
            OmegaConf.save(config=conf, f=fout, resolve=True)

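The ordering note in `_update_subconfigs` (parents before children) can be illustrated with plain dictionaries. The sketch below is a simplification under stated assumptions: `set_by_path` and `update_subconfigs` are hypothetical helpers, nested dicts replace the OmegaConf config, and each subtree is overwritten outright, whereas `OmegaConf.update` may merge values depending on its arguments and version.

```python
import copy
from typing import Any, Dict, List, Tuple


def set_by_path(conf: Dict[str, Any], dotted_path: str, value: Any) -> None:
    """Assign `value` at `dotted_path`, creating intermediate dicts as needed."""
    *parents, last = dotted_path.split(".")
    node = conf
    for key in parents:
        node = node.setdefault(key, {})
    node[last] = value


def update_subconfigs(
    conf: Dict[str, Any], modules: List[Tuple[str, Dict[str, Any]]]
) -> Dict[str, Any]:
    """`modules` lists (config path, live config) pairs with parents first."""
    for conf_path, live_cfg in modules:
        if not conf_path:  # the root model; its config is already in `conf`
            continue
        # Copy so later writes do not mutate the live config of the submodule.
        set_by_path(conf, conf_path, copy.deepcopy(live_cfg))
    return conf


saved = {"name": "parent", "asr_cfg": {"lr": 0.1}}
# The middle model changed after construction and holds a stale copy of lm_cfg.
asr_live = {"lr": 0.1, "vocab": 128, "lm_cfg": {"order": 3}}
lm_live = {"order": 4}
result = update_subconfigs(
    saved, [("", saved), ("asr_cfg", asr_live), ("asr_cfg.lm_cfg", lm_live)]
)
```

Because the parent entry is written first, the later write for `asr_cfg.lm_cfg` replaces the stale copy with the innermost model's current config. Reversing the order would let the parent's stale copy win.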
def _update_artifact_paths(self, model, path2yaml_file):
if model.artifacts is not None and len(model.artifacts) > 0:
if hasattr(model, "artifacts") and model.artifacts is not None and len(model.artifacts) > 0:
conf = OmegaConf.load(path2yaml_file)
for conf_path, item in model.artifacts.items():
if item.hashed_path is None: