Skip to content

Commit 1323b48

Browse files
Remove .train() mode exports (#9429)
* Remove `.train()` mode exports No common use cases. Signed-off-by: Glenn Jocher <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Glenn Jocher <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a4ed988 commit 1323b48

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

export.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
126126

127127

128128
@try_export
129-
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
129+
def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
130130
# YOLOv5 ONNX export
131131
check_requirements('onnx')
132132
import onnx
@@ -140,8 +140,7 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
140140
f,
141141
verbose=False,
142142
opset_version=opset,
143-
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
144-
do_constant_folding=not train,
143+
do_constant_folding=True,
145144
input_names=['images'],
146145
output_names=['output'],
147146
dynamic_axes={
@@ -459,7 +458,6 @@ def run(
459458
include=('torchscript', 'onnx'), # include formats
460459
half=False, # FP16 half-precision export
461460
inplace=False, # set YOLOv5 Detect() inplace=True
462-
train=False, # model.train() mode
463461
keras=False, # use Keras
464462
optimize=False, # TorchScript: optimize for mobile
465463
int8=False, # CoreML/TF INT8 quantization
@@ -501,7 +499,7 @@ def run(
501499
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
502500

503501
# Update model
504-
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
502+
model.eval()
505503
for k, m in model.named_modules():
506504
if isinstance(m, Detect):
507505
m.inplace = inplace
@@ -524,7 +522,7 @@ def run(
524522
if engine: # TensorRT required before ONNX
525523
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
526524
if onnx or xml: # OpenVINO requires ONNX
527-
f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
525+
f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
528526
if xml: # OpenVINO
529527
f[3], _ = export_openvino(file, metadata, half)
530528
if coreml: # CoreML
@@ -578,7 +576,6 @@ def parse_opt():
578576
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
579577
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
580578
parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
581-
parser.add_argument('--train', action='store_true', help='model.train() mode')
582579
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
583580
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
584581
parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')

0 commit comments

Comments
 (0)