Skip to content

Commit

Permalink
autocompare supports output cosine similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoguochun1995 committed Oct 9, 2024
1 parent 949a690 commit b43bf7c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
10 changes: 9 additions & 1 deletion op_tools/test/test_compare_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, DeepLink.
from op_tools.utils import compare_result
from op_tools.utils import compare_result, tensor_cos_similarity
import torch
import ditorch
import unittest
Expand Down Expand Up @@ -156,6 +156,14 @@ def test_compare_invalid_input(self):
self.assertFalse(compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3])["allclose"]) # 输入a的元素类型不符合要求
self.assertFalse(compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3])["allclose"]) # 输入b的元素类型不符合要求

def test_cosine_similarity(self):
x = torch.randn(3, 4, 4, device="cuda").float()
y = torch.randn(3, 4, 4, device="cuda")
self.assertTrue(abs(tensor_cos_similarity(x, x) - 1) < 1e-6)
self.assertTrue(abs(tensor_cos_similarity(x, -x) + 1) < 1e-6)
xy_cos_similarity = tensor_cos_similarity(x, y)
self.assertTrue(xy_cos_similarity >= -1 and xy_cos_similarity <= 1)


if __name__ == "__main__":
unittest.main()
15 changes: 14 additions & 1 deletion op_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def tensor_max_diff(a, b):
return max_abs_diff, max_relative_diff


def tensor_cos_similarity(a, b):
a_cpu, b_cpu = a.cpu().float(), b.cpu().float()
cos_sim = torch.nn.functional.cosine_similarity(a_cpu.reshape(-1), b_cpu.reshape(-1), dim=-1)
return cos_sim.item()


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
a_cpu, b_cpu = a.cpu(), b.cpu()
try:
Expand Down Expand Up @@ -246,7 +252,7 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
def compare_result(name, a, b): # noqa: C901
a_list = []
b_list = []
allclose, max_abs_diff, max_relative_diff, error_info, atol, rtol = True, 0, 0, "", 0, 0
allclose, max_abs_diff, max_relative_diff, error_info, atol, rtol, cos_similarity = True, 0, 0, "", 0, 0, -1e8
for item in traverse_container(a):
a_list.append(item)
for item in traverse_container(b):
Expand All @@ -272,6 +278,7 @@ def compare_result(name, a, b): # noqa: C901
b_item = b_list[i]
atol_i, rtol_i = 0, 0
error_info_i = ""
cos_similarity_i = None
if a_item is None and b_item is None:
allclose_i = True
max_abs_diff_i = 0
Expand All @@ -285,6 +292,7 @@ def compare_result(name, a, b): # noqa: C901
if a_item.numel() > 0:
max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
allclose_i = tensor_allclose(a_item, b_item, atol=atol_i, rtol=rtol_i)
cos_similarity_i = tensor_cos_similarity(a_item, b_item)
else:
max_abs_diff_i, max_relative_diff_i = 0.0, 0.0
allclose_i = True
Expand Down Expand Up @@ -337,10 +345,14 @@ def compare_result(name, a, b): # noqa: C901
max_relative_diff = max(max_relative_diff_i, max_relative_diff)
atol = max(atol_i, atol)
rtol = max(rtol_i, rtol)
if cos_similarity_i is None:
cos_similarity_i = 1 if allclose_i else -1
cos_similarity = max(cos_similarity, cos_similarity_i)
result_list.append(
{
"name": f"{name + prefex:<30}",
"allclose": allclose_i,
"cosine_similarity": f"{cos_similarity_i:1.9f}",
"max_abs_diff": f"{max_abs_diff_i:10.9f}",
"max_relative_diff": f"{max_relative_diff_i:10.9f}",
"atol": f"{atol_i:10.9f}",
Expand All @@ -351,6 +363,7 @@ def compare_result(name, a, b): # noqa: C901

return {
"allclose": allclose,
"cos_similarity": cos_similarity,
"max_abs_diff": max_abs_diff,
"max_relative_diff": max_relative_diff,
"error_info": error_info,
Expand Down

0 comments on commit b43bf7c

Please sign in to comment.