Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

autocompare supports output cosine similarity #56

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion op_tools/test/test_compare_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2024, DeepLink.
from op_tools.utils import compare_result
from op_tools.utils import compare_result, tensor_cos_similarity
import torch
import ditorch
import unittest
Expand Down Expand Up @@ -156,6 +156,14 @@ def test_compare_invalid_input(self):
self.assertFalse(compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3])["allclose"]) # 输入a的元素类型不符合要求
self.assertFalse(compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3])["allclose"]) # 输入b的元素类型不符合要求

def test_cosine_similarity(self):
    """Sanity-check tensor_cos_similarity on identical, negated and random pairs."""
    lhs = torch.randn(3, 4, 4, device="cuda").float()
    rhs = torch.randn(3, 4, 4, device="cuda")
    # A tensor compared with itself must score ~1; with its negation, ~-1.
    self.assertLess(abs(tensor_cos_similarity(lhs, lhs) - 1), 1e-6)
    self.assertLess(abs(tensor_cos_similarity(lhs, -lhs) + 1), 1e-6)
    # Any pair of tensors must yield a value inside the closed range [-1, 1].
    similarity = tensor_cos_similarity(lhs, rhs)
    self.assertTrue(-1 <= similarity <= 1)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
15 changes: 14 additions & 1 deletion op_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def tensor_max_diff(a, b):
return max_abs_diff, max_relative_diff


def tensor_cos_similarity(a, b):
    """Return the cosine similarity between two tensors as a Python float.

    Both inputs are moved to CPU and cast to float32 first, so tensors from
    any device (and any floating dtype) can be compared. Each tensor is
    flattened to 1-D before the similarity is computed, so the result is a
    single scalar in [-1, 1] regardless of the input shapes.
    """
    flat_a = a.cpu().float().flatten()
    flat_b = b.cpu().float().flatten()
    similarity = torch.nn.functional.cosine_similarity(flat_a, flat_b, dim=0)
    return similarity.item()


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
a_cpu, b_cpu = a.cpu(), b.cpu()
try:
Expand Down Expand Up @@ -246,7 +252,7 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
def compare_result(name, a, b): # noqa: C901
a_list = []
b_list = []
allclose, max_abs_diff, max_relative_diff, error_info, atol, rtol = True, 0, 0, "", 0, 0
allclose, max_abs_diff, max_relative_diff, error_info, atol, rtol, cos_similarity = True, 0, 0, "", 0, 0, -1e8
for item in traverse_container(a):
a_list.append(item)
for item in traverse_container(b):
Expand All @@ -272,6 +278,7 @@ def compare_result(name, a, b): # noqa: C901
b_item = b_list[i]
atol_i, rtol_i = 0, 0
error_info_i = ""
cos_similarity_i = None
if a_item is None and b_item is None:
allclose_i = True
max_abs_diff_i = 0
Expand All @@ -285,6 +292,7 @@ def compare_result(name, a, b): # noqa: C901
if a_item.numel() > 0:
max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
allclose_i = tensor_allclose(a_item, b_item, atol=atol_i, rtol=rtol_i)
cos_similarity_i = tensor_cos_similarity(a_item, b_item)
else:
max_abs_diff_i, max_relative_diff_i = 0.0, 0.0
allclose_i = True
Expand Down Expand Up @@ -337,10 +345,14 @@ def compare_result(name, a, b): # noqa: C901
max_relative_diff = max(max_relative_diff_i, max_relative_diff)
atol = max(atol_i, atol)
rtol = max(rtol_i, rtol)
if cos_similarity_i is None:
cos_similarity_i = 1 if allclose_i else -1
cos_similarity = max(cos_similarity, cos_similarity_i)
result_list.append(
{
"name": f"{name + prefex:<30}",
"allclose": allclose_i,
"cosine_similarity": f"{cos_similarity_i:1.9f}",
"max_abs_diff": f"{max_abs_diff_i:10.9f}",
"max_relative_diff": f"{max_relative_diff_i:10.9f}",
"atol": f"{atol_i:10.9f}",
Expand All @@ -351,6 +363,7 @@ def compare_result(name, a, b): # noqa: C901

return {
"allclose": allclose,
"cos_similarity": cos_similarity,
"max_abs_diff": max_abs_diff,
"max_relative_diff": max_relative_diff,
"error_info": error_info,
Expand Down