-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reconstruct operator result comparison and support output of relative… (
#26) Reconstruct operator result comparison and support output of relative error
- Loading branch information
1 parent
6c968ea
commit 610fd57
Showing
3 changed files
with
169 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (c) 2024, DeepLink. | ||
from op_tools.utils import compare_result | ||
import torch | ||
import ditorch | ||
import unittest | ||
|
||
|
||
class TestCompareResult(unittest.TestCase): | ||
|
||
def test_compare_same_tensor(self): | ||
result1 = torch.randn(10, 10).cuda() | ||
compare_info = compare_result("same_tensor", result1, result1) | ||
self.assertTrue(compare_info["allclose"]) | ||
self.assertTrue(compare_info["max_abs_diff"] == 0) | ||
self.assertTrue(compare_info["max_relative_diff"] == 0) | ||
self.assertTrue(compare_info["error_info"] == "") | ||
|
||
def test_compare_randn_tensor(self): | ||
result1 = torch.randn(10, 10).cuda() | ||
result2 = torch.randn(10, 10).cuda() | ||
max_abs = torch.abs(result1 - result2) | ||
max_abs_diff = torch.max(max_abs).item() | ||
max_relate_diff = (max_abs / (torch.abs(result1) + 1e-6)).max().item() | ||
compare_info = compare_result("randn_tensor", result1, result2) | ||
self.assertTrue(compare_info["allclose"] is False) | ||
self.assertTrue(compare_info["max_abs_diff"] == max_abs_diff) | ||
self.assertTrue( | ||
abs(compare_info["max_relative_diff"] - max_relate_diff) < 1e-3, | ||
f"{compare_info['max_relative_diff']} != {max_relate_diff}", | ||
) | ||
self.assertTrue(compare_info["error_info"] == "") | ||
|
||
def test_compare_randn_tensor_list(self): | ||
result1 = torch.randn(10, 10).cuda() | ||
result2 = torch.randn(10, 10).cuda() | ||
max_abs = torch.abs(result1 - result2) | ||
max_abs_diff = torch.max(max_abs).item() | ||
max_relate_diff = (max_abs / (torch.abs(result1) + 1e-6)).max().item() | ||
|
||
tensor_list1 = [result1, result1] | ||
tensor_list2 = [result2, result2] | ||
|
||
compare_info = compare_result("randn_tensor_list", tensor_list1, tensor_list2) | ||
self.assertTrue(compare_info["allclose"] is False) | ||
self.assertTrue(compare_info["max_abs_diff"] == max_abs_diff) | ||
self.assertTrue( | ||
abs(compare_info["max_relative_diff"] - max_relate_diff) < 1e-3, | ||
f"{compare_info['max_relative_diff']} != {max_relate_diff}", | ||
) | ||
self.assertTrue(compare_info["error_info"] == "") | ||
|
||
def test_compare_same_int_list(self): | ||
result1 = [t for t in range(10)] | ||
compare_info = compare_result("same_int_list", result1, result1) | ||
self.assertTrue(compare_info["allclose"]) | ||
self.assertTrue(compare_info["max_abs_diff"] == 0) | ||
self.assertTrue(compare_info["max_relative_diff"] == 0) | ||
self.assertTrue(compare_info["error_info"] == "") | ||
|
||
def test_compare_diff_int_list(self): | ||
result1 = [t for t in range(10)] | ||
result2 = [t * 2 for t in range(10)] | ||
compare_info = compare_result("diff_int_list", result1, result2) | ||
self.assertTrue(compare_info["allclose"] is False, compare_info) | ||
self.assertTrue(compare_info["max_abs_diff"] == 9, compare_info) | ||
self.assertTrue(abs(compare_info["max_relative_diff"] - 1) < 1e-3, compare_info) | ||
self.assertTrue(compare_info["error_info"] == "") | ||
|
||
def test_same_torch_return_type(self): | ||
result1 = torch.randn(10, 10).cuda().sort() | ||
|
||
compare_info = compare_result("same_torch_return_type", result1, result1) | ||
self.assertTrue(compare_info["allclose"] is True) | ||
self.assertTrue(compare_info["max_abs_diff"] == 0) | ||
|
||
def test_diff_torch_return_type(self): | ||
result1 = torch.randn(10, 10).cuda().sort() | ||
result2 = torch.randn(10, 10).cuda().sort() | ||
|
||
compare_info = compare_result("diff_torch_return_type", result1, result2) | ||
self.assertTrue(compare_info["allclose"] is False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters