Skip to content

Commit 7b31a53

Browse files
authored
Add tensorrt>=7.0.0 checks (#6193)
* Add `tensorrt>=7.0.0` checks * Update export.py * Update common.py * Update export.py
1 parent a2f4a17 commit 7b31a53

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

export.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@
6161
from models.yolo import Detect
6262
from utils.activations import SiLU
6363
from utils.datasets import LoadImages
64-
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args,
65-
url2file)
64+
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
65+
file_size, print_args, url2file)
6666
from utils.torch_utils import select_device
6767

6868

@@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
174174
check_requirements(('tensorrt',))
175175
import tensorrt as trt
176176

177-
opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
178-
if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
177+
if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
179178
grid = model.model[-1].anchor_grid
180179
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
181-
export_onnx(model, im, file, opset, train, False, simplify)
180+
export_onnx(model, im, file, 12, train, False, simplify) # opset 12
182181
model.model[-1].anchor_grid = grid
183182
else: # TensorRT >= 8
184-
export_onnx(model, im, file, opset, train, False, simplify)
183+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
184+
export_onnx(model, im, file, 13, train, False, simplify) # opset 13
185185
onnx = file.with_suffix('.onnx')
186186
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
187187

models/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
337337
elif engine: # TensorRT
338338
LOGGER.info(f'Loading {w} for TensorRT inference...')
339339
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
340-
check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
340+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
341341
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
342342
logger = trt.Logger(trt.Logger.INFO)
343343
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:

0 commit comments

Comments
 (0)