padding the length of input for vit_attention #45506

Merged
merged 7 commits into PaddlePaddle:develop from new_vit_382_opt on Sep 2, 2022

Conversation

@fengxiaoshuai (Contributor) commented Aug 29, 2022

PR types

Others

PR changes

Others

Describe

When the attention input length is not a multiple of 8, fp16 performance degrades badly. This change pads the input of the multihead plugin. For the vit_384 model with batch=1, the latency drops from 13.5 ms to 10.5 ms.
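For context, here is a minimal sketch of the alignment rule described above (a hypothetical helper, not code from this PR); it assumes the typical ViT-384 token count of 577 (576 patches + 1 class token):

#include <iostream>

// Hypothetical helper: round the attention sequence length up to the next
// multiple of 8 so that fp16 kernels see well-aligned shapes.
int RoundUpToMultipleOf8(int seq_len) { return (seq_len + 7) / 8 * 8; }

int main() {
  int real_seq_len = 577;  // assumed ViT-384 tokens: 576 patches + 1 cls token
  std::cout << real_seq_len << " -> " << RoundUpToMultipleOf8(real_seq_len)
            << std::endl;  // prints "577 -> 584"
  return 0;
}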

@paddle-bot commented Aug 29, 2022

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -310,6 +349,11 @@ int QkvToContextPluginDynamic::enqueue(
// input[0], (B, S, 3 * N * H, 1, 1)
int batch = input_dims.d[0];
int seq_len = input_dims.d[1];
int real_seq_len = seq_len;
if (input_desc[0].type == nvinfer1::DataType::kHALF) {
Contributor:

Please add a comment noting that fp16 needs padding.

Contributor (author):

> Please add a comment noting that fp16 needs padding.

OK.

// Fill qk_bias: 0 for valid positions, a large negative bias for padded ones.
template <typename T>
__global__ void reset_qk_bias(T *input, int real_seq_len, int seq_len) {
  if (threadIdx.x < seq_len) {
    int id = threadIdx.x + blockIdx.x * seq_len;
    input[id] = threadIdx.x >= real_seq_len ? (T)-1e20f : (T)0.0f;
  }
}
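For illustration only, a hypothetical launch of this kernel over the padded qk_bias buffer (one block per attention row, one thread per column; this assumes the padded seq_len fits in a single thread block, and the grid and argument types may differ from the PR's actual code):

// Hypothetical launch, not the PR's code.
int rows = batch * head_number_ * seq_len;  // one block per (batch, head, row)
reset_qk_bias<float><<<rows, seq_len, 0, stream>>>(qk_bias, real_seq_len, seq_len);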
Contributor:

-1e20f: be mindful of the representable range at low precision.
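The concern: fp16 cannot represent -1e20f (its most negative finite value is about -65504), so the cast overflows. A hedged sketch of one way to keep the mask bias finite per dtype (MaskBias is a hypothetical helper, not the fix adopted in this PR):

#include <cuda_fp16.h>

// Hypothetical helper: largest-magnitude negative bias that stays finite in
// each dtype; half overflows beyond +-65504, so -1e20f is unsafe there.
template <typename T>
struct MaskBias;

template <>
struct MaskBias<float> {
  static constexpr float value = -1e20f;
};

template <>
struct MaskBias<half> {
  static constexpr float value = -65504.0f;  // most negative finite fp16 value
};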

if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
  qk_bias = reinterpret_cast<float *>(workspace);
  auto size = batch * head_number_ * seq_len * seq_len;
  cudaMemset(qk_bias, 0, sizeof(float) * size);
Contributor:

Use cudaMemsetAsync; the same applies to the several calls below.
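For reference, the asynchronous variant takes the stream as a fourth argument; a minimal sketch using the names from the snippet above:

// Enqueue the zero-fill on the plugin's CUDA stream instead of a blocking memset.
cudaMemsetAsync(qk_bias, 0, sizeof(float) * size, stream);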

@@ -373,6 +423,35 @@ int QkvToContextPluginDynamic::enqueue(
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
int *padding_offset = nullptr;
half *padding_input = nullptr;
framework::Tensor padding_offset_tensor;
Contributor:

Change this to use the workspace or a member variable, to avoid allocating device memory on every enqueue.
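A hedged sketch of that suggestion: sub-allocate the two padding buffers from the TensorRT workspace pointer instead of creating a framework::Tensor inside enqueue. The layout and sizes below are assumptions for illustration, not the PR's final code:

// Hypothetical workspace layout: [padding offsets | padded qkv input].
char *ws = reinterpret_cast<char *>(workspace);
int *padding_offset = reinterpret_cast<int *>(ws);  // assumed batch + 1 entries
size_t offset_bytes = sizeof(int) * (batch + 1);
half *padding_input = reinterpret_cast<half *>(ws + offset_bytes);
// getWorkspaceSize() would then need to report at least offset_bytes plus
// sizeof(half) * batch * seq_len * 3 * head_number_ * head_size_.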

0,
sizeof(half) * batch * seq_len * 3 * head_number_ * head_size_);

set_padding_offset<<<1, 1, 0, stream>>>(
Contributor:

The concurrency here could be improved further.

Contributor (author):

> The concurrency here could be improved further.

OK.

@@ -1105,6 +1113,9 @@ def generate_trt_nodes_num():
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3, 1e-3)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3, 1e-3)
Contributor:

The fp32 tolerance should be tighter; fp32 can achieve better precision than this.

@fengxiaoshuai changed the title from vit_384_opt to vit_attention_length_padding on Sep 2, 2022
@@ -342,6 +427,12 @@ int QkvToContextPluginDynamic::enqueue(
head_number_);
qk_bias = temp_qk_bias;
}
// fake qk_bias
if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
Contributor:

This can be determined at config time; there is no need to check it on every enqueue. The same applies to the ones below.

Contributor (author):

> This can be determined at config time; there is no need to check it on every enqueue. The same applies to the ones below.

OK, this and the memset below will be handled together in configurePlugin.
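A hedged sketch of that idea: compute the "inputs have equal element counts" flag once in configurePlugin and reuse it in enqueue. The member fake_qk_bias_ is hypothetical, the signature is shown in TensorRT 8 style, and this assumes the relevant dims are static at configure time:

// Hypothetical sketch, not the PR's final code.
void QkvToContextPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *in, int nb_inputs,
    const nvinfer1::DynamicPluginTensorDesc *out, int nb_outputs) noexcept {
  // Cache whether input[1] matches input[0] in element count, so enqueue()
  // does not recompute ProductDim() on every call.
  fake_qk_bias_ = ProductDim(in[1].desc.dims) == ProductDim(in[0].desc.dims);
}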

if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
  qk_bias = reinterpret_cast<float *>(workspace);
  auto size = batch * head_number_ * seq_len * seq_len;
  cudaMemset(qk_bias, 0, sizeof(float) * size);
Contributor:

Use the async interface (cudaMemsetAsync) here too.

@fengxiaoshuai changed the title from vit_attention_length_padding to padding the length of input for vit_attention on Sep 2, 2022
@b3602sss merged commit f79be65 into PaddlePaddle:develop on Sep 2, 2022
@fengxiaoshuai deleted the new_vit_382_opt branch on October 8, 2022