Skip to content

Commit

Permalink
Reconstruct operator result comparison and support output of relative error (#26)
Browse files: browse the repository at this point in the history

Reconstruct operator result comparison and support output of relative error
  • Loading branch information
zhaoguochun1995 authored Sep 10, 2024
1 parent 6c968ea commit 610fd57
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 83 deletions.
88 changes: 5 additions & 83 deletions op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import torch
import math
import gc
import os

Expand All @@ -12,6 +11,7 @@
is_inplace_op,
is_view_op,
is_opname_match,
compare_result,
)
from .save_op_args import save_op_args, serialize_args_to_dict

Expand Down Expand Up @@ -42,84 +42,6 @@
]


def tensor_max_diff(a, b):
    """Return the largest element-wise absolute difference between two tensors.

    Both tensors are copied to the CPU before they are compared.

    Args:
        a: First tensor.
        b: Second tensor; must be broadcast-compatible with ``a``.

    Returns:
        The maximum of ``|a - b|`` as a Python number, or ``0`` when the
        tensors contain no elements.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    # bool tensors cannot be subtracted, and uint8 subtraction wraps around
    # (1 - 2 == 255), so both are widened to a signed integer type first.
    if a_cpu.dtype in (torch.bool, torch.uint8):
        a_cpu = a_cpu.int()
    if b_cpu.dtype in (torch.bool, torch.uint8):
        b_cpu = b_cpu.int()
    diff = torch.abs(a_cpu - b_cpu)
    if diff.numel() == 0:
        # max() of an empty tensor raises; with no elements nothing differs.
        return 0
    return diff.max().item()


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
    """Return whether two tensors are element-wise close, treating NaNs as equal.

    Both tensors are copied to the CPU and compared with ``torch.allclose``.

    Args:
        a: First tensor.
        b: Second tensor.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.

    Returns:
        bool: ``True`` when the tensors are close; ``False`` when they are not
        or when ``torch.allclose`` cannot compare them at all.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    try:
        return torch.allclose(a_cpu, b_cpu, atol=atol, rtol=rtol, equal_nan=True)
    except Exception:
        # Deliberate best effort: tensors that cannot be compared
        # (e.g. incompatible shapes) are reported as "not close".
        return False


def compare_result(name, a, b, atol=1e-3):
    """Compare two operator results and print a one-line report per compared value.

    Supported inputs are ``None``, tensors, Python scalars (bool/int/float),
    and tuples, lists or ``torch.return_types`` structures of those.

    Args:
        name: Label printed in front of every report line.
        a: First result.
        b: Second result, compared against ``a``.
        atol: Absolute tolerance used for tensor comparisons.

    Returns:
        tuple: ``(allclose, max_diff)``. ``max_diff`` is the largest absolute
        difference found, or NaN when some value could not be compared.
    """
    scalar_types = (bool, int, float)
    nan = float("nan")
    error_info = ""
    max_diff = nan
    allclose = False
    if a is None and b is None:
        allclose = True
        max_diff = 0
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {max_diff:20.9f} {error_info}")
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        if a.shape == b.shape:
            max_diff = tensor_max_diff(a, b)
            allclose = tensor_allclose(a, b, atol=atol)
        else:
            error_info = f"Inconsistent shape: {a.shape} {b.shape}"
        if a.dtype != b.dtype:
            error_info = f"Inconsistent dtypes: {a.dtype} {b.dtype}"
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {max_diff:20.9f} {error_info}")
    elif type(a) != type(b):  # noqa: E721
        error_info = f"Inconsistent types: {a} {b}"
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {max_diff:20.9f} {error_info}")
    elif isinstance(a, scalar_types):
        allclose = a == b or (math.isnan(a) and math.isnan(b))
        # Absolute value, so scalars are reported the same way as tensors.
        max_diff = abs(a - b)
        print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {max_diff:20.9f}")
    elif type(a).__module__.startswith("torch.return_types") or isinstance(a, (tuple, list)):
        if len(a) != len(b):
            # Indexing b[i] below would fail; report the mismatch instead.
            error_info = f"Inconsistent output length: {len(a)} {len(b)}"
            print(f"OpAutoCompareHook: {name:<50} allclose: {allclose}\tmax_diff: {max_diff:20.9f} {error_info}")
            return allclose, max_diff

        max_diff_list = []
        allclose_list = []
        for i, (a_i, b_i) in enumerate(zip(a, b)):
            # Reset per element so one element's error is not printed for the next.
            error_info_i = ""
            if isinstance(a_i, torch.Tensor) and isinstance(b_i, torch.Tensor):
                if a_i.shape == b_i.shape:
                    max_diff_i = tensor_max_diff(a_i, b_i)
                    allclose_i = tensor_allclose(a_i, b_i, atol=atol)
                else:
                    max_diff_i = nan
                    allclose_i = False
                    error_info_i = f"Inconsistent shape: {a_i.shape} {b_i.shape}"
                if a_i.dtype != b_i.dtype:
                    error_info_i = f"Inconsistent dtypes: {a_i.dtype} {b_i.dtype}"
            elif a_i is None and b_i is None:
                allclose_i = True
                max_diff_i = 0
            elif isinstance(a_i, scalar_types) and isinstance(b_i, scalar_types):
                allclose_i = a_i == b_i or (math.isnan(a_i) and math.isnan(b_i))
                max_diff_i = abs(a_i - b_i)
            else:
                # Mixed or unsupported element types cannot be subtracted.
                allclose_i = False
                max_diff_i = nan
                error_info_i = f"Inconsistent or unhandled types: {type(a_i)} {type(b_i)}"
            max_diff_list.append(max_diff_i)
            allclose_list.append(allclose_i)
            print(
                f"OpAutoCompareHook: {name:<46} {i}th allclose: {allclose_i}\tmax_diff: {max_diff_i:20.9f} {error_info_i}"
            )

        allclose = all(allclose_list)
        # max() silently drops NaN depending on element order, so propagate it
        # explicitly; an empty sequence has nothing that differs.
        if any(math.isnan(diff) for diff in max_diff_list):
            max_diff = nan
        else:
            max_diff = max(max_diff_list, default=0)
    else:
        print(f"OpAutoCompareHook: {name:} {__file__} unhandle output type: {type(a)}")

    return allclose, max_diff


