Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLIP score #1314

Merged
merged 44 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
f9c3fbe
first steps
SkafteNicki Nov 5, 2022
aaa3265
further updates
SkafteNicki Nov 5, 2022
7295fef
add some testing
SkafteNicki Nov 5, 2022
f199d7e
changelog
SkafteNicki Nov 5, 2022
7814d2b
docstring
SkafteNicki Nov 5, 2022
7ca7854
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2022
7df3cbc
add to index
SkafteNicki Nov 8, 2022
b9db500
add docstrings
SkafteNicki Nov 8, 2022
ff3e62a
update
SkafteNicki Nov 9, 2022
7e1c8d1
fix tests
SkafteNicki Nov 10, 2022
2839bae
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 10, 2022
4a7dc1d
Merge branch 'master' into metric/clip
SkafteNicki Nov 10, 2022
b1c8b27
add requirement
SkafteNicki Nov 11, 2022
c354fe0
try fixing mypy and docs
SkafteNicki Nov 11, 2022
711e343
fix
SkafteNicki Nov 11, 2022
95fbff7
skip on no transformer
SkafteNicki Nov 11, 2022
2111d09
fix typing
SkafteNicki Nov 11, 2022
9a3f256
Merge branch 'master' into metric/clip
Borda Nov 11, 2022
3702610
Apply suggestions from code review
SkafteNicki Nov 13, 2022
cd1f50e
Merge branch 'master' into metric/clip
SkafteNicki Nov 14, 2022
1df18be
add functional and refactor
SkafteNicki Nov 14, 2022
3205b91
change variable name
SkafteNicki Nov 14, 2022
6490a03
Merge branch 'master' into metric/clip
SkafteNicki Nov 14, 2022
61cedeb
fix testing
SkafteNicki Nov 14, 2022
c0cec12
try fixing typing
SkafteNicki Nov 14, 2022
61cb3cc
Merge branch 'master' into metric/clip
SkafteNicki Nov 16, 2022
0cbc15d
8g
Borda Nov 16, 2022
0a734db
fix requirement + testing
SkafteNicki Nov 16, 2022
02c5234
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 16, 2022
8a876c6
more requirements
SkafteNicki Nov 16, 2022
ae79845
fix
SkafteNicki Nov 16, 2022
89813ce
Merge branch 'master' into metric/clip
SkafteNicki Nov 16, 2022
f091e62
fix doctests
SkafteNicki Nov 16, 2022
fcf268d
Merge branch 'metric/clip' of https://github.com/PyTorchLightning/met…
SkafteNicki Nov 16, 2022
53ec80d
fix
SkafteNicki Nov 16, 2022
e3a9117
remove back
SkafteNicki Nov 17, 2022
2851923
Merge branch 'master' into metric/clip
Borda Nov 17, 2022
ea4b11a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
f1595a2
move section in index
SkafteNicki Nov 17, 2022
ae925b9
set min version of transformers
SkafteNicki Nov 17, 2022
8feb2cb
fix flake
SkafteNicki Nov 17, 2022
95bd30b
simple
Borda Nov 17, 2022
0debc25
Apply suggestions from code review
Borda Nov 17, 2022
56dc6f7
avail
Borda Nov 17, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .azure/gpu-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:

container:
image: "$(docker-image)"
options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --name ci-container -v /usr/bin/docker:/tmp/docker:ro"
options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --shm-size=8g --name ci-container -v /usr/bin/docker:/tmp/docker:ro"

workspace:
clean: all
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:

