add p1 features to six models #3462
Merged
Commits (44 in total; the diff below shows changes from 40 commits)
All commits authored by wj-Mcat:

31fb6cf  update p1 models
ca87219  update p1 models
d05461f  update p1 models
7b29166  complete all p1 method
d73e8bb  fix past_key_values
b0e3cd9  add type annotation
050867c  ffix default value
b45ed84  Merge branch 'develop' into add-p1
0c73a69  add type annotation to roformer
07069bd  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
c6fdac8  fix variable name bug
64cca16  remove position_ids cache in ernie
23ebe9e  revert dtype
73d787e  revert dtype
5abb964  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
222e14a  update ernie test_modeling
7257b90  update ernie test_modeling
d6ac632  update past_key_values testing code
0241343  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
d5bd49c  revert get_default_dtype changes
ed68867  Merge branch 'develop' into add-p1
7f2468d  fix testing header importing
e81207e  update testing metric
e90534b  remove position_ids cache
d69be2c  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
f63c463  Merge branch 'develop' into add-p1
1d5d0bc  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
a279ee3  Merge branch 'develop' into add-p1
f52aeb0  import precision of logits
ff36f39  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
b284178  Merge branch 'develop' into add-p1
c0faa40  fix yapf for tinybert modeling
814bb51  Merge branch 'add-p1' of github.com:wj-Mcat/PaddleNLP into add-p1
9684581  Merge branch 'develop' into add-p1
bfc027f  Merge branch 'develop' into add-p1
59251c0  Merge branch 'develop' into add-p1
d3a8edb  Merge branch 'develop' into add-p1
3e3ae98  Merge branch 'develop' into add-p1
6242d47  Merge branch 'develop' into add-p1
eff80ac  Merge branch 'develop' into add-p1
c3aae03  Merge branch 'develop' into add-p1
88bbcdc  Merge branch 'develop' into add-p1
53e497d  Merge branch 'develop' into add-p1
0f640d6  Merge branch 'develop' into add-p1
Files changed (ERNIE modeling):
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional, Tuple
+from paddle import Tensor
+
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -78,31 +81,31 @@ def __init__(self,
         self.dropout = nn.Dropout(hidden_dropout_prob)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                task_type_ids=None,
-                inputs_embeds=None,
-                past_key_values_length=None):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                task_type_ids: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                past_key_values_length: int = 0):

         if input_ids is not None:
-            input_shape = paddle.shape(input_ids)
-            input_embeddings = self.word_embeddings(input_ids)
-        else:
-            input_shape = paddle.shape(inputs_embeds)[:-1]
-            input_embeddings = inputs_embeds
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        input_shape = paddle.shape(inputs_embeds)[:-1]

         if position_ids is None:
             # maybe need use shape op to unify static graph and dynamic graph
             #seq_length = input_ids.shape[1]
             ones = paddle.ones(input_shape, dtype="int64")
             seq_length = paddle.cumsum(ones, axis=1)
             position_ids = seq_length - ones
-            if past_key_values_length is not None:
-                position_ids += past_key_values_length
+
+            if past_key_values_length > 0:
+                position_ids = position_ids + past_key_values_length

         position_ids.stop_gradient = True

         position_embeddings = self.position_embeddings(position_ids)
-        embeddings = input_embeddings + position_embeddings
+        embeddings = inputs_embeds + position_embeddings

         if self.type_vocab_size > 0:
             if token_type_ids is None:
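Note: as a reading aid, here is a minimal, self-contained sketch of the position-id logic in the hunk above. When the key/value cache already holds `past_key_values_length` tokens, the positions computed for the new tokens must start after that offset. The helper name and example shapes below are illustrative, not taken from the PR.

```python
import paddle


def build_position_ids(input_shape, past_key_values_length=0):
    # 0 .. seq_len-1 for the current chunk, built with cumsum as in the diff
    ones = paddle.ones(input_shape, dtype="int64")
    position_ids = paddle.cumsum(ones, axis=1) - ones
    # shift by the number of tokens already held in the cache
    if past_key_values_length > 0:
        position_ids = position_ids + past_key_values_length
    position_ids.stop_gradient = True
    return position_ids


# batch of 2, 3 new tokens, 5 cached tokens -> positions [[5, 6, 7], [5, 6, 7]]
print(build_position_ids([2, 3], past_key_values_length=5))
```

With the new `int = 0` default, the `past_key_values_length > 0` guard also skips the addition entirely when there is no cache.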
@@ -882,17 +885,17 @@ def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                task_type_ids=None,
-                past_key_values=None,
-                inputs_embeds=None,
-                use_cache=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                task_type_ids: Optional[Tensor] = None,
+                past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                use_cache: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         Args:
             input_ids (Tensor):
@@ -974,15 +977,13 @@ def forward(self,
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time."
             )
         elif input_ids is not None:
             input_shape = paddle.shape(input_ids)
         elif inputs_embeds is not None:
             input_shape = paddle.shape(inputs_embeds)[:-1]
         else:
             raise ValueError(
                 "You have to specify either input_ids or inputs_embeds")

-        past_key_values_length = None
+        # init the default bool value
+        output_attentions = output_attentions if output_attentions is not None else False
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+        return_dict = return_dict if return_dict is not None else False
+        use_cache = use_cache if use_cache is not None else False
Review comment on lines +1027 to +1030 (Contributor, PR author): However, a follow-up could change this default to `use_cache = use_cache if use_cache is not None else self.config.use_cache`.
+        past_key_values_length = 0
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
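Note: a small, runnable sketch of the cache bookkeeping in this hunk. The layout `[batch_size, num_heads, past_seq_len, head_dim]` for each cached key/value tensor is the conventional one implied by indexing `shape[2]`; it is an assumption here, not something the diff states.

```python
import paddle

# Illustrative fake cache: one (key, value) pair per layer, two layers here
batch_size, num_heads, past_seq_len, head_dim = 2, 12, 5, 64
past_key_values = tuple(
    (paddle.zeros([batch_size, num_heads, past_seq_len, head_dim]),  # cached keys
     paddle.zeros([batch_size, num_heads, past_seq_len, head_dim]))  # cached values
    for _ in range(2))

# Same length computation as the diff: dim 2 of the first layer's key cache
past_key_values_length = 0
if past_key_values is not None:
    past_key_values_length = past_key_values[0][0].shape[2]

print(past_key_values_length)  # 5
```

The review note above points out that once the model carries a config, `use_cache` could default to `self.config.use_cache` instead of `False`.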
@@ -998,11 +999,13 @@ def forward(self,
                     dtype=attention_mask.dtype)
                 attention_mask = paddle.concat([past_mask, attention_mask],
                                                axis=-1)
+
         # For 2D attention_mask from tokenizer
         elif attention_mask.ndim == 2:
             attention_mask = paddle.unsqueeze(
                 attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
             attention_mask = (1.0 - attention_mask) * -1e4
+
         attention_mask.stop_gradient = True

         embedding_output = self.embeddings(
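Note: this hunk touches two pieces of mask handling: a zero block (additive, meaning "visible") is prepended for cached positions, and a 2D tokenizer mask is converted into an additive mask (0 keeps a position, -1e4 masks it). The sketch below folds both into one helper for illustration; the control flow and names are not the PR's exact code.

```python
import paddle


def extend_attention_mask(attention_mask, past_key_values_length=0):
    # 2D [batch, seq_len] 0/1 mask from the tokenizer -> additive 4D mask
    if attention_mask.ndim == 2:
        attention_mask = paddle.unsqueeze(
            attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
        attention_mask = (1.0 - attention_mask) * -1e4
    # let new queries attend to the cached positions by prepending zeros
    if past_key_values_length > 0:
        past_mask = paddle.zeros(
            attention_mask.shape[:-1] + [past_key_values_length],
            dtype=attention_mask.dtype)
        attention_mask = paddle.concat([past_mask, attention_mask], axis=-1)
    attention_mask.stop_gradient = True
    return attention_mask


mask = paddle.to_tensor([[1, 1, 1, 0]])      # one padded position
print(extend_attention_mask(mask, 2).shape)  # [1, 1, 1, 6]
```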
@@ -1065,15 +1068,15 @@ def __init__(self, ernie, num_classes=2, dropout=None):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                inputs_embeds=None,
-                labels=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                labels: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         Args:
             input_ids (Tensor):
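Note: this and the five signature hunks that follow apply the same change to the task heads (the docstring further down names ErnieForMultipleChoice; the others follow the same pattern): every argument becomes optional and annotated, so a head can be driven by `inputs_embeds` instead of `input_ids`. A hedged usage sketch, assuming the stock PaddleNLP entry points, the "ernie-1.0" weights, and that the backbone is exposed as `model.ernie`:

```python
import paddle
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
model = ErnieForSequenceClassification.from_pretrained("ernie-1.0")
model.eval()

encoded = tokenizer("PaddleNLP is great")
input_ids = paddle.to_tensor([encoded["input_ids"]])
attention_mask = paddle.ones_like(input_ids)

with paddle.no_grad():
    # the usual call with token ids
    logits_from_ids = model(input_ids=input_ids, attention_mask=attention_mask)

    # the new path: precompute embeddings and pass inputs_embeds instead
    # (model.ernie / embeddings.word_embeddings as in the PaddleNLP heads)
    inputs_embeds = model.ernie.embeddings.word_embeddings(input_ids)
    logits_from_embeds = model(inputs_embeds=inputs_embeds,
                               attention_mask=attention_mask)

# the two paths should agree since only the embedding lookup moved outside
print(paddle.allclose(logits_from_ids, logits_from_embeds, atol=1e-5))
```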
@@ -1177,16 +1180,16 @@ def __init__(self, ernie):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                inputs_embeds=None,
-                start_positions=None,
-                end_positions=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                start_positions: Optional[Tensor] = None,
+                end_positions: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         Args:
             input_ids (Tensor):
@@ -1308,15 +1311,15 @@ def __init__(self, ernie, num_classes=2, dropout=None):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                inputs_embeds=None,
-                labels=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                labels: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         Args:
             input_ids (Tensor):
@@ -1514,17 +1517,17 @@ def __init__(self, ernie):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                masked_positions=None,
-                inputs_embeds=None,
-                labels=None,
-                next_sentence_label=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                masked_positions: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                labels: Optional[Tensor] = None,
+                next_sentence_label: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         Args:
             input_ids (Tensor):
@@ -1695,16 +1698,16 @@ def __init__(self, ernie):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                masked_positions=None,
-                inputs_embeds=None,
-                labels=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                masked_positions: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                labels: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""

         Args:
@@ -1817,15 +1820,15 @@ def __init__(self, ernie, num_choices=2, dropout=None):
         self.apply(self.init_weights)

     def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                inputs_embeds=None,
-                labels=None,
-                output_hidden_states=False,
-                output_attentions=False,
-                return_dict=False):
+                input_ids: Optional[Tensor] = None,
+                token_type_ids: Optional[Tensor] = None,
+                position_ids: Optional[Tensor] = None,
+                attention_mask: Optional[Tensor] = None,
+                inputs_embeds: Optional[Tensor] = None,
+                labels: Optional[Tensor] = None,
+                output_hidden_states: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
         r"""
         The ErnieForMultipleChoice forward method, overrides the __call__() special method.
Follow-up reply: Following your suggestion, I added a conditional check and ran a performance test; details are in paddle-tensor-add-benchmark. Based on those results, the `tensor + int` form was chosen for adding a scalar constant to a tensor.
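The linked benchmark is not reproduced here. Purely as an illustration of the kind of comparison involved, a timing harness along these lines could contrast adding a Python int to a tensor with first wrapping the constant in a tensor (iteration counts and names are arbitrary):

```python
import time

import paddle

position_ids = paddle.arange(0, 128, dtype="int64").unsqueeze(0)
offset = 5


def bench(fn, iters=10000):
    # warm up, then wall-clock the repeated adds
    for _ in range(100):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start


t_int = bench(lambda: position_ids + offset)                       # tensor + int
t_tensor = bench(lambda: position_ids + paddle.to_tensor(offset))  # tensor + tensor
print(f"tensor + int: {t_int:.3f}s    tensor + to_tensor(int): {t_tensor:.3f}s")
```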