Commit 68bc308
Recover atol of fused_attention.
limin2021 committed Nov 9, 2021
1 parent 28e539e commit 68bc308
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions python/paddle/fluid/tests/unittests/test_fused_attention_op.py
@@ -219,9 +219,9 @@ def test_fused_attention_op(self):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-3)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-3)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)


class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
@@ -249,9 +249,9 @@ def test_fused_attention_op(self):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)


class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
@@ -279,9 +279,9 @@ def test_fused_attention_op(self):
         final_out_ref, x_grad_ref = self.GetBaselineOut()
         final_out, x_grad = self.GetFusedAttentionOut()
         np.testing.assert_allclose(
-            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
         np.testing.assert_allclose(
-            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)


class TestFusedAttentionOpFp16(TestFusedAttentionOp):

4 changes: 2 additions & 2 deletions (second changed test file)
@@ -224,7 +224,7 @@ def run_imperative(self):
             fused_attn.qkv_bias.numpy(),
             fused_attn.linear_weight.numpy(),
             fused_attn.linear_bias.numpy())
-        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-3)
+        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)

def run_static(self):
fused_attn = FusedMultiHeadAttention(
@@ -312,7 +312,7 @@ def test_static_api(self):
                 self.attn_mask, ln_scale, ln_bias,
                 ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
                 linear_weight, linear_bias)
-        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-3)
+        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)

def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
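For context on the changes above, here is a minimal, self-contained sketch (not part of the commit; the arrays and offset are hypothetical) of how np.testing.assert_allclose combines rtol and atol, and why tightening atol from 1e-3 / 1e-1 back to 1e-5 makes these checks stricter:

import numpy as np

# assert_allclose passes when, elementwise,
#     |actual - desired| <= atol + rtol * |desired|
desired = np.array([0.1234567, 1.0, -2.5])  # stand-in for baseline outputs
actual = desired + 5e-4                     # stand-in for fused-op outputs, off by 5e-4

# Passes: 5e-4 <= 1e-3 + 1e-5 * |desired| everywhere.
np.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-3)

# Fails: 5e-4 > 1e-5 + 1e-5 * |desired| for these values.
try:
    np.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-5)
except AssertionError:
    print("mismatch exceeds the tighter atol=1e-5")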
