Skip to content

Commit d66ab8e

Browse files
jadepengw5688414
andauthored
TextSimilarityTask support onnx (#5841)
* TextSimilarityTask support onnx * Update text_similarity.py --------- Co-authored-by: w5688414 <[email protected]>
1 parent 4de53b4 commit d66ab8e

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

paddlenlp/taskflow/text_similarity.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ class TextSimilarityTask(Task):
168168
def __init__(self, task, model, batch_size=1, max_length=384, **kwargs):
169169
super().__init__(task=task, model=model, **kwargs)
170170
self._static_mode = True
171+
self._check_predictor_type()
171172
if not self.from_hf_hub:
172173
self._check_task_files()
173174
if self._static_mode:
@@ -273,12 +274,22 @@ def _run_model(self, inputs):
273274
if "rocketqa" in self.model_name or "ernie-search" in self.model_name:
274275
with static_mode_guard():
275276
for batch in inputs["data_loader"]:
276-
input_ids, segment_ids = self._batchify_fn(batch)
277-
self.input_handles[0].copy_from_cpu(input_ids)
278-
self.input_handles[1].copy_from_cpu(segment_ids)
279-
self.predictor.run()
280-
scores = self.output_handle[0].copy_to_cpu().tolist()
281-
results.extend(scores)
277+
278+
if self._predictor_type == "paddle-inference":
279+
input_ids, segment_ids = self._batchify_fn(batch)
280+
self.input_handles[0].copy_from_cpu(input_ids)
281+
self.input_handles[1].copy_from_cpu(segment_ids)
282+
self.predictor.run()
283+
scores = self.output_handle[0].copy_to_cpu().tolist()
284+
results.extend(scores)
285+
else:
286+
# onnx mode
287+
input_dict = {}
288+
input_ids, segment_ids = self._batchify_fn(batch)
289+
input_dict["input_ids"] = input_ids
290+
input_dict["token_type_ids"] = segment_ids
291+
scores = self.predictor.run(None, input_dict)[0].tolist()
292+
results.extend(scores)
282293
else:
283294
with static_mode_guard():
284295
for batch in inputs["data_loader"]:

0 commit comments

Comments
 (0)