Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
utils/initialization
"""
import threading


def notebook_init(verbose=True):
Expand Down Expand Up @@ -34,3 +35,13 @@ def notebook_init(verbose=True):
select_device(newline=False)
print(emojis(f'Setup complete βœ… {s}'))
return display


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper
10 changes: 0 additions & 10 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,6 @@ def handler(*args, **kwargs):
return handler


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
Expand Down
21 changes: 12 additions & 9 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
import torch

from utils import threaded


def fitness(x):
# Model fitness as a weighted combination of metrics
Expand Down Expand Up @@ -184,6 +186,7 @@ def tp_fp(self):
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

@threaded
def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn
Expand All @@ -209,7 +212,7 @@ def plot(self, normalize=True, save_dir='', names=()):
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
plt.title('Confusion Matrix')
fig.title('Confusion Matrix')
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
plt.close()
except Exception as e:
Expand Down Expand Up @@ -318,8 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):


# Plots ----------------------------------------------------------------------------------------------------------------


@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title('Precision-Recall Curve')
ax.set_title('Precision-Recall Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
fig.savefig(save_dir, dpi=250)
plt.close()
fig.close()


@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
Expand All @@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.title(f'{ylabel}-Confidence Curve')
ax.set_title(f'{ylabel}-Confidence Curve')
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
fig.savefig(save_dir, dpi=250)
plt.close()
fig.close()
19 changes: 10 additions & 9 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
is_ascii, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -150,8 +151,8 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec

LOGGER.info(f'Saving {f}... ({n}/{channels})')
plt.title('Features')
plt.savefig(f, dpi=300, bbox_inches='tight')
plt.close()
fig.savefig(f, dpi=300, bbox_inches='tight')
fig.close()
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save


Expand Down Expand Up @@ -273,12 +274,12 @@ def plot_val_txt(): # from utils.plots import *; plot_val()
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
ax.set_aspect('equal')
plt.savefig('hist2d.png', dpi=300)
fig.savefig('hist2d.png', dpi=300)

fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
ax[0].hist(cx, bins=600)
ax[1].hist(cy, bins=600)
plt.savefig('hist1d.png', dpi=200)
fig.savefig('hist1d.png', dpi=200)


def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
Expand All @@ -291,7 +292,7 @@ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
ax[i].legend()
ax[i].set_title(s[i])
plt.savefig('targets.jpg', dpi=200)
fig.savefig('targets.jpg', dpi=200)


def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study()
Expand Down Expand Up @@ -403,8 +404,8 @@ def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f
if labels is not None:
s = names[labels[i]] + (f'β€”{names[pred[i]]}' if pred is not None else '')
ax[i].set_title(s, fontsize=8, verticalalignment='top')
plt.savefig(f, dpi=300, bbox_inches='tight')
plt.close()
fig.savefig(f, dpi=300, bbox_inches='tight')
fig.close()
if verbose:
LOGGER.info(f"Saving {f}")
if labels is not None:
Expand Down Expand Up @@ -464,7 +465,7 @@ def plot_results(file='path/to/results.csv', dir=''):
LOGGER.info(f'Warning: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(save_dir / 'results.png', dpi=200)
plt.close()
fig.close()


def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
Expand Down