[fix] fix the bug of fused_attention and fused_feedforward (PaddlePaddle#36972)

* Fix bugs (a usage sketch follows below):
1. attention: set the default value of attn_dropout_rate to None
2. ffn: add an activation parameter
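Based on the two fixes above, here is a minimal, hedged usage sketch of the public fused encoder layer. The positional argument order follows the new test_fused_transformer_encoder_layer test added in this commit; the fall-back semantics of attn_dropout_rate=None and the GPU requirement are assumptions, not guarantees from this diff.

import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer

# Sketch only: requires a GPU build with the fused operators compiled in.
d_model, nhead, dim_feedforward = 64, 16, 32
encoder = FusedTransformerEncoderLayer(
    d_model, nhead, dim_feedforward,
    0.1,     # dropout_rate
    'gelu',  # activation, now forwarded to the fused feedforward (fix 2)
    None,    # attn_dropout_rate; None is assumed to fall back to dropout_rate (fix 1)
    None,    # act_dropout_rate
    False)   # normalize_before
src = paddle.rand([2, 8, d_model], dtype='float32')
out = encoder(src)  # same shape as src: [2, 8, d_model]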
zhangkaihuo committed Nov 16, 2021
1 parent c460691 commit a3a6edc
Showing 9 changed files with 350 additions and 29 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary(
kernel_primitives::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, arg0, arg1, func);
// store
-  kernel_primitives::WriteData<OutT, VecSize, 1, 1>(out + fix, result, num);
+  kernel_primitives::WriteData<OutT, VecSize, 1, 1, true>(out + fix, result,
+                                                          num);
}

// bias add forward impl for "[m, n] + [n] = [m, n]"
@@ -267,25 +268,24 @@ __global__ void BiasAddBw1DReduceKernel(const ReduceParamType<T>* temp_sum,
}

template <typename T>
-void Launch2DColumnReduce(gpuStream_t stream, const int max_threads,
-                          const int reduce_num, const int left_num,
-                          const T* d_out, T* d_bias) {
+void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx,
+                          const int max_threads, const int reduce_num,
+                          const int left_num, const T* d_out, T* d_bias) {
dim3 block;
dim3 grid;
bool should_reduce_again = false;
int blocking_size = 1;
SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size,
&should_reduce_again, &block, &grid);
+  const auto& stream = dev_ctx.stream();

if (!should_reduce_again) {
BiasAddBwSinglePassKernel<T><<<grid, block, 0, stream>>>(d_out, reduce_num,
left_num, d_bias);
} else {
framework::Tensor tmp_sum;
-    tmp_sum.mutable_data<ReduceParamType<T>>(
-        framework::make_ddim({static_cast<int64_t>(
-            left_num * grid.y * sizeof(ReduceParamType<T>))}),
-        paddle::platform::CUDAPlace());
+    tmp_sum.Resize({grid.y, left_num});
+    tmp_sum.mutable_data<ReduceParamType<T>>(dev_ctx.GetPlace());

BiasAddBw2DReduceKernel<T><<<grid, block, 0, stream>>>(
d_out, reduce_num, left_num, blocking_size,
@@ -311,8 +311,8 @@ void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m,
Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
d_out, d_bias);
} else {
-    Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num,
-                         d_out, d_bias);
+    Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out,
+                         d_bias);
}
}
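For context, the launch helpers above drive the backward of the "[m, n] + [n]" bias add, whose bias gradient is a column sum: d_bias[j] = sum_i d_out[i][j]. Below is a rough numpy sketch of the two-pass column reduction; the (grid.y, left_num) temporary corresponds to the tmp_sum tensor resized above, while the blocking and names here are illustrative, not the kernel's.

import numpy as np

def bias_add_bw_two_pass(d_out, blocking_size):
    m, n = d_out.shape                        # reduce_num = m, left_num = n
    num_blocks = (m + blocking_size - 1) // blocking_size
    # First pass: per-block partial column sums, accumulated in a wider type
    # (loosely mirroring ReduceParamType<T>); this plays the role of tmp_sum.
    partial = np.zeros((num_blocks, n), dtype=np.float64)
    for b in range(num_blocks):
        block = d_out[b * blocking_size:(b + 1) * blocking_size]
        partial[b] = block.astype(np.float64).sum(axis=0)
    # Second pass: reduce the partial sums down to d_bias.
    return partial.sum(axis=0)

d_out = np.random.rand(300, 8).astype(np.float32)
d_bias = bias_add_bw_two_pass(d_out, blocking_size=64)
assert np.allclose(d_bias, d_out.astype(np.float64).sum(axis=0))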

11 changes: 10 additions & 1 deletion paddle/fluid/operators/fused/fused_feedforward_op.cu
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/matmul_v2_op.h"

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"

@@ -261,7 +262,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
d_linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
d_dropout2_out.mutable_data<T>({bsz_seq, d_model}, place);
-    d_residual.mutable_data<T>({bsz_seq, d_model}, place);
+    d_residual.mutable_data<T>(d_x->dims(), place);

if (pre_layer_norm) {
fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
@@ -301,6 +302,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel<T> {
} else {
MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
}
std::vector<const Tensor*> ins(2);
std::vector<Tensor*> outs(1);
ins[0] = &d_residual;
ins[1] = d_x;
outs[0] = d_x;
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx, ins, &outs, elewise_add_axis, AddFunctor<T>());
}

void Compute(const framework::ExecutionContext& context) const override {
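The block added to the gradient kernel above accumulates the gradient flowing through the residual (shortcut) branch into d_x, which was previously dropped. A minimal numpy sketch of the idea; the function name is illustrative, not the operator's.

import numpy as np

def ffn_input_grad(d_x_through_ffn, d_residual):
    # The forward pass is roughly out = x + dropout(ffn(x)) (with layer_norm
    # before or after), so d_x is the sum of the gradients from both branches,
    # which is exactly what the added elementwise add (d_x += d_residual) does.
    return d_x_through_ffn + d_residual

g = np.ones((4, 8), dtype=np.float32)
assert np.allclose(ffn_input_grad(0.5 * g, g), 1.5 * g)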
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -94,6 +94,7 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
endif()

if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
@@ -26,6 +26,9 @@
from paddle.fluid import layers
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program

default_main_program().random_seed = 42


class TestFusedAttentionOp(OpTest):
17 changes: 16 additions & 1 deletion python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py
@@ -23,6 +23,7 @@
from paddle.nn.layer.common import Linear, Dropout
import unittest
from op_test import OpTest
from paddle.fluid.framework import default_main_program


class TestFusedFFNOp(OpTest):
@@ -91,7 +92,7 @@ def setUp(self):
def Base(self):
paddle.disable_static()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
-        residual = paddle.to_tensor(self.src)
+        residual = tensor_src
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_src)
linear2_out = self.linear2(
@@ -140,6 +141,7 @@ def FusedFFN(self):
return out, x.grad

def test_out_and_grad(self):
default_main_program().random_seed = 42
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedFFN()
np.testing.assert_allclose(
@@ -192,6 +194,7 @@ def getShape(self):
class APITestStaticFusedFFN(unittest.TestCase):
def test_static(self):
paddle.enable_static()
default_main_program().random_seed = 42
dtype = "float32"
layer_norm_dtype = "float32"
batch_size = 1
@@ -324,6 +327,18 @@ def test_dropout_rate_value():

self.assertRaises(ValueError, test_dropout_rate_value)

def test_dropout_mode():
x = paddle.static.data(
name='x3', shape=[1, 10, 10], dtype="float32")
linear1_weight = paddle.static.data(
name='linear1_weight3', shape=[10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight3', shape=[10, 10], dtype="float32")
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, mode='test')

self.assertRaises(ValueError, test_dropout_mode)


if __name__ == "__main__":
unittest.main()
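The new test_dropout_mode case above only checks that an unknown mode string is rejected. As a hedged reference, the mode names below are taken from paddle.nn.functional.dropout and the fused op is assumed to follow the same convention; this is an illustrative restatement of the check, not the operator's code.

VALID_DROPOUT_MODES = ('upscale_in_train', 'downscale_in_infer')

def check_dropout_mode(mode):
    # Mirrors the ValueError that test_dropout_mode expects for mode='test'.
    if mode not in VALID_DROPOUT_MODES:
        raise ValueError(
            "mode should be 'upscale_in_train' or 'downscale_in_infer'")

check_dropout_mode('upscale_in_train')  # accepted
try:
    check_dropout_mode('test')          # the value the new test passes in
except ValueError:
    pass                                # expected, matches assertRaises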
@@ -0,0 +1,188 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import paddle
from paddle.incubate.nn import FusedTransformerEncoderLayer
from paddle.nn import TransformerEncoderLayer
from paddle.fluid.framework import default_main_program
import unittest


class TestFusedTransformerEncoderLayer(unittest.TestCase):
def setActivation(self):
self.activation = 'gelu'

def setPreLayerNorm(self):
self.pre_layer_norm = False

def setAttnMask(self):
self.has_attn_mask = True

def setUp(self):
self.batch_size = np.random.randint(1, 8)
self.query_length = np.random.randint(1, 128)
self.nhead = 16
self.head_dim = 4
self.num_heads = self.nhead
self.d_model = self.head_dim * self.num_heads
self.embed_dim = self.d_model
self.dim_feedforward = np.random.randint(1, 32)
self.dropout_rate = 0
self.attn_dropout_rate = None
self.act_dropout_rate = None
self.attn_mask_type = np.float64
self.key_length = self.query_length
self.dtype = 'float32'
self.setActivation()
self.setPreLayerNorm()
self.setAttnMask()

def fused_weight(self, weight, num_head):
a = paddle.transpose(weight, perm=[1, 0])
return paddle.reshape(
a, shape=[1, num_head, int(a.shape[0] / num_head), a.shape[1]])

def fused_qkv(self, q, k, v, num_head):
fq = self.fused_weight(q, num_head)
fk = self.fused_weight(k, num_head)
fv = self.fused_weight(v, num_head)
return paddle.concat(x=[fq, fk, fv], axis=0)

def test_out(self):
default_main_program().random_seed = 42
base_encoder = TransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)
src = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.dtype)

if self.has_attn_mask:
attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
attn_mask_tensor = paddle.to_tensor(attn_mask)
else:
attn_mask = None
attn_mask_tensor = None

dout = np.random.random(src.shape).astype(self.dtype)

base_out = base_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([base_out], [paddle.to_tensor(dout)], True)

fused_encoder = FusedTransformerEncoderLayer(
self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate,
self.activation, self.attn_dropout_rate, self.act_dropout_rate,
self.pre_layer_norm)

fused_encoder.ffn._linear1_weight.set_value(base_encoder.linear1.weight)
fused_encoder.ffn._linear1_bias.set_value(base_encoder.linear1.bias)
fused_encoder.ffn._linear2_weight.set_value(base_encoder.linear2.weight)
fused_encoder.ffn._linear2_bias.set_value(base_encoder.linear2.bias)
if self.pre_layer_norm:
fused_encoder.ffn._ln1_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln1_bias.set_value(base_encoder.norm2.bias)
else:
fused_encoder.ffn._ln2_scale.set_value(base_encoder.norm2.weight)
fused_encoder.ffn._ln2_bias.set_value(base_encoder.norm2.bias)

fused_encoder.fused_attn.linear_weight.set_value(
base_encoder.self_attn.out_proj.weight)
fused_encoder.fused_attn.linear_bias.set_value(
base_encoder.self_attn.out_proj.bias)
if self.pre_layer_norm:
fused_encoder.fused_attn.pre_ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.pre_ln_bias.set_value(
base_encoder.norm1.bias)
else:
fused_encoder.fused_attn.ln_scale.set_value(
base_encoder.norm1.weight)
fused_encoder.fused_attn.ln_bias.set_value(base_encoder.norm1.bias)

q = base_encoder.self_attn.q_proj.weight
q_bias = base_encoder.self_attn.q_proj.bias
k = base_encoder.self_attn.k_proj.weight
k_bias = base_encoder.self_attn.k_proj.bias
v = base_encoder.self_attn.v_proj.weight
v_bias = base_encoder.self_attn.v_proj.bias
qkv_weight = self.fused_qkv(q, k, v, self.num_heads)
fused_encoder.fused_attn.qkv_weight.set_value(qkv_weight)

tmp = paddle.concat(x=[q_bias, k_bias, v_bias], axis=0)
qkv_bias = paddle.reshape(
tmp,
shape=[3, self.num_heads, int(tmp.shape[0] / 3 / self.num_heads)])
fused_encoder.fused_attn.qkv_bias.set_value(qkv_bias)

fused_out = fused_encoder(
paddle.to_tensor(
src, stop_gradient=False), attn_mask_tensor)
paddle.autograd.backward([fused_out], [paddle.to_tensor(dout)], True)

correct_ffn_str = 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}'.format(
self.d_model, self.dim_feedforward, self.dropout_rate,
fused_encoder.ffn._epsilon, self.activation, self.dropout_rate,
self.pre_layer_norm, self.dtype)
self.assertTrue(fused_encoder.ffn.extra_repr(), correct_ffn_str)

correct_attn_str = 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}'.format(
self.embed_dim, self.num_heads, self.dropout_rate,
self.dropout_rate, fused_encoder.fused_attn._epsilon, None, None,
self.pre_layer_norm, False, self.dtype)
self.assertTrue(fused_encoder.fused_attn.extra_repr(), correct_attn_str)

np.testing.assert_allclose(
fused_out.numpy(), base_out.numpy(), rtol=1e-3, atol=1e-4)
self.assertTrue(
np.allclose(
fused_out.grad.numpy(),
base_out.grad.numpy(),
rtol=1e-3,
atol=1e-4))


class TestFusedTransformerEncoderLayerAct(TestFusedTransformerEncoderLayer):
def setActivation(self):
self.activation = 'relu'


class TestFusedTransformerEncoderLayerPreLayerNorm(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True


class TestFusedTransformerEncoderLayerAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setAttnMask(self):
self.has_attn_mask = False


class TestFusedTransformerEncoderLayerPreLnTrueAttnMaskIsNone(
TestFusedTransformerEncoderLayer):
def setPreLayerNorm(self):
self.pre_layer_norm = True

def setAttnMask(self):
self.has_attn_mask = False


if __name__ == "__main__":
unittest.main()
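For reference, the fused_weight/fused_qkv helpers in the test above pack the separate q/k/v projection weights into the single qkv_weight parameter (and the three biases into qkv_bias) that the fused attention layer consumes. A numpy sketch of just the shape bookkeeping; the fused op's internal layout is assumed to match what the test constructs.

import numpy as np

num_heads, head_dim = 16, 4
embed_dim = num_heads * head_dim

def pack_weight(w):
    # w: [embed_dim, embed_dim]; transpose, then split the output features per
    # head, as fused_weight does with paddle.transpose + paddle.reshape.
    return w.T.reshape(num_heads, head_dim, embed_dim)

q, k, v = (np.random.rand(embed_dim, embed_dim).astype(np.float32)
           for _ in range(3))
qkv_weight = np.stack([pack_weight(q), pack_weight(k), pack_weight(v)])
qkv_bias = np.concatenate([np.zeros(embed_dim, np.float32)] * 3).reshape(
    3, num_heads, head_dim)

assert qkv_weight.shape == (3, num_heads, head_dim, embed_dim)
assert qkv_bias.shape == (3, num_heads, head_dim)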