11
11
import numpy as np
12
12
import torch
13
13
14
+ from utils import TryExcept , threaded
15
+
14
16
15
17
def fitness (x ):
16
18
# Model fitness as a weighted combination of metrics
@@ -184,36 +186,35 @@ def tp_fp(self):
184
186
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
185
187
return tp [:- 1 ], fp [:- 1 ] # remove background class
186
188
189
+ @TryExcept ('WARNING: ConfusionMatrix plot failure' )
187
190
def plot (self , normalize = True , save_dir = '' , names = ()):
188
- try :
189
- import seaborn as sn
190
-
191
- array = self .matrix / ((self .matrix .sum (0 ).reshape (1 , - 1 ) + 1E-9 ) if normalize else 1 ) # normalize columns
192
- array [array < 0.005 ] = np .nan # don't annotate (would appear as 0.00)
193
-
194
- fig = plt .figure (figsize = (12 , 9 ), tight_layout = True )
195
- nc , nn = self .nc , len (names ) # number of classes, names
196
- sn .set (font_scale = 1.0 if nc < 50 else 0.8 ) # for label size
197
- labels = (0 < nn < 99 ) and (nn == nc ) # apply names to ticklabels
198
- with warnings .catch_warnings ():
199
- warnings .simplefilter ('ignore' ) # suppress empty matrix RuntimeWarning: All-NaN slice encountered
200
- sn .heatmap (array ,
201
- annot = nc < 30 ,
202
- annot_kws = {
203
- "size" : 8 },
204
- cmap = 'Blues' ,
205
- fmt = '.2f' ,
206
- square = True ,
207
- vmin = 0.0 ,
208
- xticklabels = names + ['background FP' ] if labels else "auto" ,
209
- yticklabels = names + ['background FN' ] if labels else "auto" ).set_facecolor ((1 , 1 , 1 ))
210
- fig .axes [0 ].set_xlabel ('True' )
211
- fig .axes [0 ].set_ylabel ('Predicted' )
212
- plt .title ('Confusion Matrix' )
213
- fig .savefig (Path (save_dir ) / 'confusion_matrix.png' , dpi = 250 )
214
- plt .close ()
215
- except Exception as e :
216
- print (f'WARNING: ConfusionMatrix plot failure: { e } ' )
191
+ import seaborn as sn
192
+
193
+ array = self .matrix / ((self .matrix .sum (0 ).reshape (1 , - 1 ) + 1E-9 ) if normalize else 1 ) # normalize columns
194
+ array [array < 0.005 ] = np .nan # don't annotate (would appear as 0.00)
195
+
196
+ fig , ax = plt .subplots (1 , 1 , figsize = (12 , 9 ), tight_layout = True )
197
+ nc , nn = self .nc , len (names ) # number of classes, names
198
+ sn .set (font_scale = 1.0 if nc < 50 else 0.8 ) # for label size
199
+ labels = (0 < nn < 99 ) and (nn == nc ) # apply names to ticklabels
200
+ with warnings .catch_warnings ():
201
+ warnings .simplefilter ('ignore' ) # suppress empty matrix RuntimeWarning: All-NaN slice encountered
202
+ sn .heatmap (array ,
203
+ ax = ax ,
204
+ annot = nc < 30 ,
205
+ annot_kws = {
206
+ "size" : 8 },
207
+ cmap = 'Blues' ,
208
+ fmt = '.2f' ,
209
+ square = True ,
210
+ vmin = 0.0 ,
211
+ xticklabels = names + ['background FP' ] if labels else "auto" ,
212
+ yticklabels = names + ['background FN' ] if labels else "auto" ).set_facecolor ((1 , 1 , 1 ))
213
+ ax .set_ylabel ('True' )
214
+ ax .set_ylabel ('Predicted' )
215
+ ax .set_title ('Confusion Matrix' )
216
+ fig .savefig (Path (save_dir ) / 'confusion_matrix.png' , dpi = 250 )
217
+ plt .close (fig )
217
218
218
219
def print (self ):
219
220
for i in range (self .nc + 1 ):
@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
320
321
# Plots ----------------------------------------------------------------------------------------------------------------
321
322
322
323
324
+ @threaded
323
325
def plot_pr_curve (px , py , ap , save_dir = Path ('pr_curve.png' ), names = ()):
324
326
# Precision-recall curve
325
327
fig , ax = plt .subplots (1 , 1 , figsize = (9 , 6 ), tight_layout = True )
@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
336
338
ax .set_ylabel ('Precision' )
337
339
ax .set_xlim (0 , 1 )
338
340
ax .set_ylim (0 , 1 )
339
- plt .legend (bbox_to_anchor = (1.04 , 1 ), loc = "upper left" )
340
- plt . title ('Precision-Recall Curve' )
341
+ ax .legend (bbox_to_anchor = (1.04 , 1 ), loc = "upper left" )
342
+ ax . set_title ('Precision-Recall Curve' )
341
343
fig .savefig (save_dir , dpi = 250 )
342
- plt .close ()
344
+ plt .close (fig )
343
345
344
346
347
+ @threaded
345
348
def plot_mc_curve (px , py , save_dir = Path ('mc_curve.png' ), names = (), xlabel = 'Confidence' , ylabel = 'Metric' ):
346
349
# Metric-confidence curve
347
350
fig , ax = plt .subplots (1 , 1 , figsize = (9 , 6 ), tight_layout = True )
@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
358
361
ax .set_ylabel (ylabel )
359
362
ax .set_xlim (0 , 1 )
360
363
ax .set_ylim (0 , 1 )
361
- plt .legend (bbox_to_anchor = (1.04 , 1 ), loc = "upper left" )
362
- plt . title (f'{ ylabel } -Confidence Curve' )
364
+ ax .legend (bbox_to_anchor = (1.04 , 1 ), loc = "upper left" )
365
+ ax . set_title (f'{ ylabel } -Confidence Curve' )
363
366
fig .savefig (save_dir , dpi = 250 )
364
- plt .close ()
367
+ plt .close (fig )
0 commit comments