Skip to content

Commit e64f95a

Browse files
authored
Merge pull request #204 from Oneflow-Inc/update_resnet_to_onnx
Update resnet to onnx
2 parents cec8f14 + 163d02d commit e64f95a

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

Classification/cnns/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,9 +586,9 @@ onnx_model_dir = 'onnx/model'
586586

587587
**步骤三:调用 flow.onnx.export 方法**
588588

589-
接下来代码中会调用`oneflow_to_onnx()`方法,此方法包含了核心的模型转换方法: `flow.onnx.export()`
589+
接下来代码中会调用`oneflow_to_onnx()`方法,此方法包含了核心的模型转换方法: `oneflow_onnx.oneflow2onnx.util.export_onnx_model()`,更多OneFlow和ONNX模型转换相关的问题请看: [oneflow_convert_tools介绍](https://docs.oneflow.org/extended_topics/oneflow_convert_tools.html)
590590

591-
**flow.onnx.export** 将从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是OneFlow模型路径,第三个参数是(转换后)ONNX模型的存放路径
591+
**oneflow_to_onnx** 将从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是OneFlow模型路径,第三个参数是(转换后)ONNX模型的存放路径
592592

593593
```python
594594
onnx_model = oneflow_to_onnx(InferenceNet, flow_weights_path, onnx_model_dir, external_data=False)

Classification/cnns/resnet_to_onnx.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
"""
22
Copyright 2020 The OneFlow Authors. All rights reserved.
3-
43
Licensed under the Apache License, Version 2.0 (the "License");
54
you may not use this file except in compliance with the License.
65
You may obtain a copy of the License at
7-
86
http://www.apache.org/licenses/LICENSE-2.0
9-
107
Unless required by applicable law or agreed to in writing, software
118
distributed under the License is distributed on an "AS IS" BASIS,
129
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -29,6 +26,7 @@
2926
from resnet_model import resnet50
3027
import config as configs
3128
from imagenet1000_clsidx_to_labels import clsidx_2_labels
29+
from oneflow_onnx.oneflow2onnx.util import export_onnx_model
3230

3331
parser = configs.get_parser()
3432
args = parser.parse_args()
@@ -92,12 +90,12 @@ def oneflow_to_onnx(
9290
assert os.path.exists(flow_weights_path) and os.path.isdir(onnx_model_dir)
9391

9492
onnx_model_path = os.path.join(
95-
onnx_model_dir, os.path.basename(flow_weights_path) + ".onnx"
93+
onnx_model_dir, "model.onnx"
9694
)
97-
flow.onnx.export(
95+
export_onnx_model(
9896
job_func,
99-
flow_weights_path,
100-
onnx_model_path,
97+
flow_weight_dir=flow_weights_path,
98+
onnx_model_path=onnx_model_dir,
10199
opset=11,
102100
external_data=external_data,
103101
)
@@ -132,4 +130,4 @@ def check_equality(
132130
are_equal, onnx_res = check_equality(InferenceNet, onnx_model, image_path)
133131
clsidx_onnx = onnx_res.argmax()
134132
print("Are the results equal? {}".format("Yes" if are_equal else "No"))
135-
print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))
133+
print("Class: {}; score: {}".format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))

0 commit comments

Comments
 (0)