Skip to content

Commit

Permalink
Merge pull request #770 from roboflow/fixing_no_good_match_bug
Browse files Browse the repository at this point in the history
fix error on no found objects
  • Loading branch information
isaacrob-roboflow authored Nov 1, 2024
2 parents 727ebd0 + be40f2d commit f95802b
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 2 deletions.
5 changes: 4 additions & 1 deletion inference/models/owlv2/owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_class_preds_from_embeds(
survival_indices = torchvision.ops.nms(
to_corners(pred_boxes), pred_scores, iou_threshold
)
# put on numpy and filter to post-nms
# filter to post-nms
pred_boxes = pred_boxes[survival_indices, :]
pred_classes = pred_classes[survival_indices]
pred_scores = pred_scores[survival_indices]
Expand Down Expand Up @@ -371,6 +371,9 @@ def infer_from_embed(
all_predicted_classes.append(classes)
all_predicted_scores.append(scores)

if not all_predicted_boxes:
return []

all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
all_predicted_classes = torch.cat(all_predicted_classes, dim=0)
all_predicted_scores = torch.cat(all_predicted_scores, dim=0)
Expand Down
165 changes: 164 additions & 1 deletion tests/inference/models_predictions_tests/test_owlv2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest

from inference.core.entities.requests.owlv2 import OwlV2InferenceRequest
from inference.models.owlv2.owlv2 import OwlV2


@pytest.mark.slow
def test_owlv2():
image = {
"type": "url",
Expand Down Expand Up @@ -49,6 +52,14 @@ def test_owlv2():
assert abs(532 - posts[3].x) < 1.5
assert abs(572 - posts[4].x) < 1.5


@pytest.mark.slow
def test_owlv2_multiple_prompts():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test we can handle multiple (positive and negative) prompts for the same image
request = OwlV2InferenceRequest(
image=image,
Expand Down Expand Up @@ -96,7 +107,15 @@ def test_owlv2():
assert abs(532 - posts[2].x) < 1.5
assert abs(572 - posts[3].x) < 1.5

# test that we can handle no prompts for an image

@pytest.mark.slow
def test_owlv2_image_without_prompts():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test that we can handle an image without any prompts
request = OwlV2InferenceRequest(
image=image,
training_data=[
Expand Down Expand Up @@ -124,3 +143,147 @@ def test_owlv2():

response = OwlV2().infer_from_request(request)
assert len(response.predictions) == 5


@pytest.mark.slow
def test_owlv2_bad_prompt():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test that we can handle a bad prompt
request = OwlV2InferenceRequest(
image=image,
training_data=[
{
"image": image,
"boxes": [
{
"x": 1,
"y": 1,
"w": 1,
"h": 1,
"cls": "post",
"negative": False,
}
],
}
],
visualize_predictions=True,
confidence=0.9,
)

response = OwlV2().infer_from_request(request)
assert len(response.predictions) == 0


@pytest.mark.slow
def test_owlv2_bad_prompt_hidden_among_good_prompts():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test that we can handle a bad prompt
request = OwlV2InferenceRequest(
image=image,
training_data=[
{
"image": image,
"boxes": [
{
"x": 1,
"y": 1,
"w": 1,
"h": 1,
"cls": "post",
"negative": False,
},
{
"x": 223,
"y": 306,
"w": 40,
"h": 226,
"cls": "post",
"negative": False,
},
],
}
],
visualize_predictions=True,
confidence=0.9,
)

response = OwlV2().infer_from_request(request)
assert len(response.predictions) == 5


@pytest.mark.slow
def test_owlv2_no_training_data():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}

# test that we can handle no training data
request = OwlV2InferenceRequest(
image=image,
training_data=[],
)

response = OwlV2().infer_from_request(request)
assert len(response.predictions) == 0


@pytest.mark.slow
def test_owlv2_multiple_training_images():
image = {
"type": "url",
"value": "https://media.roboflow.com/inference/seawithdock.jpeg",
}
second_image = {
"type": "url",
"value": "https://media.roboflow.com/inference/dock2.jpg",
}

request = OwlV2InferenceRequest(
image=image,
training_data=[
{
"image": image,
"boxes": [
{
"x": 223,
"y": 306,
"w": 40,
"h": 226,
"cls": "post",
"negative": False,
}
],
},
{
"image": second_image,
"boxes": [
{
"x": 3009,
"y": 1873,
"w": 289,
"h": 811,
"cls": "post",
"negative": True,
}
],
},
],
visualize_predictions=True,
confidence=0.9,
)

response = OwlV2().infer_from_request(request)
assert len(response.predictions) == 5


if __name__ == "__main__":
test_owlv2()

0 comments on commit f95802b

Please sign in to comment.