Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions examples/torch_migration/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# BERT-SST2-Prod
Reproduction process of BERT on SST2 dataset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO 下个PR修改


# 安装说明

* 下载代码库

```shell
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP/examples/torch_migration
```

* 进入文件夹,安装requirements

```shell
pip install -r requirements.txt
```

* 安装PaddlePaddle与PyTorch

```shell
# CPU版本的PaddlePaddle
pip install paddlepaddle==2.2.0 -i https://mirror.baidu.com/pypi/simple
# 如果希望安装GPU版本的PaddlePaddle,可以使用下面的命令
# pip install paddlepaddle-gpu==2.2.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
# 安装PyTorch
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```

**注意**: 本项目依赖于paddlepaddle-2.2.0版本,安装时需要注意。

* 验证PaddlePaddle是否安装成功

运行python,输入下面的命令。

```python
import paddle
paddle.utils.run_check()
print(paddle.__version__)
```

如果输出下面的内容,则说明PaddlePaddle安装成功。

```
PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now.
2.2.0
```


* 验证PyTorch是否安装成功

运行python,输入下面的命令,如果可以正常输出,则说明torch安装成功。

```python
import torch
print(torch.__version__)
# 如果安装的是cpu版本,可以按照下面的命令确认torch是否安装成功
# 期望输出为 tensor([1.])
print(torch.Tensor([1.0]))
# 如果安装的是gpu版本,可以按照下面的命令确认torch是否安装成功
# 期望输出为 tensor([1.], device='cuda:0')
print(torch.Tensor([1.0]).cuda())
```
928 changes: 928 additions & 0 deletions examples/torch_migration/docs/ThesisReproduction_NLP.md

Large diffs are not rendered by default.

86 changes: 86 additions & 0 deletions examples/torch_migration/pipeline/Step1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# 使用方法


本部分内容以前向对齐为例,介绍基于`reprod_log`工具对齐的检查流程。其中与`reprod_log`工具有关的部分都是需要开发者添加的部分。


```shell
# 进入文件夹并生成torch的bert模型权重
cd pipeline/weights/ && python torch_bert_weights.py
# 将torch的bert模型权重转换为paddle(此时已位于pipeline/weights/目录)
python torch2paddle.py
# 进入文件夹并生成classifier权重
cd ../classifier_weights/ && python generate_classifier_weights.py
# 进入Step1文件夹
cd ../Step1/
# 生成paddle的前向数据
python pd_forward_bert.py
# 生成torch的前向数据
python pt_forward_bert.py
# 对比生成log
python check_step1.py
```

具体地,以PaddlePaddle为例,`pd_forward_bert.py`的具体代码如下所示。

```python
import numpy as np
import paddle
from reprod_log import ReprodLogger
import sys
import os
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0] # 当前目录
config_path = CURRENT_DIR.rsplit('/', 1)[0]
sys.path.append(config_path)
from models.pd_bert import *

# 导入reprod_log中的ReprodLogger类
from reprod_log import ReprodLogger

reprod_logger = ReprodLogger()

# 组网初始化加载BertModel权重
paddle_dump_path = '../weights/paddle_weight.pdparams'
config = BertConfig()
model = BertForSequenceClassification(config)
checkpoint = paddle.load(paddle_dump_path)
model.bert.load_dict(checkpoint)

# 加载分类权重
classifier_weights = paddle.load(
"../classifier_weights/paddle_classifier_weights.bin")
model.load_dict(classifier_weights)
model.eval()
# 读入fake data并转换为tensor,这里也可以固定seed在线生成fake data
fake_data = np.load("../fake_data/fake_data.npy")
fake_data = paddle.to_tensor(fake_data)
# 模型前向
out = model(fake_data)
# 保存前向结果,对于不同的任务,需要开发者添加。
reprod_logger.add("logits", out.cpu().detach().numpy())
reprod_logger.save("forward_paddle.npy")
```

diff检查的代码可以参考:[check_step1.py](./check_step1.py),具体代码如下所示。

```python
# https://github.com/littletomatodonkey/AlexNet-Prod/blob/master/pipeline/Step1/check_step1.py
# 使用reprod_log排查diff
from reprod_log import ReprodDiffHelper
if __name__ == "__main__":
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./forward_torch.npy")
paddle_info = diff_helper.load_info("./forward_paddle.npy")
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="forward_diff.log")
```

产出日志如下,同时会将check的结果保存在`forward_diff.log`文件中。

```
[2021/11/17 20:15:50] root INFO: logits:
[2021/11/17 20:15:50] root INFO: mean diff: check passed: True, value: 1.30385160446167e-07
[2021/11/17 20:15:50] root INFO: diff check passed
```

