5252arg_parser .add_argument ('--int8' , dest = 'int8' , action = 'store_true' , help = 'whether to use int8 model for benchmark' )
5353args = arg_parser .parse_args ()
5454
55- def evaluate (model , eval_dataloader , metric , postprocess = None ):
55+ def evaluate (model , eval_dataloader , postprocess = None ):
5656 """Custom evaluate function to estimate the accuracy of the model.
5757
5858 Args:
@@ -61,12 +61,14 @@ def evaluate(model, eval_dataloader, metric, postprocess=None):
6161 Returns:
6262 accuracy (float): evaluation result, the larger is better.
6363 """
64+ from neural_compressor import METRICS
6465 from neural_compressor .model import Model
6566 model = Model (model )
6667 input_tensor = model .input_tensor
6768 output_tensor = model .output_tensor if len (model .output_tensor )> 1 else \
6869 model .output_tensor [0 ]
6970 iteration = - 1
71+ metric = METRICS ('tensorflow' )['topk' ]()
7072 if args .benchmark and args .mode == 'performance' :
7173 iteration = args .iters
7274
@@ -136,9 +138,6 @@ def run(self):
136138 accuracy_criterion = AccuracyCriterion (tolerable_loss = 0.01 ),
137139 op_type_dict = {'conv2d' :{ 'weight' :{'dtype' :['fp32' ]}, 'activation' :{'dtype' :['fp32' ]} }}
138140 )
139- from neural_compressor import METRICS
140- metrics = METRICS ('tensorflow' )
141- top1 = metrics ['topk' ]()
142141 from tensorflow .core .protobuf import saved_model_pb2
143142 sm = saved_model_pb2 .SavedModel ()
144143 with tf .io .gfile .GFile (args .input_graph , "rb" ) as f :
@@ -147,10 +146,9 @@ def run(self):
147146 from neural_compressor .data import TensorflowShiftRescale
148147 postprocess = TensorflowShiftRescale ()
149148 def eval (model ):
150- return evaluate (model , eval_dataloader , top1 , postprocess )
151- q_model = quantization .fit (graph_def , conf = conf , calib_dataloader = calib_dataloader ,
152- # eval_dataloader=eval_dataloader, eval_metric=top1)
153- eval_func = eval )
149+ return evaluate (model , eval_dataloader , postprocess )
150+ q_model = quantization .fit (graph_def , conf = conf , eval_func = eval ,
151+ calib_dataloader = calib_dataloader )
154152 q_model .save (args .output_graph )
155153
156154 if args .benchmark :
@@ -163,9 +161,6 @@ def eval(model):
163161 'filter' : None
164162 }
165163 dataloader = create_dataloader ('tensorflow' , dataloader_args )
166- from neural_compressor import METRICS
167- metrics = METRICS ('tensorflow' )
168- top1 = metrics ['topk' ]()
169164
170165 if args .int8 or args .input_graph .endswith ("-tune.pb" ):
171166 input_graph = args .input_graph
@@ -180,7 +175,7 @@ def eval(model):
180175 from neural_compressor .data import TensorflowShiftRescale
181176 postprocess = TensorflowShiftRescale ()
182177 def eval (model ):
183- return evaluate (model , dataloader , top1 , postprocess )
178+ return evaluate (model , dataloader , postprocess )
184179
185180 if args .mode == 'performance' :
186181 from neural_compressor .benchmark import fit
0 commit comments