Skip to content

Commit d07ddc6

Browse files
New TryExcept decorator (#9154)
* New TryExcept decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f0e5a60 commit d07ddc6

File tree

4 files changed

+71
-61
lines changed

4 files changed

+71
-61
lines changed

utils/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,33 @@
33
utils/initialization
44
"""
55

6+
import contextlib
7+
import threading
8+
9+
10+
class TryExcept(contextlib.ContextDecorator):
11+
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
12+
def __init__(self, msg='default message here'):
13+
self.msg = msg
14+
15+
def __enter__(self):
16+
pass
17+
18+
def __exit__(self, exc_type, value, traceback):
19+
if value:
20+
print(f'{self.msg}: {value}')
21+
return True
22+
23+
24+
def threaded(func):
25+
# Multi-threads a target function and returns thread. Usage: @threaded decorator
26+
def wrapper(*args, **kwargs):
27+
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
28+
thread.start()
29+
return thread
30+
31+
return wrapper
32+
633

734
def notebook_init(verbose=True):
835
# Check system software and hardware

utils/general.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import shutil
1616
import signal
1717
import sys
18-
import threading
1918
import time
2019
import urllib
2120
from datetime import datetime
@@ -34,6 +33,7 @@
3433
import torchvision
3534
import yaml
3635

36+
from utils import TryExcept
3737
from utils.downloads import gsutil_getsize
3838
from utils.metrics import box_iou, fitness
3939

@@ -195,27 +195,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
195195
os.chdir(self.cwd)
196196

197197

198-
def try_except(func):
199-
# try-except function. Usage: @try_except decorator
200-
def handler(*args, **kwargs):
201-
try:
202-
func(*args, **kwargs)
203-
except Exception as e:
204-
print(e)
205-
206-
return handler
207-
208-
209-
def threaded(func):
210-
# Multi-threads a target function and returns thread. Usage: @threaded decorator
211-
def wrapper(*args, **kwargs):
212-
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
213-
thread.start()
214-
return thread
215-
216-
return wrapper
217-
218-
219198
def methods(instance):
220199
# Get class/instance methods
221200
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
@@ -319,7 +298,7 @@ def git_describe(path=ROOT): # path must be a directory
319298
return ''
320299

321300

322-
@try_except
301+
@TryExcept()
323302
@WorkingDirectory(ROOT)
324303
def check_git_status(repo='ultralytics/yolov5'):
325304
# YOLOv5 status check, recommend 'git pull' if code is out of date
@@ -364,7 +343,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
364343
return result
365344

366345

367-
@try_except
346+
@TryExcept()
368347
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
369348
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
370349
prefix = colorstr('red', 'bold', 'requirements:')

utils/metrics.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import torch
1313

14+
from utils import TryExcept, threaded
15+
1416

1517
def fitness(x):
1618
# Model fitness as a weighted combination of metrics
@@ -184,36 +186,35 @@ def tp_fp(self):
184186
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
185187
return tp[:-1], fp[:-1] # remove background class
186188

189+
@TryExcept('WARNING: ConfusionMatrix plot failure')
187190
def plot(self, normalize=True, save_dir='', names=()):
188-
try:
189-
import seaborn as sn
190-
191-
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
192-
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
193-
194-
fig = plt.figure(figsize=(12, 9), tight_layout=True)
195-
nc, nn = self.nc, len(names) # number of classes, names
196-
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
197-
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
198-
with warnings.catch_warnings():
199-
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
200-
sn.heatmap(array,
201-
annot=nc < 30,
202-
annot_kws={
203-
"size": 8},
204-
cmap='Blues',
205-
fmt='.2f',
206-
square=True,
207-
vmin=0.0,
208-
xticklabels=names + ['background FP'] if labels else "auto",
209-
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
210-
fig.axes[0].set_xlabel('True')
211-
fig.axes[0].set_ylabel('Predicted')
212-
plt.title('Confusion Matrix')
213-
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
214-
plt.close()
215-
except Exception as e:
216-
print(f'WARNING: ConfusionMatrix plot failure: {e}')
191+
import seaborn as sn
192+
193+
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
194+
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
195+
196+
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
197+
nc, nn = self.nc, len(names) # number of classes, names
198+
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
199+
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
200+
with warnings.catch_warnings():
201+
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
202+
sn.heatmap(array,
203+
ax=ax,
204+
annot=nc < 30,
205+
annot_kws={
206+
"size": 8},
207+
cmap='Blues',
208+
fmt='.2f',
209+
square=True,
210+
vmin=0.0,
211+
xticklabels=names + ['background FP'] if labels else "auto",
212+
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
213+
ax.set_ylabel('True')
214+
ax.set_ylabel('Predicted')
215+
ax.set_title('Confusion Matrix')
216+
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
217+
plt.close(fig)
217218

218219
def print(self):
219220
for i in range(self.nc + 1):
@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
320321
# Plots ----------------------------------------------------------------------------------------------------------------
321322

322323

324+
@threaded
323325
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
324326
# Precision-recall curve
325327
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
336338
ax.set_ylabel('Precision')
337339
ax.set_xlim(0, 1)
338340
ax.set_ylim(0, 1)
339-
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
340-
plt.title('Precision-Recall Curve')
341+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
342+
ax.set_title('Precision-Recall Curve')
341343
fig.savefig(save_dir, dpi=250)
342-
plt.close()
344+
plt.close(fig)
343345

344346

347+
@threaded
345348
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
346349
# Metric-confidence curve
347350
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
358361
ax.set_ylabel(ylabel)
359362
ax.set_xlim(0, 1)
360363
ax.set_ylim(0, 1)
361-
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
362-
plt.title(f'{ylabel}-Confidence Curve')
364+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
365+
ax.set_title(f'{ylabel}-Confidence Curve')
363366
fig.savefig(save_dir, dpi=250)
364-
plt.close()
367+
plt.close(fig)

utils/plots.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import torch
2020
from PIL import Image, ImageDraw, ImageFont
2121

22+
from utils import TryExcept, threaded
2223
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
23-
is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
24+
is_ascii, xywh2xyxy, xyxy2xywh)
2425
from utils.metrics import fitness
2526

2627
# Settings
@@ -339,7 +340,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
339340
plt.savefig(f, dpi=300)
340341

341342

342-
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
343+
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
343344
def plot_labels(labels, names=(), save_dir=Path('')):
344345
# plot dataset labels
345346
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")

0 commit comments

Comments
 (0)