Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconstruct operator result comparison and support output of relative… #26

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 5 additions & 83 deletions op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import torch
import math
import gc
import os

Expand All @@ -12,6 +11,7 @@
is_inplace_op,
is_view_op,
is_opname_match,
compare_result,
)
from .save_op_args import save_op_args, serialize_args_to_dict

Expand Down Expand Up @@ -42,84 +42,6 @@
]


def tensor_max_diff(a, b):
    """Return the maximum absolute elementwise difference between two tensors.

    Both tensors are moved to CPU first; bool tensors are promoted to int
    because subtraction is not defined for torch.bool.
    """
    x, y = a.cpu(), b.cpu()
    if x.dtype == torch.bool:
        x = x.int()
    if y.dtype == torch.bool:
        y = y.int()
    return (x - y).abs().max().item()


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
    """Return True if the two tensors are elementwise close (compared on CPU).

    NaNs in matching positions count as equal. Any failure raised by
    torch.allclose (e.g. non-broadcastable shapes or mismatched dtypes)
    is deliberately treated as "not close" rather than propagated.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    try:
        return torch.allclose(a_cpu, b_cpu, atol=atol, rtol=rtol, equal_nan=True)
    except Exception:  # comparison failure (shape/dtype mismatch) means "not close"
        return False
    # BUG FIX: removed the unreachable trailing `return False` that followed
    # the try/except (both branches already return).


def compare_result(name, a, b, atol=1e-3):
    """Compare two op outputs and print one summary line per compared value.

    Handles None, torch.Tensor, bool/int/float scalars, and tuple/list/
    torch.return_types containers of those.

    Args:
        name: label used in the printed report.
        a: first output (typically from the device op).
        b: second output (typically the cpu reference).
        atol: absolute tolerance forwarded to tensor_allclose.

    Returns:
        (allclose, max_diff) — overall closeness flag and largest difference.
    """
    error_info = ""
    max_diff = float("nan")
    allclose = False
    if a is None and b is None:
        allclose = True
        max_diff = 0
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {f'{max_diff:20.9f}'} {error_info}")
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        if a.shape == b.shape:
            max_diff = tensor_max_diff(a, b)
            # BUG FIX: the atol parameter was accepted but never used.
            allclose = tensor_allclose(a, b, atol=atol)
        else:
            max_diff = float("nan")
            allclose = False
            error_info = f"Inconsistent shape: {a.shape} {b.shape}"
        if a.dtype != b.dtype:
            error_info = f"Inconsistent dtypes: {a.dtype} {b.dtype}"
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {f'{max_diff:20.9f}'} {error_info}")
    elif type(a) != type(b):  # noqa: E721
        error_info = f"Inconsistent types: {a} {b}"
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {f'{max_diff:20.9f}'} {error_info}")
    elif isinstance(a, (bool, int, float)):
        allclose = a == b or (math.isnan(a) and math.isnan(b))
        max_diff = a - b
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {f'{max_diff:20.9f}'}")
    elif type(a).__module__.startswith("torch.return_types") or isinstance(a, (tuple, list)):
        max_diff_list = []
        allclose_list = []
        error_info_i = ""
        for i in range(len(a)):
            # BUG FIX: original tested isinstance(a[i], Tensor) twice and never
            # checked b[i], so a non-tensor b[i] crashed in tensor_max_diff.
            if isinstance(a[i], torch.Tensor) and isinstance(b[i], torch.Tensor):
                max_diff_i = tensor_max_diff(a[i], b[i])
                allclose_i = tensor_allclose(a[i], b[i], atol=atol)
                max_diff_list.append(max_diff_i)
                allclose_list.append(allclose_i)
                # BUG FIX: original compared a[0].dtype/b[0].dtype on every
                # iteration instead of the i-th element's dtypes.
                if a[i].dtype != b[i].dtype:
                    error_info_i = f"Inconsistent dtypes: {a[i].dtype} {b[i].dtype}"
                print(
                    f"OpAutoCompareHook: {name:<46} {i}th allclose: {allclose_i}\tmax_diff: {f'{max_diff_i:20.9f}'} {error_info_i}"
                )
            else:
                allclose_i = a[i] == b[i] or (math.isnan(a[i]) and math.isnan(b[i]))
                max_diff_i = a[i] - b[i]
                max_diff_list.append(max_diff_i)
                allclose_list.append(allclose_i)
                print(
                    f"OpAutoCompareHook: {name:<46} {i}th allclose: {allclose_i}\tmax_diff: {f'{max_diff_i:20.9f}'} {error_info_i}"
                )

        allclose = all(allclose_list)
        # Guard against empty containers: max() on an empty list raises.
        max_diff = max(max_diff_list) if max_diff_list else 0
    else:
        print(f"OpAutoCompareHook: {name:} {__file__} unhandle output type: {type(a)}")

    return allclose, max_diff


class BackwardHookHandle:
def __init__(self, compare_hook) -> None:
self.compare_hook = compare_hook
Expand Down Expand Up @@ -215,15 +137,15 @@ def after_call_op(self, result): # noqa:C901
)

if is_inplace_op(self.name):
allclose, max_diff = compare_result(self.name, self.args[0], args_cpu[0])
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
if not allclose:
self.save_forward_args()

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose, max_diff = compare_result(self.name, self.result_device, self.result_cpu)
allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]

self.forward_allclose = allclose
self.forward_op_id = self.id
Expand Down Expand Up @@ -297,11 +219,11 @@ def compare_all_grad(self):

if isinstance(arg, torch.Tensor) and arg.requires_grad:
arg_cpu_grad = self.args_cpu_grad[i]
allclose, max_diff = compare_result(
allclose = compare_result(
self.name + f" (ins[{i}].grad)",
self.args_grad[i],
arg_cpu_grad,
)
)["allclose"]

if not allclose:
all_grad_allclose = False
Expand Down
85 changes: 85 additions & 0 deletions op_tools/test/test_compare_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024, DeepLink.
from op_tools.utils import compare_result
import torch
import ditorch
import unittest


class TestCompareResult(unittest.TestCase):
    """Unit tests for op_tools.utils.compare_result across output types."""

    # NOTE: assertEqual/assertIs/assertAlmostEqual replace the original
    # assertTrue(x == y) style so failures report both operands.

    def test_compare_same_tensor(self):
        result1 = torch.randn(10, 10).cuda()
        compare_info = compare_result("same_tensor", result1, result1)
        self.assertTrue(compare_info["allclose"])
        self.assertEqual(compare_info["max_abs_diff"], 0)
        self.assertEqual(compare_info["max_relative_diff"], 0)
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_randn_tensor(self):
        result1 = torch.randn(10, 10).cuda()
        result2 = torch.randn(10, 10).cuda()
        max_abs = torch.abs(result1 - result2)
        max_abs_diff = torch.max(max_abs).item()
        max_relate_diff = (max_abs / (torch.abs(result1) + 1e-6)).max().item()
        compare_info = compare_result("randn_tensor", result1, result2)
        self.assertIs(compare_info["allclose"], False)
        self.assertEqual(compare_info["max_abs_diff"], max_abs_diff)
        self.assertAlmostEqual(
            compare_info["max_relative_diff"],
            max_relate_diff,
            delta=1e-3,
            msg=f"{compare_info['max_relative_diff']} != {max_relate_diff}",
        )
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_randn_tensor_list(self):
        result1 = torch.randn(10, 10).cuda()
        result2 = torch.randn(10, 10).cuda()
        max_abs = torch.abs(result1 - result2)
        max_abs_diff = torch.max(max_abs).item()
        max_relate_diff = (max_abs / (torch.abs(result1) + 1e-6)).max().item()

        tensor_list1 = [result1, result1]
        tensor_list2 = [result2, result2]

        compare_info = compare_result("randn_tensor_list", tensor_list1, tensor_list2)
        self.assertIs(compare_info["allclose"], False)
        self.assertEqual(compare_info["max_abs_diff"], max_abs_diff)
        self.assertAlmostEqual(
            compare_info["max_relative_diff"],
            max_relate_diff,
            delta=1e-3,
            msg=f"{compare_info['max_relative_diff']} != {max_relate_diff}",
        )
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_same_int_list(self):
        result1 = list(range(10))
        compare_info = compare_result("same_int_list", result1, result1)
        self.assertTrue(compare_info["allclose"])
        self.assertEqual(compare_info["max_abs_diff"], 0)
        self.assertEqual(compare_info["max_relative_diff"], 0)
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_diff_int_list(self):
        result1 = list(range(10))
        result2 = [t * 2 for t in range(10)]
        compare_info = compare_result("diff_int_list", result1, result2)
        self.assertIs(compare_info["allclose"], False, compare_info)
        self.assertEqual(compare_info["max_abs_diff"], 9, compare_info)
        self.assertAlmostEqual(compare_info["max_relative_diff"], 1, delta=1e-3, msg=str(compare_info))
        self.assertEqual(compare_info["error_info"], "")

    def test_same_torch_return_type(self):
        result1 = torch.randn(10, 10).cuda().sort()

        compare_info = compare_result("same_torch_return_type", result1, result1)
        self.assertIs(compare_info["allclose"], True)
        self.assertEqual(compare_info["max_abs_diff"], 0)

    def test_diff_torch_return_type(self):
        result1 = torch.randn(10, 10).cuda().sort()
        result2 = torch.randn(10, 10).cuda().sort()

        compare_info = compare_result("diff_torch_return_type", result1, result2)
        self.assertIs(compare_info["allclose"], False)


if __name__ == "__main__":
    unittest.main()
79 changes: 79 additions & 0 deletions op_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import re
import importlib
import math


def traverse_container(container):
Expand Down Expand Up @@ -148,3 +149,81 @@ def get_dtype_cast_dict_form_str(config):

def is_view_op(name):
    """Return True if *name* appears in the module-level VIEW_OPS collection."""
    return name in VIEW_OPS


def tensor_max_diff(a, b):
    """Return (max_abs_diff, max_relative_diff) between two tensors.

    Both tensors are compared on CPU. bool tensors are promoted to int
    because subtraction is undefined for torch.bool. The relative diff
    denominator uses |a| + 1e-6 to avoid division by zero.
    """
    x, y = a.cpu(), b.cpu()
    if x.dtype == torch.bool:
        x = x.int()
    if y.dtype == torch.bool:
        y = y.int()
    abs_diff = (x - y).abs()
    rel_diff = abs_diff / (x.abs() + 1e-6)
    return abs_diff.max().item(), rel_diff.max().item()


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
    """Return True if the two tensors are elementwise close (compared on CPU).

    NaNs in matching positions count as equal. Any failure raised by
    torch.allclose (e.g. non-broadcastable shapes or mismatched dtypes)
    is deliberately treated as "not close" rather than propagated.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    try:
        return torch.allclose(a_cpu, b_cpu, atol=atol, rtol=rtol, equal_nan=True)
    except Exception:  # comparison failure (shape/dtype mismatch) means "not close"
        return False
    # BUG FIX: removed the unreachable trailing `return False` that followed
    # the try/except (both branches already return).


