2020from neural_compressor .common import logger
2121from neural_compressor .common .base_config import BaseConfig , get_all_config_set_from_config_registry
2222from neural_compressor .common .base_tuning import EvaluationFuncWrapper , TuningConfig , init_tuning
23- from neural_compressor .common .utils import dump_elapsed_time
23+ from neural_compressor .common .utils import call_counter , dump_elapsed_time
2424from neural_compressor .tensorflow .quantization import quantize_model
2525from neural_compressor .tensorflow .quantization .config import FRAMEWORK_NAME , StaticQuantConfig
2626from neural_compressor .tensorflow .utils import BaseModel , Model , constants
@@ -36,6 +36,7 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
3636
3737
3838@dump_elapsed_time ("Pass auto-tune" )
39+ @call_counter
3940def autotune (
4041 model : Union [str , tf .keras .Model , BaseModel ],
4142 tune_config : TuningConfig ,
@@ -52,7 +53,7 @@ def autotune(
5253 baseline : float = eval_func_wrapper .evaluate (model )
5354 tuning_monitor .set_baseline (baseline )
5455 tuning_logger .tuning_start ()
55- for trial_index , quant_config in enumerate (config_loader ):
56+ for trial_index , quant_config in enumerate (config_loader , 1 ):
5657 tuning_logger .trial_start (trial_index = trial_index )
5758 tuning_logger .execution_start ()
5859 logger .info (quant_config .to_dict ())
@@ -65,8 +66,14 @@ def autotune(
6566 tuning_logger .trial_end (trial_index )
6667 if tuning_monitor .need_stop ():
6768 logger .info ("Stopped tuning." )
68- best_quant_config : BaseConfig = tuning_monitor .get_best_quant_config ()
69- best_quant_model = quantize_model (model , quant_config , calib_dataloader , calib_iteration )
69+ best_trial_record = tuning_monitor .get_best_trial_record ()
70+ if best_trial_record .trial_index != trial_index :
71+ logger .info ("Re-quantizing with best quantization config..." )
72+ del q_model
73+ best_quant_config : BaseConfig = best_trial_record .quant_config
74+ best_quant_model = quantize_model (model , best_quant_config , calib_dataloader , calib_iteration )
75+ else :
76+ best_quant_model = q_model
7077 break
7178 tuning_logger .tuning_end ()
7279 return best_quant_model
0 commit comments