From 46dfa42d871eab0069d1bd6d7b8423144da266ae Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Nov 2023 15:19:48 +0100
Subject: [PATCH] Lazy import in image/multimodal domains (#2215)

* lazy multimodal

* lazy image

* fix import logic

* fix import logic

* revert

* fixes

* fixes

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 src/torchmetrics/functional/image/lpips.py |  4 +--
 src/torchmetrics/functional/image/uqi.py   |  9 +++----
 .../functional/multimodal/clip_iqa.py       | 27 ++++++++++---------
 .../functional/multimodal/clip_score.py     | 14 +++++++---
 src/torchmetrics/multimodal/clip_iqa.py     |  4 +--
 src/torchmetrics/multimodal/clip_score.py   |  4 +--
 6 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py
index 63a708969c0..6f8d4b4a450 100644
--- a/src/torchmetrics/functional/image/lpips.py
+++ b/src/torchmetrics/functional/image/lpips.py
@@ -40,8 +40,6 @@
 
 if not _TORCHVISION_AVAILABLE:
     __doctest_skip__ = ["learned_perceptual_image_patch_similarity"]
-else:
-    from torchvision import models as tv
 
 
 def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
@@ -52,6 +50,8 @@ def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
         pretrained: If pretrained weights should be used
 
     """
+    from torchvision import models as tv
+
     if _TORCHVISION_GREATER_EQUAL_0_13:
         if pretrained:
             pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features
diff --git a/src/torchmetrics/functional/image/uqi.py b/src/torchmetrics/functional/image/uqi.py
index fe17f7b382a..c52ebe8c16b 100644
--- a/src/torchmetrics/functional/image/uqi.py
+++ b/src/torchmetrics/functional/image/uqi.py
@@ -14,8 +14,7 @@
 from typing import Optional, Sequence, Tuple
 
 import torch
-from torch import Tensor
-from torch.nn import functional as F  # noqa: N812
+from torch import Tensor, nn
 from typing_extensions import Literal
 
 from torchmetrics.functional.image.helper import _gaussian_kernel_2d
@@ -92,11 +91,11 @@ def _uqi_compute(
 
     pad_h = (kernel_size[0] - 1) // 2
     pad_w = (kernel_size[1] - 1) // 2
-    preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
-    target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
+    preds = nn.functional.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
+    target = nn.functional.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
 
     input_list = torch.cat((preds, target, preds * preds, target * target, preds * target))  # (5 * B, C, H, W)
-    outputs = F.conv2d(input_list, kernel, groups=channel)
+    outputs = nn.functional.conv2d(input_list, kernel, groups=channel)
     output_list = outputs.split(preds.shape[0])
 
     mu_pred_sq = output_list[0].pow(2)
diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py
index 078760dedde..ef2f354f1fc 100644
--- a/src/torchmetrics/functional/multimodal/clip_iqa.py
+++ b/src/torchmetrics/functional/multimodal/clip_iqa.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Literal, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -20,7 +20,11 @@
 from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
 from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10
 
-if _TRANSFORMERS_GREATER_EQUAL_4_10:
+if TYPE_CHECKING:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+
+if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
     from transformers import CLIPModel as _CLIPModel
     from transformers import CLIPProcessor as _CLIPProcessor
 
@@ -28,13 +32,10 @@ def _download_clip() -> None:
         _CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
         _CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
 
-    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
-        __doctest_skip__ = ["clip_score"]
-
+    if not _try_proceed_with_timeout(_download_clip):
+        __doctest_skip__ = ["clip_image_quality_assessment"]
 else:
     __doctest_skip__ = ["clip_image_quality_assessment"]
-    _CLIPModel = None
-    _CLIPProcessor = None
 
 if not _PIQ_GREATER_EQUAL_0_8:
     __doctest_skip__ = ["clip_image_quality_assessment"]
@@ -67,8 +68,10 @@ def _get_clip_iqa_model_and_processor(
         "openai/clip-vit-large-patch14-336",
         "openai/clip-vit-large-patch14",
     ]
-) -> Tuple[_CLIPModel, _CLIPProcessor]:
+) -> Tuple["_CLIPModel", "_CLIPProcessor"]:
     """Extract the CLIP model and processor from the model name or path."""
+    from transformers import CLIPProcessor as _CLIPProcessor
+
     if model_name_or_path == "clip_iqa":
         if not _PIQ_GREATER_EQUAL_0_8:
             raise ValueError(
@@ -141,8 +144,8 @@ def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("qua
 
 def _clip_iqa_get_anchor_vectors(
     model_name_or_path: str,
-    model: _CLIPModel,
-    processor: _CLIPProcessor,
+    model: "_CLIPModel",
+    processor: "_CLIPProcessor",
     prompts_list: List[str],
     device: Union[str, torch.device],
 ) -> Tensor:
@@ -176,8 +179,8 @@ def _clip_iqa_get_anchor_vectors(
 def _clip_iqa_update(
     model_name_or_path: str,
     images: Tensor,
-    model: _CLIPModel,
-    processor: _CLIPProcessor,
+    model: "_CLIPModel",
+    processor: "_CLIPProcessor",
     data_range: float,
     device: Union[str, torch.device],
 ) -> Tensor:
diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py
index 01576369d15..9070f4d8d6a 100644
--- a/src/torchmetrics/functional/multimodal/clip_score.py
+++ b/src/torchmetrics/functional/multimodal/clip_score.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Tuple, Union
+from typing import TYPE_CHECKING, List, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -21,7 +21,11 @@
 from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 
-if _TRANSFORMERS_GREATER_EQUAL_4_10:
+if TYPE_CHECKING and _TRANSFORMERS_GREATER_EQUAL_4_10:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+
+if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
     from transformers import CLIPModel as _CLIPModel
     from transformers import CLIPProcessor as _CLIPProcessor
 
@@ -29,9 +33,8 @@ def _download_clip() -> None:
         _CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
-    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
+    if not _try_proceed_with_timeout(_download_clip):
         __doctest_skip__ = ["clip_score"]
-
 else:
     __doctest_skip__ = ["clip_score"]
     _CLIPModel = None
@@ -96,6 +99,9 @@ def _get_clip_model_and_processor(
     ] = "openai/clip-vit-large-patch14",
 ) -> Tuple[_CLIPModel, _CLIPProcessor]:
     if _TRANSFORMERS_GREATER_EQUAL_4_10:
+        from transformers import CLIPModel as _CLIPModel
+        from transformers import CLIPProcessor as _CLIPProcessor
+
         model = _CLIPModel.from_pretrained(model_name_or_path)
         processor = _CLIPProcessor.from_pretrained(model_name_or_path)
         return model, processor
diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py
index de48df1bbae..252d7ad5931 100644
--- a/src/torchmetrics/multimodal/clip_iqa.py
+++ b/src/torchmetrics/multimodal/clip_iqa.py
@@ -39,7 +39,7 @@
 if not _MATPLOTLIB_AVAILABLE:
     __doctest_skip__ = ["CLIPImageQualityAssessment.plot"]
 
-if _TRANSFORMERS_GREATER_EQUAL_4_10:
+if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
     from transformers import CLIPModel as _CLIPModel
     from transformers import CLIPProcessor as _CLIPProcessor
 
@@ -47,7 +47,7 @@ def _download_clip() -> None:
         _CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
-    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
+    if not _try_proceed_with_timeout(_download_clip):
         __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
 else:
     __doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py
index c38b12e4a60..a5a8201893a 100644
--- a/src/torchmetrics/multimodal/clip_score.py
+++ b/src/torchmetrics/multimodal/clip_score.py
@@ -26,7 +26,7 @@
 if not _MATPLOTLIB_AVAILABLE:
     __doctest_skip__ = ["CLIPScore.plot"]
 
-if _TRANSFORMERS_GREATER_EQUAL_4_10:
+if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
     from transformers import CLIPModel as _CLIPModel
     from transformers import CLIPProcessor as _CLIPProcessor
 
@@ -34,7 +34,7 @@ def _download_clip() -> None:
         _CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         _CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
-    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
+    if not _try_proceed_with_timeout(_download_clip):
         __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
 else:
     __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
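
For readers unfamiliar with the pattern, the changes above all follow the same lazy-import recipe, condensed in the sketch below. It is simplified from _get_clip_model_and_processor in this patch; the _TRANSFORMERS_GREATER_EQUAL_4_10 guard, the doctest-skip handling, and the error paths are omitted. Imports needed only for type annotations run under typing.TYPE_CHECKING (never at runtime), and the real transformers import is deferred into the helper, so importing the torchmetrics module itself no longer pulls in transformers.

from typing import TYPE_CHECKING, Tuple

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor


def _get_clip_model_and_processor(
    model_name_or_path: str = "openai/clip-vit-large-patch14",
) -> Tuple["_CLIPModel", "_CLIPProcessor"]:
    """Build the CLIP model and processor, importing transformers only when called."""
    # Deferred import: the optional dependency is resolved at call time,
    # not when this module is imported.
    from transformers import CLIPModel as _CLIPModel
    from transformers import CLIPProcessor as _CLIPProcessor

    model = _CLIPModel.from_pretrained(model_name_or_path)
    processor = _CLIPProcessor.from_pretrained(model_name_or_path)
    return model, processor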