Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,27 +506,27 @@ def check_dataset(data, autodownload=True):

def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
from models.common import AutoShape
from models.common import AutoShape, DetectMultiBackend

def amp_allclose(model, im):
# All close FP32 vs AMP results
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance

if next(model.parameters()).device.type == 'cpu': # get model device
return False
prefix = colorstr('AMP: ')
file = ROOT / 'data' / 'images' / 'bus.jpg' # image to test
if file.exists():
im = cv2.imread(file)[..., ::-1] # OpenCV image (BGR to RGB)
elif check_online():
im = 'https://ultralytics.com/images/bus.jpg'
else:
LOGGER.warning(emojis(f'{prefix}checks skipped ⚠️, not online.'))
return True
m = AutoShape(model, verbose=False) # model
a = m(im).xywhn[0] # FP32 inference
m.amp = True
b = m(im).xywhn[0] # AMP inference
if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
return False # AMP disabled on CPU
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
try:
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
LOGGER.info(emojis(f'{prefix}checks passed βœ…'))
return True
else:
except Exception:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
return False
Expand Down