Skip to content

Commit

Permalink
Lazy import in image/multimodal domains (#2215)
Browse files Browse the repository at this point in the history
* lazy multimodal

* lazy image

* fix import logic

* fix import logic

* revert

* fixes

* fixes

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 28, 2023
1 parent 0cb4c79 commit 46dfa42
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@

if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["learned_perceptual_image_patch_similarity"]
else:
from torchvision import models as tv


def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
Expand All @@ -52,6 +50,8 @@ def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
pretrained: If pretrained weights should be used
"""
from torchvision import models as tv

if _TORCHVISION_GREATER_EQUAL_0_13:
if pretrained:
pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features
Expand Down
9 changes: 4 additions & 5 deletions src/torchmetrics/functional/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.nn import functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Literal

from torchmetrics.functional.image.helper import _gaussian_kernel_2d
Expand Down Expand Up @@ -92,11 +91,11 @@ def _uqi_compute(
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
preds = nn.functional.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
target = nn.functional.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
outputs = nn.functional.conv2d(input_list, kernel, groups=channel)
output_list = outputs.split(preds.shape[0])

mu_pred_sq = output_list[0].pow(2)
Expand Down
27 changes: 15 additions & 12 deletions src/torchmetrics/functional/multimodal/clip_iqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Literal, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -20,21 +20,22 @@
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10

if _TRANSFORMERS_GREATER_EQUAL_4_10:
if TYPE_CHECKING:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

def _download_clip() -> None:
_CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
_CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["clip_score"]

if not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["clip_image_quality_assessment"]
else:
__doctest_skip__ = ["clip_image_quality_assessment"]
_CLIPModel = None
_CLIPProcessor = None

if not _PIQ_GREATER_EQUAL_0_8:
__doctest_skip__ = ["clip_image_quality_assessment"]
Expand Down Expand Up @@ -67,8 +68,10 @@ def _get_clip_iqa_model_and_processor(
"openai/clip-vit-large-patch14-336",
"openai/clip-vit-large-patch14",
]
) -> Tuple[_CLIPModel, _CLIPProcessor]:
) -> Tuple["_CLIPModel", "_CLIPProcessor"]:
"""Extract the CLIP model and processor from the model name or path."""
from transformers import CLIPProcessor as _CLIPProcessor

if model_name_or_path == "clip_iqa":
if not _PIQ_GREATER_EQUAL_0_8:
raise ValueError(
Expand Down Expand Up @@ -141,8 +144,8 @@ def _clip_iqa_format_prompts(prompts: Tuple[Union[str, Tuple[str, str]]] = ("qua

def _clip_iqa_get_anchor_vectors(
model_name_or_path: str,
model: _CLIPModel,
processor: _CLIPProcessor,
model: "_CLIPModel",
processor: "_CLIPProcessor",
prompts_list: List[str],
device: Union[str, torch.device],
) -> Tensor:
Expand Down Expand Up @@ -176,8 +179,8 @@ def _clip_iqa_get_anchor_vectors(
def _clip_iqa_update(
model_name_or_path: str,
images: Tensor,
model: _CLIPModel,
processor: _CLIPProcessor,
model: "_CLIPModel",
processor: "_CLIPProcessor",
data_range: float,
device: Union[str, torch.device],
) -> Tensor:
Expand Down
14 changes: 10 additions & 4 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union
from typing import TYPE_CHECKING, List, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -21,17 +21,20 @@
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10

if _TRANSFORMERS_GREATER_EQUAL_4_10:
if TYPE_CHECKING and _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

def _download_clip() -> None:
_CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
_CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
if not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["clip_score"]

else:
__doctest_skip__ = ["clip_score"]
_CLIPModel = None
Expand Down Expand Up @@ -96,6 +99,9 @@ def _get_clip_model_and_processor(
] = "openai/clip-vit-large-patch14",
) -> Tuple[_CLIPModel, _CLIPProcessor]:
if _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

model = _CLIPModel.from_pretrained(model_name_or_path)
processor = _CLIPProcessor.from_pretrained(model_name_or_path)
return model, processor
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/multimodal/clip_iqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@
if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CLIPImageQualityAssessment.plot"]

if _TRANSFORMERS_GREATER_EQUAL_4_10:
if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

def _download_clip() -> None:
_CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
_CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
if not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
else:
__doctest_skip__ = ["CLIPImageQualityAssessment", "CLIPImageQualityAssessment.plot"]
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CLIPScore.plot"]

if _TRANSFORMERS_GREATER_EQUAL_4_10:
if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor

def _download_clip() -> None:
_CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
_CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_clip):
if not _try_proceed_with_timeout(_download_clip):
__doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
else:
__doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
Expand Down

0 comments on commit 46dfa42

Please sign in to comment.