Skip to content

Commit 6a92926

Browse files
glenn-jocherJoshua Friedrich
authored andcommitted
ACON Activation batch-size 1 bug patch (ultralytics#2901)
* ACON Activation batch-size 1 bug path This is not a great solution to nmaac/acon#4 but it's all I could think of at the moment. WARNING: YOLOv5 models with MetaAconC() activations are incapable of running inference at batch-size 1 properly due to a known bug in nmaac/acon#4 with no known solution. * Update activations.py * Update activations.py * Update activations.py * Update activations.py (cherry picked from commit 9c7bb5a)
1 parent 186fafb commit 6a92926

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

utils/activations.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
8484
c2 = max(r, c1 // r)
8585
self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
8686
self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
87-
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=False)
88-
self.bn1 = nn.BatchNorm2d(c2)
89-
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=False)
90-
self.bn2 = nn.BatchNorm2d(c1)
87+
self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
88+
self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
89+
# self.bn1 = nn.BatchNorm2d(c2)
90+
# self.bn2 = nn.BatchNorm2d(c1)
9191

9292
def forward(self, x):
9393
y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
94-
beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))
94+
# batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
95+
# beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) # bug/unstable
96+
beta = torch.sigmoid(self.fc2(self.fc1(y))) # bug patch BN layers removed
9597
dpx = (self.p1 - self.p2) * x
9698
return dpx * torch.sigmoid(beta * dpx) + self.p2 * x

0 commit comments

Comments
 (0)