Skip to content

Commit

Permalink
add unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Jan 9, 2025
1 parent 045faad commit 4c074c5
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/unittests/multimodal/test_clip_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,36 @@ def test_warning_on_long_caption(self, inputs, model_name_or_path):
):
metric.update(preds[0], target[0])

@skip_on_connection_issues()
def test_clip_score_image_to_image(self, inputs, model_name_or_path):
"""Test CLIP score for image-to-image comparison."""
metric = CLIPScore(model_name_or_path=model_name_or_path)
preds, _ = inputs
score = metric(preds[0][0], preds[0][1])
assert score.detach().round() == torch.tensor(96.0)

@skip_on_connection_issues()
def test_clip_score_text_to_text(self, inputs, model_name_or_path):
"""Test CLIP score for text-to-text comparison."""
metric = CLIPScore(model_name_or_path=model_name_or_path)
_, target = inputs
score = metric(target[0][0], target[0][1])
assert score.detach().round() == torch.tensor(65.0)

@skip_on_connection_issues()
def test_clip_score_functional_image_to_image(self, inputs, model_name_or_path):
"""Test functional implementation of image-to-image CLIP score."""
preds, _ = inputs
score = clip_score(preds[0][0], preds[0][1], model_name_or_path=model_name_or_path)
assert score.detach().round() == torch.tensor(96.0)

@skip_on_connection_issues()
def test_clip_score_functional_text_to_text(self, inputs, model_name_or_path):
"""Test functional implementation of text-to-text CLIP score."""
_, target = inputs
score = clip_score(target[0][0], target[0][1], model_name_or_path=model_name_or_path)
assert score.detach().round() == torch.tensor(65.0)


@pytest.mark.parametrize(
("input_data", "expected"),
Expand Down

0 comments on commit 4c074c5

Please sign in to comment.