[cherry-pick-2.2] fix bias add none bug on static graph for fused_attention_op #37607

Closed
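This PR lets the fused attention API run when the layer is built without bias parameters (bias_attr=False) in static-graph mode. Below is a minimal sketch of that scenario, mirroring the run_static path in the updated test; the import path paddle.incubate.nn.FusedMultiHeadAttention and a CUDA-enabled build are assumptions, and the positional argument order follows the test.

```python
# Sketch only: argument order mirrors the test below
# (embed_dim, num_heads, dropout, attn_dropout, kdim, vdim,
#  pre_layer_norm, need_weight, weight_attr, bias_attr).
import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention  # assumed import path

paddle.enable_static()
batch, seq, num_heads, embed_dim = 1, 2, 2, 4

x = paddle.static.data(name="X", shape=[batch, seq, embed_dim], dtype="float32")
# bias_attr=False creates the layer without qkv/linear/layer-norm biases; before
# this fix the static-graph path fed the missing bias tensors to the fused op
# unconditionally and failed.
attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0,
                               embed_dim, embed_dim, False, False, None, False)
out = attn(x, x, x)  # attn_mask omitted

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(paddle.static.default_startup_program())
res, = exe.run(paddle.static.default_main_program(),
               feed={"X": np.random.random((batch, seq, embed_dim)).astype("float32")},
               fetch_list=[out])
```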
172 changes: 111 additions & 61 deletions python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py
@@ -79,8 +79,12 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
seq_len = query.shape[1]
embed_dim = query.shape[2]

has_bias = True
if ln_bias is None:
has_bias = False

if (pre_layer_norm):
ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)

num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
@@ -89,17 +93,24 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])

qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
query = query.reshape(batch_size, seq_len, embed_dim)

qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
@@ -140,26 +151,42 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_out = fc(out_linear_input, out_linear_weight)

# bias add, dropout, residual add, layer_norm.
out_linear_bias_out = out_linear_out + out_linear_bias
if out_linear_bias is not None:
out_linear_bias_out = out_linear_out + out_linear_bias
else:
out_linear_bias_out = out_linear_out
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
if not pre_layer_norm:
out_linear_bias_dropout_residual_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
out_linear_bias_dropout_residual_out, True, has_bias, ln_2_scale,
ln_2_bias)
return out_linear_bias_dropout_residual_out


class TestFusedAttentionAPI(unittest.TestCase):
def setUp(self):
self.setXType()
self.setPreLn()
self.setAttnMask()
self.setBiasAttr()
self.config()
self.generate_input_data()

def config(self):
def setAttnMask(self):
self.has_attn_mask = True

def setBiasAttr(self):
self.bias_attr = None

def setPreLn(self):
self.pre_layer_norm = False

def setXType(self):
self.x_type = np.float32

def config(self):
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.need_weight = False

@@ -172,7 +199,6 @@ def config(self):
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None

self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
@@ -205,23 +231,32 @@ def run_imperative(self):
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
if self.bias_attr is not False:
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
'float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
fused_attn.pre_ln_bias.numpy(),
fused_attn.ln_scale.numpy(),
fused_attn.ln_bias.numpy(),
fused_attn.qkv_weight.numpy(),
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)

fused_attn_qkv_bias = None
fused_attn_linear_bias = None
fused_attn_pre_ln_bias = None
fused_attn_ln_bias = None
if self.bias_attr is not False:
fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
fused_attn_linear_bias = fused_attn.linear_bias.numpy()
fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
fused_attn_ln_bias = fused_attn.ln_bias.numpy()

ref_out = compute_reference(
self.pre_layer_norm, self.query, self.attn_mask,
fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias,
fused_attn.ln_scale.numpy(), fused_attn_ln_bias,
fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-4)

def run_static(self):
fused_attn = FusedMultiHeadAttention(
@@ -248,27 +283,53 @@ def run_static(self):
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

qkv_bias = None
linear_bias = None
ln_bias = None
ln_2_bias = None
if self.has_attn_mask:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias

def test_static_api(self):
@@ -280,35 +341,24 @@ def test_static_api(self):
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-4)

def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()


class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
def setAttnMask(self):
self.has_attn_mask = False
self.training = True
self.need_weight = False

self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
def setPreLn(self):
self.pre_layer_norm = True

self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None

self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
def setBiasAttr(self):
self.bias_attr = False


if __name__ == "__main__":
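The test is also restructured so that each configuration knob has its own setter (setXType, setPreLn, setAttnMask, setBiasAttr) that subclasses override. A hypothetical further variant, shown only to illustrate the pattern (not part of this PR):

```python
# Hypothetical subclass, not part of the PR: combines pre-layer-norm with bias_attr=False
# by overriding only the relevant setUp hooks.
class TestFusedAttentionAPIPreLnBiasIsNone(TestFusedAttentionAPI):
    def setPreLn(self):
        self.pre_layer_norm = True

    def setBiasAttr(self):
        self.bias_attr = False
```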
6 changes: 4 additions & 2 deletions python/paddle/incubate/nn/functional/fused_transformer.py
@@ -388,10 +388,12 @@ def fused_multi_head_attention(x,
if pre_ln_bias:
inputs['LnBias'] = [pre_ln_bias]
inputs['QKVW'] = [qkv_weight]
inputs['QKVBias'] = [qkv_bias]
if qkv_bias is not None:
inputs['QKVBias'] = [qkv_bias]
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = [linear_weight]
inputs['OutLinearBias'] = [linear_bias]
if linear_bias is not None:
inputs['OutLinearBias'] = [linear_bias]
if ln_scale:
inputs['Ln2Scale'] = [ln_scale]
if ln_bias:
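At the functional layer, QKVBias and OutLinearBias are now only wired into the fused op when the corresponding tensors are provided, so callers may pass None for them. A hedged sketch of such a call follows; the keyword names qkv_bias and linear_bias appear in the function body above, while the import path, the leading positional arguments, and the weight shapes are assumptions taken from the test file.

```python
# Sketch only, assuming a CUDA build of Paddle and the argument names noted above.
import paddle
import paddle.incubate.nn.functional as F  # assumed import path

embed_dim, num_heads = 4, 2
head_dim = embed_dim // num_heads

x = paddle.randn([1, 2, embed_dim])                             # [batch, seq_len, embed_dim]
qkv_weight = paddle.randn([3, num_heads, head_dim, embed_dim])  # shape used by the test above
linear_weight = paddle.randn([embed_dim, embed_dim])

out = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight,
    qkv_bias=None,     # optional after this fix
    linear_bias=None,  # optional after this fix
    attn_mask=None)
```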