Skip to content

Commit 136fa03

Browse files
authored
Update Detect() grid init for loop
May resolve threaded inference issue in #9425 (comment) by avoiding memory sharing on init. Signed-off-by: Glenn Jocher <[email protected]>
1 parent f038ad7 commit 136fa03

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

models/yolo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
4747
self.no = nc + 5 # number of outputs per anchor
4848
self.nl = len(anchors) # number of detection layers
4949
self.na = len(anchors[0]) // 2 # number of anchors
50-
self.grid = [torch.empty(1)] * self.nl # init grid
51-
self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
50+
self.grid = [torch.empty(0) for _ in range(self.nl)] # init grid
51+
self.anchor_grid = [torch.empty(0) for _ in range(self.nl)] # init anchor grid
5252
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
5353
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
5454
self.inplace = inplace # use inplace ops (e.g. slice assignment)

0 commit comments

Comments
 (0)