Skip to content

Commit 1535aec

Browse files
authored
Refactor OnnxOpVersionConversion to use the onnxscript version converter (#1784)
Refactor OnnxOpVersionConversion to use the onnxscript version converter with the fallback option. When the version converter implemented in ONNX IR supports the target version, the converter will be used. Otherwise, we fall back to use the version converter in onnx. The API call handles initializers to ensure models >2gb can be supported.
1 parent 3bbc554 commit 1535aec

File tree

2 files changed

+6
-24
lines changed

2 files changed

+6
-24
lines changed

olive/passes/onnx/conversion.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import onnx
1616
import torch
1717
import transformers
18+
from onnxscript import version_converter
1819
from packaging import version
1920
from transformers.modeling_utils import PreTrainedModel
2021

@@ -31,7 +32,7 @@
3132
from olive.model.config import IoConfig
3233
from olive.model.utils import resolve_onnx_path
3334
from olive.passes import Pass
34-
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
35+
from olive.passes.onnx.common import get_external_data_config, ir_model_to_olive_model, model_proto_to_olive_model
3536
from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config
3637

3738
logger = logging.getLogger(__name__)
@@ -655,28 +656,9 @@ def _run_for_config(
655656
self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
656657
) -> ONNXModelHandler:
657658
output_model_path = resolve_onnx_path(output_model_path)
658-
# since external data is saved in a separate file, we need to load the model to get the opset version
659-
model_proto = onnx.load(model.model_path, load_external_data=False)
660-
661-
model_opset_version = model_proto.opset_import[0].version
662-
if model_opset_version == config.target_opset:
663-
logger.info("Model is already in target opset version %s.", config.target_opset)
664-
return model
665-
666-
converted_model_proto = onnx.version_converter.convert_version(model_proto, config.target_opset)
667-
# copy the external data of original model to the new model
668-
dst_init_map = {init.name: init for init in converted_model_proto.graph.initializer}
669-
for src_init in model_proto.graph.initializer:
670-
if (
671-
src_init.name in dst_init_map
672-
and src_init.HasField("data_location")
673-
and src_init.data_location == onnx.TensorProto.EXTERNAL
674-
):
675-
dst_init_map[src_init.name].CopyFrom(src_init)
676-
onnx.external_data_helper.load_external_data_for_model(
677-
converted_model_proto, str(Path(model.model_path).resolve().parent)
678-
)
679-
return model_proto_to_olive_model(converted_model_proto, output_model_path, config)
659+
model_ir = model.load_ir_model()
660+
version_converter.convert_version(model_ir, config.target_opset, fallback=True)
661+
return ir_model_to_olive_model(model_ir, output_model_path, config)
680662

681663

682664
def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, dummy_kwargs, model):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
numpy
22
onnx
3-
onnxscript>=0.2.4
3+
onnxscript>=0.2.5
44
optuna
55
pandas
66
pydantic

0 commit comments

Comments
 (0)