From b71fe12ccf9444631fd121f068624d3a7fd4d8d8 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Thu, 9 Jan 2025 15:40:52 +0000 Subject: [PATCH] add doctest for same modality --- .../functional/multimodal/clip_score.py | 23 +++++++++++++++++++ src/torchmetrics/multimodal/clip_score.py | 18 +++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 140f0e2de76..630b61e11f5 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -264,6 +264,29 @@ def clip_score( >>> score.detach() tensor(24.4255) + Example: + >>> import torch + >>> from torchmetrics.functional.multimodal import clip_score + >>> torch.manual_seed(42) + >>> torch.cuda.manual_seed_all(42) + >>> score = clip_score( + ... torch.randint(255, (3, 224, 224)), + ... torch.randint(255, (3, 224, 224)), + ... "openai/clip-vit-base-patch16" + ... ) + >>> score.detach() + tensor(99.3556) + + Example: + >>> from torchmetrics.functional.multimodal import clip_score + >>> score = clip_score( + ... "28-year-old chef found dead in San Francisco mall", + ... "A 28-year-old chef who recently moved to San Francisco was found dead.", + ... "openai/clip-vit-base-patch16" + ... ) + >>> score.detach() + tensor(91.3950) + """ model, processor = _get_clip_model_and_processor(model_name_or_path) score, _ = _clip_score_update(source, target, model, processor) diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index 3989cdc29cc..703e122168f 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -108,6 +108,24 @@ class CLIPScore(Metric): >>> score.detach().round() tensor(25.) + Example: + >>> import torch + >>> from torchmetrics.multimodal.clip_score import CLIPScore + >>> torch.manual_seed(42) + >>> torch.cuda.manual_seed_all(42) + >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") + >>> score = metric(torch.randint(255, (3, 224, 224)),torch.randint(255, (3, 224, 224))) + >>> score.detach().round() + tensor(100.) + + Example: + >>> from torchmetrics.multimodal.clip_score import CLIPScore + >>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16") + >>> score = metric("28-year-old chef found dead in San Francisco mall", + ... "A 28-year-old chef who recently moved to San Francisco was found dead.") + >>> score.detach().round() + tensor(91.) + """ is_differentiable: bool = False