Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
methods, one_cycle, print_args, print_mutation, strip_optimizer)
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
Expand Down
11 changes: 11 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import re
import shutil
import signal
import threading
import time
import urllib
from datetime import datetime
Expand Down Expand Up @@ -167,6 +168,16 @@ def handler(*args, **kwargs):
return handler


def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread

return wrapper


def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
Expand Down
7 changes: 3 additions & 4 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import os
import warnings
from threading import Thread

import pkg_resources as pkg
import torch
Expand Down Expand Up @@ -109,7 +108,7 @@ def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
if ni < 3:
f = self.save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
plot_images(imgs, targets, paths, f)
if self.wandb and ni == 10:
files = sorted(self.save_dir.glob('train*.jpg'))
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
Expand All @@ -132,7 +131,7 @@ def on_val_end(self):

def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch
x = {k: v for k, v in zip(self.keys, vals)} # dict
x = dict(zip(self.keys, vals))
if self.csv:
file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols
Expand Down Expand Up @@ -171,7 +170,7 @@ def on_train_end(self, last, best, plots, epoch, results):
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')

if self.wandb:
self.wandb.log({k: v for k, v in zip(self.keys[3:10], results)}) # log best.pt val results
self.wandb.log(dict(zip(self.keys[3:10], results)))
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
if not self.opt.evolve:
Expand Down
13 changes: 7 additions & 6 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from PIL import Image, ImageDraw, ImageFont

from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness

# Settings
Expand All @@ -32,9 +32,9 @@ class Colors:
# Ultralytics color palette https://ultralytics.com/
def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values()
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb('#' + c) for c in hex]
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
self.n = len(self.palette)

def __call__(self, i, bgr=False):
Expand Down Expand Up @@ -100,7 +100,7 @@ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 2
if label:
tf = max(self.lw - 1, 1) # font thickness
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box
outside = p1[1] - h >= 3
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.putText(self.im,
Expand Down Expand Up @@ -184,6 +184,7 @@ def output_to_target(output):
return np.array(targets)


@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
Expand Down Expand Up @@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
ax = ax.ravel()
files = list(save_dir.glob('results*.csv'))
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
for fi, f in enumerate(files):
for f in files:
try:
data = pd.read_csv(f)
s = [x.strip() for x in data.columns]
Expand Down
7 changes: 2 additions & 5 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import os
import sys
from pathlib import Path
from threading import Thread

import numpy as np
import torch
Expand Down Expand Up @@ -255,10 +254,8 @@ def run(

# Plot images
if plots and batch_i < 3:
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred

callbacks.run('on_val_batch_end')

Expand Down