Skip to content

Commit 8d0291f

Browse files
leeflixpre-commit-ci[bot]glenn-jocher
authored
Enable TensorFlow ops for --nms and --agnostic-nms (#7281)
* enable TensorFlow ops if flag --nms or --agnostic-nms is used * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update export.py * Update export.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <[email protected]>
1 parent 2da6866 commit 8d0291f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

export.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
327327
LOGGER.info(f'\n{prefix} export failure: {e}')
328328

329329

330-
def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('TensorFlow Lite:')):
330+
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
331331
# YOLOv5 TensorFlow Lite export
332332
try:
333333
import tensorflow as tf
@@ -343,13 +343,15 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
343343
if int8:
344344
from models.tf import representative_dataset_gen
345345
dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data
346-
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib)
346+
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
347347
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
348348
converter.target_spec.supported_types = []
349349
converter.inference_input_type = tf.uint8 # or tf.int8
350350
converter.inference_output_type = tf.uint8 # or tf.int8
351351
converter.experimental_new_quantizer = True
352352
f = str(file).replace('.pt', '-int8.tflite')
353+
if nms or agnostic_nms:
354+
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
353355

354356
tflite_model = converter.convert()
355357
open(f, "wb").write(tflite_model)
@@ -524,7 +526,7 @@ def run(
524526
if pb or tfjs: # pb prerequisite to tfjs
525527
f[6] = export_pb(model, im, file)
526528
if tflite or edgetpu:
527-
f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, ncalib=100)
529+
f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
528530
if edgetpu:
529531
f[8] = export_edgetpu(model, im, file)
530532
if tfjs:

0 commit comments

Comments
 (0)