Skip to content

Commit c0f8eb0

Browse files
glenn-jocherCesarBazanAV
authored andcommitted
Improved detect.py timing (ultralytics#4741)
* Improved detect.py timing * Eliminate 1 time_sync() call * Inference-only time * dash * #Save section * Cleanup
1 parent 4c25c9c commit c0f8eb0

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

detect.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import argparse
1010
import sys
11-
import time
1211
from pathlib import Path
1312

1413
import cv2
@@ -123,8 +122,9 @@ def wrap_frozen_graph(gd, inputs, outputs):
123122
# Run inference
124123
if pt and device.type != 'cpu':
125124
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
126-
t0 = time.time()
125+
dt, seen = [0.0, 0.0, 0.0], 0
127126
for path, img, im0s, vid_cap in dataset:
127+
t1 = time_sync()
128128
if onnx:
129129
img = img.astype('float32')
130130
else:
@@ -133,9 +133,10 @@ def wrap_frozen_graph(gd, inputs, outputs):
133133
img = img / 255.0 # 0 - 255 to 0.0 - 1.0
134134
if len(img.shape) == 3:
135135
img = img[None] # expand for batch dim
136+
t2 = time_sync()
137+
dt[0] += t2 - t1
136138

137139
# Inference
138-
t1 = time_sync()
139140
if pt:
140141
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
141142
pred = model(img, augment=augment, visualize=visualize)[0]
@@ -162,17 +163,20 @@ def wrap_frozen_graph(gd, inputs, outputs):
162163
pred[..., 2] *= imgsz[1] # w
163164
pred[..., 3] *= imgsz[0] # h
164165
pred = torch.tensor(pred)
166+
t3 = time_sync()
167+
dt[1] += t3 - t2
165168

166169
# NMS
167170
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
168-
t2 = time_sync()
171+
dt[2] += time_sync() - t3
169172

170173
# Second-stage classifier (optional)
171174
if classify:
172175
pred = apply_classifier(pred, modelc, img, im0s)
173176

174177
# Process predictions
175-
for i, det in enumerate(pred): # detections per image
178+
for i, det in enumerate(pred): # per image
179+
seen += 1
176180
if webcam: # batch_size >= 1
177181
p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
178182
else:
@@ -209,8 +213,8 @@ def wrap_frozen_graph(gd, inputs, outputs):
209213
if save_crop:
210214
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
211215

212-
# Print time (inference + NMS)
213-
print(f'{s}Done. ({t2 - t1:.3f}s)')
216+
# Print time (inference-only)
217+
print(f'{s}Done. ({t3 - t2:.3f}s)')
214218

215219
# Stream results
216220
im0 = annotator.result()
@@ -237,15 +241,15 @@ def wrap_frozen_graph(gd, inputs, outputs):
237241
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
238242
vid_writer[i].write(im0)
239243

244+
# Print results
245+
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
246+
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
240247
if save_txt or save_img:
241248
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
242249
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
243-
244250
if update:
245251
strip_optimizer(weights) # update model (to fix SourceChangeWarning)
246252

247-
print(f'Done. ({time.time() - t0:.3f}s)')
248-
249253

250254
def parse_opt():
251255
parser = argparse.ArgumentParser()

val.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,22 +154,22 @@ def run(data,
154154
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
155155
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
156156
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', '[email protected]', '[email protected]:.95')
157-
p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
157+
dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
158158
loss = torch.zeros(3, device=device)
159159
jdict, stats, ap, ap_class = [], [], [], []
160160
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
161-
t_ = time_sync()
161+
t1 = time_sync()
162162
img = img.to(device, non_blocking=True)
163163
img = img.half() if half else img.float() # uint8 to fp16/32
164164
img /= 255.0 # 0 - 255 to 0.0 - 1.0
165165
targets = targets.to(device)
166166
nb, _, height, width = img.shape # batch size, channels, height, width
167-
t = time_sync()
168-
t0 += t - t_
167+
t2 = time_sync()
168+
dt[0] += t2 - t1
169169

170170
# Run model
171171
out, train_out = model(img, augment=augment) # inference and training outputs
172-
t1 += time_sync() - t
172+
dt[1] += time_sync() - t2
173173

174174
# Compute loss
175175
if compute_loss:
@@ -178,9 +178,9 @@ def run(data,
178178
# Run NMS
179179
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
180180
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
181-
t = time_sync()
181+
t3 = time_sync()
182182
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
183-
t2 += time_sync() - t
183+
dt[2] += time_sync() - t3
184184

185185
# Statistics per image
186186
for si, pred in enumerate(out):
@@ -247,7 +247,7 @@ def run(data,
247247
print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
248248

249249
# Print speeds
250-
t = tuple(x / seen * 1E3 for x in (t0, t1, t2)) # speeds per image
250+
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
251251
if not training:
252252
shape = (batch_size, 3, imgsz, imgsz)
253253
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

0 commit comments

Comments
 (0)