|
15 | 15 | import onnx
|
16 | 16 | import torch
|
17 | 17 | import transformers
|
| 18 | +from onnxscript import version_converter |
18 | 19 | from packaging import version
|
19 | 20 | from transformers.modeling_utils import PreTrainedModel
|
20 | 21 |
|
|
31 | 32 | from olive.model.config import IoConfig
|
32 | 33 | from olive.model.utils import resolve_onnx_path
|
33 | 34 | from olive.passes import Pass
|
34 |
| -from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model |
| 35 | +from olive.passes.onnx.common import get_external_data_config, ir_model_to_olive_model, model_proto_to_olive_model |
35 | 36 | from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config
|
36 | 37 |
|
37 | 38 | logger = logging.getLogger(__name__)
|
@@ -655,28 +656,9 @@ def _run_for_config(
|
655 | 656 | self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
|
656 | 657 | ) -> ONNXModelHandler:
|
657 | 658 | output_model_path = resolve_onnx_path(output_model_path)
|
658 |
| - # since external data is saved in a separate file, we need to load the model to get the opset version |
659 |
| - model_proto = onnx.load(model.model_path, load_external_data=False) |
660 |
| - |
661 |
| - model_opset_version = model_proto.opset_import[0].version |
662 |
| - if model_opset_version == config.target_opset: |
663 |
| - logger.info("Model is already in target opset version %s.", config.target_opset) |
664 |
| - return model |
665 |
| - |
666 |
| - converted_model_proto = onnx.version_converter.convert_version(model_proto, config.target_opset) |
667 |
| - # copy the external data of original model to the new model |
668 |
| - dst_init_map = {init.name: init for init in converted_model_proto.graph.initializer} |
669 |
| - for src_init in model_proto.graph.initializer: |
670 |
| - if ( |
671 |
| - src_init.name in dst_init_map |
672 |
| - and src_init.HasField("data_location") |
673 |
| - and src_init.data_location == onnx.TensorProto.EXTERNAL |
674 |
| - ): |
675 |
| - dst_init_map[src_init.name].CopyFrom(src_init) |
676 |
| - onnx.external_data_helper.load_external_data_for_model( |
677 |
| - converted_model_proto, str(Path(model.model_path).resolve().parent) |
678 |
| - ) |
679 |
| - return model_proto_to_olive_model(converted_model_proto, output_model_path, config) |
| 659 | + model_ir = model.load_ir_model() |
| 660 | + version_converter.convert_version(model_ir, config.target_opset, fallback=True) |
| 661 | + return ir_model_to_olive_model(model_ir, output_model_path, config) |
680 | 662 |
|
681 | 663 |
|
682 | 664 | def _validate_dynamic_shapes(dynamic_shapes, dummy_inputs, dummy_kwargs, model):
|
|
0 commit comments