Skip to content

Commit 39ff9f2

Browse files
authored
fix score threshold in mot_infer (PaddlePaddle#3444)
1 parent 1264fde commit 39ff9f2

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

deploy/python/mot_infer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,15 @@ def preprocess(self, im):
9393
inputs = create_inputs(im, im_info)
9494
return inputs
9595

96-
def postprocess(self, pred_dets, pred_embs):
96+
def postprocess(self, pred_dets, pred_embs, threshold):
9797
online_targets = self.tracker.update(pred_dets, pred_embs)
9898
online_tlwhs, online_ids = [], []
9999
online_scores = []
100100
for t in online_targets:
101101
tlwh = t.tlwh
102102
tid = t.track_id
103103
tscore = t.score
104+
if tscore < threshold: continue
104105
vertical = tlwh[2] / tlwh[3] > 1.6
105106
if tlwh[2] * tlwh[3] > self.tracker.min_box_area and not vertical:
106107
online_tlwhs.append(tlwh)
@@ -137,8 +138,8 @@ def predict(self, image, threshold=0.5, repeats=1):
137138
self.det_times.inference_time_s.end(repeats=repeats)
138139

139140
self.det_times.postprocess_time_s.start()
140-
online_tlwhs, online_scores, online_ids = self.postprocess(pred_dets,
141-
pred_embs)
141+
online_tlwhs, online_scores, online_ids = self.postprocess(
142+
pred_dets, pred_embs, threshold)
142143
self.det_times.postprocess_time_s.end()
143144
self.det_times.img_num += 1
144145
return online_tlwhs, online_scores, online_ids
@@ -363,7 +364,8 @@ def predict_video(detector, camera_id):
363364
online_ids,
364365
online_scores,
365366
frame_id=frame_id,
366-
fps=fps)
367+
fps=fps,
368+
threhold=FLAGS.threshold)
367369
if FLAGS.save_images:
368370
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
369371
if not os.path.exists(save_dir):

ppdet/engine/tracker.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def _eval_seq_jde(self,
112112
dataloader,
113113
save_dir=None,
114114
show_image=False,
115-
frame_rate=30):
115+
frame_rate=30,
116+
draw_threshold=0):
116117
if save_dir:
117118
if not os.path.exists(save_dir): os.makedirs(save_dir)
118119
tracker = self.model.tracker
@@ -140,6 +141,7 @@ def _eval_seq_jde(self,
140141
tlwh = t.tlwh
141142
tid = t.track_id
142143
tscore = t.score
144+
if tscore < draw_threshold: continue
143145
vertical = tlwh[2] / tlwh[3] > 1.6
144146
if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
145147
online_tlwhs.append(tlwh)
@@ -162,7 +164,8 @@ def _eval_seq_sde(self,
162164
save_dir=None,
163165
show_image=False,
164166
frame_rate=30,
165-
det_file=''):
167+
det_file='',
168+
draw_threshold=0):
166169
if save_dir:
167170
if not os.path.exists(save_dir): os.makedirs(save_dir)
168171
tracker = self.model.tracker
@@ -191,6 +194,7 @@ def _eval_seq_sde(self,
191194
dets = dets_list[frame_id]
192195
bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
193196
pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
197+
if pred_scores < draw_threshold: continue
194198
if bbox_tlwh.shape[0] > 0:
195199
pred_bboxes = paddle.concat(
196200
(bbox_tlwh[:, 0:2],
@@ -343,7 +347,8 @@ def mot_predict(self,
343347
save_images=False,
344348
save_videos=True,
345349
show_image=False,
346-
det_results_dir=''):
350+
det_results_dir='',
351+
draw_threshold=0.5):
347352
if not os.path.exists(output_dir): os.makedirs(output_dir)
348353
result_root = os.path.join(output_dir, 'mot_results')
349354
if not os.path.exists(result_root): os.makedirs(result_root)
@@ -369,15 +374,17 @@ def mot_predict(self,
369374
dataloader,
370375
save_dir=save_dir,
371376
show_image=show_image,
372-
frame_rate=frame_rate)
377+
frame_rate=frame_rate,
378+
draw_threshold=draw_threshold)
373379
elif model_type in ['DeepSORT']:
374380
results, nf, ta, tc = self._eval_seq_sde(
375381
dataloader,
376382
save_dir=save_dir,
377383
show_image=show_image,
378384
frame_rate=frame_rate,
379385
det_file=os.path.join(det_results_dir,
380-
'{}.txt'.format(seq)))
386+
'{}.txt'.format(seq)),
387+
draw_threshold=draw_threshold)
381388
else:
382389
raise ValueError(model_type)
383390

tools/infer_mot.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def parse_args():
6868
'--show_image',
6969
action='store_true',
7070
help='Show tracking results (image).')
71+
parser.add_argument(
72+
"--draw_threshold",
73+
type=float,
74+
default=0.5,
75+
help="Threshold to reserve the result for visualization.")
7176
args = parser.parse_args()
7277
return args
7378

@@ -94,7 +99,8 @@ def run(FLAGS, cfg):
9499
save_images=FLAGS.save_images,
95100
save_videos=FLAGS.save_videos,
96101
show_image=FLAGS.show_image,
97-
det_results_dir=FLAGS.det_results_dir)
102+
det_results_dir=FLAGS.det_results_dir,
103+
draw_threshold=FLAGS.draw_threshold)
98104

99105

100106
def main():

0 commit comments

Comments
 (0)