Skip to content

Commit 0529b77

Browse files
authored
Update common.py lists for tuples (#7063)
Improved profiling.
1 parent d5e363f commit 0529b77

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

models/common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
def autopad(k, p=None): # kernel, padding
3232
# Pad to 'same'
3333
if p is None:
34-
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
34+
p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
3535
return p
3636

3737

@@ -133,7 +133,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
133133
self.cv2 = Conv(c1, c_, 1, 1)
134134
self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
135135
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
136-
# self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
136+
# self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
137137

138138
def forward(self, x):
139139
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
@@ -194,7 +194,7 @@ def forward(self, x):
194194
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
195195
y1 = self.m(x)
196196
y2 = self.m(y1)
197-
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
197+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
198198

199199

200200
class Focus(nn.Module):
@@ -205,7 +205,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
205205
# self.contract = Contract(gain=2)
206206

207207
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
208-
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
208+
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
209209
# return self.conv(self.contract(x))
210210

211211

@@ -219,7 +219,7 @@ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, s
219219

220220
def forward(self, x):
221221
y = self.cv1(x)
222-
return torch.cat([y, self.cv2(y)], 1)
222+
return torch.cat((y, self.cv2(y)), 1)
223223

224224

225225
class GhostBottleneck(nn.Module):

0 commit comments

Comments
 (0)