@@ -168,6 +168,7 @@ class TextSimilarityTask(Task):
168168 def __init__ (self , task , model , batch_size = 1 , max_length = 384 , ** kwargs ):
169169 super ().__init__ (task = task , model = model , ** kwargs )
170170 self ._static_mode = True
171+ self ._check_predictor_type ()
171172 if not self .from_hf_hub :
172173 self ._check_task_files ()
173174 if self ._static_mode :
@@ -273,12 +274,22 @@ def _run_model(self, inputs):
273274 if "rocketqa" in self .model_name or "ernie-search" in self .model_name :
274275 with static_mode_guard ():
275276 for batch in inputs ["data_loader" ]:
276- input_ids , segment_ids = self ._batchify_fn (batch )
277- self .input_handles [0 ].copy_from_cpu (input_ids )
278- self .input_handles [1 ].copy_from_cpu (segment_ids )
279- self .predictor .run ()
280- scores = self .output_handle [0 ].copy_to_cpu ().tolist ()
281- results .extend (scores )
277+
278+ if self ._predictor_type == "paddle-inference" :
279+ input_ids , segment_ids = self ._batchify_fn (batch )
280+ self .input_handles [0 ].copy_from_cpu (input_ids )
281+ self .input_handles [1 ].copy_from_cpu (segment_ids )
282+ self .predictor .run ()
283+ scores = self .output_handle [0 ].copy_to_cpu ().tolist ()
284+ results .extend (scores )
285+ else :
286+ # onnx mode
287+ input_dict = {}
288+ input_ids , segment_ids = self ._batchify_fn (batch )
289+ input_dict ["input_ids" ] = input_ids
290+ input_dict ["token_type_ids" ] = segment_ids
291+ scores = self .predictor .run (None , input_dict )[0 ].tolist ()
292+ results .extend (scores )
282293 else :
283294 with static_mode_guard ():
284295 for batch in inputs ["data_loader" ]:
0 commit comments