Describe the issue
The ONNX Runtime TensorRT execution provider crashes on model load when parsing a ReduceMax node.
See the detailed reproduction steps below.
I used ONNX opset 17 and 16 with PyTorch 2.0.
The same issue occurs with onnxruntime 1.13 through 1.15,
using:
TensorRT-8.6.1.6
cuDNN 8.9.3
CUDA 11.8
Python 3.10
PyTorch 2.0
The same issue occurs with the C++ ORT APIs.
Note that because of #16883 I had to use a released package (1.13.1).
To reproduce
- get a standard faster-rcnn model from the facebookresearch detectron2 model zoo in setup_cfg()
- convert the torch model to ONNX using export_tracing()
- prepare a sample input with get_sample_inputs(), used both for inference with torch and for inference with the onnxruntime TensorRT EP
- check the resulting ONNX model with check_onnx_model(): the model is reported valid but the graph is not (don't worry, this appears to be a quirk of the onnx checker functions; see the note after check_onnx_model in the script below)
- run inference with torch and with onnxruntime and verify the results match
- => here we get an error with the TensorRT EP, while inference works with the CPU EP and the CUDA EP, and the results match with those.
As the traces show, the problem is with the ReduceMax operator, which is misinterpreted by the onnxruntime TensorRT EP rather than by TensorRT itself: when I run the model through native TensorRT (after converting the ONNX model to the trt format, which requires a special version of NVIDIA graph-surgeon) there is no problem. A small sketch for inspecting the exported ReduceMax nodes follows below.
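To see concretely which form of ReduceMax ended up in the export, here is a minimal sketch (my addition; it assumes the model was saved to onnx_output/faster_rcnn_fpn.onnx as in the script below) that lists every ReduceMax node with its attributes and inputs. With an opset-17 export, axes shows up as an attribute:

import onnx

model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")
for node in model.graph.node:
    if node.op_type == "ReduceMax":
        # under opset <= 17, `axes` is an attribute; from opset 18 on it is an input
        print(node.name,
              "attributes:", [a.name for a in node.attribute],
              "inputs:", list(node.input))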
Note: the following code demonstrates the error:
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# this is an adaptation of detectron2/tools/deploy/export_model.py
# it does export of a faster-rcnn model to onnx and test it vs the original detectron2 model
# requires any RGB input image (jpg or png)
import argparse
import os
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
from detectron2 import model_zoo
"""
# cannot use detectron2 export lib since it depends on Caffe2 which is not provided anymore with pytorch dist
from detectron2.export import (
STABLE_ONNX_OPSET_VERSION,
TracingAdapter,
dump_torchscript_IR,
scripting_with_instances,
)
"""
# use export lib stripped out of caffe2 (detectron2/export/__init__.py)
from lib.export import (
TracingAdapter,
dump_torchscript_IR,
scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
import onnx
import onnxruntime as ort
import numpy as np
import cv2 as cv2
def setup_cfg(args):
    cfg = get_cfg()
    add_pointrend_config(cfg)
    # use the detectron2 standard faster rcnn; merge the base config first so the
    # overrides below are not clobbered by values from the config file
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
    cfg.merge_from_list(args.opts)
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set detection threshold for this model
    cfg.MODEL.DEVICE = 'cuda'
    # cuda context is initialized before creating the dataloader, so we don't fork anymore
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()
    return cfg
# experimental. API not yet final
def export_tracing(torch_model, inputs):
assert TORCH_VERSION >= (1, 8)
image = inputs[0]["image"]
inputs = [{"image": image}] # remove other unused keys
inference=None
"""
if isinstance(torch_model, GeneralizedRCNN):
def inference(model, inputs):
# use do_postprocess=False so it returns ROI mask
inst = model.inference(inputs, do_postprocess=False)[0]
return [{"instances": inst}]
else:
inference = None # assume that we just call the model directly
"""
traceable_model = TracingAdapter(torch_model, inputs, inference)
with PathManager.open(os.path.join(args.output, "faster_rcnn_fpn.onnx"), "wb") as f:
torch.onnx.export(
traceable_model,
(image,),
f,
do_constant_folding=True,
export_params=True,
input_names=["image"], # the model's input names
output_names=["boxes", "labels", "scores", "image_dims"], # the model's output names
dynamic_axes={
"image" : {1: "height", 2: "width"},
"boxes" : {0: "findings"}, # boxes is a tensor of shape [number of findings, 4]
"labels" : {0: "findings"},
"scores" : {0: "findings"}
},
verbose=True,
            opset_version=17)  # issue is the same with opset 16; opset 18 is not validated for pytorch 2.0
logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
logger.info("Outputs schema: " + str(traceable_model.outputs_schema))
onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx")
onnx_model = onnx.load(onnx_model_path)
return onnx_model
def get_sample_inputs(args):
if args.sample_image is None:
# get a first batch from dataset
data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
first_batch = next(iter(data_loader))
return first_batch
else:
# get a sample data
        original_image = cv2.imread(args.sample_image)  # read the image passed via --sample-image
print ("original_image input shape :", original_image.shape)
# Do same preprocessing as DefaultPredictor
aug = T.ResizeShortestEdge(
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
image_with_different_size = aug.get_transform(original_image).apply_image(original_image)
cv2.imwrite("./inputExpanded.jpg", image_with_different_size)
image = original_image
height, width = original_image.shape[:2]
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))  # need channel-first for onnx
print ("image input shape :", image.shape)
inputs = {"image": image, "height": height, "width": width}
# Sample ready
sample_inputs = [inputs]
return sample_inputs
def check_onnx_model(onnx_model):
# Check the model
try:
onnx.checker.check_model(onnx_model, full_check=True)
except onnx.checker.ValidationError as e:
print("The model is invalid: %s" % e)
else:
print("The model is valid!")
# check the onnx graph
try:
graph = onnx_model.graph
onnx.checker.check_graph(graph)
except onnx.checker.ValidationError as e:
print("The graph is invalid: %s" % e)
else:
print("The graph is valid!")
input_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in onnx_model.graph.input]
print ('onnx model input shapes', input_shapes)
return None
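# Added note (my understanding, worth verifying against the onnx checker source):
# check_model() validates nodes against the model's own opset imports (17 here,
# where ReduceMax still carries `axes` as an attribute), whereas check_graph()
# receives only the bare graph, has no opset imports to consult, and falls back
# to the checker's default (latest) opset, where `axes` moved from attribute to
# input. That would explain "model is valid" alongside "graph is invalid:
# Unrecognized attribute: axes".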
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def eval_onnx_model(torch_model, onnx_model, sample_inputs, args):
# get D2 results
torch_model.eval()
torch_outputs = torch_model(sample_inputs)
print ('torch_outputs: ', torch_outputs)
print ('torch size of outputs: ', len(torch_outputs))
t_outputs_scores = to_numpy(torch_outputs[0]['instances'].scores)
print('d2_torch_scores: ', t_outputs_scores)
t_outputs_boxes = to_numpy(torch_outputs[0]['instances'].pred_boxes.tensor)
print('d2_torch_boxes: ', t_outputs_boxes)
t_outputs_classes = to_numpy(torch_outputs[0]['instances'].pred_classes)
print('d2_torch_classes: ', t_outputs_classes)
print('')
# get ONNXRT results
onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx")
    providers = ['TensorrtExecutionProvider']
    # providers = ['CUDAExecutionProvider']  # works!
sess_opt = ort.SessionOptions()
sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)
input_name = sess.get_inputs()[0].name
print("input name", input_name)
input_shape = sess.get_inputs()[0].shape
print("input shape", input_shape)
input_type = sess.get_inputs()[0].type
print("input type", input_type)
output_name = sess.get_outputs()[0].name
print("output name", output_name)
output_shape = sess.get_outputs()[0].shape
print("output shape", output_shape)
output_type = sess.get_outputs()[0].type
print("output type", output_type)
image = sample_inputs[0]['image']
np_image = image.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {sess.get_inputs()[0].name: np_image}
ort_outputs = sess.run(None, ort_inputs)
print ('ort_outputs: ', ort_outputs)
print('ort_outputs number: ', len(ort_outputs))
print('')
boxes = ort_outputs[0]
classes = ort_outputs[1]
scores = ort_outputs[2]
print ('ort_boxes : ', boxes)
print ('ort scores : ', scores)
print ('ort classes : ', classes)
print('')
# eval torch and onnxrt outputs
np.testing.assert_allclose(t_outputs_boxes, boxes, rtol=1e-03, atol=1e-05)
np.testing.assert_allclose(t_outputs_scores, scores, rtol=1e-03, atol=1e-05)
np.testing.assert_allclose(t_outputs_classes, classes, rtol=1e-03, atol=1e-05)
print('detectron2 torch and onnx models results match!')
print('')
return None
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export a model for deployment.")
parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
parser.add_argument("--output", help="output directory for the converted model")
parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
args = parser.parse_args()
logger = setup_logger()
logger.info("Command line arguments: " + str(args))
PathManager.mkdirs(args.output)
cfg = setup_cfg(args)
# create a torch model
torch_model = build_model(cfg)
DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
torch_model.eval()
# convert and save model
sample_inputs = get_sample_inputs(args)
onnx_model = export_tracing(torch_model, sample_inputs)
    check_onnx_model(onnx_model)
eval_onnx_model(torch_model, onnx_model, sample_inputs, args)
logger.info("Success.")
What exact command you run:
python3 export_model.py --output onnx_output --sample-image input.jpg
Full logs or other relevant observations:
[04/04 16:14:53 detectron2]: Command line arguments: Namespace(sample_image='input.jpg', output='onnx_output', opts=[])
original_image input shape : (480, 640, 3)
image input shape : torch.Size([3, 480, 640])
%/model/ReduceMax_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::ReduceMax[axes=[0], keepdims=0, onnx_name="/model/ReduceMax"](%/model/Concat_1_output_0), scope: lib.export.flatten.TracingAdapter::/detectron2.modeling.meta_arch.rcnn.GeneralizedRCNN::model # /usr/local/lib/python3.10/dist-packages/detectron2/structures/image_list.py:83:0
%max_coordinate.3 : Float(device=cpu) = onnx::ReduceMax[keepdims=0](%/model/roi_heads/Cast_9_output_0) # /usr/local/lib/python3.10/dist-packages/torchvision/ops/boxes.py:91:21
============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
[04/04 16:15:02 detectron2]: Inputs schema: TupleSchema(schemas=[ListSchema(schemas=[DictSchema(schemas=[IdentitySchema()], sizes=[1], keys=['image'])], sizes=[1])], sizes=[1])
[04/04 16:15:02 detectron2]: Outputs schema: ListSchema(schemas=[DictSchema(schemas=[InstancesSchema(schemas=[TensorWrapSchema(class_name='detectron2.structures.Boxes'), IdentitySchema(), IdentitySchema()], sizes=[1, 1, 1], keys=['pred_boxes', 'pred_classes', 'scores'])], sizes=[4], keys=['instances'])], sizes=[4])
The model is valid!
The graph is invalid: Unrecognized attribute: axes for operator ReduceMax
==> Context: Bad node spec for node. Name: /model/ReduceMax OpType: ReduceMax
onnx model input shapes [[3, 0, 0]]
2023-04-04 16:37:52.173723690 [E:onnxruntime:Default, tensorrt_execution_provider.h:61 log] [2023-04-04 16:37:52 ERROR] ReduceMax_1597: at least 1 dimensions are required for input.
2023-04-04 16:37:52.324418966 [E:onnxruntime:, inference_session.cc:1532 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:897 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /model/proposal_generator/GatherND_2_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs
Traceback (most recent call last):
File "/cad-engine/export_model.py", line 264, in <module>
eval_onnx_model(torch_model, onnx_model, sample_inputs, args)
File "/cad-engine/export_model.py", line 190, in eval_onnx_model
sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)
File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 360, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 408, in _create_inference_session
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception duri
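For completeness: the second error above explicitly asks for shape inference to be run on the model first. A minimal sketch of that step (my addition, assuming the symbolic shape inference helper shipped with the onnxruntime Python package and the model path from the script above):

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("onnx_output/faster_rcnn_fpn.onnx")
# auto_merge=True merges conflicting symbolic dimensions instead of giving up
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "onnx_output/faster_rcnn_fpn_shaped.onnx")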
Urgency
I am blocked with ONNX Runtime and need to revert to the native TensorRT APIs, which defeats our portability strategy.
Platform
Linux
OS Version
SLES15 SP4
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.13.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
TensorRT
Execution Provider Library Version
TensorRT-8.6.1.6