From 4c074c563b317083765008ee1709a302726094ef Mon Sep 17 00:00:00 2001 From: rittik9 Date: Thu, 9 Jan 2025 20:25:18 +0000 Subject: [PATCH] add unittests --- tests/unittests/multimodal/test_clip_score.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index b948e1889c7..cdabbbf502a 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -149,6 +149,36 @@ def test_warning_on_long_caption(self, inputs, model_name_or_path): ): metric.update(preds[0], target[0]) + @skip_on_connection_issues() + def test_clip_score_image_to_image(self, inputs, model_name_or_path): + """Test CLIP score for image-to-image comparison.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + preds, _ = inputs + score = metric(preds[0][0], preds[0][1]) + assert score.detach().round() == torch.tensor(96.0) + + @skip_on_connection_issues() + def test_clip_score_text_to_text(self, inputs, model_name_or_path): + """Test CLIP score for text-to-text comparison.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + _, target = inputs + score = metric(target[0][0], target[0][1]) + assert score.detach().round() == torch.tensor(65.0) + + @skip_on_connection_issues() + def test_clip_score_functional_image_to_image(self, inputs, model_name_or_path): + """Test functional implementation of image-to-image CLIP score.""" + preds, _ = inputs + score = clip_score(preds[0][0], preds[0][1], model_name_or_path=model_name_or_path) + assert score.detach().round() == torch.tensor(96.0) + + @skip_on_connection_issues() + def test_clip_score_functional_text_to_text(self, inputs, model_name_or_path): + """Test functional implementation of text-to-text CLIP score.""" + _, target = inputs + score = clip_score(target[0][0], target[0][1], model_name_or_path=model_name_or_path) + assert score.detach().round() == torch.tensor(65.0) + @pytest.mark.parametrize( ("input_data", "expected"),