平均绝对误差为1.3e-7,测试通过。
23 changes: 23 additions & 0 deletions examples/torch_migration/pipeline/Step1/check_step1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from reprod_log import ReprodDiffHelper

if __name__ == "__main__":
    # Load the forward results saved by the torch and paddle scripts,
    # compare them, and write the comparison report to forward_diff.log.
    helper = ReprodDiffHelper()
    reference = helper.load_info("./forward_torch.npy")
    candidate = helper.load_info("./forward_paddle.npy")
    helper.compare_info(reference, candidate)
    helper.report(path="forward_diff.log")
50 changes: 50 additions & 0 deletions examples/torch_migration/pipeline/Step1/pd_forward_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

import numpy as np
import paddle
from reprod_log import ReprodLogger

# Directory that contains this script.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of the script directory. It is appended to sys.path so that the
# `models.pd_bert` import that follows can be resolved. os.path.dirname is
# used instead of splitting on "/" so the result does not depend on the
# platform's path separator.
CONFIG_PATH = os.path.dirname(CURRENT_DIR)
sys.path.append(CONFIG_PATH)

from models.pd_bert import BertConfig, BertForSequenceClassification

if __name__ == "__main__":
    # Run everything on CPU.
    paddle.set_device("cpu")

    # Logger that collects the forward output to be saved.
    logger = ReprodLogger()

    # Build the classification model and load the BERT backbone weights.
    net = BertForSequenceClassification(BertConfig())
    backbone_state = paddle.load('../weights/paddle_weight.pdparams')
    net.bert.load_dict(backbone_state)

    # Load the classifier weights into the full model, then switch to
    # evaluation mode.
    head_state = paddle.load(
        "../classifier_weights/paddle_classifier_weights.bin")
    net.load_dict(head_state)
    net.eval()

    # Read the stored fake input and turn it into a paddle tensor.
    inputs = paddle.to_tensor(np.load("../fake_data/fake_data.npy"))

    # Forward pass; element 0 of the model output is recorded as "logits".
    logits = net(inputs)[0]
    logger.add("logits", logits.cpu().detach().numpy())
    logger.save("forward_paddle.npy")
48 changes: 48 additions & 0 deletions examples/torch_migration/pipeline/Step1/pt_forward_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

import numpy as np
from reprod_log import ReprodLogger
import torch

# Directory that contains this script.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of the script directory. It is appended to sys.path so that the
# `models.pt_bert` import that follows can be resolved. os.path.dirname is
# used instead of splitting on "/" so the result does not depend on the
# platform's path separator.
CONFIG_PATH = os.path.dirname(CURRENT_DIR)
sys.path.append(CONFIG_PATH)

from models.pt_bert import BertConfig, BertForSequenceClassification

if __name__ == "__main__":
    # Logger that collects the forward output to be saved.
    logger = ReprodLogger()

    # Build the classification model and load the BERT backbone weights.
    net = BertForSequenceClassification(BertConfig())
    backbone_state = torch.load('../weights/torch_weight.bin')
    net.bert.load_state_dict(backbone_state)

    # Load the classifier weights non-strictly (strict=False), then switch
    # to evaluation mode.
    head_state = torch.load(
        "../classifier_weights/torch_classifier_weights.bin")
    net.load_state_dict(head_state, strict=False)
    net.eval()

    # Read the stored fake input and turn it into a torch tensor.
    inputs = torch.from_numpy(np.load("../fake_data/fake_data.npy"))

    # Forward pass; element 0 of the model output is recorded as "logits".
    logits = net(inputs)[0]
    logger.add("logits", logits.cpu().detach().numpy())
    logger.save("forward_torch.npy")
114 changes: 114 additions & 0 deletions examples/torch_migration/pipeline/Step1/torch2paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import numpy as np
import paddle
import torch
from paddlenlp.transformers import BertForPretraining as PDBertForMaskedLM
from transformers import BertForMaskedLM as PTBertForMaskedLM


