|
5 | 5 | |API名称 | 新增API名称 | |
6 | 6 | |---|----------------------------------------------------| |
7 | 7 | |提交作者<input type="checkbox" class="rowselector hidden"> | 丘文波, 刘旺旺 | |
8 | | -|提交时间<input type="checkbox" class="rowselector hidden"> | 2022-03-04 | |
9 | | -|版本号 | V1 | |
| 8 | +|提交时间<input type="checkbox" class="rowselector hidden"> | 2023-03-05 | |
| 9 | +|版本号 | V2 | |
10 | 10 | |依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | 如无特殊情况,都应基于develop版本开发 | |
11 | 11 | |文件名 | 20230304_api_design_for_tie_weight_task_103.md<br> | |
12 | 12 |
|
@@ -45,24 +45,108 @@ paddle 中并没有对tie weight的统一实现,调用者需自己写代码实 |
45 | 45 |
|
46 | 46 | 在PaddleNLP的一些示例代码中也找到了tie weight的实现. |
47 | 47 |
|
48 | | -<img alt="img_3.png" src="img_3.png" width="700"/> |
| 48 | +(1) [代码链接1](https://github.com/qiuwenbogdut/PaddleNLP/blob/develop/examples/language_model/transformer-xl/mem_transformer.py#L811) |
| 49 | + |
| 50 | +```python |
| 51 | +if tie_weight: |
| 52 | + for i in range(len(self.crit.out_layers_weight)): |
| 53 | + self.crit.out_layers_weight[i] = self.word_emb.emb_layers[i].weight |
| 54 | + |
| 55 | +if tie_projs: |
| 56 | + for i, tie_proj in enumerate(tie_projs): |
| 57 | + if tie_proj and div_val == 1 and d_model != d_embed: |
| 58 | + self.crit.out_projs[i] = self.word_emb.emb_projs[0] |
| 59 | + elif tie_proj and div_val != 1: |
| 60 | + self.crit.out_projs[i] = self.word_emb.emb_projs[i] |
| 61 | +``` |
| 62 | + |
| 63 | +(2) [代码链接2](https://github.com/PaddlePaddle/PaddleNLP/blob/4e5df921ff61ddae1d869c37aea621b9cac6bcd4/paddlenlp/transformers/reformer/modeling.py#L1977) |
| 64 | + |
| 65 | +```python |
| 66 | +def tie_weights(self): |
| 67 | + """ |
| 68 | + Tie the weights between the input embeddings and the output embeddings. |
| 69 | + """ |
| 70 | + tie_word_embeddings = ( |
| 71 | + self.tie_word_embeddings |
| 72 | + if hasattr(self, "tie_word_embeddings") |
| 73 | + else self.config.get("tie_word_embeddings", False) |
| 74 | + ) |
| 75 | + if hasattr(self, "get_output_embeddings") and hasattr(self, "get_input_embeddings") and tie_word_embeddings: |
| 76 | + output_embeddings = self.get_output_embeddings() |
| 77 | + if output_embeddings is not None: |
| 78 | + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) |
| 79 | +``` |
| 80 | + |
49 | 81 |
|
50 | 82 | 最好是给基础模型加上tie weight的函数,减少调用者的开发. |
51 | 83 |
|
52 | 84 | # 三、业内方案调研 |
53 | 85 | 描述业内深度学习框架如何实现此功能,包括与此功能相关的现状、未来趋势;调研的范围包括但不限于TensorFlow、PyTorch、NumPy等 |
54 | 86 |
|
55 | | -(1)目前huggingface的transformers库中实现了这个tieweight 这个基础函数. |
56 | | - |
57 | | -<img alt="img_4.png" src="img_4.png" width="700"/> |
58 | | - |
59 | | -(2) tensor2tensor库 tieweight 实现代码 |
60 | | - |
61 | | -<img alt="img_5.png" src="img_5.png" width="500"/> |
62 | | - |
63 | | -(3) fairseq库 中 tie weight实现函数 |
64 | | - |
65 | | -<img alt="img_6.png" src="img_6.png" width="600"/> |
| 87 | +(1)目前huggingface的transformers库中实现了tie weight这个基础函数. [代码链接](https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/modeling_utils.py#L1172) |
| 88 | +```python |
| 89 | +def tie_weights(self): |
| 90 | + """ |
| 91 | + Tie the weights between the input embeddings and the output embeddings. |
| 92 | + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the |
| 93 | + weights instead. |
| 94 | + """ |
| 95 | + if getattr(self.config, "tie_word_embeddings", True): |
| 96 | + output_embeddings = self.get_output_embeddings() |
| 97 | + if output_embeddings is not None: |
| 98 | + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) |
| 99 | + |
| 100 | + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): |
| 101 | + if hasattr(self, self.base_model_prefix): |
| 102 | + self = getattr(self, self.base_model_prefix) |
| 103 | + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) |
| 104 | + |
| 105 | + for module in self.modules(): |
| 106 | + if hasattr(module, "_tie_weights"): |
| 107 | + module._tie_weights() |
| 108 | +``` |
| 109 | + |
| 110 | + |
| 111 | +(2) tensor2tensor库 tieweight 实现代码 [代码链接](https://github.com/tensorflow/tensor2tensor/blob/316c9ce2f2b2373f44f5be0da712dda3e5861a75/tensor2tensor/layers/modalities.py#L1106) |
| 112 | +```python |
| 113 | +def symbol_top(body_output, targets, model_hparams, vocab_size): |
| 114 | + del targets # unused arg |
| 115 | + if model_hparams.shared_embedding_and_softmax_weights: |
| 116 | + scope_name = "shared" |
| 117 | + reuse = tf.AUTO_REUSE |
| 118 | + else: |
| 119 | + scope_name = "softmax" |
| 120 | + reuse = False |
| 121 | + with tf.variable_scope(scope_name, reuse=reuse): |
| 122 | + body_output_shape = common_layers.shape_list(body_output) |
| 123 | + var = get_weights(model_hparams, vocab_size, body_output_shape[-1]) |
| 124 | + if (model_hparams.factored_logits and |
| 125 | + model_hparams.mode == tf_estimator.ModeKeys.TRAIN): |
| 126 | + # insert channels dimension |
| 127 | + body_output = tf.expand_dims(body_output, 3) |
| 128 | + return common_layers.FactoredTensor(body_output, var) |
| 129 | + else: |
| 130 | + body_output = tf.reshape(body_output, [-1, body_output_shape[-1]]) |
| 131 | + logits = tf.matmul(body_output, var, transpose_b=True) |
| 132 | + return tf.reshape(logits, |
| 133 | + body_output_shape[:-1] + [1, vocab_size]) |
| 134 | +``` |
| 135 | + |
| 136 | + |
| 137 | +(3) fairseq库 中 tie weight实现函数 [代码链接](https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/fconv.py#L480) |
| 138 | +```python |
| 139 | +self.fc2 = Linear(in_channels, out_embed_dim) |
| 140 | + if share_embed: |
| 141 | + assert out_embed_dim == embed_dim, ( |
| 142 | + "Shared embed weights implies same dimensions " |
| 143 | + " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim) |
| 144 | + ) |
| 145 | + self.fc3 = nn.Linear(out_embed_dim, num_embeddings) |
| 146 | + self.fc3.weight = self.embed_tokens.weight |
| 147 | + else: |
| 148 | + self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) |
| 149 | +``` |
66 | 150 |
|
67 | 151 | # 四、对比分析 |
68 | 152 | paddle和 huggingface的transformers 都是基于动态图进行开发, 所以准备参照huggingface的transformers 的 tie weight 函数思路去实现功能. |
|
0 commit comments