Skip to content

Commit f62609e

Browse files
authored
Update check_requirements() with cmds=() argument (#7543)
1 parent 4b284a1 commit f62609e

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

export.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,8 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
218218
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
219219
try:
220220
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
221-
try:
222-
import tensorrt as trt
223-
except Exception:
224-
s = f"\n{prefix} tensorrt not found and is required by YOLOv5"
225-
LOGGER.info(f"{s}, attempting auto-update...")
226-
r = '-U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com'
227-
LOGGER.info(subprocess.check_output(f"pip install {r}", shell=True).decode())
228-
import tensorrt as trt
221+
check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
222+
import tensorrt as trt
229223

230224
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
231225
grid = model.model[-1].anchor_grid

utils/general.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
321321

322322

323323
@try_except
324-
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
324+
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
325325
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
326326
prefix = colorstr('red', 'bold', 'requirements:')
327327
check_python() # check python version
@@ -334,7 +334,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
334334
requirements = [x for x in requirements if x not in exclude]
335335

336336
n = 0 # number of packages updates
337-
for r in requirements:
337+
for i, r in enumerate(requirements):
338338
try:
339339
pkg.require(r)
340340
except Exception: # DistributionNotFound or VersionConflict if requirements not met
@@ -343,7 +343,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
343343
LOGGER.info(f"{s}, attempting auto-update...")
344344
try:
345345
assert check_online(), f"'pip install {r}' skipped (offline)"
346-
LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
346+
LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
347347
n += 1
348348
except Exception as e:
349349
LOGGER.warning(f'{prefix} {e}')

0 commit comments

Comments
 (0)