Commit 46988e2

Fix bugs when bias add none in static graph for fused_attention op. (#37566) (#37608)

Cherry-pick of PR #37566.

Based on #37411, this PR:

- Continues to fix the bugs that occur when the bias add is None in static graph mode for the fused_attention op (see the usage sketch below).
- Polishes and improves the unittests in test_fused_attention_op_api.py.
limin2021 authored Nov 29, 2021
1 parent 4066713 commit 46988e2
Showing 2 changed files with 115 additions and 63 deletions.
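
For context, here is a minimal, hypothetical sketch of the configuration this fix targets: a FusedMultiHeadAttention layer built with bias_attr=False, so the QKV bias, output-projection bias, and layer-norm biases are all None, run under the static graph. It loosely mirrors the updated run_static test in this diff; the import path, the positional constructor argument order, and the toy shapes are assumptions taken from that test, not an authoritative usage guide.

import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention  # assumed import path

paddle.enable_static()

# Assumed toy sizes; the real test uses its own config values.
batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 2
embed_dim = num_heads * head_dim

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(
        name="X", shape=[batch_size, seq_len, embed_dim], dtype="float32")
    # Positional args follow the test's order: embed_dim, num_heads, dropout,
    # attn_dropout, kdim, vdim, pre_layer_norm, need_weight, weight_attr,
    # bias_attr. bias_attr=False is the "bias add is None" path being fixed.
    fused_attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0,
                                         embed_dim, embed_dim, False, False,
                                         None, False)
    final_out = fused_attn(x, x, x)  # no attn_mask, so SrcMask is None as well

place = paddle.CUDAPlace(0)  # the fused op runs on GPU in these tests
exe = paddle.static.Executor(place)
exe.run(startup_prog)

query = np.random.rand(batch_size, seq_len, embed_dim).astype("float32")
out, = exe.run(main_prog, feed={"X": query}, fetch_list=[final_out])

In the diff below, both the test's NumPy reference (compute_reference) and the functional wrapper (fused_multi_head_attention) gain "is not None" guards, so the QKV bias, linear bias, and layer-norm biases are simply skipped when absent instead of being added or fed as None.
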
python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py (172 changes: 111 additions & 61 deletions)
@@ -79,8 +79,12 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
seq_len = query.shape[1]
embed_dim = query.shape[2]

has_bias = True
if ln_bias is None:
has_bias = False

if (pre_layer_norm):
ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)

num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
@@ -89,17 +93,24 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])

qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
query = query.reshape(batch_size, seq_len, embed_dim)

qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
@@ -140,26 +151,42 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_out = fc(out_linear_input, out_linear_weight)

# bias add, dropout, residual add, layer_norm.
out_linear_bias_out = out_linear_out + out_linear_bias
if out_linear_bias is not None:
out_linear_bias_out = out_linear_out + out_linear_bias
else:
out_linear_bias_out = out_linear_out
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
if not pre_layer_norm:
out_linear_bias_dropout_residual_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
out_linear_bias_dropout_residual_out, True, has_bias, ln_2_scale,
ln_2_bias)
return out_linear_bias_dropout_residual_out


class TestFusedAttentionAPI(unittest.TestCase):
def setUp(self):
self.setXType()
self.setPreLn()
self.setAttnMask()
self.setBiasAttr()
self.config()
self.generate_input_data()

def config(self):
def setAttnMask(self):
self.has_attn_mask = True

def setBiasAttr(self):
self.bias_attr = None

def setPreLn(self):
self.pre_layer_norm = False

def setXType(self):
self.x_type = np.float32

def config(self):
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.need_weight = False

@@ -172,7 +199,6 @@ def config(self):
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None

self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
@@ -205,23 +231,32 @@ def run_imperative(self):
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
if self.bias_attr is not False:
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
'float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
fused_attn.pre_ln_bias.numpy(),
fused_attn.ln_scale.numpy(),
fused_attn.ln_bias.numpy(),
fused_attn.qkv_weight.numpy(),
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)

fused_attn_qkv_bias = None
fused_attn_linear_bias = None
fused_attn_pre_ln_bias = None
fused_attn_ln_bias = None
if self.bias_attr is not False:
fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
fused_attn_linear_bias = fused_attn.linear_bias.numpy()
fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
fused_attn_ln_bias = fused_attn.ln_bias.numpy()

ref_out = compute_reference(
self.pre_layer_norm, self.query, self.attn_mask,
fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias,
fused_attn.ln_scale.numpy(), fused_attn_ln_bias,
fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-4)

def run_static(self):
fused_attn = FusedMultiHeadAttention(
@@ -248,27 +283,53 @@ def run_static(self):
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

qkv_bias = None
linear_bias = None
ln_bias = None
ln_2_bias = None
if self.has_attn_mask:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias

def test_static_api(self):
Expand All @@ -280,35 +341,24 @@ def test_static_api(self):
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-4)

def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()


class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
def setAttnMask(self):
self.has_attn_mask = False
self.training = True
self.need_weight = False

self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
def setPreLn(self):
self.pre_layer_norm = True

self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None

self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
def setBiasAttr(self):
self.bias_attr = False


if __name__ == "__main__":
python/paddle/incubate/nn/functional/fused_transformer.py (6 changes: 4 additions & 2 deletions)
@@ -388,10 +388,12 @@ def fused_multi_head_attention(x,
if pre_ln_bias:
inputs['LnBias'] = [pre_ln_bias]
inputs['QKVW'] = [qkv_weight]
inputs['QKVBias'] = [qkv_bias]
if qkv_bias is not None:
inputs['QKVBias'] = [qkv_bias]
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = [linear_weight]
inputs['OutLinearBias'] = [linear_bias]
if linear_bias is not None:
inputs['OutLinearBias'] = [linear_bias]
if ln_scale:
inputs['Ln2Scale'] = [ln_scale]
if ln_bias: