|
24 | 24 | } |
25 | 25 |
|
26 | 26 |
|
27 | | -def load(model, output_dir="./saved_results"): |
| 27 | +def load(output_dir="./saved_results", model=None): |
28 | 28 | from neural_compressor.common.base_config import ConfigRegistry |
29 | 29 |
|
30 | 30 | qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json") |
31 | | - config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) |
32 | | - model.qconfig = config_mapping |
33 | | - # select load function |
34 | | - config_object = config_mapping[next(iter(config_mapping))] |
35 | | - if isinstance(config_object, FP8Config): |
36 | | - from neural_compressor.torch.algorithms.habana_fp8 import load |
37 | | - |
38 | | - return load(model, output_dir) |
| 31 | + with open(qconfig_file_path, "r") as f: |
| 32 | + per_op_qconfig = json.load(f) |
| 33 | + if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ... |
| 34 | + from neural_compressor.torch.algorithms.static_quant import load |
| 35 | + |
| 36 | + return load(output_dir) |
| 37 | + |
| 38 | + else: # FP8 |
| 39 | + config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"]) |
| 40 | + model.qconfig = config_mapping |
| 41 | + # select load function |
| 42 | + config_object = config_mapping[next(iter(config_mapping))] |
| 43 | + if isinstance(config_object, FP8Config): |
| 44 | + from neural_compressor.torch.algorithms.habana_fp8 import load |
| 45 | + |
| 46 | + return load(model, output_dir) |
0 commit comments