-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Add text semantic matching for taskflow #3003
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
8130dbf
Add text semantic matching for taskflow
w5688414 1064eb2
Incorporate text semantic search into text similarity
w5688414 3cf0954
Merge branch 'develop' into taskflow
w5688414 e7af2bb
Add rocketqa base
w5688414 50271be
Merge branch 'taskflow' of https://github.com/w5688414/PaddleNLP into…
w5688414 40b24be
Merge branch 'develop' into taskflow
w5688414 e21ddf8
Merge branch 'taskflow' of https://github.com/w5688414/PaddleNLP into…
w5688414 3fd1110
Update taskflow docs
w5688414 93bc2f4
Add rocketqa based models for taskflow
w5688414 955fa3a
Merge branch 'develop' into taskflow
w5688414 390cd04
Text similarity support paddle inference
w5688414 1ee1bc7
Merge branch 'taskflow' of https://github.com/w5688414/PaddleNLP into…
w5688414 488ca38
remove unused comments
w5688414 70bc2c0
Merge branch 'develop' into taskflow
w5688414 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
|
|
||
| import paddle | ||
| from paddlenlp.transformers import BertModel, BertTokenizer | ||
| from ..transformers import ErnieCrossEncoder, ErnieTokenizer | ||
|
|
||
| from ..data import Pad, Tuple | ||
| from .utils import static_mode_guard | ||
|
|
@@ -59,17 +60,21 @@ class TextSimilarityTask(Task): | |
| "https://bj.bcebos.com/paddlenlp/taskflow/text_similarity/simbert-base-chinese/model_config.json", | ||
| "1254bbd7598457a9dad0afcb2e24b70c" | ||
| ], | ||
| } | ||
| }, | ||
| } | ||
|
|
||
| def __init__(self, task, model, batch_size=1, max_seq_len=128, **kwargs): | ||
| def __init__(self, task, model, batch_size=1, max_seq_len=384, **kwargs): | ||
| super().__init__(task=task, model=model, **kwargs) | ||
| self._check_task_files() | ||
| if ('rocketqa' not in model): | ||
| self._check_task_files() | ||
| self._get_inference_model() | ||
| else: | ||
| self._construct_model(model) | ||
|
||
| self._construct_tokenizer(model) | ||
| self._get_inference_model() | ||
| self._batch_size = batch_size | ||
| self._max_seq_len = max_seq_len | ||
| self._usage = usage | ||
| self.model_name = model | ||
|
|
||
| def _construct_input_spec(self): | ||
| """ | ||
|
|
@@ -88,15 +93,21 @@ def _construct_model(self, model): | |
| """ | ||
| Construct the inference model for the predictor. | ||
| """ | ||
| self._model = BertModel.from_pretrained(self._task_path, | ||
| pool_act='linear') | ||
| if ("rocketqa" in model): | ||
| self._model = ErnieCrossEncoder(model) | ||
| else: | ||
| self._model = BertModel.from_pretrained(self._task_path, | ||
| pool_act='linear') | ||
| self._model.eval() | ||
|
|
||
| def _construct_tokenizer(self, model): | ||
| """ | ||
| Construct the tokenizer for the predictor. | ||
| """ | ||
| self._tokenizer = BertTokenizer.from_pretrained(model) | ||
| if ("rocketqa" in model): | ||
| self._tokenizer = ErnieTokenizer.from_pretrained(model) | ||
| else: | ||
| self._tokenizer = BertTokenizer.from_pretrained(model) | ||
|
|
||
| def _check_input_text(self, inputs): | ||
| inputs = inputs[0] | ||
|
|
@@ -118,40 +129,52 @@ def _preprocess(self, inputs): | |
| 'lazy_load'] if 'lazy_load' in self.kwargs else False | ||
|
|
||
| examples = [] | ||
|
|
||
| for data in inputs: | ||
| text1, text2 = data[0], data[1] | ||
|
|
||
| text1_encoded_inputs = self._tokenizer( | ||
| text=text1, max_seq_len=self._max_seq_len) | ||
| text1_input_ids = text1_encoded_inputs["input_ids"] | ||
| text1_token_type_ids = text1_encoded_inputs["token_type_ids"] | ||
|
|
||
| text2_encoded_inputs = self._tokenizer( | ||
| text=text2, max_seq_len=self._max_seq_len) | ||
| text2_input_ids = text2_encoded_inputs["input_ids"] | ||
| text2_token_type_ids = text2_encoded_inputs["token_type_ids"] | ||
|
|
||
| examples.append((text1_input_ids, text1_token_type_ids, | ||
| text2_input_ids, text2_token_type_ids)) | ||
| if ("rocketqa" in self.model_name): | ||
| encoded_inputs = self._tokenizer(text=text1, | ||
| text_pair=text2, | ||
| max_seq_len=self._max_seq_len) | ||
| ids = encoded_inputs["input_ids"] | ||
| segment_ids = encoded_inputs["token_type_ids"] | ||
| examples.append((ids, segment_ids)) | ||
| else: | ||
| text1_encoded_inputs = self._tokenizer( | ||
| text=text1, max_seq_len=self._max_seq_len) | ||
| text1_input_ids = text1_encoded_inputs["input_ids"] | ||
| text1_token_type_ids = text1_encoded_inputs["token_type_ids"] | ||
|
|
||
| text2_encoded_inputs = self._tokenizer( | ||
| text=text2, max_seq_len=self._max_seq_len) | ||
| text2_input_ids = text2_encoded_inputs["input_ids"] | ||
| text2_token_type_ids = text2_encoded_inputs["token_type_ids"] | ||
|
|
||
| examples.append((text1_input_ids, text1_token_type_ids, | ||
| text2_input_ids, text2_token_type_ids)) | ||
|
|
||
| batches = [ | ||
| examples[idx:idx + self._batch_size] | ||
| for idx in range(0, len(examples), self._batch_size) | ||
| ] | ||
|
|
||
| batchify_fn = lambda samples, fn=Tuple( | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64' | ||
| ), # text1_input_ids | ||
| Pad(axis=0, | ||
| pad_val=self._tokenizer.pad_token_type_id, | ||
| dtype='int64'), # text1_token_type_ids | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64' | ||
| ), # text2_input_ids | ||
| Pad(axis=0, | ||
| pad_val=self._tokenizer.pad_token_type_id, | ||
| dtype='int64'), # text2_token_type_ids | ||
| ): [data for data in fn(samples)] | ||
| if ("rocketqa" in self.model_name): | ||
| batchify_fn = lambda samples, fn=Tuple( | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # input ids | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id | ||
| ), # token type ids | ||
| ): [data for data in fn(samples)] | ||
| else: | ||
| batchify_fn = lambda samples, fn=Tuple( | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64' | ||
| ), # text1_input_ids | ||
| Pad(axis=0, | ||
| pad_val=self._tokenizer.pad_token_type_id, | ||
| dtype='int64'), # text1_token_type_ids | ||
| Pad(axis=0, pad_val=self._tokenizer.pad_token_id, dtype='int64' | ||
| ), # text2_input_ids | ||
| Pad(axis=0, | ||
| pad_val=self._tokenizer.pad_token_type_id, | ||
| dtype='int64'), # text2_token_type_ids | ||
| ): [data for data in fn(samples)] | ||
|
|
||
| outputs = {} | ||
| outputs['data_loader'] = batches | ||
|
|
@@ -164,26 +187,35 @@ def _run_model(self, inputs): | |
| Run the task model from the outputs of the `_tokenize` function. | ||
| """ | ||
| results = [] | ||
| with static_mode_guard(): | ||
| if ("rocketqa" in self.model_name): | ||
| for batch in inputs['data_loader']: | ||
| text1_ids, text1_segment_ids, text2_ids, text2_segment_ids = self._batchify_fn( | ||
| batch) | ||
| self.input_handles[0].copy_from_cpu(text1_ids) | ||
| self.input_handles[1].copy_from_cpu(text1_segment_ids) | ||
| self.predictor.run() | ||
| vecs_text1 = self.output_handle[1].copy_to_cpu() | ||
|
|
||
| self.input_handles[0].copy_from_cpu(text2_ids) | ||
| self.input_handles[1].copy_from_cpu(text2_segment_ids) | ||
| self.predictor.run() | ||
| vecs_text2 = self.output_handle[1].copy_to_cpu() | ||
|
|
||
| vecs_text1 = vecs_text1 / (vecs_text1**2).sum( | ||
| axis=1, keepdims=True)**0.5 | ||
| vecs_text2 = vecs_text2 / (vecs_text2**2).sum( | ||
| axis=1, keepdims=True)**0.5 | ||
| similarity = (vecs_text1 * vecs_text2).sum(axis=1) | ||
| results.extend(similarity) | ||
| input_ids, segment_ids = self._batchify_fn(batch) | ||
| input_ids = paddle.to_tensor(input_ids, dtype='int64') | ||
| segment_ids = paddle.to_tensor(segment_ids, dtype='int64') | ||
| scores = self._model.matching(input_ids=input_ids, | ||
| token_type_ids=segment_ids) | ||
| results.extend(scores.numpy().tolist()) | ||
| else: | ||
| with static_mode_guard(): | ||
| for batch in inputs['data_loader']: | ||
| text1_ids, text1_segment_ids, text2_ids, text2_segment_ids = self._batchify_fn( | ||
| batch) | ||
| self.input_handles[0].copy_from_cpu(text1_ids) | ||
| self.input_handles[1].copy_from_cpu(text1_segment_ids) | ||
| self.predictor.run() | ||
| vecs_text1 = self.output_handle[1].copy_to_cpu() | ||
|
|
||
| self.input_handles[0].copy_from_cpu(text2_ids) | ||
| self.input_handles[1].copy_from_cpu(text2_segment_ids) | ||
| self.predictor.run() | ||
| vecs_text2 = self.output_handle[1].copy_to_cpu() | ||
|
|
||
| vecs_text1 = vecs_text1 / (vecs_text1**2).sum( | ||
| axis=1, keepdims=True)**0.5 | ||
| vecs_text2 = vecs_text2 / (vecs_text2**2).sum( | ||
| axis=1, keepdims=True)**0.5 | ||
| similarity = (vecs_text1 * vecs_text2).sum(axis=1) | ||
| results.extend(similarity) | ||
| inputs['result'] = results | ||
| return inputs | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议默认调用不加model参数进行简化,直接
similarity = Taskflow("text_similarity"),然后参考UIE提供一个模型选择的表格There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改