Skip to content

Commit eb4146f

Browse files
zldrobitglenn-jocherunknownpre-commit-ci[bot]
authored
Add EdgeTPU support (ultralytics#3630)
* Add models/tf.py for TensorFlow and TFLite export * Set auto=False for int8 calibration * Update requirements.txt for TensorFlow and TFLite export * Read anchors directly from PyTorch weights * Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export * Remove check_anchor_order, check_file, set_logging from import * Reformat code and optimize imports * Autodownload model and check cfg * update --source path, img-size to 320, single output * Adjust representative_dataset * Put representative dataset in tfl_int8 block * detect.py TF inference * weights to string * weights to string * cleanup tf.py * Add --dynamic-batch-size * Add xywh normalization to reduce calibration error * Update requirements.txt TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error * Fix imports Move C3 from models.experimental to models.common * Add models/tf.py for TensorFlow and TFLite export * Set auto=False for int8 calibration * Update requirements.txt for TensorFlow and TFLite export * Read anchors directly from PyTorch weights * Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export * Remove check_anchor_order, check_file, set_logging from import * Reformat code and optimize imports * Autodownload model and check cfg * update --source path, img-size to 320, single output * Adjust representative_dataset * detect.py TF inference * Put representative dataset in tfl_int8 block * weights to string * weights to string * cleanup tf.py * Add --dynamic-batch-size * Add xywh normalization to reduce calibration error * Update requirements.txt TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error * Fix imports Move C3 from models.experimental to models.common * implement C3() and SiLU() * Add TensorFlow and TFLite Detection * Add --tfl-detect for TFLite Detection * Add int8 quantized TFLite inference in detect.py * Add --edgetpu for Edge TPU detection * Fix --img-size to add rectangle TensorFlow and TFLite input * Add --no-tf-nms to detect objects using models combined with TensorFlow NMS * Fix --img-size list type input * Update README.md * Add Android project for TFLite inference * Upgrade TensorFlow v2.3.1 -> v2.4.0 * Disable normalization of xywh * Rewrite names init in detect.py * Change input resolution 640 -> 320 on Android * Disable NNAPI * Update README.me --img 640 -> 320 * Update README.me for Edge TPU * Update README.md * Fix reshape dim to support dynamic batching * Fix reshape dim to support dynamic batching * Add epsilon argument in tf_BN, which is different between TF and PT * Set stride to None if not using PyTorch, and do not warmup without PyTorch * Add list support in check_img_size() * Add list input support in detect.py * sys.path.append('./') to run from yolov5/ * Add int8 quantization support for TensorFlow 2.5 * Add get_coco128.sh * Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU) * Update requirements.txt * Replace torch.load() with attempt_load() * Update requirements.txt * Add --tf-raw-resize to set half_pixel_centers=False * Remove android directory * Update README.md * Update README.md * Add multiple OS support for EdgeTPU detection * Fix export and detect * Export 3 YOLO heads with Edge TPU models * Remove xywh denormalization with Edge TPU models in detect.py * Fix saved_model and pb detect error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix pre-commit.ci failure * Add edgetpu in export.py docstring * Fix Edge TPU model detection exported by TF 2.7 * Add class names for TF/TFLite in DetectMultibackend * Fix assignment with nl in TFLite Detection * Add check when getting Edge TPU compiler version * Add UTF-8 encoding in opening --data file for Windows * Remove redundant TensorFlow import * Add Edge TPU in export.py's docstring * Add the detect layer in Edge TPU model conversion * Default `dnn=False` * Cleanup data.yaml loading * Update detect.py * Update val.py * Comments and generalize data.yaml names Co-authored-by: Glenn Jocher <[email protected]> Co-authored-by: unknown <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 99cc35e commit eb4146f

File tree

4 files changed

+37
-8
lines changed

4 files changed

+37
-8
lines changed

detect.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
@torch.no_grad()
3939
def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
4040
source=ROOT / 'data/images', # file/dir/URL/glob, 0 for webcam
41+
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
4142
imgsz=(640, 640), # inference size (height, width)
4243
conf_thres=0.25, # confidence threshold
4344
iou_thres=0.45, # NMS IOU threshold
@@ -76,7 +77,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
7677

7778
# Load model
7879
device = select_device(device)
79-
model = DetectMultiBackend(weights, device=device, dnn=dnn)
80+
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
8081
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
8182
imgsz = check_img_size(imgsz, s=stride) # check image size
8283

