Skip to content

Commit b82405d

Browse files
glenn-jochertdhooghe
authored andcommitted
AutoBatch checks against failed solutions (ultralytics#8159)
* AutoBatch checks against failed solutions @kalenmike this is a simple improvement to AutoBatch to verify that returned solutions have not already failed, i.e. return batch-size 8 when 8 already produced CUDA out of memory. This is a halfway fix until I can implement a 'final solution' that will actively verify the solved-for batch size rather than passively assume it works. * Update autobatch.py * Update autobatch.py
1 parent 34d9a4b commit b82405d

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

utils/autobatch.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import torch
1010

11-
from utils.general import LOGGER, colorstr
11+
from utils.general import LOGGER, colorstr, emojis
1212
from utils.torch_utils import profile
1313

1414

@@ -26,32 +26,41 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
2626
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
2727
# print(autobatch(model))
2828

29+
# Check device
2930
prefix = colorstr('AutoBatch: ')
3031
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
3132
device = next(model.parameters()).device # get model device
3233
if device.type == 'cpu':
3334
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
3435
return batch_size
3536

37+
# Inspect CUDA memory
3638
gb = 1 << 30 # bytes to GiB (1024 ** 3)
3739
d = str(device).upper() # 'CUDA:0'
3840
properties = torch.cuda.get_device_properties(device) # device properties
39-
t = properties.total_memory / gb # (GiB)
40-
r = torch.cuda.memory_reserved(device) / gb # (GiB)
41-
a = torch.cuda.memory_allocated(device) / gb # (GiB)
42-
f = t - (r + a) # free inside reserved
41+
t = properties.total_memory / gb # GiB total
42+
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
43+
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
44+
f = t - (r + a) # GiB free
4345
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
4446

47+
# Profile batch sizes
4548
batch_sizes = [1, 2, 4, 8, 16]
4649
try:
4750
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
48-
y = profile(img, model, n=3, device=device)
51+
results = profile(img, model, n=3, device=device)
4952
except Exception as e:
5053
LOGGER.warning(f'{prefix}{e}')
5154

52-
y = [x[2] for x in y if x] # memory [2]
53-
batch_sizes = batch_sizes[:len(y)]
54-
p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit
55+
# Fit a solution
56+
y = [x[2] for x in results if x] # memory [2]
57+
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
5558
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
56-
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)')
59+
if None in results: # some sizes failed
60+
i = results.index(None) # first fail index
61+
if b >= batch_sizes[i]: # y intercept above failure point
62+
b = batch_sizes[max(i - 1, 0)] # select prior safe point
63+
64+
fraction = np.polyval(p, b) / t # actual fraction predicted
65+
LOGGER.info(emojis(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅'))
5766
return b

0 commit comments

Comments
 (0)