|
1 | 1 | # Model validation metrics |
2 | 2 |
|
| 3 | +import math |
3 | 4 | import warnings |
4 | 5 | from pathlib import Path |
5 | 6 |
|
6 | 7 | import matplotlib.pyplot as plt |
7 | 8 | import numpy as np |
8 | 9 | import torch |
9 | 10 |
|
10 | | -from . import general |
11 | | - |
12 | 11 |
|
13 | 12 | def fitness(x): |
14 | 13 | # Model fitness as a weighted combination of metrics |
@@ -128,7 +127,7 @@ def process_batch(self, detections, labels): |
128 | 127 | detections = detections[detections[:, 4] > self.conf] |
129 | 128 | gt_classes = labels[:, 0].int() |
130 | 129 | detection_classes = detections[:, 5].int() |
131 | | - iou = general.box_iou(labels[:, 1:], detections[:, :4]) |
| 130 | + iou = box_iou(labels[:, 1:], detections[:, :4]) |
132 | 131 |
|
133 | 132 | x = torch.where(iou > self.iou_thres) |
134 | 133 | if x[0].shape[0]: |
@@ -184,6 +183,84 @@ def print(self): |
184 | 183 | print(' '.join(map(str, self.matrix[i]))) |
185 | 184 |
|
186 | 185 |
|
| 186 | +def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): |
| 187 | + # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 |
| 188 | + box2 = box2.T |
| 189 | + |
| 190 | + # Get the coordinates of bounding boxes |
| 191 | + if x1y1x2y2: # x1, y1, x2, y2 = box1 |
| 192 | + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] |
| 193 | + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] |
| 194 | + else: # transform from xywh to xyxy |
| 195 | + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 |
| 196 | + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 |
| 197 | + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 |
| 198 | + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 |
| 199 | + |
| 200 | + # Intersection area |
| 201 | + inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ |
| 202 | + (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) |
| 203 | + |
| 204 | + # Union Area |
| 205 | + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps |
| 206 | + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps |
| 207 | + union = w1 * h1 + w2 * h2 - inter + eps |
| 208 | + |
| 209 | + iou = inter / union |
| 210 | + if GIoU or DIoU or CIoU: |
| 211 | + cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width |
| 212 | + ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height |
| 213 | + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 |
| 214 | + c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared |
| 215 | + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + |
| 216 | + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared |
| 217 | + if DIoU: |
| 218 | + return iou - rho2 / c2 # DIoU |
| 219 | + elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 |
| 220 | + v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) |
| 221 | + with torch.no_grad(): |
| 222 | + alpha = v / (v - iou + (1 + eps)) |
| 223 | + return iou - (rho2 / c2 + v * alpha) # CIoU |
| 224 | + else: # GIoU https://arxiv.org/pdf/1902.09630.pdf |
| 225 | + c_area = cw * ch + eps # convex area |
| 226 | + return iou - (c_area - union) / c_area # GIoU |
| 227 | + else: |
| 228 | + return iou # IoU |
| 229 | + |
| 230 | + |
| 231 | +def box_iou(box1, box2): |
| 232 | + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py |
| 233 | + """ |
| 234 | + Return intersection-over-union (Jaccard index) of boxes. |
| 235 | + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. |
| 236 | + Arguments: |
| 237 | + box1 (Tensor[N, 4]) |
| 238 | + box2 (Tensor[M, 4]) |
| 239 | + Returns: |
| 240 | + iou (Tensor[N, M]): the NxM matrix containing the pairwise |
| 241 | + IoU values for every element in boxes1 and boxes2 |
| 242 | + """ |
| 243 | + |
| 244 | + def box_area(box): |
| 245 | + # box = 4xn |
| 246 | + return (box[2] - box[0]) * (box[3] - box[1]) |
| 247 | + |
| 248 | + area1 = box_area(box1.T) |
| 249 | + area2 = box_area(box2.T) |
| 250 | + |
| 251 | + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) |
| 252 | + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) |
| 253 | + return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) |
| 254 | + |
| 255 | + |
| 256 | +def wh_iou(wh1, wh2): |
| 257 | + # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2 |
| 258 | + wh1 = wh1[:, None] # [N,1,2] |
| 259 | + wh2 = wh2[None] # [1,M,2] |
| 260 | + inter = torch.min(wh1, wh2).prod(2) # [N,M] |
| 261 | + return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) |
| 262 | + |
| 263 | + |
187 | 264 | # Plots ---------------------------------------------------------------------------------------------------------------- |
188 | 265 |
|
189 | 266 | def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): |
|
0 commit comments