From 862e07a0c09e6e0fec656cef239d6f701c4696c5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 22 Mar 2022 16:57:20 +0100 Subject: [PATCH 1/3] Update loss.py --- utils/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index a06330e034bc..a7ecc6337ca6 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -129,8 +129,9 @@ def __call__(self, p, targets): # predictions, targets n = b.shape[0] # number of targets if n: - pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # target-subset of predictions - + # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0 + pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions + # Regression pxy = pxy.sigmoid() * 2 - 0.5 pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] From 7a10fb11bc56cba3f2a2de4c4794d35a57eca35d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Mar 2022 15:59:07 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- utils/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/loss.py b/utils/loss.py index a7ecc6337ca6..e8ecc6d15485 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -131,7 +131,7 @@ def __call__(self, p, targets): # predictions, targets if n: # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0 pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions - + # Regression pxy = pxy.sigmoid() * 2 - 0.5 pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i] From dfeb9c11f01a72b63baadda57957369c49e0cc8b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 22 Mar 2022 17:00:39 +0100 Subject: [PATCH 3/3] Update loss.py --- utils/loss.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index e8ecc6d15485..bf9b592d4ad2 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -108,13 +108,15 @@ def __init__(self, model, autobalance=False): if g > 0: BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) - det = de_parallel(model).model[-1] # Detect() module - self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 - self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index + m = de_parallel(model).model[-1] # Detect() module + self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 + self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance + self.na = m.na # number of anchors + self.nc = m.nc # number of classes + self.nl = m.nl # number of layers + self.anchors = m.anchors self.device = device - for k in 'na', 'nc', 'nl', 'anchors': - setattr(self, k, getattr(det, k)) def __call__(self, p, targets): # predictions, targets lcls = torch.zeros(1, device=self.device) # class loss