Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tester/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,28 @@ def test(self):
# torch's from_dlpack now don't support negative strides
paddle_output = paddle_output.contiguous()

paddle_output_cache = []
if self.api_config.api_name == "paddle.linalg.eigh":
# The output of eigen vectors are not unique, because multiplying an eigen vector by -1 in the real case
# or by e^(i*\theta) in the complex case produces another set of valid eigen vectors of the matrix.
# So we test whether the elements of each coef_vector (i.e. paddle_output / torch_output for each eigen vector)
# are all the same and whether the |coef| == 1 for simplicity.
paddle_output, torch_output = list(paddle_output), list(torch_output)
paddle_output_cache = [i.clone() for i in paddle_output]
eigvector_len = paddle_output[1].shape[-2]
paddle_eigvectors = paddle_output.pop(1).matrix_transpose().reshape([-1, eigvector_len])
torch_eigvectors = torch_output.pop(1).transpose(-1, -2).reshape((-1, eigvector_len))
for i in range(paddle_eigvectors.shape[0]):
coef_vector = paddle.to_tensor(paddle_eigvectors[i].numpy()/torch_eigvectors[i].numpy(), dtype=paddle_eigvectors[i].dtype)
coef_vector = coef_vector.round(2)
coef_0 = paddle_eigvectors[i].numpy()[0]/torch_eigvectors[i].numpy()[0]
coef_vector_approx = torch.tensor([coef_0] * eigvector_len)
abs_coef = coef_vector.abs().astype("float64")[0]
one = torch.tensor(1.0, dtype=torch.float64)
paddle_output.append([coef_vector, abs_coef])
torch_output.append([coef_vector_approx, one])
Comment on lines +252 to +253
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我的想法是:

paddle_output = [coef_vector, abs_coef]
torch_output = [coef_vector_approx, one]

这样可以同时比较绝对误差和相对误差,不需要 paddle_output_cache~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯,忘记删了



if isinstance(paddle_output, paddle.Tensor):
if isinstance(torch_output, torch.Tensor):
try:
Expand Down Expand Up @@ -332,6 +354,9 @@ def test(self):
write_to_log("accuracy_error", self.api_config.config)
return

if self.api_config.api_name == "paddle.linalg.eigh":
paddle_output = paddle_output_cache

if self.need_check_grad() and torch_grad_success:
try:
paddle_out_grads = None
Expand Down
1 change: 1 addition & 0 deletions tester/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ def is_forward_only(self):
"distributed_push_sparse",
"dpsgd",
"edit_distance",
"eigh",
"eigvals",
"embedding_grad_dense",
"embedding_with_eltwise_add_xpu",
Expand Down