
Commit 0f0b226

Add SMOT tracker (#1573)
* add smot
* lint
* lint
* lint
* tutorial
* tutorial change
* fix comments
1 parent b8a3135 commit 0f0b226

23 files changed: +3007 -1 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ Check the HD video at [Youtube](https://www.youtube.com/watch?v=nfpouVAzXt0) or
| [Semantic Segmentation:](https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation) <br/>associate each pixel of an image <br/> with a categorical label. | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation"><img src="docs/_static/semantic-segmentation.png" alt="semantic" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">FCN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">PSP</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">ICNet</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DeepLab-v3</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DeepLab-v3+</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">DANet</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#semantic-segmentation">FastSCNN</a> |
| [Instance Segmentation:](https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation) <br/>detect objects and associate <br/> each pixel inside object area with an <br/> instance label. | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation"><img src="docs/_static/instance-segmentation.png" alt="instance" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/segmentation.html#instance-segmentation">Mask RCNN</a>|
| [Pose Estimation:](https://gluon-cv.mxnet.io/model_zoo/pose.html) <br/>detect human pose <br/> from images. | <a href="https://gluon-cv.mxnet.io/model_zoo/pose.html"><img src="docs/_static/pose-estimation.svg" alt="pose" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/pose.html#simple-pose-with-resnet">Simple Pose</a>|
-| [Video Action Recognition:](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) <br/>recognize human actions <br/> in a video. | <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html"><img src="docs/_static/action-recognition.png" alt="action_recognition" height="200"/></a> | MXNet: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">C3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">P3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a> <br/> PyTorch: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">CSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/SlowFast.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TPN</a> |
+| [Video Action Recognition:](https://gluon-cv.mxnet.io/model_zoo/action_recognition.html) <br/>recognize human actions <br/> in a video. | <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html"><img src="docs/_static/action-recognition.png" alt="action_recognition" height="200"/></a> | MXNet: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">C3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">P3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a> <br/> PyTorch: <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">I3D_slow</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">R2+1D</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">Non-local</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">CSN</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">SlowFast</a>, <a href="https://gluon-cv.mxnet.io/model_zoo/action_recognition.html">TPN</a> |
| [Depth Prediction:](https://gluon-cv.mxnet.io/model_zoo/depth.html) <br/>predict depth map <br/> from images. | <a href="https://gluon-cv.mxnet.io/model_zoo/depth.html"><img src="docs/_static/depth.png" alt="depth" height="200"/></a> | <a href="https://gluon-cv.mxnet.io/model_zoo/depth.html#kitti-dataset">Monodepth2</a>|
| [GAN:](https://github.com/dmlc/gluon-cv/tree/master/scripts/gan) <br/>generate visually deceptive images | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan"><img src="https://github.com/dmlc/gluon-cv/raw/master/scripts/gan/wgan/fake_samples_400000.png" alt="lsun" height="200"/></a> | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/wgan">WGAN</a>, <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/cycle_gan">CycleGAN</a>, <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/gan/stylegan">StyleGAN</a>|
| [Person Re-ID:](https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline) <br/>re-identify pedestrians across scenes | <a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline"><img src="https://user-images.githubusercontent.com/3307514/46702937-f4311800-cbd9-11e8-8eeb-c945ec5643fb.png" alt="re-id" height="160"/></a> |<a href="https://github.com/dmlc/gluon-cv/tree/master/scripts/re-id/baseline">Market1501 baseline</a> |

docs/_static/smot_demo.gif

770 KB

docs/tutorials/index.rst

Lines changed: 6 additions & 0 deletions
@@ -284,6 +284,12 @@ Object Tracking

       SiamRPN training on VID、DET、COCO、Youtube_bb and test on Otb2015

+   .. card::
+      :title: Pre-trained SMOT Models
+      :link: ../build/examples_tracking/demo_smot.html
+
+      Perform Multi-Object Tracking in real-world video with pre-trained SMOT models.
+

 Depth Prediction
 ---------------------
docs/tutorials/tracking/demo_smot.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
"""03. Multiple object tracking with pre-trained SMOT models
=============================================================

In this tutorial, we present a method,
called `Single-Shot Multi Object Tracking (SMOT) <https://arxiv.org/abs/2010.16031>`_, to perform multi-object tracking.
SMOT is a new tracking framework that converts any single-shot detector (SSD) model into an online multiple object tracker,
which emphasizes detecting and tracking object paths simultaneously.
In the example below, we directly use the SSD-Mobilenet object detector pretrained on COCO from :ref:`gluoncv-model-zoo`
and perform multiple object tracking on an arbitrary video.
We want to point out that SMOT is very efficient: its runtime is close to that of the chosen detector.

"""

######################################################################
# Predict with a SMOT model
# ----------------------------
#
# First, we download a video from the MOT Challenge website.

from gluoncv import utils
video_path = 'https://motchallenge.net/sequenceVideos/MOT17-02-FRCNN-raw.webm'
im_video = utils.download(video_path)

################################################################
# Then you can simply use our provided script under `/scripts/tracking/smot/demo.py` to obtain the multi-object tracking result.
#
# ::
#
#     python demo.py MOT17-02-FRCNN-raw.webm
#
#
################################################################
# You can see the tracking results below. Here, we only track persons,
# but you can track other objects as long as your detector is trained on that category.
#
# .. raw:: html
#
#     <div align="center">
#         <img src="../../_static/smot_demo.gif">
#     </div>
#
#     <br>

################################################################
# Our model is able to track multiple persons even when they are partially occluded.
# Try it on your own video and see the results!
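
To make the detect-then-associate idea behind this tutorial concrete, here is a minimal, self-contained sketch (plain Python and numpy, no GluonCV) that links per-frame detections into tracks by greedy IoU matching. It only illustrates the general concept behind detector-driven trackers; the `iou` and `link_detections` helpers are hypothetical and are not part of GluonCV's SMOT implementation or its API.

# Illustration only: greedy IoU association of per-frame detections into tracks.
import numpy as np

def iou(a, b):
    """IoU between two boxes in (xmin, ymin, xmax, ymax) format."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-12)

def link_detections(frames, iou_thresh=0.3):
    """frames: list of (N_t, 4) arrays of detections, one array per frame.
    Returns a list of tracks; each track is a list of (frame_idx, box)."""
    tracks = []   # all tracks, finished and active
    active = []   # indices into `tracks` that are still being extended
    for t, dets in enumerate(frames):
        unmatched = list(range(len(dets)))
        next_active = []
        for ti in active:
            last_box = tracks[ti][-1][1]
            # greedily pick the best-overlapping unmatched detection
            best, best_iou = None, iou_thresh
            for di in unmatched:
                o = iou(last_box, dets[di])
                if o > best_iou:
                    best, best_iou = di, o
            if best is not None:
                tracks[ti].append((t, dets[best]))
                unmatched.remove(best)
                next_active.append(ti)
        # every unmatched detection starts a new track
        for di in unmatched:
            tracks.append([(t, dets[di])])
            next_active.append(len(tracks) - 1)
        active = next_active
    return tracks

# toy example: one object moving right, a second object appearing in frame 1
frames = [np.array([[10, 10, 50, 80]]),
          np.array([[14, 10, 54, 80], [100, 20, 140, 90]]),
          np.array([[18, 10, 58, 80], [102, 20, 142, 90]])]
for k, tr in enumerate(link_detections(frames)):
    print("track", k, "spans frames", [f for f, _ in tr])

Running this prints a first track spanning frames 0-2 and a second track spanning frames 1-2; SMOT builds on the same intuition but performs detection and association jointly inside the SSD network.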

gluoncv/model_zoo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -39,3 +39,4 @@
 from .siamrpn import *
 from .fastscnn import *
 from .monodepthv2 import *
+from .smot import *

gluoncv/model_zoo/model_store.py

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@
     ('661ee2e1bf824f4f4549b3488c59dec0b0078c38', 'monodepth2_resnet18_posenet_kitti_mono_640x192'),
     ('c14979bb016ed4f555fa09004ddc7616dd60b8b9', 'monodepth2_resnet18_posenet_kitti_mono_stereo_640x192'),
     ('299b1d9d8a2bcf7c122acd0d23606af4fdfbe7e1', 'i3d_slow_resnet101_f16s4_kinetics700'),
+    ('d6758fc8cddfaaa8d0f7ff2e21adf5b0180f6b4b', 'smot_ssd_bifpn_mobilenet'),
     ]}

 apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'

gluoncv/model_zoo/smot/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# pylint: disable=wildcard-import
"""
SMOT: Single-Shot Multi Object Tracking
https://arxiv.org/abs/2010.16031
"""
from __future__ import absolute_import
from .smot_tracker import *
from .tracktors import *

gluoncv/model_zoo/smot/anchor.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# pylint: disable=unused-import
"""Anchor box generator for SSD detector."""
from __future__ import absolute_import

import numpy as np
from mxnet import gluon


class SSDAnchorGenerator(gluon.HybridBlock):
    """Bounding box anchor generator for Single-shot Object Detection.

    Parameters
    ----------
    index : int
        Index of this generator in SSD models; this is required for naming.
    im_size : tuple of int
        Input image size as (H, W), used to clip anchors when `clip` is True.
    sizes : iterable of floats
        Sizes of anchor boxes.
    ratios : iterable of floats
        Aspect ratios of anchor boxes.
    step : int or float
        Step size of anchor boxes.
    alloc_size : tuple of int
        Allocate size for the anchor boxes as (H, W).
        Usually we generate enough anchors for a large feature map, e.g. 128x128.
        Later in inference we can have variable input sizes,
        at which time we can crop corresponding anchors from this large
        anchor map so we can skip re-generating anchors for each input.
    offsets : tuple of float
        Center offsets of anchor boxes as (h, w) in range(0, 1).
    clip : bool
        Whether to clip anchor coordinates to the image size.

    """
    def __init__(self, index, im_size, sizes, ratios, step, alloc_size=(128, 128),
                 offsets=(0.5, 0.5), clip=False, **kwargs):
        super(SSDAnchorGenerator, self).__init__(**kwargs)
        assert len(im_size) == 2
        self._im_size = im_size
        self._clip = clip
        self._sizes = sizes
        self._ratios = ratios
        anchors = self._generate_anchors(self._sizes, self._ratios, step, alloc_size, offsets)
        self._num_anchors = np.size(anchors) // 4
        self.anchors = self.params.get_constant('anchor_%d' % (index), anchors)

    def _generate_anchors(self, sizes, ratios, step, alloc_size, offsets):
        # pylint: disable=unused-argument,too-many-function-args
        """Generate anchors once. Anchors are stored in (center_x, center_y, w, h) format."""
        anchors = []
        for i in range(alloc_size[0]):
            for j in range(alloc_size[1]):
                cy = (i + offsets[0]) * step
                cx = (j + offsets[1]) * step

                for sz in self._sizes:
                    for r in ratios:
                        sr = np.sqrt(r)
                        w = sz * sr
                        h = sz / sr
                        anchors.append([cx, cy, w, h])
        return np.array(anchors).reshape(1, 1, alloc_size[0], alloc_size[1], -1)

    @property
    def num_depth(self):
        """Number of anchors at each pixel."""
        return len(self._sizes) * len(self._ratios)

    @property
    def num_anchors(self):
        """Total number of pre-allocated anchors."""
        return self._num_anchors

    # pylint: disable=arguments-differ
    def hybrid_forward(self, F, x, anchors):
        # crop the pre-generated anchor map to the spatial size of the feature map `x`
        a = F.slice_like(anchors, x * 0, axes=(2, 3))
        a = a.reshape((1, -1, 4))
        if self._clip:
            cx, cy, cw, ch = a.split(axis=-1, num_outputs=4)
            H, W = self._im_size
            a = F.concat(*[cx.clip(0, W), cy.clip(0, H), cw.clip(0, W), ch.clip(0, H)], dim=-1)
        return a.reshape((1, -1, 4))
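
A short usage sketch for the anchor generator above, assuming MXNet is installed and GluonCV with this commit is importable; the concrete sizes, ratios, and feature-map shape are made-up example values, not settings taken from the SMOT model.

import mxnet as mx
from gluoncv.model_zoo.smot.anchor import SSDAnchorGenerator

# 2 sizes x 3 aspect ratios -> 6 anchors per spatial location, on a stride-16 feature map
gen = SSDAnchorGenerator(index=0, im_size=(512, 512), sizes=(51.2, 102.4),
                         ratios=(1, 2, 0.5), step=16, clip=True)
gen.initialize()  # materialize the pre-generated constant anchor map

# a dummy 32x32 feature map; only its spatial shape matters here
feat = mx.nd.zeros((1, 256, 32, 32))
anchors = gen(feat)      # the 128x128 anchor map cropped to the feature-map size
print(anchors.shape)     # (1, 32 * 32 * 6, 4), boxes in (center_x, center_y, w, h)
print(gen.num_depth)     # 6 anchors per spatial location

Because `slice_like` crops the stored map to whatever spatial size the feature map has, the same pre-generated 128x128 anchor constant serves any input resolution up to that size without regenerating anchors.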

gluoncv/model_zoo/smot/decoders.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
"""
MXNet implementation of SMOT: Single-Shot Multi Object Tracking
https://arxiv.org/abs/2010.16031
"""
from mxnet import gluon
from gluoncv.nn.bbox import BBoxCenterToCorner


class NormalizedLandmarkCenterDecoder(gluon.HybridBlock):
    """
    Decode landmark (five key point) training targets with normalized center offsets.
    This decoder must cooperate with the corresponding encoder of the same `stds`
    in order to get properly reconstructed landmarks.

    Key points are returned in image coordinates, one (x, y) pair per point.

    Parameters
    ----------
    stds : array-like of size 4
        Std value to be multiplied onto encoded values, default is (0.1, 0.1, 0.2, 0.2).
    means : array-like of size 4
        Mean value to be added to encoded values, default is (0., 0., 0., 0.).
    convert_anchor : boolean, default is True
        Whether to convert anchors from (center_x, center_y, w, h) to corner format.

    """

    def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.),
                 convert_anchor=True):
        super(NormalizedLandmarkCenterDecoder, self).__init__()
        assert len(stds) == 4, "Landmark decoder requires 4 std values."
        self._stds = stds
        self._means = means
        if convert_anchor:
            self.center_to_conner = BBoxCenterToCorner(split=True)
        else:
            self.center_to_conner = None

    def hybrid_forward(self, F, x, anchors):
        """center decoder forward"""
        if self.center_to_conner is not None:
            a = self.center_to_conner(anchors)
        else:
            a = anchors.split(axis=-1, num_outputs=4)
        ld = F.split(x, axis=-1, num_outputs=10)

        x0 = F.broadcast_add(F.broadcast_mul(ld[0] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
        y0 = F.broadcast_add(F.broadcast_mul(ld[1] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
        x1 = F.broadcast_add(F.broadcast_mul(ld[2] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
        y1 = F.broadcast_add(F.broadcast_mul(ld[3] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
        x2 = F.broadcast_add(F.broadcast_mul(ld[4] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
        y2 = F.broadcast_add(F.broadcast_mul(ld[5] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
        x3 = F.broadcast_add(F.broadcast_mul(ld[6] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
        y3 = F.broadcast_add(F.broadcast_mul(ld[7] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
        x4 = F.broadcast_add(F.broadcast_mul(ld[8] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
        y4 = F.broadcast_add(F.broadcast_mul(ld[9] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])

        return F.concat(x0, y0, x1, y1, x2, y2, x3, y3, x4, y4, dim=-1)


class GeneralNormalizedKeyPointsDecoder(gluon.HybridBlock):
    """
    Decode key point training targets with normalized center offsets,
    for an arbitrary number of key points per anchor.
    This decoder must cooperate with the corresponding encoder of the same `stds`
    in order to get properly reconstructed key points.

    Key points are returned in image coordinates, one (x, y) pair per point.

    Parameters
    ----------
    num_points : int
        Number of key points to decode per anchor.
    stds : array-like of size 2
        Std value to be multiplied onto encoded values, default is (0.2, 0.2).
    means : array-like of size 2
        Mean value to be added to encoded values, default is (0.5, 0.2).
    convert_anchor : boolean, default is True
        Whether to convert anchors from (center_x, center_y, w, h) to corner format.

    """

    def __init__(self, num_points, stds=(0.2, 0.2), means=(0.5, 0.2),
                 convert_anchor=True):
        super(GeneralNormalizedKeyPointsDecoder, self).__init__()
        assert len(stds) == 2, "Key point decoder requires 2 std values."
        self._stds = stds
        self._means = means
        self._size = num_points * 2
        if convert_anchor:
            self.center_to_conner = BBoxCenterToCorner(split=True)
        else:
            self.center_to_conner = None

    def hybrid_forward(self, F, x, anchors):
        """key point decoder forward"""
        if self.center_to_conner is not None:
            a = self.center_to_conner(anchors)
        else:
            a = anchors.split(axis=-1, num_outputs=4)
        ld = F.split(x, axis=-1, num_outputs=self._size)

        outputs = []
        for i in range(0, self._size, 2):
            px = F.broadcast_add(F.broadcast_mul(ld[i] * self._stds[0] + self._means[0], a[2] - a[0]), a[0])
            py = F.broadcast_add(F.broadcast_mul(ld[i + 1] * self._stds[1] + self._means[1], a[3] - a[1]), a[1])
            outputs.extend([px, py])

        return F.concat(*outputs, dim=-1)
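
To make the decoding rule used by the two blocks above explicit, here is a small numpy-only sketch of the per-point transform: each normalized offset is rescaled by its std and mean and then mapped into the anchor box's corner coordinates. This is an illustration only; the `decode_keypoints` helper is hypothetical, and the real decoders operate on batched MXNet tensors.

import numpy as np

def center_to_corner(anchor):
    # (cx, cy, w, h) -> (xmin, ymin, xmax, ymax)
    cx, cy, w, h = anchor
    return np.array([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])

def decode_keypoints(encoded, anchor_center, stds=(0.2, 0.2), means=(0.5, 0.2)):
    """encoded: (2 * num_points,) normalized offsets; anchor_center: (cx, cy, w, h)."""
    x0, y0, x1, y1 = center_to_corner(anchor_center)
    out = []
    for i in range(0, len(encoded), 2):
        kx = (encoded[i] * stds[0] + means[0]) * (x1 - x0) + x0
        ky = (encoded[i + 1] * stds[1] + means[1]) * (y1 - y0) + y0
        out.extend([kx, ky])
    return np.array(out)

# one anchor centered at (100, 100) with size 50x80, two key points
anchor = (100.0, 100.0, 50.0, 80.0)
encoded = np.array([0.0, 0.0, 1.0, 1.0])   # toy raw network outputs
print(decode_keypoints(encoded, anchor))
# [100.  76. 110.  92.] -> key points expressed in image coordinates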
