Skip to content

Commit 2fad967

Browse files
author
Ling Evan
committed
instance mask cropping + bug fix on original ltwh update
1 parent a353969 commit 2fad967

File tree

4 files changed

+53
-11
lines changed

4 files changed

+53
-11
lines changed

deep_sort_realtime/deep_sort/detection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ class Detection(object):
1616
A feature vector that describes the object contained in this image.
1717
class_name : Optional str
1818
Detector predicted class name.
19+
instance_mask : Optional
20+
Instance mask corresponding to bounding box
1921
others : Optional any
2022
Other supplementary fields associated with the detection that should be stored as a "memory" to be retrieved through the track downstream.
2123
@@ -30,12 +32,13 @@ class Detection(object):
3032
3133
"""
3234

33-
def __init__(self, ltwh, confidence, feature, class_name=None, others=None):
35+
def __init__(self, ltwh, confidence, feature, class_name=None, instance_mask=None, others=None):
3436
# def __init__(self, ltwh, feature):
3537
self.ltwh = np.asarray(ltwh, dtype=np.float)
3638
self.confidence = float(confidence)
3739
self.feature = np.asarray(feature, dtype=np.float32)
3840
self.class_name = class_name
41+
self.instance_mask = instance_mask
3942
self.others = others
4043

4144
def get_ltwh(self):

deep_sort_realtime/deep_sort/track.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class Track:
4444
Classname of matched detection
4545
det_conf : Optional float
4646
Confidence associated with matched detection
47+
instance_mask : Optional
48+
Instance mask associated with matched detection
4749
others : Optional any
4850
Any supplementary fields related to matched detection
4951
@@ -80,6 +82,7 @@ def __init__(
8082
original_ltwh=None,
8183
det_class=None,
8284
det_conf=None,
85+
instance_mask=None,
8386
others=None,
8487
):
8588
self.mean = mean
@@ -100,6 +103,7 @@ def __init__(
100103
self.original_ltwh = original_ltwh
101104
self.det_class = det_class
102105
self.det_conf = det_conf
106+
self.instance_mask = instance_mask
103107
self.others = others
104108

105109
def to_tlwh(self, orig=False, orig_strict=False):
@@ -189,12 +193,24 @@ def get_det_class(self):
189193
"""
190194
return self.det_class
191195

196+
def get_instance_mask(self):
197+
'''
198+
Get instance mask associated with the detection. Will be None if there is no associated detection this round.
199+
'''
200+
return self.instance_mask
201+
192202
def get_det_supplementary(self):
193203
"""
194204
Get supplementary info associated with the detection. Will be None if there is no associated detection this round.
195205
"""
196206
return self.others
197207

208+
def get_feature(self):
209+
'''
210+
Get latest appearance feature
211+
'''
212+
return self.features[-1]
213+
198214
def predict(self, kf):
199215
"""Propagate the state distribution to the current time step using a
200216
Kalman filter prediction step.
@@ -210,6 +226,7 @@ def predict(self, kf):
210226
self.time_since_update += 1
211227
self.original_ltwh = None
212228
self.det_conf = None
229+
self.instance_mask = None
213230
self.others = None
214231

215232
def update(self, kf, detection):
@@ -231,6 +248,7 @@ def update(self, kf, detection):
231248
self.features.append(detection.feature)
232249
self.det_conf = detection.confidence
233250
self.det_class = detection.class_name
251+
self.instance_mask = detection.instance_mask
234252
self.others = detection.others
235253

236254
self.hits += 1

deep_sort_realtime/deep_sort/tracker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,10 @@ def _initiate_track(self, detection):
194194
self.max_age,
195195
# mean, covariance, self._next_id, self.n_init, self.max_age,
196196
feature=detection.feature,
197+
original_ltwh=detection.get_ltwh(),
197198
det_class=detection.class_name,
198199
det_conf=detection.confidence,
200+
instance_mask=detection.instance_mask,
199201
others=detection.others,
200202
)
201203
)

deep_sort_realtime/deepsort_tracker.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(
144144
logger.info(f'- in-build embedder : {"No" if self.embedder is None else "Yes"}')
145145
logger.info(f'- polygon detections : {"No" if polygon is False else "Yes"}')
146146

147-
def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, others=None):
147+
def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, others=None, instance_masks=None):
148148

149149
"""Run multi-target tracker on a particular sequence.
150150
@@ -162,6 +162,8 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
162162
Provide today's date, for naming of tracks
163163
others: Optional[ List ] = None
164164
Other things associated with detections to be stored in tracks; these could be, for example, the corresponding segmentation masks or other associated values. Currently others is ignored when polygon is True.
165+
instance_masks: Optional[ List ] = None
166+
Instance masks corresponding to detections. If given, they are used to filter out the background so that only the foreground is used for appearance embedding. Expects each mask to be a numpy boolean matrix.
165167
166168
Returns
167169
-------
@@ -184,10 +186,10 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
184186
raw_detections = [d for d in raw_detections if d[0][2] > 0 and d[0][3] > 0]
185187

