Skip to content

Commit

Permalink
add doctest for same modality
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Jan 9, 2025
1 parent 03efd65 commit b71fe12
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/torchmetrics/functional/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,29 @@ def clip_score(
>>> score.detach()
tensor(24.4255)
Example:
>>> import torch
>>> from torchmetrics.functional.multimodal import clip_score
>>> torch.manual_seed(42)
>>> torch.cuda.manual_seed_all(42)
>>> score = clip_score(
... torch.randint(255, (3, 224, 224)),
... torch.randint(255, (3, 224, 224)),
... "openai/clip-vit-base-patch16"
... )
>>> score.detach()
tensor(99.3556)
Example:
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(
... "28-year-old chef found dead in San Francisco mall",
... "A 28-year-old chef who recently moved to San Francisco was found dead.",
... "openai/clip-vit-base-patch16"
... )
>>> score.detach()
tensor(91.3950)
"""
model, processor = _get_clip_model_and_processor(model_name_or_path)
score, _ = _clip_score_update(source, target, model, processor)
Expand Down
18 changes: 18 additions & 0 deletions src/torchmetrics/multimodal/clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ class CLIPScore(Metric):
>>> score.detach().round()
tensor(25.)
Example:
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> torch.manual_seed(42)
>>> torch.cuda.manual_seed_all(42)
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric(torch.randint(255, (3, 224, 224)),torch.randint(255, (3, 224, 224)))
>>> score.detach().round()
tensor(100.)
Example:
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric("28-year-old chef found dead in San Francisco mall",
... "A 28-year-old chef who recently moved to San Francisco was found dead.")
>>> score.detach().round()
tensor(91.)
"""

is_differentiable: bool = False
Expand Down

0 comments on commit b71fe12

Please sign in to comment.