Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions tester/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, api_config, test_amp):
self.test_amp = test_amp
self.converter = get_converter()

#@func_set_timeout(600)
# @func_set_timeout(600)
def test(self):
if self.need_skip():
print("[Skip]", flush=True)
Expand Down Expand Up @@ -219,17 +219,23 @@ def test(self):
write_to_log("paddle_error", self.api_config.config)
return

if self.api_config.api_name == "paddle.incubate.nn.functional.fused_rms_norm":
if self.api_config.api_name == "paddle.incubate.nn.functional.fused_rms_norm":
paddle_output = paddle_output[0]
if self.api_config.api_name == "paddle.unique":
elif self.api_config.api_name == "paddle.unique":
if "return_index=True" in self.api_config.config:
paddle_output = list(paddle_output)
paddle_output.pop(1)
paddle_output = tuple(paddle_output)
if self.api_config.api_name in {"paddle.mode", "paddle.Tensor.mode"}:
elif self.api_config.api_name in {
"paddle.mode",
"paddle.Tensor.mode",
}:
paddle_output = paddle_output[0]
torch_output = torch_output[0]
if self.api_config.api_name in {"paddle.strided_slice", "paddle.vander"} and any(s < 0 for s in paddle_output.strides):
elif self.api_config.api_name in {
"paddle.strided_slice",
"paddle.vander",
} and any(s < 0 for s in paddle_output.strides):
# torch's from_dlpack currently doesn't support negative strides
paddle_output = paddle_output.contiguous()

Expand Down Expand Up @@ -360,38 +366,32 @@ def test(self):
if self.api_config.api_name == "paddle.Tensor.__setitem__":
torch_out_grads = torch_out_grads[0]
paddle_out_grads = paddle_out_grads[0]

# All configs that are not compared with torch
# should be moved to tester/api_config/5_accuracy/grads_diff.txt
if self.api_config.api_name == "paddle.tensordot":
paddle_out_grads = paddle_out_grads[:2]
if self.api_config.api_name == "paddle.combinations":
paddle_out_grads = []
torch_out_grads = []
if self.api_config.api_name == "paddle.diagonal_scatter":
paddle_out_grads = paddle_out_grads[:1]
torch_out_grads = torch_out_grads[:1]
if self.api_config.api_name == "paddle.lerp":
paddle_out_grads = paddle_out_grads[:2]
torch_out_grads = torch_out_grads[:2]
if self.api_config.api_name == "paddle.nn.utils.parameters_to_vector":
paddle_out_grads = []
torch_out_grads = []
if self.api_config.api_name == "paddle.nn.functional.kl_div":
paddle_out_grads = paddle_out_grads[:1]
torch_out_grads = torch_out_grads[:1]
if self.api_config.api_name == "paddle.scale":
paddle_out_grads = paddle_out_grads[0]
torch_out_grads = torch_out_grads[0]
if self.api_config.api_name == "paddle.nn.functional.gaussian_nll_loss":

# All configs that are not compared with torch should be copied
# to tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
if self.api_config.api_name in {
"paddle.lerp",
"paddle.tensordot",
}:
paddle_out_grads = paddle_out_grads[:2]
torch_out_grads = torch_out_grads[:2]
if self.api_config.api_name == "paddle.nn.functional.binary_cross_entropy":
paddle_out_grads = paddle_out_grads[0]
torch_out_grads = torch_out_grads[0]
if self.api_config.api_name == "paddle.nn.functional.binary_cross_entropy_with_logits":
elif self.api_config.api_name in {
"paddle.Tensor.fill_diagonal_tensor",
"paddle.diagonal_scatter",
"paddle.nn.functional.binary_cross_entropy",
"paddle.nn.functional.binary_cross_entropy_with_logits",
"paddle.nn.functional.gaussian_nll_loss",
"paddle.nn.functional.kl_div",
"paddle.scale",
}:
paddle_out_grads = paddle_out_grads[0]
torch_out_grads = torch_out_grads[0]
elif self.api_config.api_name in {
"paddle.combinations",
"paddle.nn.utils.parameters_to_vector",
}:
paddle_out_grads = []
torch_out_grads = []