class BackwardHookHandle:
def __init__(self, compare_hook) -> None:
self.compare_hook = compare_hook
Expand Down Expand Up @@ -215,15 +137,15 @@ def after_call_op(self, result): # noqa:C901
)

if is_inplace_op(self.name):
allclose, max_diff = compare_result(self.name, self.args[0], args_cpu[0])
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
if not allclose:
self.save_forward_args()

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose, max_diff = compare_result(self.name, self.result_device, self.result_cpu)
allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]

self.forward_allclose = allclose
self.forward_op_id = self.id
Expand Down Expand Up @@ -297,11 +219,11 @@ def compare_all_grad(self):

if isinstance(arg, torch.Tensor) and arg.requires_grad:
arg_cpu_grad = self.args_cpu_grad[i]
allclose, max_diff = compare_result(
allclose = compare_result(
self.name + f" (ins[{i}].grad)",
self.args_grad[i],
arg_cpu_grad,
)
)["allclose"]

if not allclose:
all_grad_allclose = False
Expand Down
85 changes: 85 additions & 0 deletions op_tools/test/test_compare_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2024, DeepLink.
from op_tools.utils import compare_result
import torch
import ditorch
import unittest


class TestCompareResult(unittest.TestCase):
    """Checks the dictionary returned by ``compare_result`` for tensors, lists and sort results."""

    def test_compare_same_tensor(self):
        # A tensor compared with itself must match exactly.
        result1 = torch.randn(10, 10).cuda()
        compare_info = compare_result("same_tensor", result1, result1)
        self.assertTrue(compare_info["allclose"])
        self.assertEqual(compare_info["max_abs_diff"], 0)
        self.assertEqual(compare_info["max_relative_diff"], 0)
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_randn_tensor(self):
        # Two independent random tensors: the reported differences must agree
        # with the values computed directly here.
        result1 = torch.randn(10, 10).cuda()
        result2 = torch.randn(10, 10).cuda()
        abs_diff = torch.abs(result1 - result2)
        max_abs_diff = torch.max(abs_diff).item()
        max_relative_diff = (abs_diff / (torch.abs(result1) + 1e-6)).max().item()
        compare_info = compare_result("randn_tensor", result1, result2)
        self.assertIs(compare_info["allclose"], False)
        self.assertEqual(compare_info["max_abs_diff"], max_abs_diff)
        self.assertAlmostEqual(
            compare_info["max_relative_diff"],
            max_relative_diff,
            delta=1e-3,
            msg=f"{compare_info['max_relative_diff']} != {max_relative_diff}",
        )
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_randn_tensor_list(self):
        # Same as the single-tensor case, but with the tensors wrapped in lists.
        result1 = torch.randn(10, 10).cuda()
        result2 = torch.randn(10, 10).cuda()
        abs_diff = torch.abs(result1 - result2)
        max_abs_diff = torch.max(abs_diff).item()
        max_relative_diff = (abs_diff / (torch.abs(result1) + 1e-6)).max().item()

        tensor_list1 = [result1, result1]
        tensor_list2 = [result2, result2]

        compare_info = compare_result("randn_tensor_list", tensor_list1, tensor_list2)
        self.assertIs(compare_info["allclose"], False)
        self.assertEqual(compare_info["max_abs_diff"], max_abs_diff)
        self.assertAlmostEqual(
            compare_info["max_relative_diff"],
            max_relative_diff,
            delta=1e-3,
            msg=f"{compare_info['max_relative_diff']} != {max_relative_diff}",
        )
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_same_int_list(self):
        result1 = list(range(10))
        compare_info = compare_result("same_int_list", result1, result1)
        self.assertTrue(compare_info["allclose"])
        self.assertEqual(compare_info["max_abs_diff"], 0)
        self.assertEqual(compare_info["max_relative_diff"], 0)
        self.assertEqual(compare_info["error_info"], "")

    def test_compare_diff_int_list(self):
        # Element t is compared with 2 * t, so the largest absolute gap is 9
        # and the relative gap is ~1 for every non-zero element.
        result1 = list(range(10))
        result2 = [t * 2 for t in range(10)]
        compare_info = compare_result("diff_int_list", result1, result2)
        self.assertIs(compare_info["allclose"], False, compare_info)
        self.assertEqual(compare_info["max_abs_diff"], 9, compare_info)
        self.assertAlmostEqual(compare_info["max_relative_diff"], 1, delta=1e-3, msg=compare_info)
        self.assertEqual(compare_info["error_info"], "")

    def test_same_torch_return_type(self):
        # The result of Tensor.sort() is a torch.return_types structure.
        result1 = torch.randn(10, 10).cuda().sort()

        compare_info = compare_result("same_torch_return_type", result1, result1)
        self.assertIs(compare_info["allclose"], True)
        self.assertEqual(compare_info["max_abs_diff"], 0)

    def test_diff_torch_return_type(self):
        result1 = torch.randn(10, 10).cuda().sort()
        result2 = torch.randn(10, 10).cuda().sort()

        compare_info = compare_result("diff_torch_return_type", result1, result2)
        self.assertIs(compare_info["allclose"], False)


