[Tutorial] Add torch migration tutorial (#3641)
Co-authored-by: Zhong Hui <[email protected]>
ymyjl and ZHUI authored Nov 3, 2022
1 parent afeb623 commit 1fc23a8
Showing 41 changed files with 5,615 additions and 0 deletions.
62 changes: 62 additions & 0 deletions examples/torch_migration/README.md
@@ -0,0 +1,62 @@
# BERT-SST2-Prod
Reproduction process of BERT on SST2 dataset

# Installation

* Clone the repository and enter the example folder

```shell
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP/examples/torch_migration
```

* Install the requirements

```shell
pip install -r requirements.txt
```

* Install PaddlePaddle and PyTorch

```shell
# CPU build of PaddlePaddle
pip install paddlepaddle==2.2.0 -i https://mirror.baidu.com/pypi/simple
# To install the GPU build of PaddlePaddle instead, use the command below
# pip install paddlepaddle-gpu==2.2.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
# Install PyTorch
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```

**Note**: This project depends on PaddlePaddle 2.2.0, so make sure that version is installed.

* Verify the PaddlePaddle installation

Start `python` and enter the following commands.

```python
import paddle
paddle.utils.run_check()
print(paddle.__version__)
```

If the output looks like the following, PaddlePaddle was installed successfully.

```
PaddlePaddle is installed successfully! Let's start deep learning with PaddlePaddle now.
2.2.0
```


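* Verify the PyTorch installation

Start `python` and enter the following commands. If they run and print normally, torch was installed successfully.

```python
import torch
print(torch.__version__)
# For a CPU build, the following confirms that torch works
# Expected output: tensor([1.])
print(torch.Tensor([1.0]))
# For a GPU build, the following confirms that torch works
# Expected output: tensor([1.], device='cuda:0')
print(torch.Tensor([1.0]).cuda())
```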
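
The two checks above can also be scripted without importing the frameworks themselves. The following is a minimal sketch that assumes the pip distribution names `paddlepaddle`, `paddlepaddle-gpu`, and `torch` from the install commands in this README; the helper `installed_version` is a hypothetical name, not part of either library.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a pip distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Distribution names are assumptions taken from the pip commands above.
    for name in ("paddlepaddle", "paddlepaddle-gpu", "torch"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

This only reports what pip has recorded; `paddle.utils.run_check()` is still the way to confirm that the runtime actually works.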
928 changes: 928 additions & 0 deletions examples/torch_migration/docs/ThesisReproduction_NLP.md

Large diffs are not rendered by default.

86 changes: 86 additions & 0 deletions examples/torch_migration/pipeline/Step1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
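# Usage


This section uses forward alignment as an example to walk through the alignment-check workflow based on the `reprod_log` tool. The parts that involve `reprod_log` are the ones developers need to add themselves.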


```shell
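# Starting from examples/torch_migration: enter the weights folder and generate the torch BERT model weights
cd pipeline/weights/ && python torch_bert_weights.py
# Still in pipeline/weights/, convert the torch BERT weights to paddle
python torch2paddle.py
# Enter the classifier_weights folder and generate the classifier weights
cd ../classifier_weights/ && python generate_classifier_weights.py
# Enter the Step1 folder
cd ../Step1/
# Generate the paddle forward output
python pd_forward_bert.py
# Generate the torch forward output
python pt_forward_bert.py
# Compare the two outputs and write the diff log
python check_step1.py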
```
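
The weight-conversion step (`torch2paddle.py`) is needed because the two frameworks store linear-layer weights differently: `torch.nn.Linear` keeps its weight as `(out_features, in_features)`, while `paddle.nn.Linear` keeps it as `(in_features, out_features)`, so 2-D linear weights are transposed during conversion. A NumPy-only sketch of that idea follows; the array names and sizes are illustrative assumptions, not taken from the conversion script.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features = 3, 5  # illustrative sizes, not from the script

x = rng.standard_normal((2, in_features)).astype("float32")
w_torch = rng.standard_normal((out_features, in_features)).astype("float32")

# PyTorch convention: y = x @ W^T, with W shaped (out, in).
y_torch = x @ w_torch.T

# Paddle convention: y = x @ W, with W shaped (in, out),
# so the converted weight is the transpose of the PyTorch one.
w_paddle = w_torch.T
y_paddle = x @ w_paddle

print(w_torch.shape, w_paddle.shape)
print(np.allclose(y_torch, y_paddle))
```

Embedding and LayerNorm weights have the same layout in both frameworks, which is why the conversion script leaves them untouched.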

Taking PaddlePaddle as an example, the code of `pd_forward_bert.py` is shown below.

```python
import os
import sys

import numpy as np
import paddle
# Import the ReprodLogger class from reprod_log
from reprod_log import ReprodLogger

CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]  # current directory
config_path = CURRENT_DIR.rsplit('/', 1)[0]
sys.path.append(config_path)
from models.pd_bert import BertConfig, BertForSequenceClassification

reprod_logger = ReprodLogger()

# Build the network and load the BertModel weights
paddle_dump_path = '../weights/paddle_weight.pdparams'
config = BertConfig()
model = BertForSequenceClassification(config)
checkpoint = paddle.load(paddle_dump_path)
model.bert.load_dict(checkpoint)

# Load the classifier weights
classifier_weights = paddle.load(
    "../classifier_weights/paddle_classifier_weights.bin")
model.load_dict(classifier_weights)
model.eval()
# Read the fake data and convert it to a tensor; the fake data could also be
# generated on the fly with a fixed seed
fake_data = np.load("../fake_data/fake_data.npy")
fake_data = paddle.to_tensor(fake_data)
# Forward pass; as in the shipped pd_forward_bert.py, the logits are the first output
out = model(fake_data)[0]
# Save the forward result; developers add this part for their own task
reprod_logger.add("logits", out.cpu().detach().numpy())
reprod_logger.save("forward_paddle.npy")
```
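
As the comment in the script notes, the fake input can also be generated on the fly with a fixed seed instead of being read from disk. The following is a minimal sketch, assuming a batch of 4 sequences of length 64 and a vocabulary of 30522 tokens (the shape that appears in the `test_forward` comments of `torch2paddle.py`); the function name `make_fake_data` is hypothetical.

```python
import numpy as np


def make_fake_data(seed=42, batch_size=4, seq_len=64, vocab_size=30522):
    """Generate deterministic fake token ids; all defaults are assumptions."""
    rng = np.random.RandomState(seed)
    # Start at 1 so that id 0 (commonly the padding id) is never drawn.
    ids = rng.randint(1, vocab_size, size=(batch_size, seq_len))
    return ids.astype("int64")


if __name__ == "__main__":
    fake_data = make_fake_data()
    print(fake_data.shape, fake_data.dtype)
```

Saving the result once with `np.save` lets the PaddlePaddle and PyTorch scripts read exactly the same input, which is what the forward comparison relies on.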

For the diff check, see [check_step1.py](./check_step1.py); the code is shown below.

```python
# https://github.com/littletomatodonkey/AlexNet-Prod/blob/master/pipeline/Step1/check_step1.py
# Use reprod_log to check the diff
from reprod_log import ReprodDiffHelper

if __name__ == "__main__":
    diff_helper = ReprodDiffHelper()
    torch_info = diff_helper.load_info("./forward_torch.npy")
    paddle_info = diff_helper.load_info("./forward_paddle.npy")
    diff_helper.compare_info(torch_info, paddle_info)
    diff_helper.report(path="forward_diff.log")

The resulting log is shown below; the check result is also saved to `forward_diff.log`.

```
[2021/11/17 20:15:50] root INFO: logits:
[2021/11/17 20:15:50] root INFO: mean diff: check passed: True, value: 1.30385160446167e-07
[2021/11/17 20:15:50] root INFO: diff check passed
```

The mean absolute difference is about 1.3e-7, so the check passes.
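
The reported value is a mean absolute difference between the two saved `logits` arrays, compared against a tolerance. The following is a minimal NumPy sketch of that comparison, assuming a threshold of 1e-6 (an assumption here; consult the `reprod_log` documentation for the value it actually uses). `mean_abs_diff` and `check_passed` are hypothetical helper names.

```python
import numpy as np


def mean_abs_diff(a, b):
    """Mean absolute element-wise difference between two same-shaped arrays."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.mean(np.abs(a - b)))


def check_passed(a, b, threshold=1e-6):
    """True when the mean absolute difference is below the assumed threshold."""
    return mean_abs_diff(a, b) < threshold


if __name__ == "__main__":
    torch_logits = np.array([[0.25, -1.5], [2.0, 0.5]])
    paddle_logits = torch_logits + 1e-7  # simulate a tiny numerical drift
    print(mean_abs_diff(torch_logits, paddle_logits))
    print(check_passed(torch_logits, paddle_logits))
```

A mean can hide a few large outliers, so it may be worth also looking at the maximum absolute difference when a check passes only narrowly.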
23 changes: 23 additions & 0 deletions examples/torch_migration/pipeline/Step1/check_step1.py
@@ -0,0 +1,23 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from reprod_log import ReprodDiffHelper

if __name__ == "__main__":
    diff_helper = ReprodDiffHelper()
    torch_info = diff_helper.load_info("./forward_torch.npy")
    paddle_info = diff_helper.load_info("./forward_paddle.npy")

    diff_helper.compare_info(torch_info, paddle_info)
    diff_helper.report(path="forward_diff.log")
50 changes: 50 additions & 0 deletions examples/torch_migration/pipeline/Step1/pd_forward_bert.py
@@ -0,0 +1,50 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

import numpy as np
import paddle
from reprod_log import ReprodLogger

CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]  # current directory
CONFIG_PATH = CURRENT_DIR.rsplit('/', 1)[0]
sys.path.append(CONFIG_PATH)

from models.pd_bert import BertConfig, BertForSequenceClassification

if __name__ == "__main__":
    paddle.set_device("cpu")

    # def logger
    reprod_logger = ReprodLogger()

    paddle_dump_path = '../weights/paddle_weight.pdparams'
    config = BertConfig()
    model = BertForSequenceClassification(config)
    checkpoint = paddle.load(paddle_dump_path)
    model.bert.load_dict(checkpoint)

    classifier_weights = paddle.load(
        "../classifier_weights/paddle_classifier_weights.bin")
    model.load_dict(classifier_weights)
    model.eval()
    # read or gen fake data

    fake_data = np.load("../fake_data/fake_data.npy")
    fake_data = paddle.to_tensor(fake_data)
    # forward
    out = model(fake_data)[0]
    reprod_logger.add("logits", out.cpu().detach().numpy())
    reprod_logger.save("forward_paddle.npy")
48 changes: 48 additions & 0 deletions examples/torch_migration/pipeline/Step1/pt_forward_bert.py
@@ -0,0 +1,48 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os

import numpy as np
from reprod_log import ReprodLogger
import torch

CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]  # current directory
CONFIG_PATH = CURRENT_DIR.rsplit('/', 1)[0]
sys.path.append(CONFIG_PATH)

from models.pt_bert import BertConfig, BertForSequenceClassification

if __name__ == "__main__":
    # def logger
    reprod_logger = ReprodLogger()

    pytorch_dump_path = '../weights/torch_weight.bin'
    config = BertConfig()
    model = BertForSequenceClassification(config)
    checkpoint = torch.load(pytorch_dump_path)
    model.bert.load_state_dict(checkpoint)

    classifier_weights = torch.load(
        "../classifier_weights/torch_classifier_weights.bin")
    model.load_state_dict(classifier_weights, strict=False)
    model.eval()

    # read or gen fake data
    fake_data = np.load("../fake_data/fake_data.npy")
    fake_data = torch.from_numpy(fake_data)
    # forward
    out = model(fake_data)[0]
    reprod_logger.add("logits", out.cpu().detach().numpy())
    reprod_logger.save("forward_torch.npy")
114 changes: 114 additions & 0 deletions examples/torch_migration/pipeline/Step1/torch2paddle.py
@@ -0,0 +1,114 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import numpy as np
import paddle
import torch
from paddlenlp.transformers import BertForPretraining as PDBertForMaskedLM
from transformers import BertForMaskedLM as PTBertForMaskedLM


def convert_pytorch_checkpoint_to_paddle(
    pytorch_checkpoint_path="pytorch_model.bin",
    paddle_dump_path="model_state.pdparams",
    version="old",
):
    # Hugging Face -> Paddle parameter-name mapping. This script builds the
    # mapping but does not apply it: keys are saved under their original names.
    hf_to_paddle = {
        "embeddings.LayerNorm": "embeddings.layer_norm",
        "encoder.layer": "encoder.layers",
        "attention.self.query": "self_attn.q_proj",
        "attention.self.key": "self_attn.k_proj",
        "attention.self.value": "self_attn.v_proj",
        "attention.output.dense": "self_attn.out_proj",
        "intermediate.dense": "linear1",
        "output.dense": "linear2",
        "attention.output.LayerNorm": "norm1",
        "output.LayerNorm": "norm2",
        "predictions.decoder.": "predictions.decoder_",
        "predictions.transform.dense": "predictions.transform",
        "predictions.transform.LayerNorm": "predictions.layer_norm",
    }
    do_not_transpose = []
    if version == "old":
        hf_to_paddle.update({
            "predictions.bias": "predictions.decoder_bias",
            ".gamma": ".weight",
            ".beta": ".bias",
        })
        do_not_transpose = do_not_transpose + ["predictions.decoder.weight"]

    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        is_transpose = False
        if k[-7:] == ".weight":
            # embeddings.weight and LayerNorm.weight do not transpose
            if all(d not in k for d in do_not_transpose):
                if ".embeddings." not in k and ".LayerNorm." not in k:
                    if v.ndim == 2:
                        if 'embeddings' not in k:
                            # 2-D linear weights: (out, in) -> (in, out)
                            v = v.transpose(0, 1)
                            is_transpose = True
        oldk = k
        print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)


def compare(out_torch, out_paddle):
    out_torch = out_torch.detach().numpy()
    out_paddle = out_paddle.detach().numpy()
    assert out_torch.shape == out_paddle.shape
    abs_dif = np.abs(out_torch - out_paddle)
    mean_dif = np.mean(abs_dif)
    max_dif = np.max(abs_dif)
    min_dif = np.min(abs_dif)
    print("mean_dif:{}".format(mean_dif))
    print("max_dif:{}".format(max_dif))
    print("min_dif:{}".format(min_dif))


def test_forward():
    paddle.set_device("cpu")
    model_torch = PTBertForMaskedLM.from_pretrained("./bert-base-uncased")
    model_paddle = PDBertForMaskedLM.from_pretrained("./bert-base-uncased")
    model_torch.eval()
    model_paddle.eval()
    np.random.seed(42)
    x = np.random.randint(1,
                          model_paddle.bert.config["vocab_size"],
                          size=(4, 64))
    input_torch = torch.tensor(x, dtype=torch.int64)
    out_torch = model_torch(input_torch)[0]

    input_paddle = paddle.to_tensor(x, dtype=paddle.int64)
    out_paddle = model_paddle(input_paddle)[0]

    print("torch result shape:{}".format(out_torch.shape))
    print("paddle result shape:{}".format(out_paddle.shape))
    compare(out_torch, out_paddle)


if __name__ == "__main__":
    convert_pytorch_checkpoint_to_paddle("test.bin", "test_paddle.pdparams")
    # test_forward()
    # torch result shape:torch.Size([4, 64, 30522])
    # paddle result shape:[4, 64, 30522]
    # mean_dif:1.666686512180604e-05
    # max_dif:0.00015211105346679688
    # min_dif:0.0