From 80c1f77ac9bcc65dfcea9c6c7f87ce80153cb285 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 5 Mar 2021 20:27:07 +0800 Subject: [PATCH] Add faster transformer for decoding (#37) * add faster transformer * add README * add fp16 * add big config * delete temp time record * improve performance * vocab_size * add decoding sample config * add desription for variable in op.cc * add clang format * add reader description * update comments * process format and rm useless import * comments * add decription * undo useless change * update load * update according to comments * ulter * update readme * add more info in readme * update readme * gpu:0 -> gpu --- .clang-format | 29 ++ .clang_format.hook | 15 + .../machine_translation/transformer/README.md | 4 + .../transformer/faster_transformer/README.md | 162 +++++++++ .../encoder_decoding_predict.py | 115 +++++++ paddlenlp/__init__.py | 1 + paddlenlp/ext_op/CMakeLists.txt | 141 ++++++++ paddlenlp/ext_op/README.md | 109 ++++++ paddlenlp/ext_op/__init__.py | 2 + .../ext_op/sample/config/decoding.sample.yaml | 95 ++++++ paddlenlp/ext_op/sample/decoding_sample.py | 107 ++++++ .../ext_op/sample/encoder_decoding_sample.py | 108 ++++++ paddlenlp/ext_op/src/CMakeLists.txt | 23 ++ paddlenlp/ext_op/src/fusion_decoding_op.cc | 185 +++++++++++ paddlenlp/ext_op/src/fusion_decoding_op.cu | 266 +++++++++++++++ paddlenlp/ext_op/src/fusion_decoding_op.h | 20 ++ paddlenlp/ext_op/src/pd_traits.h | 25 ++ paddlenlp/ext_op/transformer/__init__.py | 0 paddlenlp/ext_op/transformer/decoding.py | 312 ++++++++++++++++++ .../ext_op/transformer/faster_transformer.py | 133 ++++++++ 20 files changed, 1852 insertions(+) create mode 100644 .clang-format create mode 100755 .clang_format.hook create mode 100644 examples/machine_translation/transformer/faster_transformer/README.md create mode 100644 examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py create mode 100644 paddlenlp/ext_op/CMakeLists.txt create mode 100644 paddlenlp/ext_op/README.md create mode 100644 paddlenlp/ext_op/__init__.py create mode 100644 paddlenlp/ext_op/sample/config/decoding.sample.yaml create mode 100644 paddlenlp/ext_op/sample/decoding_sample.py create mode 100644 paddlenlp/ext_op/sample/encoder_decoding_sample.py create mode 100644 paddlenlp/ext_op/src/CMakeLists.txt create mode 100644 paddlenlp/ext_op/src/fusion_decoding_op.cc create mode 100644 paddlenlp/ext_op/src/fusion_decoding_op.cu create mode 100644 paddlenlp/ext_op/src/fusion_decoding_op.h create mode 100644 paddlenlp/ext_op/src/pd_traits.h create mode 100644 paddlenlp/ext_op/transformer/__init__.py create mode 100644 paddlenlp/ext_op/transformer/decoding.py create mode 100644 paddlenlp/ext_op/transformer/faster_transformer.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000000..30863c27a8fd --- /dev/null +++ b/.clang-format @@ -0,0 +1,29 @@ +# This file is used by clang-format to autoformat paddle source code +# +# The clang-format is part of llvm toolchain. +# It need to install llvm and clang to format source code style. +# +# The basic usage is, +# clang-format -i -style=file PATH/TO/SOURCE/CODE +# +# The -style=file implicit use ".clang-format" file located in one of +# parent directory. +# The -i means inplace change. 
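+# For example, to format one of the sources in place through the
+# version-checked wrapper script shipped alongside this config
+# (run from the repository root):
+#   bash ./.clang_format.hook -style=file -i paddlenlp/ext_op/src/fusion_decoding_op.cc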
+# +# The document of clang-format is +# http://clang.llvm.org/docs/ClangFormat.html +# http://clang.llvm.org/docs/ClangFormatStyleOptions.html +--- +Language: Cpp +BasedOnStyle: Google +IndentWidth: 2 +TabWidth: 2 +ContinuationIndentWidth: 4 +MaxEmptyLinesToKeep: 2 +AccessModifierOffset: -2 # The private/protected/public has no indent in class +Standard: Cpp11 +AllowAllParametersOfDeclarationOnNextLine: true +BinPackParameters: false +BinPackArguments: false +... + diff --git a/.clang_format.hook b/.clang_format.hook new file mode 100755 index 000000000000..40d70f56cf97 --- /dev/null +++ b/.clang_format.hook @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +set -e + +readonly VERSION="3.8" + +version=$(clang-format -version) + +if ! [[ $version == *"$VERSION"* ]]; then + echo "clang-format version check failed." + echo "a version contains '$VERSION' is needed, but get '$version'" + echo "you can install the right version, and make an soft-link to '\$PATH' env" + exit -1 +fi + +clang-format $@ diff --git a/examples/machine_translation/transformer/README.md b/examples/machine_translation/transformer/README.md index 3379d538c942..dcc933b6add9 100644 --- a/examples/machine_translation/transformer/README.md +++ b/examples/machine_translation/transformer/README.md @@ -132,6 +132,10 @@ python deploy/python/inference.py --config ./configs/transformer.base.yaml 翻译结果同样将会保存在 `predict.txt` 文件中,可以在配置文件中自定义更改 `output_file` 来指定预测结果写入到的文件的名称。 +## 使用 Faster Transformer 实现预测 + +具体的说明可以参考 `faster_transformer/README.md`。`cd faster_transformer/` 即可查看。 + ## 模型评估 预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标): diff --git a/examples/machine_translation/transformer/faster_transformer/README.md b/examples/machine_translation/transformer/faster_transformer/README.md new file mode 100644 index 000000000000..b25562ddcdd0 --- /dev/null +++ b/examples/machine_translation/transformer/faster_transformer/README.md @@ -0,0 +1,162 @@ +# Faster Transformer 预测 + +在这里我们集成了 Nvidia [Faster Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer) 用于预测加速。同时集成了 Faster Transformer float32 以及 float16 预测。目前仅集成 beam search 作为解码的策略,并应用在动态图 Transformer 英德翻译的推理预测中。以下是使用 Faster Transformer 的说明。 + +## 使用环境说明 + +* 本项目依赖于 PaddlePaddle 2.0.1 及以上版本或适当的 develop 版本 +* CMake >= 3.10 +* CUDA 10.1 或是更新的版本(需要 PaddlePaddle 框架一致) +* gcc 版本需要与编译 PaddlePaddle 版本一致,比如使用 gcc8.2 +* 推荐使用 Python3 +* [Faster Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer/v3.1#setup) 使用必要的环境 + +## 快速开始 + +### 编译自定义OP + +自定义 OP 需要将实现的 C++、CUDA 代码编译成动态库,我们提供对应的 CMakeLists.txt ,可以参考使用如下的方式完成编译。同样的自定义 op 编译的说明也可以在自定义 op 对应的路径 `PaddleNLP/paddlenlp/ext_op/` 下面找到。 + +#### 克隆 PaddleNLP + +首先,因为需要基于当前环境重新编译,当前的 paddlenlp 的 python 包里面并不包含 Faster Transformer 相关 lib,需要克隆一个 PaddleNLP,并重新编译: +``` sh +git clone https://github.com/PaddlePaddle/PaddleNLP.git +``` + +其次,配置环境变量,让我们可以使用当前 clone 的 paddlenlp,并进入到自定义 OP 的路径,准备后续的编译操作: + +``` sh +export PYTHONPATH=$PWD/PaddleNLP/:$PYTHONPATH +cd PaddleNLP/paddlenlp/ext_op/ +``` + +#### 编译 + +编译之前,请确保安装的 PaddlePaddle 的版本需要大于 2.0.1,并且正常可用。 + +编译自定义 OP 可以参照一下步骤: + +``` sh +mkdir build +cd build/ +cmake .. 
-DSM=xx -DCMAKE_BUILD_TYPE=Release +make -j +cd ../ +``` + +注意:`xx` 是指的所用 GPU 的 compute capability。举例来说,可以将之指定为 70(V100) 或是 75(T4)。 + +最终,编译会在 `./build/lib/` 路径下,产出 `libdecoding_op.so`,即需要的 Faster Transformer decoding 执行的库。 + +## 使用 Faster Transformer 完成预测 + +编写 python 脚本的时候,调用 `FasterTransformer` API 并传入 `libdecoding_op.so` 的位置即可实现将 Faster Transformer 用于当前的预测。 + +举例如下: + +``` python +from paddlenlp.ext_op import FasterTransformer + +transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) +``` + +更详细的例子可以参考 `encoder_decoding_predict.py`,我们提供了更详细用例。 + + +#### 数据准备 + +公开数据集:WMT 翻译大赛是机器翻译领域最具权威的国际评测大赛,其中英德翻译任务提供了一个中等规模的数据集,这个数据集是较多论文中使用的数据集,也是 Transformer 论文中用到的一个数据集。我们也将[WMT'14 EN-DE 数据集](http://www.statmt.org/wmt14/translation-task.html)作为示例提供。 + +同时,我们提供了一份已经处理好的数据集,可以编写如下代码,对应的数据集将会自动下载并且解压到 `~/.paddlenlp/datasets/machine_translation/WMT14ende/`。 + +``` python +# 获取默认的数据处理方式 +transform_func = WMT14ende.get_default_transform_func(root=root) +# 下载并处理 WMT14.en-de 翻译数据集 +dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func) +``` + + +#### 模型推断 + +使用模型推断前提是需要指定一个合适的 checkpoint,需要在对应的 `../configs/transformer.base.yaml` 中修改对应的模型载入的路径参数 `init_from_params`。 + +我们提供一个已经训练好的动态图的 base model 的 checkpoint 以供使用,可以通过[tranformer-base-wmt_ende_bpe](https://paddlenlp.bj.bcebos.com/models/transformers/transformer/tranformer-base-wmt_ende_bpe.tar.gz)下载。 + +``` sh +wget https://paddlenlp.bj.bcebos.com/models/transformers/transformer/tranformer-base-wmt_ende_bpe.tar.gz +tar -zxf tranformer-base-wmt_ende_bpe.tar.gz +``` + +然后,需要修改对应的 `../configs/transformer.base.yaml` 配置文件中的 `init_from_params` 的值为 `./base_trained_models/step_final/`。 + +#### 使用动态图预测(使用 float32 decoding 预测) + +以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译: + +``` sh +# setting visible devices for prediction +export CUDA_VISIBLE_DEVICES=0 +export FLAGS_fraction_of_gpu_memory_to_use=0.1 +cp -rf ../../../../paddlenlp/ext_op/build/third-party/build/bin/decoding_gemm ./ +./decoding_gemm 8 4 8 64 38512 32 512 0 +python encoder_decoding_predict.py --config ../configs/transformer.base.yaml --decoding-lib ../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so +``` + +其中,`--config` 选项用于指明配置文件的位置,而 `--decoding-lib` 选项用于指明编译好的 Faster Transformer decoding lib 的位置。 + +翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `./sample/config/transformer.base.yaml` 文件中查阅注释说明并进行更改设置。如果执行不提供 `--config` 选项,程序将默认使用 base model 的配置。 + + +#### 使用动态图预测(使用 float16 decoding 预测) + +float16 与 float32 预测的基本流程相同,不过在使用 float16 的 decoding 进行预测的时候,需要再加上 `--use-fp16-decoding` 选项。后按照与之前相同的方式执行即可。具体执行方式如下: + +``` sh +# setting visible devices for prediction +export CUDA_VISIBLE_DEVICES=0 +export FLAGS_fraction_of_gpu_memory_to_use=0.1 +cp -rf ../../../../paddlenlp/ext_op/build/third-party/build/bin/decoding_gemm ./ +./decoding_gemm 8 4 8 64 38512 32 512 1 +python encoder_decoding_predict.py --config ../configs/transformer.base.yaml --decoding-lib ../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so --use-fp16-decoding +``` + +其中,`--config` 选项用于指明配置文件的位置,而 `--decoding-lib` 选项用于指明编译好的 Faster Transformer decoding 
lib 的位置。 + +翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `./sample/config/transformer.base.yaml` 文件中查阅注释说明并进行更改设置。如果执行不提供 `--config` 选项,程序将默认使用 base model 的配置。 + +需要注意的是,目前预测仅实现了单卡的预测,原因在于,翻译后面需要的模型评估依赖于预测结果写入文件顺序,多卡情况下,目前暂未支持将结果按照指定顺序写入文件。 + +## 模型评估 + +评估方式与动态图评估方式相同,预测结果中每行输出是对应行输入的得分最高的翻译,对于使用 BPE 的数据,预测出的翻译结果也将是 BPE 表示的数据,要还原成原始的数据(这里指 tokenize 后的数据)才能进行正确的评估。评估过程具体如下(BLEU 是翻译任务常用的自动评估方法指标): + +``` sh +# 还原 predict.txt 中的预测结果为 tokenize 后的数据 +sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt +# 若无 BLEU 评估工具,需先进行下载 +git clone https://github.com/moses-smt/mosesdecoder.git +# 以英德翻译 newstest2014 测试数据为例 +perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/WMT14ende/WMT14.en-de/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt +``` + +执行上述操作之后,可以看到类似如下的结果,此处结果是 base model 在 newstest2014 上的 BLEU 结果: +``` +BLEU = 26.89, 58.4/32.6/20.5/13.4 (BP=1.000, ratio=1.010, hyp_len=65166, ref_len=64506) +``` diff --git a/examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py b/examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py new file mode 100644 index 000000000000..809279f62971 --- /dev/null +++ b/examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py @@ -0,0 +1,115 @@ +import sys +import os +import numpy as np +from attrdict import AttrDict +import argparse +import time + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +import yaml +from pprint import pprint + +from paddlenlp.transformers import TransformerModel +from paddlenlp.transformers import position_encoding_init +from paddlenlp.ext_op import FasterTransformer + +sys.path.append("../") +import reader + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="../configs/transformer.base.yaml", + type=str, + help="Path of the config file. ") + parser.add_argument( + "--decoding-lib", + default="../../../../paddlenlp/ext_op/build/lib/libdecoding_op.so", + type=str, + help="Path of libdecoding_op.so. ") + parser.add_argument( + "--use-fp16-decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False): + """ + Post-process the decoded sequence. + """ + eos_pos = len(seq) - 1 + for i, idx in enumerate(seq): + if idx == eos_idx: + eos_pos = i + break + seq = [ + idx for idx in seq[:eos_pos + 1] + if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx) + ] + return seq + + +def do_predict(args): + place = "gpu" + paddle.set_device(place) + + # Define data loader + test_loader, to_tokens = reader.create_infer_loader(args) + + # Define model + transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + transformer.eval() + + # Load checkpoint. 
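+    # `FasterTransformer.load` (see paddlenlp/ext_op/transformer/faster_transformer.py)
+    # reads the trained dygraph checkpoint and adapts it for the fused decoding
+    # op, e.g. zeroing the padding (bos) row of the target word embedding.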
+ transformer.load(init_from_params=os.path.join(args.init_from_params, + "transformer.pdparams")) + + f = open(args.output_file, "w") + with paddle.no_grad(): + for (src_word, ) in test_loader: + finished_seq = transformer(src_word=src_word) + finished_seq = finished_seq.numpy().transpose([1, 2, 0]) + for ins in finished_seq: + for beam_idx, beam in enumerate(ins): + if beam_idx >= args.n_best: + break + id_list = post_process_seq(beam, args.bos_idx, args.eos_idx) + word_list = to_tokens(id_list) + sequence = " ".join(word_list) + "\n" + f.write(sequence) + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + args.decoding_lib = ARGS.decoding_lib + args.use_fp16_decoding = ARGS.use_fp16_decoding + + do_predict(args) diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py index 1488df683b3f..73722a32d72b 100644 --- a/paddlenlp/__init__.py +++ b/paddlenlp/__init__.py @@ -17,6 +17,7 @@ from . import data from . import datasets from . import embeddings +from . import ext_op from . import layers from . import metrics from . import models diff --git a/paddlenlp/ext_op/CMakeLists.txt b/paddlenlp/ext_op/CMakeLists.txt new file mode 100644 index 000000000000..5196524c4999 --- /dev/null +++ b/paddlenlp/ext_op/CMakeLists.txt @@ -0,0 +1,141 @@ +cmake_minimum_required(VERSION 3.10 FATAL_ERROR) +project(FasterTransformer LANGUAGES CXX CUDA) + +find_package(CUDA 10.1 REQUIRED) + +INCLUDE(ExternalProject) + +set(CXX_STD "11" CACHE STRING "C++ standard") + +set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) + +list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64) + +if (${CUDA_VERSION} GREATER_EQUAL 11.0) + message(STATUS "Add DCUDA11_MODE") + add_definitions("-DCUDA11_MODE") +endif() + +# Setting compiler flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall") + +if (SM STREQUAL 80 OR + SM STREQUAL 86 OR + SM STREQUAL 70 OR + SM STREQUAL 75 OR + SM STREQUAL 61 OR + SM STREQUAL 60) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\"") + if (SM STREQUAL 70 OR SM STREQUAL 75 OR SM STREQUAL 80 OR SM STREQUAL 86) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + endif() +message("-- Assign GPU architecture (sm=${SM})") + +else() +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ + -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ + -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ + ") + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + +message("-- Assign GPU architecture (sm=70,75)") +endif() + +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0") +set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall") + +set(CMAKE_CXX_STANDARD "${CXX_STD}") +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}") + +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY 
${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +list(APPEND COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDA_PATH}/include) + +set(COMMON_LIB_DIRS + ${CUDA_PATH}/lib64 +) + +set(THIRD_PARTY_PATH "third-party") +set(THIRD_PARTY_NAME "fastertransformer") + +ExternalProject_Add( + extern_${THIRD_PARTY_NAME} + GIT_REPOSITORY https://github.com/NVIDIA/DeepLearningExamples.git + GIT_TAG master + PREFIX ${THIRD_PARTY_PATH} + SOURCE_DIR ${THIRD_PARTY_PATH}/${THIRD_PARTY_NAME} + SOURCE_SUBDIR "FasterTransformer/v3.1" + BINARY_DIR ${THIRD_PARTY_PATH}/build + INSTALL_COMMAND "" + CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DSM=${SM} +) +ExternalProject_Get_property(extern_${THIRD_PARTY_NAME} BINARY_DIR) +ExternalProject_Get_property(extern_${THIRD_PARTY_NAME} SOURCE_DIR) +ExternalProject_Get_property(extern_${THIRD_PARTY_NAME} SOURCE_SUBDIR) + +set(FT_INCLUDE_PATH ${SOURCE_DIR}/${SOURCE_SUBDIR}) +set(FT_LIB_PATH ${BINARY_DIR}/lib) + +include_directories( + ${FT_INCLUDE_PATH} +) + +set(PYTHON_PATH "python" CACHE STRING "Python path") + +execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.sysconfig.get_include())" + RESULT_VARIABLE _INC_PYTHON_SUCCESS + OUTPUT_VARIABLE _INC_PYTHON_VALUES) +if (NOT _INC_PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Python config Error.") +endif() +string(REGEX REPLACE ";" "\\\\;" _INC_PYTHON_VALUES ${_INC_PYTHON_VALUES}) +string(REGEX REPLACE "\n" ";" _INC_PYTHON_VALUES ${_INC_PYTHON_VALUES}) +list(GET _INC_PYTHON_VALUES 0 PY_INCLUDE_DIR) +list(GET _INC_PYTHON_VALUES 1 PY_SUFFIX) + +list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}) +list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}/third_party) + +include_directories( + ${COMMON_HEADER_DIRS} +) + +execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import paddle; print(paddle.sysconfig.get_lib())" + RESULT_VARIABLE _LIB_PYTHON_SUCCESS + OUTPUT_VARIABLE _LIB_PYTHON_VALUES) +if (NOT _LIB_PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Python config Error.") +endif() +string(REGEX REPLACE ";" "\\\\;" _LIB_PYTHON_VALUES ${_LIB_PYTHON_VALUES}) +string(REGEX REPLACE "\n" ";" _LIB_PYTHON_VALUES ${_LIB_PYTHON_VALUES}) +list(GET _LIB_PYTHON_VALUES 0 PY_LIB_DIR) +list(GET _LIB_PYTHON_VALUES 1 PY_SUFFIX) +list(APPEND COMMON_LIB_DIRS ${PY_LIB_DIR}) + +link_directories( + ${COMMON_LIB_DIRS} +) + +link_directories( + ${FT_LIB_PATH} +) + +add_subdirectory(src) diff --git a/paddlenlp/ext_op/README.md b/paddlenlp/ext_op/README.md new file mode 100644 index 000000000000..f45820d12956 --- /dev/null +++ b/paddlenlp/ext_op/README.md @@ -0,0 +1,109 @@ +# 自定义OP编译使用 + +## 子目录结构 + +```text +. 
+├── sample/ # 基于 Transformer 机器翻译使用样例(beam search) +├── src/ # 自定义 OP C++ CUDA 代码 +└── transformer/ # Python API 封装脚本 +``` + +## 使用环境说明 + +* 本项目依赖于 PaddlePaddle 2.0.1 及以上版本或适当的 develop 版本 +* CMake >= 3.10 +* CUDA 10.1 或是更新的版本(需要 PaddlePaddle 框架一致) +* gcc 版本需要与编译 PaddlePaddle 版本一致,比如使用 gcc8.2 +* 推荐使用 Python3 +* [Faster Transformer](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer/v3.1#setup) 使用必要的环境 + +## 快速开始 + +### 编译自定义OP + +自定义 OP 需要将实现的 C++、CUDA 代码编译成动态库,我们提供对应的 CMakeLists.txt ,可以参考使用如下的方式完成编译。 + +#### 克隆 PaddleNLP + +首先,因为需要基于当前环境重新编译,当前的 paddlenlp 的 python 包里面并不包含 Faster Transformer 相关 lib,需要克隆一个 PaddleNLP,并重新编译: +``` sh +git clone https://github.com/PaddlePaddle/PaddleNLP.git +``` + +其次,配置环境变量,让我们可以使用当前 clone 的 paddlenlp,并进入到自定义 OP 的路径,准备后续的编译操作: + +``` sh +export PYTHONPATH=$PWD/PaddleNLP/:$PYTHONPATH +cd PaddleNLP/paddlenlp/ext_op/ +``` + +#### 编译 + +编译之前,请确保安装的 PaddlePaddle 的版本需要大于 2.0.1,并且正常可用。 + +编译自定义 OP 可以参照一下步骤: + +``` sh +mkdir build +cd build/ +cmake .. -DSM=xx -DCMAKE_BUILD_TYPE=Release +make -j +cd ../ +``` + +注意:`xx` 是指的所用 GPU 的 compute capability。举例来说,可以将之指定为 70(V100) 或是 75(T4)。 + +最终,编译会在 `./build/lib/` 路径下,产出 `libdecoding_op.so`,即需要的 Faster Transformer decoding 执行的库。 + +#### 使用 + +编写 python 脚本的时候,调用 `FasterTransformer` API 并传入 `libdecoding_op.so` 的位置即可实现将 Faster Transformer 用于当前的预测。 + +举例如下: + +``` python +from paddlenlp.ext_op import FasterTransformer + +transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) +``` + +更详细的例子可以参考 `./sample/decoding_sample.py` 以及 `./sample/encoder_decoding_sample.py`,我们提供了更详细用例。 + +#### 执行 decoding on PaddlePaddle + +使用 PaddlePaddle 仅执行 decoding 测试(float32): + +``` sh +export CUDA_VISIBLE_DEVICES=0 +export FLAGS_fraction_of_gpu_memory_to_use=0.1 +./build/third-party/build/bin/decoding_gemm 32 4 8 64 30000 32 512 0 +python sample/decoding_sample.py --config ./sample/config/decoding.sample.yaml --decoding-lib ./build/lib/libdecoding_op.so +``` + +使用 PaddlePaddle 仅执行 decoding 测试(float16): +执行 float16 的 decoding,需要在执行的时候,加上 `--use-fp16-decoding` 选项。 + +``` sh +export CUDA_VISIBLE_DEVICES=0 +export FLAGS_fraction_of_gpu_memory_to_use=0.1 +./build/third-party/build/bin/decoding_gemm 32 4 8 64 30000 32 512 1 +python sample/decoding_sample.py --config ./sample/config/decoding.sample.yaml --decoding-lib ./build/lib/libdecoding_op.so --use-fp16-decoding +``` + +其中,`decoding_gemm` 不同参数的意义可以参考 [FasterTransformer 文档](https://github.com/NVIDIA/DeepLearningExamples/tree/master/FasterTransformer/v3.1#execute-the-decoderdecoding-demos)。 diff --git a/paddlenlp/ext_op/__init__.py b/paddlenlp/ext_op/__init__.py new file mode 100644 index 000000000000..3e1c5915157d --- /dev/null +++ b/paddlenlp/ext_op/__init__.py @@ -0,0 +1,2 @@ +from .transformer.decoding import * +from .transformer.faster_transformer import * diff --git a/paddlenlp/ext_op/sample/config/decoding.sample.yaml b/paddlenlp/ext_op/sample/config/decoding.sample.yaml new file mode 100644 index 000000000000..5724d8460641 --- /dev/null +++ b/paddlenlp/ext_op/sample/config/decoding.sample.yaml @@ -0,0 +1,95 @@ +# The frequency to save trained 
models when training. +save_step: 10000 +# The frequency to fetch and print output when training. +print_step: 100 +# Path of the checkpoint, to resume the previous training +init_from_checkpoint: "" +# Path of the pretrain model, to better solve the current task +init_from_pretrain_model: "" +# Path of trained parameter, to make prediction +init_from_params: "./trained_models/step_final/" +# The directory for saving model +save_model: "trained_models" +# The directory for saving inference model +inference_model_dir: "infer_model" +# Set seed for CE or debug +random_seed: None +# The file to output the translation results of predict_file to. +output_file: "predict.txt" +# The , and tokens in the dictionary. +special_token: ["", "", ""] +# The directory to store data. +root: None + +# Whether to use cuda +use_gpu: True + +# Args for reader, see reader.py for details +pool_size: 200000 +sort_type: "global" +batch_size: 4096 +infer_batch_size: 32 +shuffle_batch: True +# Data shuffle only works when sort_type is pool or none +shuffle: True +# shuffle_seed must be set when shuffle is True and using multi-cards to train. +# Otherwise, the number of batches cannot be guaranteed. +shuffle_seed: 128 + +# Hyparams for training: +# The number of epoches for training +epoch: 30 + +# The hyper parameters for Adam optimizer. +# This static learning_rate will be applied to the LearningRateScheduler +# derived learning rate the to get the final learning rate. +learning_rate: 2.0 +beta1: 0.9 +beta2: 0.997 +eps: 1e-9 +# The parameters for learning rate scheduling. +warmup_steps: 8000 +# The weight used to mix up the ground-truth distribution and the fixed +# uniform distribution in label smoothing when training. +# Set this as zero if label smoothing is not wanted. +label_smooth_eps: 0.1 + +# Hyparams for generation: +# The parameters for beam search. +beam_size: 4 +max_out_len: 32 +# The number of decoded sentences to output. +n_best: 1 + +# Hyparams for model: +# These following five vocabularies related configurations will be set +# automatically according to the passed vocabulary path and special tokens. +# Size of source word dictionary. +src_vocab_size: 30000 +# Size of target word dictionay +trg_vocab_size: 30000 +# Used to pad vocab size to be multiple of pad_factor. +pad_factor: 8 +# Index for token +bos_idx: 0 +# Index for token +eos_idx: 1 +# Index for token +unk_idx: 2 +# Max length of sequences deciding the size of position encoding table. +max_length: 32 +# The dimension for word embeddings, which is also the last dimension of +# the input and output of multi-head attention, position-wise feed-forward +# networks, encoder and decoder. +d_model: 512 +# Size of the hidden layer in position-wise feed-forward networks. +d_inner_hid: 2048 +# Number of head used in multi-head attention. +n_head: 8 +# Number of sub-layers to be stacked in the encoder and decoder. +n_layer: 6 +# Dropout rates. +dropout: 0.1 +# The flag indicating whether to share embedding and softmax weights. +# Vocabularies in source and target should be same for weight sharing. 
+weight_sharing: True diff --git a/paddlenlp/ext_op/sample/decoding_sample.py b/paddlenlp/ext_op/sample/decoding_sample.py new file mode 100644 index 000000000000..95f25b522bd8 --- /dev/null +++ b/paddlenlp/ext_op/sample/decoding_sample.py @@ -0,0 +1,107 @@ +import sys +import os +import numpy as np +from attrdict import AttrDict +import argparse +import time + +import paddle + +import yaml +from pprint import pprint + +from paddlenlp.ext_op import FasterTransformer + +from paddlenlp.utils.log import logger + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="./sample/config/decoding.sample.yaml", + type=str, + help="Path of the config file. ") + parser.add_argument( + "--decoding-lib", + default="./build/lib/libdecoding_op.so", + type=str, + help="Path of libdecoding_op.so. ") + parser.add_argument( + "--use-fp16-decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def generate_encoder_result(batch_size, max_seq_len, memory_hidden_dim, dtype): + memory_sequence_length = np.random.randint( + 1, max_seq_len, size=batch_size).astype(np.int32) + memory_sequence_length[np.random.randint(0, batch_size)] = max_seq_len + outter_embbeding = np.random.randn(memory_hidden_dim) * 0.01 + + memory = [] + mem_max_seq_len = np.max(memory_sequence_length) + for i in range(batch_size): + data = np.random.randn(mem_max_seq_len, memory_hidden_dim) * 0.01 + for j in range(memory_sequence_length[i], mem_max_seq_len): + data[j] = outter_embbeding + memory.append(data) + memory = np.asarray(memory) + memory = paddle.to_tensor(memory, dtype=dtype) + memory_sequence_length = paddle.to_tensor( + memory_sequence_length, dtype="int32") + + return memory, memory_sequence_length + + +def do_predict(args): + place = "gpu" + place = paddle.set_device(place) + + # Define model + transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + transformer.eval() + + enc_output, mem_seq_len = generate_encoder_result( + args.infer_batch_size, args.max_length, args.d_model, "float16" + if args.use_fp16_decoding else "float32") + with paddle.no_grad(): + for i in range(100): + # For warmup. 
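+            # The first 50 iterations serve as warmup (GPU and cuBLAS
+            # initialization); timing starts at i == 50 and the reported figure
+            # is the average over the remaining 50 decoding calls.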
+ if 50 == i: + start = time.time() + transformer.decoding( + enc_output=enc_output, memory_seq_lens=mem_seq_len) + logger.info("Average test time for decoding is %f ms" % ( + (time.time() - start) / 50 * 1000)) + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + args.decoding_lib = ARGS.decoding_lib + args.use_fp16_decoding = ARGS.use_fp16_decoding + + do_predict(args) diff --git a/paddlenlp/ext_op/sample/encoder_decoding_sample.py b/paddlenlp/ext_op/sample/encoder_decoding_sample.py new file mode 100644 index 000000000000..07af848ca32d --- /dev/null +++ b/paddlenlp/ext_op/sample/encoder_decoding_sample.py @@ -0,0 +1,108 @@ +import sys +import os +import numpy as np +from attrdict import AttrDict +import argparse +import time + +import paddle + +import yaml +from pprint import pprint + +from paddlenlp.ext_op import FasterTransformer + +from paddlenlp.utils.log import logger +from paddlenlp.data import Pad + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + default="./sample/config/decoding.sample.yaml", + type=str, + help="Path of the config file. ") + parser.add_argument( + "--decoding-lib", + default="./build/lib/libdecoding_op.so", + type=str, + help="Path of libdecoding_op.so. ") + parser.add_argument( + "--use-fp16-decoding", + action="store_true", + help="Whether to use fp16 decoding to predict. ") + args = parser.parse_args() + return args + + +def generate_src_word(batch_size, vocab_size, max_length, eos_idx, pad_idx): + memory_sequence_length = np.random.randint( + low=1, high=max_length, size=batch_size).astype(np.int32) + data = [] + for i in range(batch_size): + data.append( + np.random.randint( + low=3, + high=vocab_size, + size=memory_sequence_length[i], + dtype=np.int64)) + + word_pad = Pad(pad_idx) + src_word = word_pad([list(word) + [eos_idx] for word in data]) + + return paddle.to_tensor(src_word, dtype="int64") + + +def do_predict(args): + place = "gpu" + paddle.set_device(place) + + # Define model + transformer = FasterTransformer( + src_vocab_size=args.src_vocab_size, + trg_vocab_size=args.trg_vocab_size, + max_length=args.max_length + 1, + n_layer=args.n_layer, + n_head=args.n_head, + d_model=args.d_model, + d_inner_hid=args.d_inner_hid, + dropout=args.dropout, + weight_sharing=args.weight_sharing, + bos_id=args.bos_idx, + eos_id=args.eos_idx, + beam_size=args.beam_size, + max_out_len=args.max_out_len, + decoding_lib=args.decoding_lib, + use_fp16_decoding=args.use_fp16_decoding) + + # Set evaluate mode + transformer.eval() + + src_word = generate_src_word( + batch_size=args.infer_batch_size, + vocab_size=args.src_vocab_size, + max_length=args.max_length, + eos_idx=args.eos_idx, + pad_idx=args.bos_idx) + + with paddle.no_grad(): + for i in range(100): + # For warmup. 
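+            # Same warmup/timing scheme as decoding_sample.py, but here the
+            # measured call runs the regular Paddle encoder followed by the
+            # fused decoding op, so the average covers the full translation
+            # forward pass for one batch.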
+ if 50 == i: + start = time.time() + transformer(src_word=src_word) + logger.info("Average test time for encoder-decoding is %f ms" % ( + (time.time() - start) / 50 * 1000)) + + +if __name__ == "__main__": + ARGS = parse_args() + yaml_file = ARGS.config + with open(yaml_file, 'rt') as f: + args = AttrDict(yaml.safe_load(f)) + pprint(args) + args.decoding_lib = ARGS.decoding_lib + args.use_fp16_decoding = ARGS.use_fp16_decoding + + do_predict(args) diff --git a/paddlenlp/ext_op/src/CMakeLists.txt b/paddlenlp/ext_op/src/CMakeLists.txt new file mode 100644 index 000000000000..a3acec118562 --- /dev/null +++ b/paddlenlp/ext_op/src/CMakeLists.txt @@ -0,0 +1,23 @@ +include_directories(${PY_INCLUDE_DIR}) +include_directories(${PY_INCLUDE_DIR}\third_party) + +if(EXISTS ${PY_LIB_DIR}/libpaddle_framework.so) + set(lib_link + -lpaddle_framework + ) +endif() + +set(ft_lib_link + -ldecoder -ldecoding -ltopk -lcuda_int8_kernels -lcuda_kernels -lonline_softmax_beamsearch +) + +add_definitions(-DPADDLE_WITH_CUDA) +add_definitions(-DEIGEN_USE_GPU) +add_definitions(-DPADDLE_USE_DSO) +add_definitions(-DPADDLE_WITH_MKLDNN) + +set(decoding_op_files fusion_decoding_op.cc fusion_decoding_op.cu) + +add_library(decoding_op SHARED ${decoding_op_files}) +add_dependencies(decoding_op extern_${THIRD_PARTY_NAME}) +target_link_libraries(decoding_op PRIVATE -lcublas -lcudart ${lib_link} ${ft_lib_link}) diff --git a/paddlenlp/ext_op/src/fusion_decoding_op.cc b/paddlenlp/ext_op/src/fusion_decoding_op.cc new file mode 100644 index 000000000000..81e6a3535a10 --- /dev/null +++ b/paddlenlp/ext_op/src/fusion_decoding_op.cc @@ -0,0 +1,185 @@ +#include +#include + +#include "fastertransformer/common.h" +#include "fastertransformer/decoding_beamsearch.h" +#include "fastertransformer/open_decoder.h" + +#include "fusion_decoding_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/errors.h" + +namespace paddle { +namespace operators { + +class FusionDecodingOp : public framework::OperatorWithKernel { +public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto input_dims = ctx->GetInputDim("Input"); + auto beam_size = ctx->Attrs().Get("beam_size"); + auto max_len = ctx->Attrs().Get("max_len"); + int batch_size = input_dims[0] / beam_size; + + auto output_dims = framework::make_ddim({max_len, batch_size, beam_size}); + ctx->SetOutputDim("OutputIds", output_dims); + ctx->SetOutputDim("ParentIds", output_dims); + ctx->SetOutputDim("SequenceLength", + framework::make_ddim({batch_size * beam_size})); + } + +protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class FusionDecodingOpMaker : public framework::OpProtoAndCheckerMaker { +public: + void Make() override { + // do op maker. + // Add Parameters. + AddInput("Input", "The input of fusion_decoding op. "); + AddInput("MemSeqLen", "The sequence lengths of memory sequence. "); + AddInput("WordEmbedding", + "The input represents embedding tensors for target Ids. "); + + AddInput("SelfLayernormWeight", + "The tensors of layer norm's scale before self " + "attention layers. ") + .AsDuplicable(); + AddInput("SelfLayernormBias", + "The tensors of layer norm's bias before self attention " + "layers. 
") + .AsDuplicable() + .AsDispensable(); + AddInput("SelfQueryWeight", + "The tensors of self attention's query projection weights. ") + .AsDuplicable(); + AddInput("SelfQueryBias", + "The tensors of self attention's query projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("SelfKeyWeight", + "The tensors of self attention's key projection weights. ") + .AsDuplicable(); + AddInput("SelfKeyBias", + "The tensors of self attention's key projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("SelfValueWeight", + "The tensors of self attention's value projection weights. ") + .AsDuplicable(); + AddInput("SelfValueBias", + "The tensors of self attention's value projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("SelfOutWeight", + "The tensors of self attention's output projection weights. ") + .AsDuplicable(); + AddInput("SelfOutBias", + "The tensors of self attention's output projection biases. ") + .AsDuplicable() + .AsDispensable(); + + AddInput( + "CrossLayernormWeight", + "The tensors of layer norm's weights before cross attention layers. ") + .AsDuplicable(); + AddInput( + "CrossLayernormBias", + "The tensors of layer norm's biases before cross attention layers. ") + .AsDuplicable() + .AsDispensable(); + AddInput("CrossQueryWeight", + "The tensors of cross attention's query projection weights. ") + .AsDuplicable(); + AddInput("CrossQueryBias", + "The tensors of cross attention's query projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("CrossKeyWeight", + "The tensors of cross attention's key projection weights. ") + .AsDuplicable(); + AddInput("CrossKeyBias", + "The tensors of cross attention's key projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("CrossValueWeight", + "The tensors of cross attention's value projection weights. ") + .AsDuplicable(); + AddInput("CrossValueBias", + "The tensors of cross attention's value projection biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("CrossOutWeight", + "The tensors of cross attention's output projection weights. ") + .AsDuplicable(); + AddInput("CrossOutBias", + "The tensors of cross attention's output projection biases. ") + .AsDuplicable() + .AsDispensable(); + + AddInput("FFNLayernormWeight", + "The tensors of layer norm's weights before ffn. ") + .AsDuplicable(); + AddInput("FFNLayernormBias", + "The tensors of layer norm's biases before ffn. ") + .AsDuplicable() + .AsDispensable(); + AddInput("FFNInterWeight", "The tensors of inter fc weights. ") + .AsDuplicable(); + AddInput("FFNInterBias", "The tensors of inter fc biases. ") + .AsDuplicable() + .AsDispensable(); + AddInput("FFNOutWeight", "The tensors of output weights. ").AsDuplicable(); + AddInput("FFNOutBias", "The tensors of output biases. ") + .AsDuplicable() + .AsDispensable(); + + AddInput("DecoderLayernormWeight", + "The tensor of layer norm's weights after decoders. "); + AddInput("DecoderLayernormBias", + "The tensor of layer norm's biases after decoders. "); + AddInput("EmbWeight", "The tensor of logits projection weight. "); + AddInput("EmbBias", "The tensor of logits projection bias. ") + .AsDispensable(); + AddInput("PositionEncEmb", "The tensor of positional enbedding table. "); + + AddOutput("OutputIds", "The tensor of output ids. "); + AddOutput("ParentIds", "The tensor of parent ids. "); + AddOutput("SequenceLength", "The tensor of sequence length. "); + + AddAttr( + "decoding_strategy", + "Decoding strategies. As for now, only beam search is supported. 
") + .SetDefault("beam_search"); + AddAttr("beam_size", "The beam size for beam search. ").SetDefault(1); + AddAttr("n_head", "The number of heads. ").SetDefault(8); + AddAttr("size_per_head", "The size per head. ").SetDefault(64); + AddAttr("num_layer", "The number of layers. ").SetDefault(6); + AddAttr("bos_id", "Start id. ").SetDefault(0); + AddAttr("eos_id", "End id. ").SetDefault(1); + AddAttr("max_len", "Max output length. ").SetDefault(256); + AddAttr("beam_search_diversity_rate", + "The diversity rate for beam search. ") + .SetDefault(0.0); + + AddComment(R"DOC( + Decoding Operator. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(fusion_decoding, + ops::FusionDecodingOp, + ops::FusionDecodingOpMaker); +REGISTER_OP_CPU_KERNEL(fusion_decoding, ops::NotImpleKernel); diff --git a/paddlenlp/ext_op/src/fusion_decoding_op.cu b/paddlenlp/ext_op/src/fusion_decoding_op.cu new file mode 100644 index 000000000000..06a18a3c6e6f --- /dev/null +++ b/paddlenlp/ext_op/src/fusion_decoding_op.cu @@ -0,0 +1,266 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "fastertransformer/cuda/cub/cub.cuh" + + +#include "fusion_decoding_op.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/enforce.h" +#include "pd_traits.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class FusionDecodingKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto stream = ctx.cuda_device_context().stream(); + auto& dev_ctx = + ctx.template device_context(); + + auto* input = ctx.Input("Input"); + + auto* memory_sequence_length = ctx.Input("MemSeqLen"); + auto* word_emb = ctx.Input("WordEmbedding"); + + auto self_layernorm_weight = ctx.MultiInput("SelfLayernormWeight"); + auto self_layernorm_bias = ctx.MultiInput("SelfLayernormBias"); + auto self_attn_query_weight = ctx.MultiInput("SelfQueryWeight"); + auto self_attn_query_bias = ctx.MultiInput("SelfQueryBias"); + auto self_attn_key_weight = ctx.MultiInput("SelfKeyWeight"); + auto self_attn_key_bias = ctx.MultiInput("SelfKeyBias"); + auto self_attn_value_weight = ctx.MultiInput("SelfValueWeight"); + auto self_attn_value_bias = ctx.MultiInput("SelfValueBias"); + auto self_attn_output_weight = ctx.MultiInput("SelfOutWeight"); + auto self_attn_output_bias = ctx.MultiInput("SelfOutBias"); + + auto cross_layernorm_weight = + ctx.MultiInput("CrossLayernormWeight"); + auto cross_layernorm_bias = ctx.MultiInput("CrossLayernormBias"); + auto cross_attn_query_weight = ctx.MultiInput("CrossQueryWeight"); + auto cross_attn_query_bias = ctx.MultiInput("CrossQueryBias"); + auto cross_attn_key_weight = ctx.MultiInput("CrossKeyWeight"); + auto cross_attn_key_bias = ctx.MultiInput("CrossKeyBias"); + auto cross_attn_value_weight = ctx.MultiInput("CrossValueWeight"); + auto cross_attn_value_bias = ctx.MultiInput("CrossValueBias"); + auto cross_attn_output_weight = ctx.MultiInput("CrossOutWeight"); + auto cross_attn_output_bias = ctx.MultiInput("CrossOutBias"); + + auto ffn_layernorm_weight = ctx.MultiInput("FFNLayernormWeight"); + auto ffn_layernorm_bias = ctx.MultiInput("FFNLayernormBias"); + auto ffn_intermediate_weight = 
ctx.MultiInput("FFNInterWeight"); + auto ffn_intermediate_bias = ctx.MultiInput("FFNInterBias"); + auto ffn_output_weight = ctx.MultiInput("FFNOutWeight"); + auto ffn_output_bias = ctx.MultiInput("FFNOutBias"); + + auto* decoder_layernorm_weight = + ctx.Input("DecoderLayernormWeight"); + auto* decoder_layernorm_bias = ctx.Input("DecoderLayernormBias"); + auto* embedding_weight = ctx.Input("EmbWeight"); + auto* embedding_bias = ctx.Input("EmbBias"); + auto* position_encoding_table = ctx.Input("PositionEncEmb"); + + auto* output_ids = ctx.Output("OutputIds"); + auto* parent_ids = ctx.Output("ParentIds"); + auto* sequence_length = ctx.Output("SequenceLength"); + + // Not used for now. + std::string decoding_strategy = ctx.Attr("decoding_strategy"); + int beam_width_ = + (decoding_strategy == "beam_search") ? ctx.Attr("beam_size") : 1; + int64_t max_seq_len_ = ctx.Attr("max_len"); + int head_num_ = ctx.Attr("n_head"); + int size_per_head_ = ctx.Attr("size_per_head"); + int num_layer_ = ctx.Attr("num_layer"); + int start_id_ = ctx.Attr("bos_id"); + int end_id_ = ctx.Attr("eos_id"); + float beam_search_diversity_rate_ = + ctx.Attr("beam_search_diversity_rate"); + + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + + auto input_dims = input->dims(); + int batch_size_ = input_dims[0] / beam_width_; + const int memory_max_seq_len = input_dims[1]; + const int memory_hidden_dim = input_dims[2]; + const int vocab_size = word_emb->dims()[0]; + + DecodingInitParam decoding_params; + decoding_params.cublas_handle = dev_ctx.cublas_handle(); + + decoding_params.output_ids = output_ids->mutable_data(ctx.GetPlace()); + decoding_params.parent_ids = parent_ids->mutable_data(ctx.GetPlace()); + decoding_params.sequence_length = + sequence_length->mutable_data(ctx.GetPlace()); + + typedef DecoderTransformerTraits DecodingTraits_; + DecodingBeamsearch* decoding_beamsearch_; + decoding_params.stream = stream; + int device_id; + cudaGetDevice(&device_id); + fastertransformer::Allocator allocator_(device_id); + + decoding_beamsearch_ = new DecodingBeamsearch( + allocator_, + batch_size_, + beam_width_, + max_seq_len_, + head_num_, + size_per_head_, + vocab_size, + num_layer_, + memory_hidden_dim, + memory_max_seq_len, + start_id_, + end_id_, + beam_search_diversity_rate_); + + decoding_params.memory_tensor = + reinterpret_cast(input->data()); + decoding_params.memory_sequence_length = + memory_sequence_length->data(); + + DecoderInitParam* params = + new DecoderInitParam[num_layer_]; + + for (int i = 0; i < num_layer_; i++) { + params[i].stream = stream; + params[i].cublas_handle = dev_ctx.cublas_handle(); + + // self attn + params[i].self_layernorm.gamma = reinterpret_cast( + self_layernorm_weight[i]->data()); + params[i].self_layernorm.beta = + reinterpret_cast(self_layernorm_bias[i]->data()); + // query + params[i].self_attention.query_weight.kernel = + reinterpret_cast( + self_attn_query_weight[i]->data()); + params[i].self_attention.query_weight.bias = + reinterpret_cast( + self_attn_query_bias[i]->data()); + // key + params[i].self_attention.key_weight.kernel = + reinterpret_cast( + self_attn_key_weight[i]->data()); + params[i].self_attention.key_weight.bias = + reinterpret_cast(self_attn_key_bias[i]->data()); + // value + params[i].self_attention.value_weight.kernel = + reinterpret_cast( + self_attn_value_weight[i]->data()); + params[i].self_attention.value_weight.bias = + reinterpret_cast( + self_attn_value_bias[i]->data()); + // out proj + 
params[i].self_attention.attention_output_weight.kernel = + reinterpret_cast( + self_attn_output_weight[i]->data()); + params[i].self_attention.attention_output_weight.bias = + reinterpret_cast( + self_attn_output_bias[i]->data()); + + // cross + params[i].cross_layernorm.gamma = reinterpret_cast( + cross_layernorm_weight[i]->data()); + params[i].cross_layernorm.beta = reinterpret_cast( + cross_layernorm_bias[i]->data()); + // query + params[i].cross_attention.query_weight.kernel = + reinterpret_cast( + cross_attn_query_weight[i]->data()); + params[i].cross_attention.query_weight.bias = + reinterpret_cast( + cross_attn_query_bias[i]->data()); + // key + params[i].cross_attention.key_weight.kernel = + reinterpret_cast( + cross_attn_key_weight[i]->data()); + params[i].cross_attention.key_weight.bias = + reinterpret_cast(cross_attn_key_bias[i]->data()); + // value + params[i].cross_attention.value_weight.kernel = + reinterpret_cast( + cross_attn_value_weight[i]->data()); + params[i].cross_attention.value_weight.bias = + reinterpret_cast( + cross_attn_value_bias[i]->data()); + // out proj + params[i].cross_attention.attention_output_weight.kernel = + reinterpret_cast( + cross_attn_output_weight[i]->data()); + params[i].cross_attention.attention_output_weight.bias = + reinterpret_cast( + cross_attn_output_bias[i]->data()); + + // ffn + params[i].ffn_layernorm.gamma = reinterpret_cast( + ffn_layernorm_weight[i]->data()); + params[i].ffn_layernorm.beta = + reinterpret_cast(ffn_layernorm_bias[i]->data()); + // intermediate proj + params[i].ffn.intermediate_weight.kernel = + reinterpret_cast( + ffn_intermediate_weight[i]->data()); + params[i].ffn.intermediate_weight.bias = + reinterpret_cast( + ffn_intermediate_bias[i]->data()); + // out proj + params[i].ffn.output_weight.kernel = + reinterpret_cast(ffn_output_weight[i]->data()); + params[i].ffn.output_weight.bias = + reinterpret_cast(ffn_output_bias[i]->data()); + } + + decoding_params.layernorm.gamma = + reinterpret_cast(decoder_layernorm_weight->data()); + decoding_params.layernorm.beta = + reinterpret_cast(decoder_layernorm_bias->data()); + // for embedding + decoding_params.embedding_table = + reinterpret_cast(word_emb->data()); + + // for weight sharing matmul + decoding_params.embedding_kernel = + reinterpret_cast(embedding_weight->data()); + // for matmul bias + decoding_params.embedding_bias = + (embedding_bias) + ? 
reinterpret_cast(embedding_bias->data()) + : nullptr; + decoding_params.position_encoding_table = + reinterpret_cast(position_encoding_table->data()); + + decoding_beamsearch_->forward(params, decoding_params); + + delete decoding_beamsearch_; + delete[] params; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plf = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + fusion_decoding, + ops::FusionDecodingKernel, + ops::FusionDecodingKernel); diff --git a/paddlenlp/ext_op/src/fusion_decoding_op.h b/paddlenlp/ext_op/src/fusion_decoding_op.h new file mode 100644 index 000000000000..2cd6ac995ec5 --- /dev/null +++ b/paddlenlp/ext_op/src/fusion_decoding_op.h @@ -0,0 +1,20 @@ +#pragma once + +#include "fastertransformer/common.h" +#include "fastertransformer/decoding_beamsearch.h" +#include "fastertransformer/open_decoder.h" + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class NotImpleKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW("CPU is not support for this kernel now. Please use GPU. "); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddlenlp/ext_op/src/pd_traits.h b/paddlenlp/ext_op/src/pd_traits.h new file mode 100644 index 000000000000..1e90da055d8e --- /dev/null +++ b/paddlenlp/ext_op/src/pd_traits.h @@ -0,0 +1,25 @@ +#pragma once + +#include "fastertransformer/common.h" +#include "paddle/fluid/platform/float16.h" + +using namespace fastertransformer; +namespace paddle { +template +class PDTraits; + +template <> +class PDTraits { +public: + typedef float DataType; + static const OperationType OpType = OperationType::FP32; +}; + +template <> +class PDTraits { +public: + typedef half DataType; + static const OperationType OpType = OperationType::FP16; +}; + +} // namespace paddle diff --git a/paddlenlp/ext_op/transformer/__init__.py b/paddlenlp/ext_op/transformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/paddlenlp/ext_op/transformer/decoding.py b/paddlenlp/ext_op/transformer/decoding.py new file mode 100644 index 000000000000..6f69625c75b8 --- /dev/null +++ b/paddlenlp/ext_op/transformer/decoding.py @@ -0,0 +1,312 @@ +import os +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.fluid.layer_helper import LayerHelper + + +def infer_transformer_decoder( + enc_output, memory_seq_lens, word_emb, slf_ln_weight, slf_ln_bias, + slf_q_weight, slf_q_bias, slf_k_weight, slf_k_bias, slf_v_weight, + slf_v_bias, slf_out_weight, slf_out_bias, cross_ln_weight, + cross_ln_bias, cross_q_weight, cross_q_bias, cross_k_weight, + cross_k_bias, cross_v_weight, cross_v_bias, cross_out_weight, + cross_out_bias, ffn_ln_weight, ffn_ln_bias, ffn_inter_weight, + ffn_inter_bias, ffn_out_weight, ffn_out_bias, decoder_ln_weight, + decoder_ln_bias, linear_weight, linear_bias, pos_emb, _beam_size, + _n_head, _size_per_head, _n_layer, _bos_id, _eos_id, _max_out_len, + _beam_search_diversity_rate): + helper = LayerHelper('fusion_decoding', **locals()) + + inputs = { + 'Input': enc_output, + 'MemSeqLen': memory_seq_lens, + 'WordEmbedding': word_emb, + 'SelfLayernormWeight': slf_ln_weight, + 'SelfLayernormBias': slf_ln_bias, + 'SelfQueryWeight': slf_q_weight, + 'SelfQueryBias': slf_q_bias, + 'SelfKeyWeight': slf_k_weight, + 'SelfKeyBias': slf_k_bias, + 'SelfValueWeight': slf_v_weight, + 'SelfValueBias': 
slf_v_bias, + 'SelfOutWeight': slf_out_weight, + 'SelfOutBias': slf_out_bias, + 'CrossLayernormWeight': cross_ln_weight, + 'CrossLayernormBias': cross_ln_bias, + 'CrossQueryWeight': cross_q_weight, + 'CrossQueryBias': cross_q_bias, + 'CrossKeyWeight': cross_k_weight, + 'CrossKeyBias': cross_k_bias, + 'CrossValueWeight': cross_v_weight, + 'CrossValueBias': cross_v_bias, + 'CrossOutWeight': cross_out_weight, + 'CrossOutBias': cross_out_bias, + 'FFNLayernormWeight': ffn_ln_weight, + 'FFNLayernormBias': ffn_ln_bias, + 'FFNInterWeight': ffn_inter_weight, + 'FFNInterBias': ffn_inter_bias, + 'FFNOutWeight': ffn_out_weight, + 'FFNOutBias': ffn_out_bias, + 'DecoderLayernormWeight': decoder_ln_weight, + 'DecoderLayernormBias': decoder_ln_bias, + 'EmbWeight': linear_weight, + 'EmbBias': linear_bias, + 'PositionEncEmb': pos_emb + } + + attrs = { + 'beam_size': _beam_size, + 'n_head': _n_head, + 'size_per_head': _size_per_head, + 'num_layer': _n_layer, + 'bos_id': _bos_id, + 'eos_id': _eos_id, + 'max_len': _max_out_len, + 'beam_search_diversity_rate': _beam_search_diversity_rate + } + + output_ids = helper.create_variable_for_type_inference("int32") + parent_ids = helper.create_variable_for_type_inference("int32") + sequence_length = helper.create_variable_for_type_inference("int32") + + outputs = { + 'OutputIds': output_ids, + 'ParentIds': parent_ids, + 'SequenceLength': sequence_length + } + + helper.append_op( + type='fusion_decoding', inputs=inputs, outputs=outputs, attrs=attrs) + + return output_ids, parent_ids, sequence_length + + +def finalize(beam_size, output_ids, parent_ids, out_seq_lens, max_seq_len=None): + if max_seq_len is None: + max_seq_len = paddle.max(out_seq_lens) + output_ids = paddle.slice(output_ids, [0], [0], [max_seq_len]) + parent_ids = paddle.slice(parent_ids, [0], [0], [max_seq_len]) % beam_size + ids = paddle.nn.functional.gather_tree(output_ids, parent_ids) + return ids + + +def transfer_param(p, is_bias=False): + param_shape = p.shape + del p + return paddle.create_parameter( + shape=param_shape, dtype="float16", is_bias=is_bias) + + +class InferTransformerDecoding(nn.Layer): + def __init__(self, + decoder, + word_embedding, + positional_embedding, + linear, + max_length, + n_layer, + n_head, + d_model, + bos_id=0, + eos_id=1, + beam_size=4, + max_out_len=256, + beam_search_diversity_rate=0.0, + decoding_lib=None, + use_fp16_decoding=False): + if decoding_lib is None: + raise ValueError( + "The args decoding_lib must be set to use Faster Transformer. 
") + elif not os.path.exists(decoding_lib): + raise ValueError("The path to decoding lib is not exist.") + + super(InferTransformerDecoding, self).__init__() + paddle.utils.load_op_library(decoding_lib) + for arg, value in locals().items(): + if arg not in [ + "self", "decoder", "word_embedding", "positional_embedding", + "linear" + ]: + setattr(self, "_" + arg, value) + # process weights + if use_fp16_decoding: + for mod in decoder.layers: + mod.norm1.weight = transfer_param(mod.norm1.weight) + mod.norm1.bias = transfer_param(mod.norm1.bias, is_bias=True) + mod.self_attn.q_proj.weight = transfer_param( + mod.self_attn.q_proj.weight) + mod.self_attn.q_proj.bias = transfer_param( + mod.self_attn.q_proj.bias, is_bias=True) + mod.self_attn.k_proj.weight = transfer_param( + mod.self_attn.k_proj.weight) + mod.self_attn.k_proj.bias = transfer_param( + mod.self_attn.k_proj.bias, is_bias=True) + mod.self_attn.v_proj.weight = transfer_param( + mod.self_attn.v_proj.weight) + mod.self_attn.v_proj.bias = transfer_param( + mod.self_attn.v_proj.bias, is_bias=True) + mod.self_attn.out_proj.weight = transfer_param( + mod.self_attn.out_proj.weight) + mod.self_attn.out_proj.bias = transfer_param( + mod.self_attn.out_proj.bias, is_bias=True) + + mod.norm2.weight = transfer_param(mod.norm2.weight) + mod.norm2.bias = transfer_param(mod.norm2.bias, is_bias=True) + mod.cross_attn.q_proj.weight = transfer_param( + mod.cross_attn.q_proj.weight) + mod.cross_attn.q_proj.bias = transfer_param( + mod.cross_attn.q_proj.bias, is_bias=True) + mod.cross_attn.k_proj.weight = transfer_param( + mod.cross_attn.k_proj.weight) + mod.cross_attn.k_proj.bias = transfer_param( + mod.cross_attn.k_proj.bias, is_bias=True) + mod.cross_attn.v_proj.weight = transfer_param( + mod.cross_attn.v_proj.weight) + mod.cross_attn.v_proj.bias = transfer_param( + mod.cross_attn.v_proj.bias, is_bias=True) + mod.cross_attn.out_proj.weight = transfer_param( + mod.cross_attn.out_proj.weight) + mod.cross_attn.out_proj.bias = transfer_param( + mod.cross_attn.out_proj.bias, is_bias=True) + + mod.norm3.weight = transfer_param(mod.norm3.weight) + mod.norm3.bias = transfer_param(mod.norm3.bias, is_bias=True) + mod.linear1.weight = transfer_param(mod.linear1.weight) + mod.linear1.bias = transfer_param( + mod.linear1.bias, is_bias=True) + mod.linear2.weight = transfer_param(mod.linear2.weight) + mod.linear2.bias = transfer_param( + mod.linear2.bias, is_bias=True) + + decoder.norm.weight = transfer_param(decoder.norm.weight) + decoder.norm.bias = transfer_param(decoder.norm.bias, is_bias=True) + + linear.weight = transfer_param(linear.weight) + + self.slf_ln_weight = [] + self.slf_ln_bias = [] + self.slf_q_weight = [] + self.slf_q_bias = [] + self.slf_k_weight = [] + self.slf_k_bias = [] + self.slf_v_weight = [] + self.slf_v_bias = [] + self.slf_out_weight = [] + self.slf_out_bias = [] + + self.cross_ln_weight = [] + self.cross_ln_bias = [] + self.cross_q_weight = [] + self.cross_q_bias = [] + self.cross_k_weight = [] + self.cross_k_bias = [] + self.cross_v_weight = [] + self.cross_v_bias = [] + self.cross_out_weight = [] + self.cross_out_bias = [] + + self.ffn_ln_weight = [] + self.ffn_ln_bias = [] + self.ffn_inter_weight = [] + self.ffn_inter_bias = [] + self.ffn_out_weight = [] + self.ffn_out_bias = [] + + for mod in decoder.layers: + self.slf_ln_weight.append(mod.norm1.weight) + self.slf_ln_bias.append(mod.norm1.bias) + self.slf_q_weight.append(mod.self_attn.q_proj.weight) + self.slf_q_bias.append(mod.self_attn.q_proj.bias) + 
self.slf_k_weight.append(mod.self_attn.k_proj.weight) + self.slf_k_bias.append(mod.self_attn.k_proj.bias) + self.slf_v_weight.append(mod.self_attn.v_proj.weight) + self.slf_v_bias.append(mod.self_attn.v_proj.bias) + self.slf_out_weight.append(mod.self_attn.out_proj.weight) + self.slf_out_bias.append(mod.self_attn.out_proj.bias) + + self.cross_ln_weight.append(mod.norm2.weight) + self.cross_ln_bias.append(mod.norm2.bias) + self.cross_q_weight.append(mod.cross_attn.q_proj.weight) + self.cross_q_bias.append(mod.cross_attn.q_proj.bias) + self.cross_k_weight.append(mod.cross_attn.k_proj.weight) + self.cross_k_bias.append(mod.cross_attn.k_proj.bias) + self.cross_v_weight.append(mod.cross_attn.v_proj.weight) + self.cross_v_bias.append(mod.cross_attn.v_proj.bias) + self.cross_out_weight.append(mod.cross_attn.out_proj.weight) + self.cross_out_bias.append(mod.cross_attn.out_proj.bias) + + self.ffn_ln_weight.append(mod.norm3.weight) + self.ffn_ln_bias.append(mod.norm3.bias) + self.ffn_inter_weight.append(mod.linear1.weight) + self.ffn_inter_bias.append(mod.linear1.bias) + self.ffn_out_weight.append(mod.linear2.weight) + self.ffn_out_bias.append(mod.linear2.bias) + + self.decoder_ln_weight = [decoder.norm.weight] + self.decoder_ln_bias = [decoder.norm.bias] + + self.pos_emb = [positional_embedding.weight] + self.word_emb = [word_embedding.weight] + + self.linear_weight = [linear.weight] + self.linear_bias = [linear.bias] + + def forward(self, enc_output, memory_seq_lens): + enc_output = nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch( + enc_output, self._beam_size) + memory_seq_lens = nn.decode.BeamSearchDecoder.tile_beam_merge_with_batch( + memory_seq_lens, self._beam_size) + + output_ids, parent_ids, sequence_length = infer_transformer_decoder( + [enc_output], [memory_seq_lens], + [paddle.cast( + self.word_emb[0], dtype="float16")] + if self._use_fp16_decoding else self.word_emb, + self.slf_ln_weight, + self.slf_ln_bias, + self.slf_q_weight, + self.slf_q_bias, + self.slf_k_weight, + self.slf_k_bias, + self.slf_v_weight, + self.slf_v_bias, + self.slf_out_weight, + self.slf_out_bias, + self.cross_ln_weight, + self.cross_ln_bias, + self.cross_q_weight, + self.cross_q_bias, + self.cross_k_weight, + self.cross_k_bias, + self.cross_v_weight, + self.cross_v_bias, + self.cross_out_weight, + self.cross_out_bias, + self.ffn_ln_weight, + self.ffn_ln_bias, + self.ffn_inter_weight, + self.ffn_inter_bias, + self.ffn_out_weight, + self.ffn_out_bias, + self.decoder_ln_weight, + self.decoder_ln_bias, + self.linear_weight, + self.linear_bias, [paddle.cast( + self.pos_emb[0], dtype="float16")] + if self._use_fp16_decoding else self.pos_emb, + self._beam_size, + self._n_head, + int(self._d_model / self._n_head), + self._n_layer, + self._bos_id, + self._eos_id, + self._max_out_len, + self._beam_search_diversity_rate) + + ids = finalize(self._beam_size, output_ids, parent_ids, sequence_length) + + return ids diff --git a/paddlenlp/ext_op/transformer/faster_transformer.py b/paddlenlp/ext_op/transformer/faster_transformer.py new file mode 100644 index 000000000000..a2519802b13d --- /dev/null +++ b/paddlenlp/ext_op/transformer/faster_transformer.py @@ -0,0 +1,133 @@ +import os +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddlenlp.transformers import TransformerModel, position_encoding_init +from paddlenlp.ext_op import InferTransformerDecoding + + +class FasterTransformer(TransformerModel): + def __init__(self, + src_vocab_size, + trg_vocab_size, + 
max_length, + n_layer, + n_head, + d_model, + d_inner_hid, + dropout, + weight_sharing, + bos_id=0, + eos_id=1, + beam_size=4, + max_out_len=256, + decoding_lib=None, + use_fp16_decoding=False): + if decoding_lib is None: + raise ValueError( + "The args decoding_lib must be set to use Faster Transformer. ") + elif not os.path.exists(decoding_lib): + raise ValueError("The path to decoding lib is not exist.") + + args = dict(locals()) + args.pop("self") + args.pop("__class__", None) + self.beam_size = args.pop("beam_size") + self.max_out_len = args.pop("max_out_len") + self.decoding_lib = args.pop("decoding_lib") + self.use_fp16_decoding = args.pop("use_fp16_decoding") + self.dropout = dropout + self.weight_sharing = weight_sharing + self.trg_vocab_size = trg_vocab_size + self.d_model = d_model + self.bos_id = bos_id + self.max_length = max_length + super(FasterTransformer, self).__init__(**args) + + self.decoding_linear = nn.Linear( + in_features=d_model, out_features=trg_vocab_size) + + self.decoding = InferTransformerDecoding( + decoder=self.transformer.decoder, + word_embedding=self.trg_word_embedding.word_embedding, + positional_embedding=self.trg_pos_embedding.pos_encoder, + linear=self.decoding_linear, + max_length=max_length, + n_layer=n_layer, + n_head=n_head, + d_model=d_model, + bos_id=bos_id, + eos_id=eos_id, + beam_size=beam_size, + max_out_len=max_out_len, + decoding_lib=self.decoding_lib, + use_fp16_decoding=self.use_fp16_decoding) + + def forward(self, src_word): + src_max_len = paddle.shape(src_word)[-1] + src_slf_attn_bias = paddle.cast( + src_word == self.bos_id, + dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e9 + src_pos = paddle.cast( + src_word != self.bos_id, dtype="int64") * paddle.arange( + start=0, end=src_max_len) + + # Run encoder + src_emb = self.src_word_embedding(src_word) + src_pos_emb = self.src_pos_embedding(src_pos) + src_emb = src_emb + src_pos_emb + enc_input = F.dropout( + src_emb, p=self.dropout, + training=False) if self.dropout else src_emb + enc_output = self.transformer.encoder(enc_input, src_slf_attn_bias) + + if self.use_fp16_decoding: + enc_output = paddle.cast(enc_output, dtype="float16") + + mem_seq_lens = paddle.sum(paddle.cast( + src_word != self.bos_id, dtype="int32"), + axis=1) + ids = self.decoding(enc_output, mem_seq_lens) + + return ids + + def load(self, init_from_params): + # Load the trained model + assert init_from_params, ( + "Please set init_from_params to load the infer model.") + + model_dict = paddle.load(init_from_params) + + # To set weight[padding_idx] to 0. + model_dict["trg_word_embedding.word_embedding.weight"][ + self.bos_id] = [0] * self.d_model + + # Dealing with weight sharing. 
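+        # When weight sharing is enabled, the target word embedding doubles as
+        # the pre-softmax projection, so decoding_linear reuses the transposed
+        # embedding matrix; otherwise the trained linear.weight is copied over.
+        # In both cases decoding_linear.bias is initialized to zeros.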
+ if self.weight_sharing: + model_dict["decoding_linear.weight"] = np.transpose(model_dict[ + "trg_word_embedding.word_embedding.weight"]) + model_dict["decoding_linear.bias"] = np.zeros( + [self.trg_vocab_size], dtype="float32") + else: + model_dict["decoding_linear.weight"] = model_dict["linear.weight"] + model_dict["decoding_linear.bias"] = np.zeros( + [self.trg_vocab_size], dtype="float32") + + # To avoid a longer length than training, reset the size of position + # encoding to max_length + model_dict["encoder.pos_encoder.weight"] = position_encoding_init( + self.max_length, self.d_model) + model_dict["decoder.pos_encoder.weight"] = position_encoding_init( + self.max_length, self.d_model) + + if self.use_fp16_decoding: + for item in self.state_dict(): + if "decoder" in item: + model_dict[item] = np.float16(model_dict[item]) + model_dict["decoding_linear.weight"] = np.float16(model_dict[ + "decoding_linear.weight"]) + + self.load_dict(model_dict)
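
For reference, the following is a minimal usage sketch of the `FasterTransformer` class added in this patch. It is not part of the patch itself: every hyperparameter value, vocabulary size and file path below is a placeholder that must match your own trained Transformer checkpoint, and the `decoding_lib` path is only an assumption about where the custom-op build under `paddlenlp/ext_op/` places the compiled shared library.

```python
import paddle
from paddlenlp.ext_op.transformer.faster_transformer import FasterTransformer

paddle.set_device("gpu")

# Placeholder hyperparameters; they must match the checkpoint being loaded.
# decoding_lib must point to the shared library built from paddlenlp/ext_op/.
model = FasterTransformer(
    src_vocab_size=10000,
    trg_vocab_size=10000,
    max_length=256,
    n_layer=6,
    n_head=8,
    d_model=512,
    d_inner_hid=2048,
    dropout=0.1,
    weight_sharing=True,
    bos_id=0,
    eos_id=1,
    beam_size=4,
    max_out_len=256,
    decoding_lib="./build/lib/libdecoding_op.so",  # assumed build output path
    use_fp16_decoding=False)

# load() reads the trained parameters, ties or copies the output projection,
# resets the position encodings to max_length, and optionally casts the
# decoder weights to float16.
model.load("./trained_models/step_final/transformer.pdparams")
model.eval()

# Padded source token ids; padding uses bos_id (0), which forward() relies on
# when building the attention bias and the memory sequence lengths.
src_word = paddle.to_tensor([[28, 761, 5, 1, 0, 0]], dtype="int64")

with paddle.no_grad():
    # Beam search results assembled by gather_tree inside finalize().
    ids = model(src_word)
print(ids.numpy())
```

The entire beam-search decode runs inside the fused `fusion_decoding` custom op rather than step by step in Python, which is where the speedup comes from; `finalize` then applies `gather_tree` to the op's outputs to produce the returned ids.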