def convert_pytorch_checkpoint_to_paddle(
    pytorch_checkpoint_path="pytorch_model.bin",
    paddle_dump_path="model_state.pdparams",
    version="old",
):
    """Convert a PyTorch checkpoint file into a Paddle parameter file.

    The checkpoint is loaded on CPU. Every 2-D tensor whose key ends in
    ``.weight`` is transposed, unless the key contains ``embeddings`` or
    ``.LayerNorm.`` or matches an entry of ``do_not_transpose``. All tensors
    are converted to numpy arrays and written with ``paddle.save`` under
    their original key names. One line is printed per converted entry.

    Args:
        pytorch_checkpoint_path: Path of the checkpoint read by ``torch.load``.
        paddle_dump_path: Destination path passed to ``paddle.save``.
        version: When ``"old"``, ``predictions.decoder.weight`` is
            additionally excluded from transposition.
    """
    # NOTE(review): this name mapping is built but never applied to the keys
    # below, so parameters keep their original names. Confirm whether key
    # renaming is intended before wiring it in.
    hf_to_paddle = {
        "embeddings.LayerNorm": "embeddings.layer_norm",
        "encoder.layer": "encoder.layers",
        "attention.self.query": "self_attn.q_proj",
        "attention.self.key": "self_attn.k_proj",
        "attention.self.value": "self_attn.v_proj",
        "attention.output.dense": "self_attn.out_proj",
        "intermediate.dense": "linear1",
        "output.dense": "linear2",
        "attention.output.LayerNorm": "norm1",
        "output.LayerNorm": "norm2",
        "predictions.decoder.": "predictions.decoder_",
        "predictions.transform.dense": "predictions.transform",
        "predictions.transform.LayerNorm": "predictions.layer_norm",
    }
    do_not_transpose = []
    if version == "old":
        hf_to_paddle.update({
            "predictions.bias": "predictions.decoder_bias",
            ".gamma": ".weight",
            ".beta": ".bias",
        })
        do_not_transpose = do_not_transpose + ["predictions.decoder.weight"]

    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for name, tensor in pytorch_state_dict.items():
        # Only 2-D ``*.weight`` tensors are transposed; embedding and
        # LayerNorm weights and explicitly excluded keys are left as-is.
        # (presumably the two frameworks store linear weights with swapped
        # axes -- TODO confirm)
        is_transpose = (
            name.endswith(".weight")
            and all(skip not in name for skip in do_not_transpose)
            and ".LayerNorm." not in name
            and "embeddings" not in name
            and tensor.ndim == 2
        )
        if is_transpose:
            tensor = tensor.transpose(0, 1)
        # The flag is reported as computed; previously it was reset to False
        # before logging, so transposed weights were misreported.
        print(f"Converting: {name} => {name} | is_transpose {is_transpose}")
        paddle_state_dict[name] = tensor.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)


def compare(out_torch, out_paddle):
    """Print the mean, max and min absolute difference of two outputs.

    Both arguments are converted with ``.detach().numpy()`` before the
    element-wise absolute difference is taken.

    Raises:
        AssertionError: If the two arrays do not have the same shape.
    """
    torch_arr = out_torch.detach().numpy()
    paddle_arr = out_paddle.detach().numpy()
    assert torch_arr.shape == paddle_arr.shape

    gap = np.abs(torch_arr - paddle_arr)
    stats = (
        ("mean_dif", np.mean(gap)),
        ("max_dif", np.max(gap)),
        ("min_dif", np.min(gap)),
    )
    for label, value in stats:
        print(f"{label}:{value}")


def test_forward():
    """Compare torch and paddle masked-LM outputs on the same random ids.

    Both models are loaded from ``./bert-base-uncased`` and put in eval mode.
    A seeded (4, 64) array of random token ids is fed to each model, the
    shape of element 0 of each output is printed, and the two outputs are
    passed to ``compare``.
    """
    paddle.set_device("cpu")
    pt_model = PTBertForMaskedLM.from_pretrained("./bert-base-uncased")
    pd_model = PDBertForMaskedLM.from_pretrained("./bert-base-uncased")
    pt_model.eval()
    pd_model.eval()

    # Fixed seed so both frameworks receive identical, repeatable input.
    np.random.seed(42)
    vocab_size = pd_model.bert.config["vocab_size"]
    token_ids = np.random.randint(1, vocab_size, size=(4, 64))

    pt_out = pt_model(torch.tensor(token_ids, dtype=torch.int64))[0]
    pd_out = pd_model(paddle.to_tensor(token_ids, dtype=paddle.int64))[0]

    print(f"torch result shape:{pt_out.shape}")
    print(f"paddle result shape:{pd_out.shape}")
    compare(pt_out, pd_out)


if __name__ == "__main__":
    # Convert the local torch checkpoint into a paddle parameter file.
    convert_pytorch_checkpoint_to_paddle(
        pytorch_checkpoint_path="test.bin",
        paddle_dump_path="test_paddle.pdparams",
    )
    # The forward comparison is disabled by default:
    # test_forward()
    # Output noted for test_forward():
    # torch result shape:torch.Size([4, 64, 30522])
    # paddle result shape:[4, 64, 30522]
    # mean_dif:1.666686512180604e-05
    # max_dif:0.00015211105346679688
    # min_dif:0.0
Loading