Skip to content

Commit 01cdb76

Browse files
authored
Add SPPF() layer (#4420)
* Add `SPPF()` layer * Cleanup * Add credit
1 parent 24bea5e commit 01cdb76

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

models/common.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
161161

162162

163163
class SPP(nn.Module):
164-
# Spatial pyramid pooling layer used in YOLOv3-SPP
164+
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
165165
def __init__(self, c1, c2, k=(5, 9, 13)):
166166
super().__init__()
167167
c_ = c1 // 2 # hidden channels
@@ -176,6 +176,24 @@ def forward(self, x):
176176
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
177177

178178

179+
class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    # One stride-1 max-pool of size k is applied three times in sequence and every
    # intermediate result is kept, which is equivalent to SPP(k=(5, 9, 13)) when k=5.
    def __init__(self, c1, c2, k=5):  # ch_in, ch_out, pool kernel size
        super().__init__()
        hidden = c1 // 2  # hidden channels
        # 1x1 conv in, then 1x1 conv out over the 4 concatenated feature maps
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(4 * hidden, c2, 1, 1)
        # padding=k // 2 with stride=1 is chosen so pooled maps can be concatenated with the input
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        # Start the pyramid with the channel-reduced input
        levels = [self.cv1(x)]
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            # Each level is the max-pool of the previous one (3 pooled levels in total)
            for _ in range(3):
                levels.append(self.m(levels[-1]))
            # Concatenate along dim 1 and project to c2 channels
            return self.cv2(torch.cat(levels, 1))
195+
196+
179197
class Focus(nn.Module):
180198
# Focus wh information into c-space
181199
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups

models/yolo.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
237237
pass
238238

239239
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
240-
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
241-
C3, C3TR, C3SPP, C3Ghost]:
240+
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
241+
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
242242
c1, c2 = ch[f], args[0]
243243
if c2 != no: # if not output
244244
c2 = make_divisible(c2 * gw, 8)
@@ -279,6 +279,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
279279
parser = argparse.ArgumentParser()
280280
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
281281
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
282+
parser.add_argument('--profile', action='store_true', help='profile model speed')
282283
opt = parser.parse_args()
283284
opt.cfg = check_file(opt.cfg) # check file
284285
set_logging()
@@ -289,8 +290,9 @@ def parse_model(d, ch): # model_dict, input_channels(3)
289290
model.train()
290291

291292
# Profile
292-
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
293-
# y = model(img, profile=True)
293+
if opt.profile:
294+
img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
295+
y = model(img, profile=True)
294296

295297
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
296298
# from torch.utils.tensorboard import SummaryWriter

0 commit comments

Comments
 (0)