test-docs:
runs-on: ubuntu-20.04
timeout-minutes: 15
timeout-minutes: 20
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `TschuprowsT` ([#1334](https://github.com/Lightning-AI/metrics/pull/1334))


- Added `CLIPScore` to new multimodal package ([#1314](https://github.com/Lightning-AI/metrics/pull/1314))


### Changed

- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
Expand Down
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ Or directly from conda

nominal/*

.. toctree::
:maxdepth: 2
:name: multimodal
:caption: Multimodal
:glob:

multimodal/*

SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
.. toctree::
:maxdepth: 2
:name: pairwise
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,5 @@
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
.. _Tschuprow's T: https://en.wikipedia.org/wiki/Tschuprow%27s_T
.. _Pearson's Contingency Coefficient: https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/pearcont.htm
.. _CLIP score: https://arxiv.org/pdf/2104.08718.pdf
.. _Huggingface OpenAI: https://huggingface.co/openai
20 changes: 20 additions & 0 deletions docs/source/multimodal/clip_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
.. customcarditem::
:header: CLIP Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

############################################
CLIP Score
############################################

Module Interface
________________

.. autoclass:: torchmetrics.multimodal.clip_score.CLIPScore
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.multimodal.clip_score.clip_score
:noindex:
1 change: 1 addition & 0 deletions requirements/devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
-r text.txt
# -r detection.txt # version collision with min versio of PyTorch
-r audio.txt
-r multimodal.txt

# add extra testing
-r image_test.txt
Expand Down
1 change: 1 addition & 0 deletions requirements/multimodal.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers>=4.10.0
2 changes: 1 addition & 1 deletion requirements/text_test.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
jiwer>=2.3.0
rouge-score>=0.0.4
bert_score==0.3.10
transformers>=4.4.0
transformers>=4.10.0
Borda marked this conversation as resolved.
Show resolved Hide resolved
huggingface-hub<0.7 # hotfix, failing SDR for latest PT 1.11
sacrebleu>=2.0.0
17 changes: 17 additions & 0 deletions src/torchmetrics/functional/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from torchmetrics.functional.multimodal.clip_score import clip_score # noqa: F401
140 changes: 140 additions & 0 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from transformers import CLIPModel as _CLIPModel
from transformers import CLIPProcessor as _CLIPProcessor
else:
__doctest_skip__ = ["clip_score"]
_CLIPModel = None # type:ignore
_CLIPProcessor = None # type:ignore


def _clip_score_update(
images: Union[Tensor, List[Tensor]],
text: Union[str, List[str]],
model: _CLIPModel,
processor: _CLIPProcessor,
) -> Tuple[Tensor, int]:
if not isinstance(images, list):
if images.ndim == 3:
images = [images]
else: # unwrap into list
images = [i for i in images]

if not all(i.ndim == 3 for i in images):
raise ValueError("Expected all images to be 3d but found image that has either more or less")

if not isinstance(text, list):
text = [text]

if len(text) != len(images):
raise ValueError(
f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
)
device = images[0].device
processed_input = processor(
text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True
) # type:ignore

img_features = model.get_image_features(processed_input["pixel_values"].to(device))
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

txt_features = model.get_text_features(
processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)

# cosine similarity between feature vectors
score = 100 * (img_features * txt_features).sum(axis=-1)
return score, len(text)


def _get_model_and_processor(
model_name_or_path: Literal[
"openai/clip-vit-base-patch16",
"openai/clip-vit-base-patch32",
"openai/clip-vit-large-patch14-336",
"openai/clip-vit-large-patch14",
] = "openai/clip-vit-large-patch14",
) -> Tuple[_CLIPModel, _CLIPProcessor]:
if _TRANSFORMERS_AVAILABLE:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
model = _CLIPModel.from_pretrained(model_name_or_path)
processor = _CLIPProcessor.from_pretrained(model_name_or_path)
return model, processor
else:
raise ModuleNotFoundError(
"`clip_score` metric requires `transformers` package be installed."
" Either install with `pip install transformers>=4.0` or `pip install torchmetrics[multimodal]`."
)


def clip_score(
images: Union[Tensor, List[Tensor]],
text: Union[str, List[str]],
model_name_or_path: Literal[
"openai/clip-vit-base-patch16",
"openai/clip-vit-base-patch32",
"openai/clip-vit-large-patch14-336",
"openai/clip-vit-large-patch14",
] = "openai/clip-vit-large-patch14",
) -> Tensor:
"""`CLIP Score`_ is a reference free metric that can be used to evaluate the correlation between a generated
caption for an image and the actual content of the image. It has been found to be highly correlated with human
judgement. The metric is defined as:

.. math::
\text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)

which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and
textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
to 100 the better.

.. note:: Metric is not scriptable

Args:
images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text: Either a single caption or a list of captions
model_name_or_path: string indicating the version of the CLIP model to use. Available models are
`"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
and `"openai/clip-vit-large-patch14"`,

Raises:
ModuleNotFoundError:
If transformers package is not installed
ValueError:
If not all images have format [C, H, W]
ValueError:
If the number of images and captions do not match

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
>>> print(score.detach())
tensor(24.4255)
"""
model, processor = _get_model_and_processor(model_name_or_path)
device = images.device if isinstance(images, Tensor) else images[0].device
score, _ = _clip_score_update(images, text, model.to(device), processor)
score = score.mean(0)
return torch.max(score, torch.zeros_like(score))
17 changes: 17 additions & 0 deletions src/torchmetrics/multimodal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if _TRANSFORMERS_AVAILABLE:
from torchmetrics.multimodal.clip_score import CLIPScore # noqa: F401
105 changes: 105 additions & 0 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.multimodal.clip_score import _clip_score_update, _get_model_and_processor
from torchmetrics.utilities.imports import _TRANSFORMERS_AVAILABLE

if not _TRANSFORMERS_AVAILABLE:
__doctest_skip__ = ["CLIPScore"]

from torchmetrics import Metric


class CLIPScore(Metric):
"""`CLIP Score`_ is a reference free metric that can be used to evaluate the correlation between a generated
caption for an image and the actual content of the image. It has been found to be highly correlated with human
judgement. The metric is defined as:

.. math::
\text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)

which corresponds to the cosine similarity between visual CLIP embedding :math:`E_i` for an image :math:`i` and
textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
to 100 the better.

.. note:: Metric is not scriptable

Args:
model_name_or_path: string indicating the version of the CLIP model to use. Available models are
`"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
and `"openai/clip-vit-large-patch14"`,

kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ModuleNotFoundError:
If transformers package is not installed

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.multimodal import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> print(score.detach())
tensor(25.0936)
"""

is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = True
score: Tensor
n_samples: Tensor

def __init__(
self,
model_name_or_path: Literal[
"openai/clip-vit-base-patch16",
"openai/clip-vit-base-patch32",
"openai/clip-vit-large-patch14-336",
"openai/clip-vit-large-patch14",
] = "openai/clip-vit-large-patch14",
**kwargs: Any,
) -> None:

super().__init__(**kwargs)
self.model, self.processor = _get_model_and_processor(model_name_or_path)
self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

def update(self, images: Union[Tensor, List[Tensor]], text: Union[str, List[str]]) -> None:
"""Updates CLIP score on a batch of images and text.

Args:
images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text: Either a single caption or a list of captions

Raises:
ValueError:
If not all images have format [C, H, W]
ValueError:
If the number of images and captions do not match
"""
score, n_samples = _clip_score_update(images, text, self.model, self.processor)
self.score += score.sum(0)
self.n_samples += n_samples

def compute(self) -> Tensor:
"""Computes accumulated clip score."""
return torch.max(self.score / self.n_samples, torch.zeros_like(self.score))
Loading