@@ -37,17 +37,21 @@ def metric(k): # compute metric
37
37
bpr = (best > 1. / thr ).float ().mean () # best possible recall
38
38
return bpr , aat
39
39
40
- bpr , aat = metric (m .anchor_grid .clone ().cpu ().view (- 1 , 2 ))
40
+ anchors = m .anchor_grid .clone ().cpu ().view (- 1 , 2 ) # current anchors
41
+ bpr , aat = metric (anchors )
41
42
print (f'anchors/target = { aat :.2f} , Best Possible Recall (BPR) = { bpr :.4f} ' , end = '' )
42
43
if bpr < 0.98 : # threshold to recompute
43
44
print ('. Attempting to improve anchors, please wait...' )
44
45
na = m .anchor_grid .numel () // 2 # number of anchors
45
- new_anchors = kmean_anchors (dataset , n = na , img_size = imgsz , thr = thr , gen = 1000 , verbose = False )
46
- new_bpr = metric (new_anchors .reshape (- 1 , 2 ))[0 ]
46
+ try :
47
+ anchors = kmean_anchors (dataset , n = na , img_size = imgsz , thr = thr , gen = 1000 , verbose = False )
48
+ except Exception as e :
49
+ print (f'{ prefix } ERROR: { e } ' )
50
+ new_bpr = metric (anchors )[0 ]
47
51
if new_bpr > bpr : # replace anchors
48
- new_anchors = torch .tensor (new_anchors , device = m .anchors .device ).type_as (m .anchors )
49
- m .anchor_grid [:] = new_anchors .clone ().view_as (m .anchor_grid ) # for inference
50
- m .anchors [:] = new_anchors .clone ().view_as (m .anchors ) / m .stride .to (m .anchors .device ).view (- 1 , 1 , 1 ) # loss
52
+ anchors = torch .tensor (anchors , device = m .anchors .device ).type_as (m .anchors )
53
+ m .anchor_grid [:] = anchors .clone ().view_as (m .anchor_grid ) # for inference
54
+ m .anchors [:] = anchors .clone ().view_as (m .anchors ) / m .stride .to (m .anchors .device ).view (- 1 , 1 , 1 ) # loss
51
55
check_anchor_order (m )
52
56
print (f'{ prefix } New anchors saved to model. Update model *.yaml to use these anchors in the future.' )
53
57
else :
@@ -119,6 +123,7 @@ def print_results(k):
119
123
print (f'{ prefix } Running kmeans for { n } anchors on { len (wh )} points...' )
120
124
s = wh .std (0 ) # sigmas for whitening
121
125
k , dist = kmeans (wh / s , n , iter = 30 ) # points, mean distance
126
+ assert len (k ) == n , print (f'{ prefix } ERROR: scipy.cluster.vq.kmeans requested { n } points but returned only { len (k )} ' )
122
127
k *= s
123
128
wh = torch .tensor (wh , dtype = torch .float32 ) # filtered
124
129
wh0 = torch .tensor (wh0 , dtype = torch .float32 ) # unfiltered
0 commit comments