if isinstance(paddle_out_grads, paddle.Tensor):
if isinstance(torch_out_grads, torch.Tensor):
Expand Down
34 changes: 34 additions & 0 deletions tester/api_config/5_accuracy/accuracy_gpu_error_grads_diff.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,40 @@ paddle.scale(Tensor([2, 1, 1, 2, 3],"float64"), scale=Tensor([],"float32"), )
paddle.scale(Tensor([2, 1, 2, 3],"float32"), scale=Tensor([],"float32"), )
paddle.scale(Tensor([2, 1, 2, 3],"float64"), scale=Tensor([],"float32"), )
paddle.scale(Tensor([2, 3, 4, 5, 6],"float32"), scale=Tensor([1],"float32"), )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2, 3],"float32"), Tensor([10, 2, 3],"float32"), Tensor([10, 2, 1],"float32"), False, 1e-06, "none", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2, 3],"float32"), Tensor([10, 2, 3],"float32"), Tensor([10, 2, 1],"float32"), full=False, reduction="none", )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2, 3],"float32"), Tensor([10, 2, 3],"float32"), Tensor([10, 2],"float32"), False, 1e-06, "none", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2, 3],"float32"), Tensor([10, 2, 3],"float32"), Tensor([10, 2],"float32"), full=False, reduction="none", )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), False, 1e-06, "none", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), True, 1e-06, "mean", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), True, 1e-06, "sum", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), full=False, reduction="none", )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), full=True, reduction="mean", )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), Tensor([10, 2],"float32"), full=True, reduction="sum", )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float64"), Tensor([10, 2],"float64"), Tensor([10, 2],"float64"), False, 1e-06, "none", None, )
paddle.nn.functional.gaussian_nll_loss(Tensor([10, 2],"float64"), Tensor([10, 2],"float64"), Tensor([10, 2],"float64"), full=False, reduction="none", )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 3, 2],"float32"), Tensor([2, 2, 3],"float32"), offset=0, dim1=1, dim2=2, )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 3, 2],"float64"), Tensor([2, 2, 3],"float64"), offset=0, dim1=1, dim2=2, )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 4],"float32"), Tensor([4, 2],"float32"), 0, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 4],"float64"), Tensor([4, 2],"float64"), 0, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 7],"float32"), Tensor([2, 4],"float32"), 0, 1, 2, )
paddle.Tensor.fill_diagonal_tensor(Tensor([2, 4, 7],"float64"), Tensor([2, 4],"float64"), 0, 1, 2, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float32"), Tensor([1],"float32"), -2, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float32"), Tensor([2],"float32"), -1, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float32"), Tensor([2],"float32"), 1, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float32"), Tensor([3],"float32"), 0, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float64"), Tensor([1],"float64"), -2, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float64"), Tensor([2],"float64"), -1, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float64"), Tensor([2],"float64"), 1, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([3, 3],"float64"), Tensor([3],"float64"), 0, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float32"), Tensor([2],"float32"), offset=1, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float32"), Tensor([3],"float32"), offset=-1, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float32"), Tensor([3],"float32"), offset=0, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float64"), Tensor([2],"float64"), offset=1, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float64"), Tensor([3],"float64"), offset=-1, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([4, 3],"float64"), Tensor([3],"float64"), offset=0, dim1=0, dim2=1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([5, 3],"float32"), Tensor([3],"float32"), 0, 0, 1, )
paddle.Tensor.fill_diagonal_tensor(Tensor([5, 3],"float64"), Tensor([3],"float64"), 0, 0, 1, )
paddle.nn.functional.binary_cross_entropy(Tensor([1, 1, 2],"float64"), label=Tensor([1, 1, 2],"float64"), weight=None, reduction="mean", name=None, )
paddle.nn.functional.binary_cross_entropy(Tensor([100, 1],"float32"), Tensor([100, 1],"float32"), reduction="none", )
paddle.nn.functional.binary_cross_entropy(Tensor([100, 80],"float32"), Tensor([100, 80],"float32"), reduction="none", )
Expand Down
4 changes: 2 additions & 2 deletions tester/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@
]
)

# All configs that report dtype diff when not in not_check_dtype list
# should be moved to tester/api_config/5_accuracy/dtype_diff.txt
# All configs that report dtype diff when not in not_check_dtype list should be
# copied to tester/api_config/5_accuracy/accuracy_gpu_error_dtype_diff.txt
not_check_dtype = frozenset(
[
"paddle.where",
Expand Down