def compare_result(name, a, b, atol=1e-3, rtol=1e-3):
    """Compare two op outputs pairwise and print one summary line per element.

    Both outputs are flattened with traverse_container and compared element
    by element. Supported leaf types: None, torch.Tensor, and bool/int/float
    scalars; anything else is reported in error_info.

    Args:
        name: label used in the printed report.
        a: first output (typically from the device op).
        b: second output (typically the cpu reference).
        atol / rtol: tolerances forwarded to tensor_allclose.

    Returns:
        dict with keys "allclose" (bool), "max_abs_diff" (float),
        "max_relative_diff" (float) and "error_info" (str).
    """
    a_list = []
    b_list = []
    allclose, max_abs_diff, max_relative_diff, error_info = True, 0, 0, ""
    for item in traverse_container(a):
        a_list.append(item)
    for item in traverse_container(b):
        b_list.append(item)

    if len(a_list) != len(b_list):
        error_info += f"Inconsistent output length: {len(a_list)} {len(b_list)}, {a} {b}"
        max_abs_diff = float("nan")
        max_relative_diff = float("nan")
        allclose = False
        return {"allclose": allclose, "max_abs_diff": max_abs_diff, "max_relative_diff": max_relative_diff, "error_info": error_info}

    for i in range(len(a_list)):
        a_item = a_list[i]
        b_item = b_list[i]
        if a_item is None and b_item is None:
            allclose_i = True
            max_abs_diff_i = 0
            max_relative_diff_i = 0
        elif isinstance(a_item, torch.Tensor) and isinstance(b_item, torch.Tensor):
            if a_item.shape == b_item.shape:
                max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
                allclose_i = tensor_allclose(a_item, b_item, atol=atol, rtol=rtol)
            else:
                error_info += f"Inconsistent shape: {a_item.shape} {b_item.shape}"
                max_abs_diff_i = float("nan")
                max_relative_diff_i = float("nan")
                allclose_i = False
            if a_item.dtype != b_item.dtype:
                error_info += f"Inconsistent dtypes: {a_item.dtype} {b_item.dtype}"

        # BUG FIX: original compared type(a) != type(b) (the containers) here,
        # so mismatched *element* types fell through and raised NameError below.
        elif type(a_item) != type(b_item):  # noqa: E721
            error_info += f"Inconsistent types: {type(a_item)} {type(b_item)}"
            max_abs_diff_i = float("nan")
            max_relative_diff_i = float("nan")
            allclose_i = False
        elif isinstance(a_item, (bool, int, float)):
            allclose_i = a_item == b_item or (math.isnan(a_item) and math.isnan(b_item))
            max_abs_diff_i = abs(a_item - b_item)
            max_relative_diff_i = max_abs_diff_i / (abs(a_item) + 1e-6)
        else:
            # BUG FIX: unhandled leaf types previously left allclose_i /
            # max_abs_diff_i unbound; report them instead of crashing.
            error_info += f"Unhandled output type: {type(a_item)}"
            max_abs_diff_i = float("nan")
            max_relative_diff_i = float("nan")
            allclose_i = False
        if len(a_list) > 1:
            prefix = f" {i}th "
        else:
            prefix = ""

        print(f"compare_result: {name + prefix:<50} allclose: {allclose_i}\tmax_abs_diff: {f'{max_abs_diff_i:20.9f}'} \tmax_relative_diff: {f'{max_relative_diff_i:20.9f}'} {error_info}")  # noqa: E501
        allclose = allclose_i and allclose
        max_abs_diff = max(max_abs_diff_i, max_abs_diff)
        max_relative_diff = max(max_relative_diff_i, max_relative_diff)
    return {"allclose": allclose, "max_abs_diff": max_abs_diff, "max_relative_diff": max_relative_diff, "error_info": error_info}