@@ -220,14 +220,14 @@ def run(
220
220
# Metrics
221
221
for si , pred in enumerate (out ):
222
222
labels = targets [targets [:, 0 ] == si , 1 :]
223
- nl = len (labels )
224
- tcls = labels [:, 0 ].tolist () if nl else [] # target class
223
+ nl , npr = labels .shape [0 ], pred .shape [0 ] # number of labels, predictions
225
224
path , shape = Path (paths [si ]), shapes [si ][0 ]
225
+ correct = torch .zeros (npr , niou , dtype = torch .bool , device = device ) # init
226
226
seen += 1
227
227
228
- if len ( pred ) == 0 :
228
+ if npr == 0 :
229
229
if nl :
230
- stats .append ((torch . zeros ( 0 , niou , dtype = torch .bool ), torch . Tensor (), torch . Tensor (), tcls ))
230
+ stats .append ((correct , * torch .zeros (( 3 , 0 )) ))
231
231
continue
232
232
233
233
# Predictions
@@ -244,9 +244,7 @@ def run(
244
244
correct = process_batch (predn , labelsn , iouv )
245
245
if plots :
246
246
confusion_matrix .process_batch (predn , labelsn )
247
- else :
248
- correct = torch .zeros (pred .shape [0 ], niou , dtype = torch .bool )
249
- stats .append ((correct .cpu (), pred [:, 4 ].cpu (), pred [:, 5 ].cpu (), tcls )) # (correct, conf, pcls, tcls)
247
+ stats .append ((correct , pred [:, 4 ], pred [:, 5 ], labels [:, 0 ])) # (correct, conf, pcls, tcls)
250
248
251
249
# Save/log
252
250
if save_txt :
@@ -265,7 +263,7 @@ def run(
265
263
callbacks .run ('on_val_batch_end' )
266
264
267
265
# Compute metrics
268
- stats = [np . concatenate (x , 0 ) for x in zip (* stats )] # to numpy
266
+ stats = [torch . cat (x , 0 ). cpu (). numpy ( ) for x in zip (* stats )] # to numpy
269
267
if len (stats ) and stats [0 ].any ():
270
268
tp , fp , p , r , f1 , ap , ap_class = ap_per_class (* stats , plot = plots , save_dir = save_dir , names = names )
271
269
ap50 ,
ap = ap [:,
0 ],
ap .
mean (
1 )
# [email protected] , [email protected] :0.95
0 commit comments