forked from PaddlePaddle/PaddleNLP
Commit: Add faster transformer for decoding (PaddlePaddle#37)
* add faster transformer
* add README
* add fp16
* add big config
* delete temp time record
* improve performance
* vocab_size
* add decoding sample config
* add description for variable in op.cc
* add clang format
* add reader description
* update comments
* process format and rm useless import
* comments
* add description
* undo useless change
* update load
* update according to comments
* ulter
* update readme
* add more info in readme
* update readme
* gpu:0 -> gpu
Showing 20 changed files with 1,852 additions and 0 deletions.
29 changes: 29 additions & 0 deletions — clang-format configuration file

``` yaml
# This file is used by clang-format to autoformat Paddle source code.
#
# clang-format is part of the LLVM toolchain;
# LLVM and Clang need to be installed to format source code with it.
#
# The basic usage is:
#   clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# -style=file implicitly uses the ".clang-format" file located in one of
# the parent directories.
# -i means in-place change.
#
# The clang-format documentation:
#   http://clang.llvm.org/docs/ClangFormat.html
#   http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # private/protected/public members get no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
```
15 changes: 15 additions & 0 deletions — clang-format version-check hook script

``` sh
#!/usr/bin/env bash
set -e

readonly VERSION="3.8"

version=$(clang-format -version)

if ! [[ $version == *"$VERSION"* ]]; then
    echo "clang-format version check failed."
    echo "a version containing '$VERSION' is needed, but got '$version'"
    echo "you can install the right version and make a soft link on your '\$PATH' env"
    exit 1
fi

clang-format "$@"
```
162 changes: 162 additions & 0 deletions — examples/machine_translation/transformer/faster_transformer/README.md
# Faster Transformer Inference

Here we integrate NVIDIA [Faster Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer) to accelerate inference, supporting both float32 and float16 decoding. Currently only beam search is integrated as the decoding strategy, and it is applied to inference for the dygraph Transformer English-German translation model. The following explains how to use Faster Transformer.

## Environment Requirements

* This project depends on PaddlePaddle 2.0.1 or later, or an appropriate develop build
* CMake >= 3.10
* CUDA 10.1 or later (must be consistent with the PaddlePaddle build)
* The gcc version must match the one used to build PaddlePaddle, e.g. gcc 8.2
* Python3 is recommended
* The environment required by [Faster Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer/v3.1#setup)
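As a quick way to confirm the toolchain matches the requirements above, something like the following can be run first (a sketch; `paddle.utils.run_check()` also verifies that the GPU build works):

``` sh
python3 -c "import paddle; print(paddle.__version__); paddle.utils.run_check()"
cmake --version   # expect >= 3.10
gcc --version     # expect the version used to build PaddlePaddle, e.g. 8.2
nvcc --version    # expect CUDA 10.1 or later
```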
## Quick Start

### Compiling the Custom OP

The custom OP's C++ and CUDA code must be compiled into a dynamic library. We provide a corresponding CMakeLists.txt, and the build can be done as shown below. The same build instructions can also be found under the custom OP's own path, `PaddleNLP/paddlenlp/ext_op/`.

#### Cloning PaddleNLP

First, since the library has to be rebuilt against the current environment and the released paddlenlp Python package does not include the Faster Transformer libs, clone PaddleNLP and rebuild:

``` sh
git clone https://github.com/PaddlePaddle/PaddleNLP.git
```

Next, set the environment variable so the cloned paddlenlp is used, and enter the custom OP directory to prepare for the build:

``` sh
export PYTHONPATH=$PWD/PaddleNLP/:$PYTHONPATH
cd PaddleNLP/paddlenlp/ext_op/
```
#### Building

Before building, make sure the installed PaddlePaddle is newer than 2.0.1 and working correctly.

The custom OP can then be built with the following steps:

``` sh
mkdir build
cd build/
cmake .. -DSM=xx -DCMAKE_BUILD_TYPE=Release
make -j
cd ../
```

Note: `xx` is the compute capability of the GPU in use. For example, it would be 70 for a V100 or 75 for a T4.
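If the compute capability of the local GPU is unknown, the query below may help (a sketch: the `compute_cap` field assumes a fairly recent NVIDIA driver; otherwise, look the model up at https://developer.nvidia.com/cuda-gpus):

``` sh
# print the GPU model and its compute capability, e.g. "Tesla V100, 7.0" -> -DSM=70
nvidia-smi --query-gpu=name,compute_cap --format=csv
```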
Finally, the build produces `libdecoding_op.so` under `./build/lib/`, the library needed to run Faster Transformer decoding.
## Running Inference with Faster Transformer

In a Python script, call the `FasterTransformer` API and pass in the location of `libdecoding_op.so` to use Faster Transformer for the current prediction.

For example:

``` python
from paddlenlp.ext_op import FasterTransformer

transformer = FasterTransformer(
    src_vocab_size=args.src_vocab_size,
    trg_vocab_size=args.trg_vocab_size,
    max_length=args.max_length + 1,
    n_layer=args.n_layer,
    n_head=args.n_head,
    d_model=args.d_model,
    d_inner_hid=args.d_inner_hid,
    dropout=args.dropout,
    weight_sharing=args.weight_sharing,
    bos_id=args.bos_idx,
    eos_id=args.eos_idx,
    beam_size=args.beam_size,
    max_out_len=args.max_out_len,
    decoding_lib=args.decoding_lib,
    use_fp16_decoding=args.use_fp16_decoding)
```

For a more detailed example, refer to `encoder_decoding_predict.py`, where we provide a fuller use case.
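As a minimal sketch of how this object is then used for inference (condensed from the `encoder_decoding_predict.py` script added in this commit; `args`, `test_loader`, and the checkpoint path are assumed to be set up as in that script):

``` python
import os
import paddle

transformer.eval()  # inference only
# load the trained checkpoint from the directory given by init_from_params
transformer.load(init_from_params=os.path.join(args.init_from_params,
                                               "transformer.pdparams"))
with paddle.no_grad():
    for (src_word, ) in test_loader:
        # output is presumably [seq_len, batch, beam]; reorder to [batch, beam, seq_len]
        finished_seq = transformer(src_word=src_word).numpy().transpose([1, 2, 0])
```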
#### Data Preparation

Public dataset: the WMT translation campaign is the most authoritative international evaluation in machine translation. Its English-German task provides a medium-sized dataset that is used in many papers, including the original Transformer paper. We use the [WMT'14 EN-DE dataset](http://www.statmt.org/wmt14/translation-task.html) as the example here.

We also provide a preprocessed version of the dataset. With the following code, it is downloaded and extracted automatically to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`:

``` python
# Get the default preprocessing pipeline
transform_func = WMT14ende.get_default_transform_func(root=root)
# Download and preprocess the WMT14.en-de translation dataset
dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
```
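The snippet above assumes `WMT14ende` is already imported and `root` is defined; presumably the missing setup is along these lines (the import path follows the paddlenlp package layout of this era, and `root=None` is assumed to fall back to the default download directory):

``` python
from paddlenlp.datasets import WMT14ende

root = None  # assumed to default to ~/.paddlenlp/datasets/machine_translation/WMT14ende/
```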
#### Model Inference

Running inference requires a suitable checkpoint. Set the model-loading path parameter `init_from_params` in the corresponding `../configs/transformer.base.yaml`.

We provide a trained dygraph base-model checkpoint, available from [tranformer-base-wmt_ende_bpe](https://paddlenlp.bj.bcebos.com/models/transformers/transformer/tranformer-base-wmt_ende_bpe.tar.gz):

``` sh
wget https://paddlenlp.bj.bcebos.com/models/transformers/transformer/tranformer-base-wmt_ende_bpe.tar.gz
tar -zxf tranformer-base-wmt_ende_bpe.tar.gz
```

Then set the value of `init_from_params` in the `../configs/transformer.base.yaml` config file to `./base_trained_models/step_final/`.
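After this edit, the relevant entry in `transformer.base.yaml` would look roughly like the following (a sketch; surrounding keys omitted and exact quoting may differ):

``` yaml
# Directory containing the trained parameters used to initialize the model.
init_from_params: "./base_trained_models/step_final/"
```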
#### Dygraph Prediction (float32 decoding)

Taking the English-German translation data as an example, once the model is trained, run the following commands to translate the text in the specified file:

``` sh
# set visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
export FLAGS_fraction_of_gpu_memory_to_use=0.1
cp -rf ../../../../paddlenlp/ext_op/build/third-party/build/bin/decoding_gemm ./
./decoding_gemm 8 4 8 64 38512 32 512 0
python encoder_decoding_predict.py --config ../configs/transformer.base.yaml --decoding-lib ../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so
```
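For reference, the positional arguments of `decoding_gemm` appear to map as follows, going by the FasterTransformer v3.1 repository (treat this mapping as an assumption rather than a guarantee):

``` sh
# ./decoding_gemm <batch_size> <beam_width> <head_number> <size_per_head> \
#                 <vocab_size> <seq_len> <memory_hidden_dim> <is_fp16>
./decoding_gemm 8 4 8 64 38512 32 512 0   # last arg 0 = fp32; the fp16 run below uses 1
```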
Here, the `--config` option gives the location of the config file, and `--decoding-lib` gives the location of the compiled Faster Transformer decoding lib.

The translation results are written to the file specified by `output_file`. Running prediction requires setting `init_from_params` to the directory containing the model. The remaining parameters are documented in the comments of the `./sample/config/transformer.base.yaml` file and can be changed there. If the `--config` option is not provided, the program defaults to the base model configuration.
#### Dygraph Prediction (float16 decoding)

The float16 workflow is the same as float32, except that the `--use-fp16-decoding` option must be added when predicting with float16 decoding; everything else runs exactly as before:

``` sh
# set visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
export FLAGS_fraction_of_gpu_memory_to_use=0.1
cp -rf ../../../../paddlenlp/ext_op/build/third-party/build/bin/decoding_gemm ./
./decoding_gemm 8 4 8 64 38512 32 512 1
python encoder_decoding_predict.py --config ../configs/transformer.base.yaml --decoding-lib ../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so --use-fp16-decoding
```

The `--config` and `--decoding-lib` options have the same meanings as in the float32 case, and the translation results are again written to the file specified by `output_file`.

Note that prediction is currently implemented for a single GPU only: the model evaluation that follows translation depends on the order in which predictions are written to the output file, and writing results in a specified order is not yet supported in the multi-GPU case.
## Model Evaluation

Evaluation works the same way as for the dygraph model. Each line of the prediction output is the highest-scoring translation for the corresponding input line. For data using BPE, the predicted translations are also in BPE form and must be restored to the original (i.e. tokenized) form before they can be evaluated correctly. The evaluation procedure is as follows (BLEU is the standard automatic evaluation metric for translation tasks):

``` sh
# restore the predictions in predict.txt to tokenized form
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
# if the BLEU evaluation tool is not available, download it first
git clone https://github.com/moses-smt/mosesdecoder.git
# using the English-German newstest2014 test data as an example
perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/WMT14ende/WMT14.en-de/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
```

After running the above, you should see output like the following; this is the BLEU result of the base model on newstest2014:
```
BLEU = 26.89, 58.4/32.6/20.5/13.4 (BP=1.000, ratio=1.010, hyp_len=65166, ref_len=64506)
```
115 changes: 115 additions & 0 deletions — examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

``` python
import sys
import os
import argparse
from pprint import pprint

import yaml
from attrdict import AttrDict

import paddle

from paddlenlp.ext_op import FasterTransformer

sys.path.append("../")
import reader


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="../configs/transformer.base.yaml",
        type=str,
        help="Path of the config file. ")
    parser.add_argument(
        "--decoding-lib",
        default="../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so",
        type=str,
        help="Path of libdecoding_op.so. ")
    parser.add_argument(
        "--use-fp16-decoding",
        action="store_true",
        help="Whether to use fp16 decoding to predict. ")
    args = parser.parse_args()
    return args


def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """
    Post-process the decoded sequence: truncate at the first EOS and
    optionally strip the BOS/EOS tokens.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
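# For example (a quick sanity check, not part of the original script):
#   post_process_seq([0, 5, 6, 1, 7], bos_idx=0, eos_idx=1)
# truncates after the first EOS and strips BOS/EOS, returning [5, 6].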
def do_predict(args):
    place = "gpu"
    paddle.set_device(place)

    # Define data loader
    test_loader, to_tokens = reader.create_infer_loader(args)

    # Define model
    transformer = FasterTransformer(
        src_vocab_size=args.src_vocab_size,
        trg_vocab_size=args.trg_vocab_size,
        max_length=args.max_length + 1,
        n_layer=args.n_layer,
        n_head=args.n_head,
        d_model=args.d_model,
        d_inner_hid=args.d_inner_hid,
        dropout=args.dropout,
        weight_sharing=args.weight_sharing,
        bos_id=args.bos_idx,
        eos_id=args.eos_idx,
        beam_size=args.beam_size,
        max_out_len=args.max_out_len,
        decoding_lib=args.decoding_lib,
        use_fp16_decoding=args.use_fp16_decoding)

    # Set evaluate mode
    transformer.eval()

    # Load checkpoint.
    transformer.load(init_from_params=os.path.join(args.init_from_params,
                                                   "transformer.pdparams"))

    # Use a context manager so the output file is flushed and closed on exit.
    with open(args.output_file, "w") as f:
        with paddle.no_grad():
            for (src_word, ) in test_loader:
                finished_seq = transformer(src_word=src_word)
                # [seq_len, batch, beam] -> [batch, beam, seq_len]
                finished_seq = finished_seq.numpy().transpose([1, 2, 0])
                for ins in finished_seq:
                    for beam_idx, beam in enumerate(ins):
                        if beam_idx >= args.n_best:
                            break
                        id_list = post_process_seq(beam, args.bos_idx,
                                                   args.eos_idx)
                        word_list = to_tokens(id_list)
                        sequence = " ".join(word_list) + "\n"
                        f.write(sequence)


if __name__ == "__main__":
    ARGS = parse_args()
    yaml_file = ARGS.config
    with open(yaml_file, 'rt') as f:
        args = AttrDict(yaml.safe_load(f))
    pprint(args)
    # Command-line options override the values loaded from the yaml config.
    args.decoding_lib = ARGS.decoding_lib
    args.use_fp16_decoding = ARGS.use_fp16_decoding

    do_predict(args)
```