if __name__ == "__main__":
unittest.main()
79 changes: 79 additions & 0 deletions op_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import re
import importlib
import math


def traverse_container(container):
Expand Down Expand Up @@ -148,3 +149,81 @@ def get_dtype_cast_dict_form_str(config):

def is_view_op(name):
    """Return whether ``name`` is contained in ``VIEW_OPS``."""
    return name in VIEW_OPS


def tensor_max_diff(a, b):
    """Return the largest absolute and relative element-wise differences of two tensors.

    Both tensors are copied to the CPU before they are compared. The relative
    difference is measured against ``a``: ``|a - b| / (|a| + 1e-6)``.

    Args:
        a: First tensor; it is the denominator of the relative difference.
        b: Second tensor; must be broadcast-compatible with ``a``.

    Returns:
        tuple: ``(max_abs_diff, max_relative_diff)`` as Python numbers, or
        ``(0, 0)`` when the tensors contain no elements.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    # bool tensors cannot be subtracted, and uint8 subtraction wraps around
    # (1 - 2 == 255), so both are widened to a signed integer type first.
    if a_cpu.dtype in (torch.bool, torch.uint8):
        a_cpu = a_cpu.int()
    if b_cpu.dtype in (torch.bool, torch.uint8):
        b_cpu = b_cpu.int()
    diff = torch.abs(a_cpu - b_cpu)
    if diff.numel() == 0:
        # max() of an empty tensor raises; with no elements nothing differs.
        return 0, 0
    max_abs_diff = diff.max().item()
    # The small epsilon keeps the division finite where a is exactly zero.
    max_relative_diff = (diff / (a_cpu.abs() + 1e-6)).max().item()
    return max_abs_diff, max_relative_diff


def tensor_allclose(a, b, atol=1e-3, rtol=1e-3):
    """Return whether two tensors are element-wise close, treating NaNs as equal.

    Both tensors are copied to the CPU and compared with ``torch.allclose``.

    Args:
        a: First tensor.
        b: Second tensor.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        rtol: Relative tolerance forwarded to ``torch.allclose``.

    Returns:
        bool: ``True`` when the tensors are close; ``False`` when they are not
        or when ``torch.allclose`` cannot compare them at all.
    """
    a_cpu, b_cpu = a.cpu(), b.cpu()
    try:
        return torch.allclose(a_cpu, b_cpu, atol=atol, rtol=rtol, equal_nan=True)
    except Exception:
        # Deliberate best effort: tensors that cannot be compared
        # (e.g. incompatible shapes) are reported as "not close".
        return False


def _nan_max(x, y):
    """Return ``max(x, y)``, or NaN when either operand is NaN."""
    # Built-in max() keeps or drops NaN depending on argument order.
    if math.isnan(x) or math.isnan(y):
        return float("nan")
    return max(x, y)


def compare_result(name, a, b, atol=1e-3, rtol=1e-3):
    """Compare two operator results item by item and print one report line per item.

    Both results are flattened with ``traverse_container`` and the items are
    compared pairwise. ``None`` pairs, tensor pairs and Python scalar pairs
    (bool/int/float) are supported; anything else is reported as a mismatch.

    Args:
        name: Label printed in front of every report line.
        a: First result; it is the denominator of the relative difference.
        b: Second result, compared against ``a``.
        atol: Absolute tolerance used for tensor comparisons.
        rtol: Relative tolerance used for tensor comparisons.

    Returns:
        dict: ``allclose`` (bool), ``max_abs_diff`` and ``max_relative_diff``
        (numbers, NaN when some item could not be compared) and ``error_info``
        (accumulated description of every inconsistency found).
    """
    a_list = list(traverse_container(a))
    b_list = list(traverse_container(b))
    allclose, max_abs_diff, max_relative_diff, error_info = True, 0, 0, ""
    nan = float("nan")

    if len(a_list) != len(b_list):
        error_info += f"Inconsistent output length: {len(a_list)} {len(b_list)}, {a} {b}"
        return {"allclose": False, "max_abs_diff": nan, "max_relative_diff": nan, "error_info": error_info}

    for i, (a_item, b_item) in enumerate(zip(a_list, b_list)):
        if a_item is None and b_item is None:
            allclose_i = True
            max_abs_diff_i = 0
            max_relative_diff_i = 0
        elif isinstance(a_item, torch.Tensor) and isinstance(b_item, torch.Tensor):
            if a_item.shape == b_item.shape:
                max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
                allclose_i = tensor_allclose(a_item, b_item, atol=atol, rtol=rtol)
            else:
                error_info += f"Inconsistent shape: {a_item.shape} {b_item.shape}"
                max_abs_diff_i = nan
                max_relative_diff_i = nan
                allclose_i = False
            if a_item.dtype != b_item.dtype:
                error_info += f"Inconsistent dtypes: {a_item.dtype} {b_item.dtype}"
        elif type(a_item) != type(b_item):  # noqa: E721
            # Compare the item types, not the container types.
            error_info += f"Inconsistent types: {type(a_item)} {type(b_item)}"
            max_abs_diff_i = nan
            max_relative_diff_i = nan
            allclose_i = False
        elif isinstance(a_item, (bool, int, float)):
            allclose_i = a_item == b_item or (math.isnan(a_item) and math.isnan(b_item))
            max_abs_diff_i = abs(a_item - b_item)
            max_relative_diff_i = max_abs_diff_i / (abs(a_item) + 1e-6)
        else:
            # Without this branch the per-item values would be unbound, or
            # silently reused from the previous item.
            error_info += f"Unhandled output type: {type(a_item)}"
            max_abs_diff_i = nan
            max_relative_diff_i = nan
            allclose_i = False

        prefix = f" {i}th " if len(a_list) > 1 else ""
        print(
            f"compare_result: {name + prefix:<50} allclose: {allclose_i}\t"
            f"max_abs_diff: {max_abs_diff_i:20.9f} \t"
            f"max_relative_diff: {max_relative_diff_i:20.9f} {error_info}"
        )
        allclose = allclose_i and allclose
        max_abs_diff = _nan_max(max_abs_diff_i, max_abs_diff)
        max_relative_diff = _nan_max(max_relative_diff_i, max_relative_diff)
    return {"allclose": allclose, "max_abs_diff": max_abs_diff, "max_relative_diff": max_relative_diff, "error_info": error_info}

0 comments on commit 610fd57

Please sign in to comment.