Skip to content

Commit a776002

Browse files
authored
add test (#4879)
1 parent da6c908 commit a776002

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

paddlenlp/taskflow/text_similarity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import paddle
1616

17-
from paddlenlp.transformers import BertModel, BertTokenizer
17+
from paddlenlp.transformers import AutoModel, AutoTokenizer
1818

1919
from ..data import Pad, Tuple
2020
from ..transformers import ErnieCrossEncoder, ErnieTokenizer
@@ -155,7 +155,7 @@ def _construct_model(self, model):
155155
if "rocketqa" in model:
156156
self._model = ErnieCrossEncoder(model)
157157
else:
158-
self._model = BertModel.from_pretrained(self._task_path, pool_act="linear")
158+
self._model = AutoModel.from_pretrained(self._task_path, pool_act="linear")
159159
self._model.eval()
160160

161161
def _construct_tokenizer(self, model):
@@ -165,7 +165,7 @@ def _construct_tokenizer(self, model):
165165
if "rocketqa" in model:
166166
self._tokenizer = ErnieTokenizer.from_pretrained(model)
167167
else:
168-
self._tokenizer = BertTokenizer.from_pretrained(model)
168+
self._tokenizer = AutoTokenizer.from_pretrained(model)
169169

170170
def _check_input_text(self, inputs):
171171
inputs = inputs[0]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from paddlenlp.taskflow import Taskflow
18+
19+
20+
class TestTextSimilarityTask(unittest.TestCase):
    """Smoke tests for the ``text_similarity`` Taskflow task."""

    def _assert_similarity_result(self, result):
        """Check one result entry: both text keys present, float similarity."""
        # assertIn / assertIsInstance report the offending value on failure,
        # unlike assertTrue(<expr>), which only reports "False is not true".
        self.assertIn("text1", result)
        self.assertIn("text2", result)
        self.assertIsInstance(result["similarity"], float)

    def test_bert_model(self):
        """Run the task on one pair and on a batch of two pairs.

        The task is built with the ``simbert-base-chinese`` model name and
        loads weights from the ``__internal_testing__/tiny-random-bert``
        task path. Only the structure of the output is checked (number of
        results, keys, type of the score), not the similarity values.
        """
        similarity = Taskflow(
            task="text_similarity",
            model="simbert-base-chinese",
            task_path="__internal_testing__/tiny-random-bert",
        )

        # Single text pair -> exactly one result.
        results = similarity([["世界上什么东西最小", "世界上什么东西最小?"]])
        self.assertEqual(len(results), 1)
        self._assert_similarity_result(results[0])

        # Two text pairs -> one result per pair.
        results = similarity(
            [
                ["光眼睛大就好看吗", "眼睛好看吗?"],
                ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"],
            ]
        )
        self.assertEqual(len(results), 2)
        for result in results:
            self._assert_similarity_result(result)

0 commit comments

Comments
 (0)