Merged. Changes from all commits, 4 files changed: models/common.py, train.py, utils/autobatch.py, utils/general.py.
models/common.py (5 changes: 3 additions & 2 deletions)

```diff
@@ -524,9 +524,10 @@ class AutoShape(nn.Module):
     max_det = 1000  # maximum number of detections per image
     amp = False  # Automatic Mixed Precision (AMP) inference

-    def __init__(self, model):
+    def __init__(self, model, verbose=True):
         super().__init__()
-        LOGGER.info('Adding AutoShape... ')
+        if verbose:
+            LOGGER.info('Adding AutoShape... ')
         copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
         self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
         self.pt = not self.dmb or model.pt  # PyTorch model
```
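The new `verbose` flag exists so that internal callers can wrap a model without emitting the "Adding AutoShape..." log line; the AMP check added to utils/general.py below relies on it. A minimal usage sketch, assuming `model` is a loaded YOLOv5 detection model:

```python
# Minimal sketch, assuming `model` is a loaded YOLOv5 detection model.
from models.common import AutoShape

loud = AutoShape(model)                  # logs "Adding AutoShape... " as before (default verbose=True)
quiet = AutoShape(model, verbose=False)  # silent wrap, as used internally by check_amp()
```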
train.py (16 changes: 8 additions & 8 deletions)

```diff
@@ -27,7 +27,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm
@@ -46,10 +45,10 @@
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
-from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
-                           check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
-                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
-                           one_cycle, print_args, print_mutation, strip_optimizer)
+from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
+                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
+                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
+                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
@@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    amp = check_amp(model)  # check AMP

     # Freeze
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
@@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary

     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz)
+        batch_size = check_train_batch_size(model, imgsz, amp)
         loggers.on_params_update({"batch_size": batch_size})

     # Optimizer
@@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, [email protected], [email protected]:0.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = amp.GradScaler(enabled=cuda)
+    scaler = torch.cuda.amp.GradScaler(enabled=amp)
     stopper = EarlyStopping(patience=opt.patience)
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run('on_train_start')
@@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

             # Forward
-            with amp.autocast(enabled=cuda):
+            with torch.cuda.amp.autocast(amp):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
```
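The net effect in train.py: AMP is no longer enabled whenever CUDA is available. It is now a boolean decided once by `check_amp(model)` and threaded through both `GradScaler` and `autocast`. A simplified sketch of the resulting wiring, where `model`, `loader`, `optimizer`, and `compute_loss` stand in for the real `train()` locals:

```python
# Simplified sketch of the new AMP wiring; not the full train() loop.
import torch

amp = check_amp(model)                           # False on CPU, or if the AMP sanity check fails
scaler = torch.cuda.amp.GradScaler(enabled=amp)  # behaves as a pass-through when amp is False

for imgs, targets in loader:
    with torch.cuda.amp.autocast(amp):           # mixed precision only when amp is True
        pred = model(imgs)
        loss, _ = compute_loss(pred, targets)
    scaler.scale(loss).backward()                # scaling is a no-op with a disabled scaler
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```

Because a disabled `GradScaler` degrades to ordinary FP32 training, no branching is needed anywhere else in the loop.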
utils/autobatch.py (5 changes: 2 additions & 3 deletions)

```diff
@@ -7,15 +7,14 @@

 import numpy as np
 import torch
-from torch.cuda import amp

 from utils.general import LOGGER, colorstr
 from utils.torch_utils import profile


-def check_train_batch_size(model, imgsz=640):
+def check_train_batch_size(model, imgsz=640, amp=True):
     # Check YOLOv5 training batch size
-    with amp.autocast():
+    with torch.cuda.amp.autocast(amp):
         return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size
```
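Threading the flag into `check_train_batch_size()` means the automatic batch-size search profiles memory under the same precision mode that training will actually use, so the estimate stays accurate when AMP is disabled. A hypothetical call mirroring train.py:

```python
# Hypothetical call mirroring train.py: probe batch size under the real precision mode.
amp = check_amp(model)                                          # decided once, up front
batch_size = check_train_batch_size(model, imgsz=640, amp=amp)  # memory-aware search
```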
utils/general.py (25 changes: 24 additions & 1 deletion)

```diff
@@ -36,9 +36,11 @@
 from utils.downloads import gsutil_getsize
 from utils.metrics import box_iou, fitness

-# Settings
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
+RANK = int(os.getenv('RANK', -1))
+
+# Settings
 DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
@@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True):
     return data  # dictionary


+def check_amp(model):
+    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
+    from models.common import AutoShape
+
+    if next(model.parameters()).device.type == 'cpu':  # get model device
+        return False
+    prefix = colorstr('AMP: ')
+    im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1]  # OpenCV image (BGR to RGB)
+    m = AutoShape(model, verbose=False)  # model
+    a = m(im).xyxy[0]  # FP32 inference
+    m.amp = True
+    b = m(im).xyxy[0]  # AMP inference
+    if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0):  # close to 1.0 pixel bounding box
+        LOGGER.info(emojis(f'{prefix}checks passed βœ…'))
+        return True
+    else:
+        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
+        LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
+        return False
+
+
 def url2file(url):
     # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
     url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
```
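`check_amp` is a runtime sanity test rather than a capability query: it runs one image (data/images/bus.jpg) through the model in FP32 and again with AMP enabled, then requires the same number of detections with box coordinates matching within an absolute tolerance of 1.0, i.e. roughly one pixel. A minimal standalone sketch, assuming `model` is a YOLOv5 detection model already on a CUDA device:

```python
# Minimal sketch, assuming `model` is a YOLOv5 detection model on a CUDA device.
from utils.general import check_amp

amp = check_amp(model)  # True only if FP32 and AMP detections agree within ~1 px
# train.py feeds this straight into GradScaler(enabled=amp) and autocast(amp), so
# GPUs with unreliable half-precision kernels fall back to full precision, with a
# logged warning pointing at https://github.com/ultralytics/yolov5/issues/7908.
```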