@@ -18,7 +18,7 @@ def fitness(x):
     return (x[:, :4] * w).sum(1)
 
 
-def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
+def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
     """ Compute the average precision, given the recall and precision curves.
     Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
     # Arguments
@@ -37,15 +37,15 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
     tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
 
     # Find unique classes
-    unique_classes = np.unique(target_cls)
+    unique_classes, nt = np.unique(target_cls, return_counts=True)
     nc = unique_classes.shape[0]  # number of classes, number of detections
 
     # Create Precision-Recall curve and compute AP for each class
     px, py = np.linspace(0, 1, 1000), []  # for plotting
     ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
     for ci, c in enumerate(unique_classes):
         i = pred_cls == c
-        n_l = (target_cls == c).sum()  # number of labels
+        n_l = nt[ci]  # number of labels
         n_p = i.sum()  # number of predictions
 
         if n_p == 0 or n_l == 0:
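Note on the counting change in this hunk: np.unique(..., return_counts=True) returns the sorted unique classes together with how many labels each one has, so nt[ci] carries the same value as the per-class boolean sum it replaces while scanning target_cls only once. A minimal sketch with a made-up label array (not data from the repository):

    import numpy as np

    target_cls = np.array([0, 2, 2, 5, 0, 2])                       # hypothetical example labels
    unique_classes, nt = np.unique(target_cls, return_counts=True)  # [0 2 5], [2 3 1]

    for ci, c in enumerate(unique_classes):
        assert nt[ci] == (target_cls == c).sum()  # same label count, computed once up front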
@@ -56,7 +56,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
             tpc = tp[i].cumsum(0)
 
             # Recall
-            recall = tpc / (n_l + 1e-16)  # recall curve
+            recall = tpc / (n_l + eps)  # recall curve
             r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
 
             # Precision
@@ -70,7 +70,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
                     py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5
 
     # Compute F1 (harmonic mean of precision and recall)
-    f1 = 2 * p * r / (p + r + 1e-16)
+    f1 = 2 * p * r / (p + r + eps)
     names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
     names = {i: v for i, v in enumerate(names)}  # to dict
     if plot:
@@ -80,7 +80,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
         plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
 
     i = f1.mean(0).argmax()  # max F1 index
-    return p [:, i ], r [:, i ], ap , f1 [:, i ], unique_classes .astype ('int32' )
+    p, r, f1 = p[:, i], r[:, i], f1[:, i]
+    tp = (r * nt).round()  # true positives
+    fp = (tp / (p + eps) - tp).round()  # false positives
+    return tp, fp, p, r, f1, ap, unique_classes.astype('int32')
 
 
 def compute_ap(recall, precision):
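On the new return values in this hunk: p, r and f1 are reduced to their values at the max-F1 confidence index, per-class true positives are recovered as recall times the label count nt, and false positives follow from the precision identity precision = TP / (TP + FP), i.e. FP = TP / precision - TP. A rough numeric sketch of that identity, plus how a call site would unpack the widened return (the variable names below just mirror the return order and are illustrative, not taken from any particular call site):

    # precision = TP / (TP + FP)  =>  FP = TP / precision - TP
    tp_count, prec = 80.0, 0.8
    fp_count = tp_count / prec - tp_count  # 20.0 false positives

    # callers now unpack seven values instead of five, e.g.
    tp, fp, p, r, f1, ap, ap_class = ap_per_class(tp, conf, pred_cls, target_cls)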
@@ -162,6 +165,12 @@ def process_batch(self, detections, labels):
     def matrix(self):
         return self.matrix
 
+    def tp_fp(self):
+        tp = self.matrix.diagonal()  # true positives
+        fp = self.matrix.sum(1) - tp  # false positives
+        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
+        return tp[:-1], fp[:-1]  # remove background class
+
     def plot(self, normalize=True, save_dir='', names=()):
         try:
             import seaborn as sn
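On the new tp_fp() helper: assuming the confusion matrix is indexed [predicted class, true class], as the commented-out fn line implies, the diagonal counts correct assignments, each row sum minus its diagonal entry counts predictions of that class matched to something else, and the final index is the implicit background class, which both vectors drop. A toy illustration with a made-up 2-class-plus-background matrix, not repository data:

    import numpy as np

    matrix = np.array([[8., 1., 2.],   # rows: predicted class
                       [0., 5., 1.],   # cols: true class
                       [2., 3., 0.]])  # last row/col: background

    tp = matrix.diagonal()     # [8. 5. 0.]
    fp = matrix.sum(1) - tp    # [3. 1. 5.]
    tp, fp = tp[:-1], fp[:-1]  # drop the background entry, as tp_fp() does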