diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index ee653fac9d..2ae03bc8dc 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -209,7 +209,7 @@ def __init__( ) -> None: super().__init__() torch.hub._validate_not_a_forked_repo = lambda a, b, c: True - self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose) + self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) self.eval() self.channel_wise = channel_wise @@ -297,7 +297,7 @@ class RadImageNetPerceptualSimilarity(nn.Module): def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: super().__init__() - self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose) + self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) self.eval() for param in self.parameters(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 784b25f663..f87b16fb71 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -56,6 +56,8 @@ nib, _ = optional_import("nibabel") http_error, has_req = optional_import("requests", name="HTTPError") +file_url_error, has_gdown = optional_import("gdown.exceptions", name="FileURLRetrievalError") + quick_test_var = "QUICKTEST" _tf32_enabled = None @@ -63,6 +65,23 @@ MODULE_PATH = Path(__file__).resolve().parents[1] +DOWNLOAD_EXCEPTS: tuple[type, ...] = (ContentTooShortError, HTTPError, ConnectionError) +if has_req: + DOWNLOAD_EXCEPTS += (http_error,) +if has_gdown: + DOWNLOAD_EXCEPTS += (file_url_error,) + +DOWNLOAD_FAIL_MSGS = ( + "unexpected EOF", # incomplete download + "network issue", + "gdown dependency", # gdown not installed + "md5 check", + "limit", # HTTP Error 503: Egress is over the account limit + "authenticate", + "timed out", # urlopen error [Errno 110] Connection timed out + "HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub +) + def testing_data_config(*keys): """get _test_data_config[keys0][keys1]...[keysN]""" @@ -142,29 +161,21 @@ def assert_allclose( @contextmanager def skip_if_downloading_fails(): + """ + Skips a test if downloading something raises an exception recognised to indicate a download has failed. + """ + try: yield - except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e: # noqa: B030 - raise unittest.SkipTest(f"error while downloading: {e}") from e + except DOWNLOAD_EXCEPTS as e: + raise unittest.SkipTest(f"Error while downloading: {e}") from e except ssl.SSLError as ssl_e: if "decryption failed" in str(ssl_e): raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e except (RuntimeError, OSError) as rt_e: err_str = str(rt_e) - if any( - k in err_str - for k in ( - "unexpected EOF", # incomplete download - "network issue", - "gdown dependency", # gdown not installed - "md5 check", - "limit", # HTTP Error 503: Egress is over the account limit - "authenticate", - "timed out", # urlopen error [Errno 110] Connection timed out - "HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub - ) - ): - raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download + if any(k in err_str for k in DOWNLOAD_FAIL_MSGS): + raise unittest.SkipTest(f"Error while downloading: {rt_e}") from rt_e # incomplete download raise rt_e