Skip to content

Commit e5c11ff

Browse files
zldrobit, pre-commit-ci[bot], and glenn-jocher
authored and committed
Fix TF exports >= 2GB (ultralytics#6292)
* Fix exporting saved_model: pb exceeds 2GB
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Replace TF v1.x API with TF v2.x API for saved_model export
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Clean up
* Remove lambda in tf.function()
* Revert "Remove lambda in tf.function()" to be compatible with TF v2.4
  This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779.
* Fix for pre-commit.ci
* Cleanup1
* Cleanup2
* Backwards compatibility update
* Update common.py
* Update common.py
* Cleanup3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
1 parent 84dfc43 commit e5c11ff

File tree

2 files changed

+48
-55
lines changed

2 files changed

+48
-55
lines changed

export.py

Lines changed: 45 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616
TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite
1717
TensorFlow.js | `tfjs` | yolov5s_web_model/
1818
19+
Requirements:
20+
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU
21+
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
22+
1923
Usage:
2024
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
2125
@@ -45,6 +49,7 @@
4549
import subprocess
4650
import sys
4751
import time
52+
import warnings
4853
from pathlib import Path
4954

5055
import pandas as pd
@@ -239,41 +244,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
239244
except Exception as e:
240245
LOGGER.info(f'\n{prefix} export failure: {e}')
241246

242-
def export_keras(model, im, file, dynamic, prefix=colorstr('Keras:')):
243-
# YOLOv5 TensorFlow SavedModel export
244-
try:
245-
import tensorflow as tf
246-
from tensorflow import keras
247-
248-
from models.keras import TFDetect, KerasModel
249-
250-
LOGGER.info(f'\n{prefix} starting export with keras {tf.__version__}...')
251-
f = str(file).replace('.pt', '.h5')
252-
batch_size, ch, *imgsz = list(im.shape) # BCHW
253-
254-
model = KerasModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
255-
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for Keras
256-
_ = model.predict(im) # first call to create weights
257-
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
258-
outputs = model.predict(inputs)
259-
keras_model = keras.Model(inputs=inputs, outputs=outputs, name="yolov5n")
260-
keras_model.trainable = False
261-
keras_model.summary()
262-
keras_model.save(f, save_format='h5')
263-
264-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
265-
return keras_model, f
266-
except Exception as e:
267-
LOGGER.info(f'\n{prefix} export failure: {e}')
268-
return None, None
269247

270248
def export_saved_model(model, im, file, dynamic,
271249
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
272-
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
250+
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
273251
# YOLOv5 TensorFlow SavedModel export
274252
try:
275253
import tensorflow as tf
276-
from tensorflow import keras
254+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
277255

278256
from models.tf import TFDetect, TFModel
279257

@@ -282,16 +260,28 @@ def export_saved_model(model, im, file, dynamic,
282260
batch_size, ch, *imgsz = list(im.shape) # BCHW
283261

284262
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
285-
im = tf.ones((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
286-
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
287-
y = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
288-
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
263+
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
264+
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
265+
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
289266
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
290-
keras_model = keras.Model(inputs=inputs, outputs=outputs)
267+
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
291268
keras_model.trainable = False
292269
keras_model.summary()
293-
keras_model.save(f, save_format='tf')
294-
270+
if keras:
271+
keras_model.save(f, save_format='tf')
272+
else:
273+
m = tf.function(lambda x: keras_model(x)) # full model
274+
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
275+
m = m.get_concrete_function(spec)
276+
frozen_func = convert_variables_to_constants_v2(m)
277+
tfm = tf.Module()
278+
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
279+
tfm.__call__(im)
280+
tf.saved_model.save(
281+
tfm,
282+
f,
283+
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
284+
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
295285
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
296286
return keras_model, f
297287
except Exception as e:
@@ -358,13 +348,14 @@ def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
358348
cmd = 'edgetpu_compiler --version'
359349
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
360350
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
361-
if subprocess.run(cmd, shell=True).returncode != 0:
351+
if subprocess.run(cmd + ' >/dev/null', shell=True).returncode != 0:
362352
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
353+
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
363354
for c in ['curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
364355
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
365356
'sudo apt-get update',
366357
'sudo apt-get install edgetpu-compiler']:
367-
subprocess.run(c, shell=True, check=True)
358+
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
368359
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
369360

370361
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
@@ -446,16 +437,17 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
446437
tf_exports = list(x in include for x in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs')) # TensorFlow exports
447438
file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)
448439

449-
# Checks
450-
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
451-
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12
452-
453440
# Load PyTorch model
454441
device = select_device(device)
455442
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
456443
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
457444
nc, names = model.nc, model.names # number of classes, class names
458445

446+
# Checks
447+
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
448+
opset = 12 if ('openvino' in include) else opset # OpenVINO requires opset <= 12
449+
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'
450+
459451
# Input
460452
gs = int(max(model.stride)) # grid size (max stride)
461453
imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
@@ -477,10 +469,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
477469

478470
for _ in range(2):
479471
y = model(im) # dry runs
480-
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} ({file_size(file):.1f} MB)")
472+
shape = tuple(y[0].shape) # model output shape
473+
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
481474

482475
# Exports
483-
f = [''] * 11 # exported filenames
476+
f = [''] * 10 # exported filenames
477+
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
484478
if 'torchscript' in include:
485479
f[0] = export_torchscript(model, im, file, optimize)
486480
if 'engine' in include: # TensorRT required before ONNX
@@ -510,17 +504,15 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
510504
if tfjs:
511505
f[9] = export_tfjs(model, im, file)
512506

513-
if 'keras' in include:
514-
_, f[10] = export_keras(model, im, file, dynamic)
515-
516507
# Finish
517508
f = [str(x) for x in f if x] # filter out '' and None
518-
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
519-
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
520-
f"\nVisualize with https://netron.app"
521-
f"\nDetect with `python detect.py --weights {f[-1]}`"
522-
f" or `model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
523-
f"\nValidate with `python val.py --weights {f[-1]}`")
509+
if any(f):
510+
LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
511+
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
512+
f"\nDetect: python detect.py --weights {f[-1]}"
513+
f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
514+
f"\nValidate: python val.py --weights {f[-1]}"
515+
f"\nVisualize: https://netron.app")
524516
return f # return list of exported files/dirs
525517

526518

models/common.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
359359
if saved_model: # SavedModel
360360
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
361361
import tensorflow as tf
362-
model = tf.keras.models.load_model(w)
362+
keras = False # assume TF1 saved_model
363+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
363364
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
364365
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
365366
import tensorflow as tf
@@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
431432
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
432433
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
433434
if self.saved_model: # SavedModel
434-
y = self.model(im, training=False).numpy()
435+
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
435436
elif self.pb: # GraphDef
436437
y = self.frozen_func(x=self.tf.constant(im)).numpy()
437438
elif self.tflite: # Lite

0 commit comments

Comments
 (0)