Skip to content

Commit f94b1ac

Browse files
authored
Merge pull request #4 from PINTO0309/fix_irversion
Fix to preserve domain and ir_version
2 parents e9c04e3 + ea41207 commit f94b1ac

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

snc4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from snc4onnx.onnx_network_combine import combine, main
22

3-
__version__ = '1.0.12'
3+
__version__ = '1.0.13'

snc4onnx/onnx_network_combine.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,15 @@ def has_duplicates(seq):
230230
## 1. ONNX load
231231
tmp_onnx_graphs = []
232232
custom_domain_check_onnx_nodes = []
233+
max_ir_version: int = 0
233234
if len(onnx_graphs) > 0:
234235
for onnx_graph in onnx_graphs:
236+
domain: str = onnx_graph.domain
237+
ir_version: int = onnx_graph.ir_version
238+
max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version
235239
gs_graph = gs.import_onnx(onnx_graph)
236240
gs_graph.cleanup().toposort()
237-
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
241+
tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
238242
custom_domain_check_onnx_nodes = \
239243
custom_domain_check_onnx_nodes + \
240244
[
@@ -243,9 +247,13 @@ def has_duplicates(seq):
243247
]
244248
else:
245249
for onnx_path in input_onnx_file_paths:
246-
gs_graph = gs.import_onnx(onnx.load(onnx_path))
250+
onnx_graph = onnx.load(onnx_path)
251+
domain: str = onnx_graph.domain
252+
ir_version: int = onnx_graph.ir_version
253+
max_ir_version = ir_version if max_ir_version < ir_version else max_ir_version
254+
gs_graph = gs.import_onnx(onnx_graph)
247255
gs_graph.cleanup().toposort()
248-
tmp_onnx_graphs.append(gs.export_onnx(gs_graph))
256+
tmp_onnx_graphs.append(gs.export_onnx(gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
249257
custom_domain_check_onnx_graph = onnx.load(onnx_path)
250258
custom_domain_check_onnx_nodes = \
251259
custom_domain_check_onnx_nodes + \
@@ -436,7 +444,7 @@ def has_duplicates(seq):
436444

437445
# Cleaning
438446
src_gs_model.cleanup().toposort()
439-
combined_model = gs.export_onnx(src_gs_model)
447+
combined_model = gs.export_onnx(src_gs_model, do_type_check=False, **{'ir_version': max_ir_version})
440448

441449
## Output of onnx files in the process of fusion
442450
if output_of_onnx_file_in_the_process_of_fusion and output_onnx_file_path:
@@ -484,7 +492,7 @@ def has_duplicates(seq):
484492
replaced_output_names.append(tmp_replaced_output_name)
485493

486494
gs_combined_model.cleanup().toposort()
487-
combined_model = gs.export_onnx(gs_combined_model)
495+
combined_model = gs.export_onnx(gs_combined_model, do_type_check=False, **{'ir_version': max_ir_version})
488496

489497
## 4. Optimize
490498
try:

0 commit comments

Comments
 (0)