@@ -144,7 +144,7 @@ def __init__(
144
144
logger .info (f'- in-build embedder : { "No" if self .embedder is None else "Yes" } ' )
145
145
logger .info (f'- polygon detections : { "No" if polygon is False else "Yes" } ' )
146
146
147
- def update_tracks (self , raw_detections , embeds = None , frame = None , today = None , others = None ):
147
+ def update_tracks (self , raw_detections , embeds = None , frame = None , today = None , others = None , instance_masks = None ):
148
148
149
149
"""Run multi-target tracker on a particular sequence.
150
150
@@ -162,6 +162,8 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
162
162
Provide today's date, for naming of tracks
163
163
others: Optional[ List ] = None
164
164
Other things associated to detections to be stored in tracks, usually, could be corresponding segmentation mask, other associated values, etc. Currently others is ignored with polygon is True.
165
+ instance_masks: Optional [ List ] = None
166
+ Instance masks corresponding to detections. If given, they are used to filter out background and only use foreground for apperance embedding. Expects numpy boolean mask matrix.
165
167
166
168
Returns
167
169
-------
@@ -184,10 +186,10 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
184
186
raw_detections = [d for d in raw_detections if d [0 ][2 ] > 0 and d [0 ][3 ] > 0 ]
185
187
186
188
if embeds is None :
187
- embeds = self .generate_embeds (frame , raw_detections )
189
+ embeds = self .generate_embeds (frame , raw_detections , instance_masks = instance_masks )
188
190
189
191
# Proper deep sort detection objects that consist of bbox, confidence and embedding.
190
- detections = self .create_detections (raw_detections , embeds , others = others )
192
+ detections = self .create_detections (raw_detections , embeds , instance_masks = instance_masks , others = others )
191
193
else :
192
194
polygons , bounding_rects = self .process_polygons (raw_detections [0 ])
193
195
@@ -218,15 +220,24 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
218
220
def refresh_track_ids (self ):
219
221
self .tracker ._next_id
220
222
221
- def generate_embeds (self , frame , raw_dets ):
222
- crops = self .crop_bb (frame , raw_dets )
223
- return self .embedder .predict (crops )
223
+ def generate_embeds (self , frame , raw_dets , instance_masks = None ):
224
+ crops , cropped_inst_masks = self .crop_bb (frame , raw_dets , instance_masks = instance_masks )
225
+ if cropped_inst_masks is not None :
226
+ masked_crops = []
227
+ for crop , mask in zip (crops , cropped_inst_masks ):
228
+ masked_crop = np .zeros_like (crop )
229
+ masked_crop = masked_crop + np .array ([123.675 , 116.28 , 103.53 ], dtype = crop .dtype )
230
+ masked_crop [mask ] = crop [mask ]
231
+ masked_crops .append (masked_crop )
232
+ return self .embedder .predict (masked_crops )
233
+ else :
234
+ return self .embedder .predict (crops )
224
235
225
236
def generate_embeds_poly (self , frame , polygons , bounding_rects ):
226
237
crops = self .crop_poly_pad_black (frame , polygons , bounding_rects )
227
238
return self .embedder .predict (crops )
228
239
229
- def create_detections (self , raw_dets , embeds , others = None ):
240
+ def create_detections (self , raw_dets , embeds , instance_masks = None , others = None ):
230
241
detection_list = []
231
242
for i , (raw_det , embed ) in enumerate (zip (raw_dets , embeds )):
232
243
detection_list .append (
@@ -235,6 +246,7 @@ def create_detections(self, raw_dets, embeds, others=None):
235
246
raw_det [1 ],
236
247
embed ,
237
248
class_name = raw_det [2 ] if len (raw_det )== 3 else None ,
249
+ instance_mask = instance_masks [i ] if isinstance (instance_masks , Iterable ) else instance_masks ,
238
250
others = others [i ] if isinstance (others , Iterable ) else others ,
239
251
)
240
252
) # raw_det = [bbox, conf_score, class]
@@ -265,10 +277,14 @@ def process_polygons(raw_polygons):
265
277
return polygons , bounding_rects
266
278
267
279
@staticmethod
268
- def crop_bb (frame , raw_dets ):
280
+ def crop_bb (frame , raw_dets , instance_masks = None ):
269
281
crops = []
270
282
im_height , im_width = frame .shape [:2 ]
271
- for detection in raw_dets :
283
+ if instance_masks is not None :
284
+ masks = []
285
+ else :
286
+ masks = None
287
+ for i , detection in enumerate (raw_dets ):
272
288
l , t , w , h = [int (x ) for x in detection [0 ]]
273
289
r = l + w
274
290
b = t + h
@@ -277,7 +293,10 @@ def crop_bb(frame, raw_dets):
277
293
crop_t = max (0 , t )
278
294
crop_b = min (im_height , b )
279
295
crops .append (frame [crop_t :crop_b , crop_l :crop_r ])
280
- return crops
296
+ if instance_masks is not None :
297
+ masks .append ( instance_masks [i ][crop_t :crop_b , crop_l :crop_r ] )
298
+
299
+ return crops , masks
281
300
282
301
@staticmethod
283
302
def crop_poly_pad_black (frame , polygons , bounding_rects ):
0 commit comments