Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
print_args, print_mutation, strip_optimizer)
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.plots import check_font, plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
Expand Down Expand Up @@ -105,6 +105,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
init_seeds(1 + RANK)
with torch_distributed_zero_first(LOCAL_RANK):
data_dict = data_dict or check_dataset(data) # check if None
if not is_ascii(data_dict['names']): # non-latin labels, i.e. asian, arabic, cyrillic
check_font('Arial.Unicode.ttf', progress=True)
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
Expand Down
4 changes: 2 additions & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,13 @@ def check_file(file, suffix=''):
return files[0] # return file


def check_font(font=FONT):
def check_font(font=FONT, progress=False):
# Download font to CONFIG_DIR if necessary
font = Path(font)
if not font.exists() and not (CONFIG_DIR / font.name).exists():
url = "https://ultralytics.com/assets/" + font.name
LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
torch.hub.download_url_to_file(url, str(font), progress=False)
torch.hub.download_url_to_file(url, str(font), progress=progress)


def check_dataset(data, autodownload=True):
Expand Down
7 changes: 4 additions & 3 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from PIL import Image, ImageDraw, ImageFont

from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
increment_path, is_ascii, is_chinese, try_except, xywh2xyxy, xyxy2xywh)
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -72,11 +72,12 @@ class Annotator:
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
self.pil = pil or not is_ascii(example) or is_chinese(example)
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
self.pil = pil or non_ascii
if self.pil: # use PIL
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
self.font = check_pil_font(font='Arial.Unicode.ttf' if is_chinese(example) else font,
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
else: # use cv2
self.im = im
Expand Down