@@ -204,6 +205,7 @@ def parse_opt():
204205
parser = argparse.ArgumentParser()
205206
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
206207
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
208+
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
207209
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
208210
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
209211
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')

export.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,24 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
248248
LOGGER.info(f'\n{prefix} export failure: {e}')
249249

250250

251+
def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
252+
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
253+
try:
254+
cmd = 'edgetpu_compiler --version'
255+
out = subprocess.run(cmd, shell=True, capture_output=True, check=True)
256+
ver = out.stdout.decode().split()[-1]
257+
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
258+
f = str(file).replace('.pt', '-int8_edgetpu.tflite')
259+
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
260+
261+
cmd = f"edgetpu_compiler -s {f_tfl}"
262+
subprocess.run(cmd, shell=True, check=True)
263+
264+
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
265+
except Exception as e:
266+
LOGGER.info(f'\n{prefix} export failure: {e}')
267+
268+
251269
def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
252270
# YOLOv5 TensorFlow.js export
253271
try:
@@ -285,6 +303,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
285303

286304

287305
def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
306+
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
288307
try:
289308
check_requirements(('tensorrt',))
290309
import tensorrt as trt
@@ -356,7 +375,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
356375
):
357376
t = time.time()
358377
include = [x.lower() for x in include]
359-
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'tfjs')) # TensorFlow exports
378+
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
360379
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
361380

362381
# Checks
@@ -405,15 +424,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
405424

406425
# TensorFlow Exports
407426
if any(tf_exports):
408-
pb, tflite, tfjs = tf_exports[1:]
427+
pb, tflite, edgetpu, tfjs = tf_exports[1:]
409428
assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
410429
model = export_saved_model(model, im, file, dynamic, tf_nms=nms or agnostic_nms or tfjs,
411430
agnostic_nms=agnostic_nms or tfjs, topk_per_class=topk_per_class, topk_all=topk_all,
412431
conf_thres=conf_thres, iou_thres=iou_thres) # keras model
413432
if pb or tfjs: # pb prerequisite to tfjs
414433
export_pb(model, im, file)
415-
if tflite:
416-
export_tflite(model, im, file, int8=int8, data=data, ncalib=100)
434+
if tflite or edgetpu:
435+
export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
436+
if edgetpu:
437+
export_edgetpu(model, im, file)
417438
if tfjs:
418439
export_tfjs(model, im, file)
419440

models/common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import requests
1818
import torch
1919
import torch.nn as nn
20+
import yaml
2021
from PIL import Image
2122
from torch.cuda import amp
2223

@@ -276,14 +277,15 @@ def forward(self, x):
276277

277278
class DetectMultiBackend(nn.Module):
278279
# YOLOv5 MultiBackend class for python inference on various backends
279-
def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
280+
def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
280281
# Usage:
281282
# PyTorch: weights = *.pt
282283
# TorchScript: *.torchscript
283284
# CoreML: *.mlmodel
284285
# TensorFlow: *_saved_model
285286
# TensorFlow: *.pb
286287
# TensorFlow Lite: *.tflite
288+
# TensorFlow Edge TPU: *_edgetpu.tflite
287289
# ONNX Runtime: *.onnx
288290
# OpenCV DNN: *.onnx with dnn=True
289291
# TensorRT: *.engine
@@ -297,6 +299,9 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
297299
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
298300
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
299301
w = attempt_download(w) # download if not local
302+
if data: # data.yaml path (optional)
303+
with open(data, errors='ignore') as f:
304+
names = yaml.safe_load(f)['names'] # class names
300305

301306
if jit: # TorchScript
302307
LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -343,7 +348,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
343348
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
344349
context = model.create_execution_context()
345350
batch_size = bindings['images'].shape[0]
346-
else: # TensorFlow model (TFLite, pb, saved_model)
351+
else: # TensorFlow (TFLite, pb, saved_model)
347352
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
348353
LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
349354
import tensorflow as tf
@@ -425,6 +430,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
425430
y[..., 1] *= h # y
426431
y[..., 2] *= w # w
427432
y[..., 3] *= h # h
433+
428434
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
429435
return (y, []) if val else y
430436

val.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def run(data,
124124
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
125125

126126
# Load model
127-
model = DetectMultiBackend(weights, device=device, dnn=dnn)
127+
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data)
128128
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
129129
imgsz = check_img_size(imgsz, s=stride) # check image size
130130
half &= (pt or jit or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA

0 commit comments

Comments
 (0)