
Commit 6929649

[103] Add tie_weights capability: submit RFC document, v2
1 parent 9cb4f54 commit 6929649


1 file changed: +98 -14 lines changed

docs/community/rfcs/20230304_api_design_for_tie_weight_task_103.md

Lines changed: 98 additions & 14 deletions
@@ -5,8 +5,8 @@
 |API name | Name of the new API |
 |---|----------------------------------------------------|
 |Authors<input type="checkbox" class="rowselector hidden"> | 丘文波, 刘旺旺 |
-|Submission date<input type="checkbox" class="rowselector hidden"> | 2022-03-04 |
-|Version | V1 |
+|Submission date<input type="checkbox" class="rowselector hidden"> | 2022-03-05 |
+|Version | V2 |
 |Required PaddlePaddle version<input type="checkbox" class="rowselector hidden"> | Unless there is a special reason, develop against the develop branch |
 |File name | 20230304_api_design_for_tie_weight_task_103.md<br> |

@@ -45,24 +45,108 @@ There is no unified tie-weight implementation in Paddle; callers need to write their own code

 A tie-weight implementation can also be found in some of the PaddleNLP example code.

-<img alt="img_3.png" src="img_3.png" width="700"/>
+(1) [Code link 1](https://github.com/qiuwenbogdut/PaddleNLP/blob/develop/examples/language_model/transformer-xl/mem_transformer.py#L811)
+
+```python
+if tie_weight:
+    for i in range(len(self.crit.out_layers_weight)):
+        self.crit.out_layers_weight[i] = self.word_emb.emb_layers[i].weight
+
+if tie_projs:
+    for i, tie_proj in enumerate(tie_projs):
+        if tie_proj and div_val == 1 and d_model != d_embed:
+            self.crit.out_projs[i] = self.word_emb.emb_projs[0]
+        elif tie_proj and div_val != 1:
+            self.crit.out_projs[i] = self.word_emb.emb_projs[i]
+```
+
+(2) [Code link 2](https://github.com/PaddlePaddle/PaddleNLP/blob/4e5df921ff61ddae1d869c37aea621b9cac6bcd4/paddlenlp/transformers/reformer/modeling.py#L1977)
+
+```python
+def tie_weights(self):
+    """
+    Tie the weights between the input embeddings and the output embeddings.
+    """
+    tie_word_embeddings = (
+        self.tie_word_embeddings
+        if hasattr(self, "tie_word_embeddings")
+        else self.config.get("tie_word_embeddings", False)
+    )
+    if hasattr(self, "get_output_embeddings") and hasattr(self, "get_input_embeddings") and tie_word_embeddings:
+        output_embeddings = self.get_output_embeddings()
+        if output_embeddings is not None:
+            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+```
+

 It would be best to add a tie-weight method to the base model, so that callers do not have to implement it themselves.

 # 3. Survey of industry solutions
 Describe how deep-learning frameworks in the industry implement this feature, including its current status and future trends; the survey covers, but is not limited to, TensorFlow, PyTorch, NumPy, etc.

-(1) At present, the huggingface transformers library implements tie weight as a basic utility function.
-
-<img alt="img_4.png" src="img_4.png" width="700"/>
-
-(2) tie-weight implementation code in the tensor2tensor library
-
-<img alt="img_5.png" src="img_5.png" width="500"/>
-
-(3) tie-weight implementation function in the fairseq library
-
-<img alt="img_6.png" src="img_6.png" width="600"/>
+(1) At present, the huggingface transformers library implements tie weight as a basic utility function. [Code link](https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/modeling_utils.py#L1172)
+```python
+def tie_weights(self):
+    """
+    Tie the weights between the input embeddings and the output embeddings.
+
+    If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
+    weights instead.
+    """
+    if getattr(self.config, "tie_word_embeddings", True):
+        output_embeddings = self.get_output_embeddings()
+        if output_embeddings is not None:
+            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+
+    if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
+        if hasattr(self, self.base_model_prefix):
+            self = getattr(self, self.base_model_prefix)
+        self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+
+    for module in self.modules():
+        if hasattr(module, "_tie_weights"):
+            module._tie_weights()
+```
+
+
+(2) tie-weight implementation code in the tensor2tensor library. [Code link](https://github.com/tensorflow/tensor2tensor/blob/316c9ce2f2b2373f44f5be0da712dda3e5861a75/tensor2tensor/layers/modalities.py#L1106)
+```python
+def symbol_top(body_output, targets, model_hparams, vocab_size):
+    del targets  # unused arg
+    if model_hparams.shared_embedding_and_softmax_weights:
+        scope_name = "shared"
+        reuse = tf.AUTO_REUSE
+    else:
+        scope_name = "softmax"
+        reuse = False
+    with tf.variable_scope(scope_name, reuse=reuse):
+        body_output_shape = common_layers.shape_list(body_output)
+        var = get_weights(model_hparams, vocab_size, body_output_shape[-1])
+        if (model_hparams.factored_logits and
+                model_hparams.mode == tf_estimator.ModeKeys.TRAIN):
+            # insert channels dimension
+            body_output = tf.expand_dims(body_output, 3)
+            return common_layers.FactoredTensor(body_output, var)
+        else:
+            body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
+            logits = tf.matmul(body_output, var, transpose_b=True)
+            return tf.reshape(logits,
+                              body_output_shape[:-1] + [1, vocab_size])
+```
+
+
+(3) tie-weight implementation function in the fairseq library. [Code link](https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/fconv.py#L480)
+```python
+self.fc2 = Linear(in_channels, out_embed_dim)
+if share_embed:
+    assert out_embed_dim == embed_dim, (
+        "Shared embed weights implies same dimensions "
+        " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
+    )
+    self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
+    self.fc3.weight = self.embed_tokens.weight
+else:
+    self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
+```


 # 4. Comparative analysis
 Both Paddle and the huggingface transformers library are developed on dynamic graphs, so the plan is to implement this feature following the approach of the transformers tie_weights function.
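To make the comparison above concrete, here is a minimal, self-contained sketch of how a HuggingFace-style tie_weights hook could look on top of paddle.nn. Every name in it (TiedLMHead, ToyLM, get_input_embeddings, get_output_embeddings) is an illustrative assumption for this sketch, not an existing Paddle or PaddleNLP API, and the actual design may differ.

```python
# A sketch only: weight tying on a toy Paddle language model, mirroring the
# get_input_embeddings / get_output_embeddings / tie_weights contract quoted above.
import paddle
import paddle.nn as nn


class TiedLMHead(nn.Layer):
    """Output projection whose weight matrix can be shared with an nn.Embedding."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # Stored as [vocab_size, hidden_size], i.e. the same layout as the embedding
        # table, so tying can simply rebind this attribute to the embedding weight.
        self.weight = self.create_parameter(shape=[vocab_size, hidden_size])
        self.bias = self.create_parameter(shape=[vocab_size], is_bias=True)

    def forward(self, hidden_states):
        # matmul with transpose_y acts as a Linear layer over the shared table.
        return paddle.matmul(hidden_states, self.weight, transpose_y=True) + self.bias


class ToyLM(nn.Layer):
    def __init__(self, vocab_size=1000, hidden_size=64, tie_word_embeddings=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = TiedLMHead(hidden_size, vocab_size)
        self.tie_word_embeddings = tie_word_embeddings
        self.tie_weights()

    def get_input_embeddings(self):
        return self.embedding

    def get_output_embeddings(self):
        return self.lm_head

    def tie_weights(self):
        # Only tie when requested and when both ends of the tie are exposed,
        # following the flow of the transformers implementation quoted above.
        if not getattr(self, "tie_word_embeddings", False):
            return
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            # One Parameter, two users: gradients accumulate in a single tensor.
            output_embeddings.weight = self.get_input_embeddings().weight

    def forward(self, input_ids):
        return self.lm_head(self.embedding(input_ids))


model = ToyLM()
assert model.lm_head.weight is model.embedding.weight  # tied: same Parameter object
```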

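As a quick sanity check of the sketch (still under the same assumptions, using the hypothetical ToyLM class above), one forward/backward pass should leave gradients from both the embedding lookup and the output projection in the single shared parameter:

```python
import paddle

model = ToyLM(vocab_size=10, hidden_size=4)             # hypothetical class from the sketch above
input_ids = paddle.to_tensor([[1, 2, 3]], dtype="int64")

logits = model(input_ids)                                # shape [1, 3, 10]
logits.sum().backward()

print(model.embedding.weight.grad.shape)                 # [10, 4]: one gradient tensor for both uses
print(model.lm_head.weight is model.embedding.weight)    # True: the tie persists across steps
```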