Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import yaml
from tqdm import tqdm

from utils.general import colorstr
from utils.general import LOGGER, colorstr, emojis

PREFIX = colorstr('AutoAnchor: ')


def check_anchor_order(m):
Expand All @@ -19,14 +21,12 @@ def check_anchor_order(m):
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
print('Reversing anchor order')
LOGGER.info(f'{PREFIX}Reversing anchor order')
m.anchors[:] = m.anchors.flip(0)


def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
prefix = colorstr('autoanchor: ')
print(f'\n{prefix}Analyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
Expand All @@ -42,23 +42,24 @@ def metric(k): # compute metric

anchors = m.anchors.clone() * m.stride.to(m.anchors.device).view(-1, 1, 1) # current anchors
bpr, aat = metric(anchors.cpu().view(-1, 2))
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
if bpr < 0.98: # threshold to recompute
print('. Attempting to improve anchors, please wait...')
s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
if bpr > 0.98: # threshold to recompute
LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset βœ…'))
else:
LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
na = m.anchors.numel() // 2 # number of anchors
try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e:
print(f'{prefix}ERROR: {e}')
LOGGER.info(f'{PREFIX}ERROR: {e}')
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
LOGGER.info(f'{PREFIX}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else:
print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline
LOGGER.info(f'{PREFIX}Original anchors better than new anchors. Proceeding with original anchors.')


def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
Expand All @@ -81,7 +82,6 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
from scipy.cluster.vq import kmeans

thr = 1 / thr
prefix = colorstr('autoanchor: ')

def metric(k, wh): # compute metrics
r = wh[:, None] / k[None]
Expand All @@ -93,15 +93,17 @@ def anchor_fitness(k): # mutation fitness
_, best = metric(torch.tensor(k, dtype=torch.float32), wh)
return (best * (best > thr).float()).mean() # fitness

def print_results(k):
def print_results(k, verbose=True):
k = k[np.argsort(k.prod(1))] # sort small to large
x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
f'past_thr={x[x > thr].mean():.3f}-mean: '
for i, x in enumerate(k):
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
s += '%i,%i, ' % (round(x[0]), round(x[1]))
if verbose:
LOGGER.info(s[:-2])
return k

if isinstance(dataset, str): # *.yaml file
Expand All @@ -117,19 +119,19 @@ def print_results(k):
# Filter
i = (wh0 < 3.0).any(1).sum()
if i:
print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1

# Kmeans calculation
print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
assert len(k) == n, f'{PREFIX}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
k *= s
wh = torch.tensor(wh, dtype=torch.float32) # filtered
wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
k = print_results(k)
k = print_results(k, verbose=False)

# Plot
# k, d = [None] * 20, [None] * 20
Expand All @@ -146,7 +148,7 @@ def print_results(k):
# Evolve
npr = np.random
f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
pbar = tqdm(range(gen), desc=f'{PREFIX}Evolving anchors with Genetic Algorithm:') # progress bar
for _ in pbar:
v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
Expand All @@ -155,8 +157,8 @@ def print_results(k):
fg = anchor_fitness(kg)
if fg > f:
f, k = fg, kg.copy()
pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
if verbose:
print_results(k)
print_results(k, verbose)

return print_results(k)
14 changes: 7 additions & 7 deletions utils/autobatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from torch.cuda import amp

from utils.general import colorstr
from utils.general import LOGGER, colorstr
from utils.torch_utils import profile


Expand All @@ -27,11 +27,11 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
# print(autobatch(model))

prefix = colorstr('autobatch: ')
print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
prefix = colorstr('AutoBatch: ')
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
return batch_size

d = str(device).upper() # 'CUDA:0'
Expand All @@ -40,18 +40,18 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GiB)
a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GiB)
f = t - (r + a) # free inside reserved
print(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')

batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
y = profile(img, model, n=3, device=device)
except Exception as e:
print(f'{prefix}{e}')
LOGGER.warning(f'{prefix}{e}')

y = [x[2] for x in y if x] # memory [2]
batch_sizes = batch_sizes[:len(y)]
p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
print(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)')
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)')
return b
6 changes: 3 additions & 3 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import torch
from PIL import Image, ImageDraw, ImageFont

from utils.general import (Timeout, check_requirements, clip_coords, increment_path, is_ascii, is_chinese, try_except,
user_config_dir, xywh2xyxy, xyxy2xywh)
from utils.general import (LOGGER, Timeout, check_requirements, clip_coords, increment_path, is_ascii, is_chinese,
try_except, user_config_dir, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -328,7 +328,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
print('Plotting labels... ')
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
Expand Down