Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use module-based k2 import guard #6006

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions nemo/collections/asr/losses/lattice_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,6 @@ def __init__(

# we assume that self._blank + 1 == num_classes
if backend == "k2":
# use k2 import guard
from nemo.core.utils.k2_utils import k2_import_guard

k2_import_guard()

if criterion_type == "ml":
from nemo.collections.asr.parts.k2.ml_loss import MLLoss as K2Loss
elif criterion_type == "map":
Expand Down
5 changes: 0 additions & 5 deletions nemo/collections/asr/modules/graph_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ def __init__(

# we assume that self._blank + 1 == num_classes
if backend == "k2":
# use k2 import guard
from nemo.core.utils.k2_utils import k2_import_guard

k2_import_guard()

if self.dec_type == "topo":
from nemo.collections.asr.parts.k2.graph_decoders import BaseDecoder as Decoder
elif self.dec_type == "token_lm":
Expand Down
5 changes: 0 additions & 5 deletions nemo/collections/asr/parts/k2/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ def _init_k2(self):
This method is expected to run after the __init__ which sets self._cfg
self._cfg is expected to have the attribute graph_module_cfg
"""
# use k2 import guard
from nemo.core.utils.k2_utils import k2_import_guard

k2_import_guard()

if not hasattr(self, "_cfg"):
raise ValueError("self._cfg must be set before calling _init_k2().")
if not hasattr(self._cfg, "graph_module_cfg") or self._cfg.graph_module_cfg is None:
Expand Down
10 changes: 1 addition & 9 deletions nemo/collections/asr/parts/k2/graph_compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,7 @@
from nemo.collections.asr.parts.k2.topologies import build_topo
from nemo.collections.asr.parts.k2.utils import compose_with_self_loops, intersect_with_self_loops

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on
from nemo.core.utils.k2_guard import k2 # import k2 from guard module


class CtcTopologyCompiler(object):
Expand All @@ -55,9 +50,6 @@ def __init__(
topo_with_self_loops: bool = True,
device: torch.device = torch.device("cpu"),
):
# use k2 import guard
k2_import_guard()

self.topo_type = topo_type
self.device = device
self.base_graph = k2.arc_sort(build_topo(topo_type, list(range(num_classes)), topo_with_self_loops)).to(
Expand Down
10 changes: 1 addition & 9 deletions nemo/collections/asr/parts/k2/graph_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,9 @@
prep_padded_densefsavec,
shift_labels_inpl,
)
from nemo.core.utils.k2_guard import k2 # import k2 from guard module
from nemo.utils import logging

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on


class BaseDecoder(object):
"""Base graph decoder with topology for decoding graph.
Expand All @@ -57,8 +51,6 @@ def __init__(
topo_with_self_loops: bool = True,
device: torch.device = torch.device("cpu"),
):
# use k2 import guard
k2_import_guard()

if cfg is not None:
intersect_pruned = cfg.get("intersect_pruned", intersect_pruned)
Expand Down
11 changes: 1 addition & 10 deletions nemo/collections/asr/parts/k2/map_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,9 @@
prep_padded_densefsavec,
shift_labels_inpl,
)
from nemo.core.utils.k2_guard import k2 # import k2 from guard module
from nemo.utils import logging

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on


class MAPLoss(torch.nn.Module):
"""
Expand Down Expand Up @@ -78,9 +72,6 @@ def __init__(
intersect_conf: GraphIntersectDenseConfig = GraphIntersectDenseConfig(),
boost_coeff: float = 0.0,
):
# use k2 import guard
k2_import_guard()

super().__init__()
if cfg is not None:
topo_type = cfg.get("topo_type", topo_type)
Expand Down
10 changes: 1 addition & 9 deletions nemo/collections/asr/parts/k2/ml_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@
prep_padded_densefsavec,
)

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on
from nemo.core.utils.k2_guard import k2 # import k2 from guard module


class MLLoss(torch.nn.Module):
Expand All @@ -71,9 +66,6 @@ def __init__(
graph_type: str = "topo",
token_lm: Optional[Union['k2.Fsa', str]] = None,
):
# use k2 import guard
k2_import_guard()

super().__init__()
if cfg is not None:
topo_type = cfg.get("topo_type", topo_type)
Expand Down
7 changes: 1 addition & 6 deletions nemo/collections/asr/parts/k2/topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@

from typing import List

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on
from nemo.core.utils.k2_guard import k2 # import k2 from guard module


def build_topo(name: str, tokens: List[int], with_self_loops: bool = True) -> 'k2.Fsa':
Expand Down
8 changes: 1 addition & 7 deletions nemo/collections/asr/parts/k2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,9 @@

import torch

from nemo.core.utils.k2_guard import k2 # import k2 from guard module
from nemo.utils import logging

# use k2 import guard
# fmt: off
from nemo.core.utils.k2_utils import k2_import_guard # isort:skip
k2_import_guard()
import k2 # isort:skip
# fmt: on


def create_supervision(input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Creates a special supervisions tensor from input lengths.
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.core.utils import k2_utils, numba_utils, process_launcher
from nemo.core.utils import numba_utils, process_launcher
61 changes: 27 additions & 34 deletions nemo/core/utils/k2_utils.py → nemo/core/utils/k2_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
"""
Guard for importing optional NeMo dependency k2.
Contains checks for k2 availability and version.
Use `from nemo.core.utils.k2_guard import k2` to import k2 instead of direct import.
If there is an error, the module will raise an exception with a helpful message.
"""

__K2_MINIMUM_MAJOR_VERSION__ = 1
__K2_MINIMUM_MINOR_VERSION__ = 11
import textwrap

from packaging.version import Version
from pytorch_lightning.utilities.imports import package_available

__K2_MINIMUM_MAJOR_VERSION = 1
__K2_MINIMUM_MINOR_VERSION = 11

__K2_MINIMUM_VERSION = Version(f"{__K2_MINIMUM_MAJOR_VERSION}.{__K2_MINIMUM_MINOR_VERSION}")

K2_INSTALLATION_MESSAGE = (
"Could not import `k2`.\n"
Expand All @@ -26,39 +38,20 @@
"It is advised to always install k2 using setup.sh only, "
"as different versions of k2 may not interact with the NeMo code as expected."
)
K2_IMPORT_SUCCESS = "Successfully imported k2."


def k2_is_available() -> Tuple[bool, str]:
try:
import k2
if not package_available("k2"):
raise ModuleNotFoundError(K2_INSTALLATION_MESSAGE)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is it an optional dependency if merely importing it will raise an error? K2 is not installed by default, so the check should be done at init, inside the place where it will be used, rather than during a simple import of a file.

Copy link
Collaborator

@titu1994 titu1994 Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, k2 is not included in the import path of `__init__.py` files for this reason, but at least the check is done at init of the class. This change instead puts it in the import path, so if a user simply imports the file it will crash with a helpful error.

The Numba check specifically defers this crash to the moment the user actually tries to use the module, rather than when the module is simply imported.

Copy link
Collaborator Author

@artbataev artbataev Feb 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Importing k2 raises an error only when importing `k2_guard`, or `k2` from `k2_guard`.
In that case `k2_guard` will not be imported and cached.

So, every module that tries to import k2 from k2_guard will raise an error (if k2 is missing or has an incorrect version) and will not be initialized (no classes in this module will be defined).
As a result, importing from any of the modules in nemo that depend on k2 (possibly indirectly) will result in an error with the necessary message, which is the correct behavior.

But importing below the top level (e.g., inside a function) allows the error to be raised at runtime, so we do not need to call k2_import_guard():

if backend == "k2":
    if criterion_type == "ml":
        from nemo.collections.asr.parts.k2.ml_loss import MLLoss as K2Loss  # maybe error here
    elif criterion_type == "map":
        from nemo.collections.asr.parts.k2.map_loss import MAPLoss as K2Loss  # or here

This doesn't change the previous behavior: in these modules the k2_import_guard() call was used at the top level, so an incorrect k2 version resulted in an immediate crash before any defined class could be imported.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can set __K2_MINIMUM_MINOR_VERSION = 111 and see that the behavior didn't change and everything works.


return True, K2_IMPORT_SUCCESS
except (ImportError, ModuleNotFoundError):
pass
return False, K2_INSTALLATION_MESSAGE
import k2 # noqa: E402

__k2_version = Version(k2.__dev_version__)

def k2_satisfies_minimum_version() -> Tuple[bool, str]:
import k2

k2_major_version, k2_minor_version = k2.__dev_version__.split(".")[:2]
ver_satisfies = (
int(k2_major_version) >= __K2_MINIMUM_MAJOR_VERSION__ and int(k2_minor_version) >= __K2_MINIMUM_MINOR_VERSION__
)
return (
ver_satisfies,
(
f"Minimum required k2 version: {__K2_MINIMUM_MAJOR_VERSION__}.{__K2_MINIMUM_MINOR_VERSION__};\n"
f"Installed k2 version: {k2_major_version}.{k2_minor_version}"
),
if __k2_version < __K2_MINIMUM_VERSION:
raise ImportError(
textwrap.dedent(
f"""
Minimum required k2 version: {__K2_MINIMUM_VERSION};
Installed k2 version: {__k2_version}
"""
)
)


def k2_import_guard():
avail, msg = k2_is_available()
if not avail:
raise ModuleNotFoundError(msg)
ver_ge, msg = k2_satisfies_minimum_version()
if not ver_ge:
raise ImportError(msg)