Skip to content

Commit

Permalink
Fix YOLOv5 AnchorGenerator compatibility (#345)
Browse files Browse the repository at this point in the history
* Adopt the new Detect Layer compatibility

* Set fuse False as default

* Cleanup test

* Fix docstring and unittest

* Fix docstrings
  • Loading branch information
zhiqwang authored Mar 8, 2022
1 parent 719f76f commit 79cb37d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 41 deletions.
19 changes: 6 additions & 13 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,20 @@ def test_load_from_ultralytics_voc(

# Define YOLOv5 model
model_yolov5 = load_yolov5_model(checkpoint_path)
model_yolov5.conf = conf # confidence threshold (0-1)
model_yolov5.iou = iou # NMS IoU threshold (0-1)
model_yolov5.eval()
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou, agnostic=True)
out_from_yolov5 = outs[0]
out_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLO.load_from_yolov5(
checkpoint_path,
score_thresh=conf,
version=version,
)
model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version)
model_yolort.eval()
with torch.no_grad():
out_from_yolort = model_yolort(img[None])
out_yolort = model_yolort(img[None])

torch.testing.assert_allclose(out_from_yolort[0]["boxes"], out_from_yolov5[:, :4])
torch.testing.assert_allclose(out_from_yolort[0]["scores"], out_from_yolov5[:, 4])
torch.testing.assert_allclose(out_from_yolort[0]["labels"], out_from_yolov5[:, 5].to(dtype=torch.int64))
torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4])
torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4])
torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64))


def test_read_image_to_tensor():
Expand Down
15 changes: 9 additions & 6 deletions test/test_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib

from torch import Tensor
from yolort.v5 import load_yolov5_model, attempt_download
from yolort.v5 import AutoShape, attempt_download, load_yolov5_model


def test_attempt_download():
Expand All @@ -15,16 +15,19 @@ def test_attempt_download():
assert readable_hash[:8] == "9ca9a642"


def test_load_yolov5_model():
def test_load_yolov5_model_autoshape_attached():
img_path = "test/assets/zidane.jpg"

model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")
model_url = "https://github.com/ultralytics/yolov5/releases/download/v6.1/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="8b3b748c")

model = load_yolov5_model(checkpoint_path)
# Attach AutoShape
model = AutoShape(model)

model = load_yolov5_model(checkpoint_path, autoshape=True, verbose=False)
results = model(img_path)

assert isinstance(results.pred, list)
assert len(results.pred) == 1
assert isinstance(results.pred[0], Tensor)
assert results.pred[0].shape == (3, 6)
assert results.pred[0].shape == (4, 6)
46 changes: 25 additions & 21 deletions yolort/v5/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from pathlib import Path

import torch
from torch import nn

from .models import AutoShape
from .models.yolo import Model
from .utils import attempt_download, intersect_dicts, set_logging
from .models.yolo import Detect, Model
from .utils import attempt_download

__all__ = ["add_yolov5_context", "load_yolov5_model", "get_yolov5_size"]

Expand Down Expand Up @@ -46,32 +46,36 @@ def get_yolov5_size(depth_multiple, width_multiple):
)


def load_yolov5_model(checkpoint_path: str, autoshape: bool = False, verbose: bool = True):
def load_yolov5_model(checkpoint_path: str, fuse: bool = False):
"""
Creates a specified YOLOv5 model
Creates a specified YOLOv5 model.
Note:
Currently this tool is mainly used to load the checkpoints trained by yolov5
with support for versions v3.1, v4.0 (v5.0) and v6.0 (v6.1). In addition it is
available for inference with AutoShape attached for versions v6.0 (v6.1).
Args:
checkpoint_path (str): path of the YOLOv5 model, i.e. 'yolov5s.pt'
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model. Default: False.
verbose (bool): print all information to screen. Default: True.
fuse (bool): fuse model Conv2d() + BatchNorm2d() layers. Default: False
Returns:
YOLOv5 pytorch model
"""
set_logging(verbose=verbose)

with add_yolov5_context():
ckpt = torch.load(attempt_download(checkpoint_path), map_location=torch.device("cpu"))

if isinstance(ckpt, dict):
model_ckpt = ckpt["model"] # load model

model = Model(model_ckpt.yaml) # create model
ckpt_state_dict = model_ckpt.float().state_dict() # checkpoint state_dict as FP32
ckpt_state_dict = intersect_dicts(ckpt_state_dict, model.state_dict(), exclude=["anchors"])
model.load_state_dict(ckpt_state_dict, strict=False)

if autoshape:
model = AutoShape(model)

return model
if fuse:
model = ckpt["ema" if ckpt.get("ema") else "model"].float().fuse().eval()
else: # without layer fuse
model = ckpt["ema" if ckpt.get("ema") else "model"].float().eval()

# Compatibility updates
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
if isinstance(m, Detect):
if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility
delattr(m, "anchor_grid")
setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl)

return model
2 changes: 1 addition & 1 deletion yolort/v5/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, version="r4.0"):
def forward(self, x: Tensor) -> Tensor:
return self.act(self.bn(self.conv(x)))

def fuseforward(self, x):
def forward_fuse(self, x: Tensor) -> Tensor:
return self.act(self.conv(x))


Expand Down

0 comments on commit 79cb37d

Please sign in to comment.