Skip to content

Commit a53dee7

Browse files
authored
Reduce val device transfers (ultralytics#7525)
1 parent 9436600 commit a53dee7

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

val.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,14 @@ def run(
220220
# Metrics
221221
for si, pred in enumerate(out):
222222
labels = targets[targets[:, 0] == si, 1:]
223-
nl = len(labels)
224-
tcls = labels[:, 0].tolist() if nl else [] # target class
223+
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
225224
path, shape = Path(paths[si]), shapes[si][0]
225+
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
226226
seen += 1
227227

228-
if len(pred) == 0:
228+
if npr == 0:
229229
if nl:
230-
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
230+
stats.append((correct, *torch.zeros((3, 0))))
231231
continue
232232

233233
# Predictions
@@ -244,9 +244,7 @@ def run(
244244
correct = process_batch(predn, labelsn, iouv)
245245
if plots:
246246
confusion_matrix.process_batch(predn, labelsn)
247-
else:
248-
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
249-
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
247+
stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
250248

251249
# Save/log
252250
if save_txt:
@@ -265,7 +263,7 @@ def run(
265263
callbacks.run('on_val_batch_end')
266264

267265
# Compute metrics
268-
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
266+
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
269267
if len(stats) and stats[0].any():
270268
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
271269
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95

0 commit comments

Comments
 (0)