186188
if embeds is None:
187-
embeds = self.generate_embeds(frame, raw_detections)
189+
embeds = self.generate_embeds(frame, raw_detections, instance_masks=instance_masks)
188190

189191
# Proper deep sort detection objects that consist of bbox, confidence and embedding.
190-
detections = self.create_detections(raw_detections, embeds, others=others)
192+
detections = self.create_detections(raw_detections, embeds, instance_masks=instance_masks, others=others)
191193
else:
192194
polygons, bounding_rects = self.process_polygons(raw_detections[0])
193195

@@ -218,15 +220,24 @@ def update_tracks(self, raw_detections, embeds=None, frame=None, today=None, oth
218220
def refresh_track_ids(self):
219221
self.tracker._next_id
220222

221-
def generate_embeds(self, frame, raw_dets):
222-
crops = self.crop_bb(frame, raw_dets)
223-
return self.embedder.predict(crops)
223+
def generate_embeds(self, frame, raw_dets, instance_masks=None):
224+
crops, cropped_inst_masks = self.crop_bb(frame, raw_dets, instance_masks=instance_masks)
225+
if cropped_inst_masks is not None:
226+
masked_crops = []
227+
for crop, mask in zip(crops, cropped_inst_masks):
228+
masked_crop = np.zeros_like(crop)
229+
masked_crop = masked_crop + np.array([123.675, 116.28, 103.53], dtype=crop.dtype)
230+
masked_crop[mask] = crop[mask]
231+
masked_crops.append(masked_crop)
232+
return self.embedder.predict(masked_crops)
233+
else:
234+
return self.embedder.predict(crops)
224235

225236
def generate_embeds_poly(self, frame, polygons, bounding_rects):
226237
crops = self.crop_poly_pad_black(frame, polygons, bounding_rects)
227238
return self.embedder.predict(crops)
228239

229-
def create_detections(self, raw_dets, embeds, others=None):
240+
def create_detections(self, raw_dets, embeds, instance_masks=None, others=None):
230241
detection_list = []
231242
for i, (raw_det, embed) in enumerate(zip(raw_dets, embeds)):
232243
detection_list.append(
@@ -235,6 +246,7 @@ def create_detections(self, raw_dets, embeds, others=None):
235246
raw_det[1],
236247
embed,
237248
class_name=raw_det[2] if len(raw_det)==3 else None,
249+
instance_mask = instance_masks[i] if isinstance(instance_masks, Iterable) else instance_masks,
238250
others = others[i] if isinstance(others, Iterable) else others,
239251
)
240252
) # raw_det = [bbox, conf_score, class]
@@ -265,10 +277,14 @@ def process_polygons(raw_polygons):
265277
return polygons, bounding_rects
266278

267279
@staticmethod
268-
def crop_bb(frame, raw_dets):
280+
def crop_bb(frame, raw_dets, instance_masks=None):
269281
crops = []
270282
im_height, im_width = frame.shape[:2]
271-
for detection in raw_dets:
283+
if instance_masks is not None:
284+
masks = []
285+
else:
286+
masks = None
287+
for i, detection in enumerate(raw_dets):
272288
l, t, w, h = [int(x) for x in detection[0]]
273289
r = l + w
274290
b = t + h
@@ -277,7 +293,10 @@ def crop_bb(frame, raw_dets):
277293
crop_t = max(0, t)
278294
crop_b = min(im_height, b)
279295
crops.append(frame[crop_t:crop_b, crop_l:crop_r])
280-
return crops
296+
if instance_masks is not None:
297+
masks.append( instance_masks[i][crop_t:crop_b, crop_l:crop_r] )
298+
299+
return crops, masks
281300

282301
@staticmethod
283302
def crop_poly_pad_black(frame, polygons, bounding_rects):

0 commit comments

Comments
 (0)