From a46e3251944642f9102aa16ce2d2f9d3a804ff8a Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 28 Jul 2023 12:34:52 -0700 Subject: [PATCH] Fix import guard checks (#7124) Signed-off-by: smajumdar --- nemo/utils/model_utils.py | 2 +- tests/collections/nlp/test_flash_attention.py | 21 ++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 42a0b108944d..0f70631f140f 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -576,7 +576,7 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op f"Could not check version compatibility." ) return False, msg - except (ImportError, ModuleNotFoundError): + except (ImportError, ModuleNotFoundError, AttributeError): pass msg = f"Lib {lib_name} has not been installed. Please use pip or conda to install this package." diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py index 727742fdffb5..aa96b6753849 100644 --- a/tests/collections/nlp/test_flash_attention.py +++ b/tests/collections/nlp/test_flash_attention.py @@ -44,16 +44,23 @@ except (ImportError, ModuleNotFoundError): HAVE_TRITON = False -import pynvml +try: + import pynvml + + HAVE_PYNVML = True +except (ImportError, ModuleNotFoundError): + HAVE_PYNVML = False def HAVE_AMPERE_GPU(): - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - device_arch = pynvml.nvmlDeviceGetArchitecture(handle) - pynvml.nvmlShutdown() - return device_arch == pynvml.NVML_DEVICE_ARCH_AMPERE - + if HAVE_PYNVML: + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + device_arch = pynvml.nvmlDeviceGetArchitecture(handle) + pynvml.nvmlShutdown() + return device_arch == pynvml.NVML_DEVICE_ARCH_AMPERE + else: + return False @pytest.mark.run_only_on('GPU') @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")