Skip to content

Commit 535f285

Browse files
authored
Update loss.py with if self.gr < 1: (ultralytics#7087)
* Update loss.py with `if self.gr < 1:` * Update loss.py
1 parent f123016 commit 535f285

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

utils/loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ def __call__(self, p, targets): # predictions, targets
139139
lbox += (1.0 - iou).mean() # iou loss
140140

141141
# Objectness
142-
score_iou = iou.detach().clamp(0).type(tobj.dtype)
142+
iou = iou.detach().clamp(0).type(tobj.dtype)
143143
if self.sort_obj_iou:
144-
sort_id = torch.argsort(score_iou)
145-
b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
146-
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou # iou ratio
144+
j = iou.argsort()
145+
b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
146+
if self.gr < 1:
147+
iou = (1.0 - self.gr) + self.gr * iou
148+
tobj[b, a, gj, gi] = iou # iou ratio
147149

148150
# Classification
149151
if self.nc > 1: # cls loss (only if multiple classes)

0 commit comments

Comments
 (0)