From 821369caadbae3c2fb04984c03161ace8e58e593 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Thu, 25 Aug 2022 05:03:44 +0200 Subject: [PATCH 1/3] Faster threaded plots --- utils/__init__.py | 11 +++++++++++ utils/general.py | 10 ---------- utils/metrics.py | 6 +++++- utils/plots.py | 3 ++- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/utils/__init__.py b/utils/__init__.py index a63c473a4340..b9c75dbc2251 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -2,6 +2,7 @@ """ utils/initialization """ +import threading def notebook_init(verbose=True): @@ -34,3 +35,13 @@ def notebook_init(verbose=True): select_device(newline=False) print(emojis(f'Setup complete ✅ {s}')) return display + + +def threaded(func): + # Multi-threads a target function and returns thread. Usage: @threaded decorator + def wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + + return wrapper \ No newline at end of file diff --git a/utils/general.py b/utils/general.py index d8c90f10ac8f..da9f2b56e100 100755 --- a/utils/general.py +++ b/utils/general.py @@ -206,16 +206,6 @@ def handler(*args, **kwargs): return handler -def threaded(func): - # Multi-threads a target function and returns thread. Usage: @threaded decorator - def wrapper(*args, **kwargs): - thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) - thread.start() - return thread - - return wrapper - - def methods(instance): # Get class/instance methods return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] diff --git a/utils/metrics.py b/utils/metrics.py index 8fa3c7e217c7..0ad7663576bc 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -11,6 +11,8 @@ import numpy as np import torch +from utils import threaded + def fitness(x): # Model fitness as a weighted combination of metrics @@ -184,6 +186,7 @@ def tp_fp(self): # fn = self.matrix.sum(0) - tp # false negatives (missed detections) return tp[:-1], fp[:-1] # remove background class + @threaded def plot(self, normalize=True, save_dir='', names=()): try: import seaborn as sn @@ -319,7 +322,7 @@ def wh_iou(wh1, wh2, eps=1e-7): # Plots ---------------------------------------------------------------------------------------------------------------- - +@threaded def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): # Precision-recall curve fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) @@ -342,6 +345,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): plt.close() +@threaded def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'): # Metric-confidence curve fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) diff --git a/utils/plots.py b/utils/plots.py index d35e2bdd168a..6fb71cbb740b 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -19,8 +19,9 @@ import torch from PIL import Image, ImageDraw, ImageFont +from utils import threaded from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path, - is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh) + is_ascii, try_except, xywh2xyxy, xyxy2xywh) from utils.metrics import fitness # Settings From 4749b21ecbff4deda096661b9a64f861720f4ebe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Aug 2022 03:04:18 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- utils/__init__.py | 2 +- utils/metrics.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/__init__.py b/utils/__init__.py index b9c75dbc2251..eacb7bc4327a 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -44,4 +44,4 @@ def wrapper(*args, **kwargs): thread.start() return thread - return wrapper \ No newline at end of file + return wrapper diff --git a/utils/metrics.py b/utils/metrics.py index 0ad7663576bc..a195b46a72a9 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -322,6 +322,7 @@ def wh_iou(wh1, wh2, eps=1e-7): # Plots ---------------------------------------------------------------------------------------------------------------- + @threaded def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): # Precision-recall curve From 1e00ba229a5d96fa8c723d75e4d836503e63d0f7 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Thu, 25 Aug 2022 13:27:56 +0200 Subject: [PATCH 3/3] threaded-safe plots --- utils/metrics.py | 16 +++++++--------- utils/plots.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/utils/metrics.py b/utils/metrics.py index a195b46a72a9..2fe26b4e587b 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -212,7 +212,7 @@ def plot(self, normalize=True, save_dir='', names=()): yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) fig.axes[0].set_xlabel('True') fig.axes[0].set_ylabel('Predicted') - plt.title('Confusion Matrix') + fig.title('Confusion Matrix') fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) plt.close() except Exception as e: @@ -321,8 +321,6 @@ def wh_iou(wh1, wh2, eps=1e-7): # Plots ---------------------------------------------------------------------------------------------------------------- - - @threaded def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): # Precision-recall curve @@ -340,10 +338,10 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()): ax.set_ylabel('Precision') ax.set_xlim(0, 1) ax.set_ylim(0, 1) - plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - plt.title('Precision-Recall Curve') + ax.set_title('Precision-Recall Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") fig.savefig(save_dir, dpi=250) - plt.close() + fig.close() @threaded @@ -363,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi ax.set_ylabel(ylabel) ax.set_xlim(0, 1) ax.set_ylim(0, 1) - plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - plt.title(f'{ylabel}-Confidence Curve') + ax.set_title(f'{ylabel}-Confidence Curve') + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") fig.savefig(save_dir, dpi=250) - plt.close() + fig.close() diff --git a/utils/plots.py b/utils/plots.py index 6fb71cbb740b..a9f39c38b480 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -151,8 +151,8 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec LOGGER.info(f'Saving {f}... ({n}/{channels})') plt.title('Features') - plt.savefig(f, dpi=300, bbox_inches='tight') - plt.close() + fig.savefig(f, dpi=300, bbox_inches='tight') + fig.close() np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save @@ -274,12 +274,12 @@ def plot_val_txt(): # from utils.plots import *; plot_val() fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True) ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) ax.set_aspect('equal') - plt.savefig('hist2d.png', dpi=300) + fig.savefig('hist2d.png', dpi=300) fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) ax[0].hist(cx, bins=600) ax[1].hist(cy, bins=600) - plt.savefig('hist1d.png', dpi=200) + fig.savefig('hist1d.png', dpi=200) def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() @@ -292,7 +292,7 @@ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}') ax[i].legend() ax[i].set_title(s[i]) - plt.savefig('targets.jpg', dpi=200) + fig.savefig('targets.jpg', dpi=200) def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study() @@ -404,8 +404,8 @@ def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f if labels is not None: s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '') ax[i].set_title(s, fontsize=8, verticalalignment='top') - plt.savefig(f, dpi=300, bbox_inches='tight') - plt.close() + fig.savefig(f, dpi=300, bbox_inches='tight') + fig.close() if verbose: LOGGER.info(f"Saving {f}") if labels is not None: @@ -465,7 +465,7 @@ def plot_results(file='path/to/results.csv', dir=''): LOGGER.info(f'Warning: Plotting error for {f}: {e}') ax[1].legend() fig.savefig(save_dir / 'results.png', dpi=200) - plt.close() + fig.close() def profile_idetection(start=0, stop=0, labels=(), save_dir=''):