
Commit

Automatic precision comparison supports customizing the data type used in CPU calculations by operator.
zhaoguochun1995 committed Oct 8, 2024
1 parent da1161e commit 5c25bd9
Showing 7 changed files with 1,024 additions and 207 deletions.
1,056 changes: 910 additions & 146 deletions README.md

Large diffs are not rendered by default.
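The headline change: when the auto-compare hook recomputes an operator on the CPU to obtain a reference result, the dtype used for that CPU computation can now be chosen per operator through environment variables. A per-operator variable such as LINEAR_OP_DTYPE_CAST_DICT overrides the global OP_DTYPE_CAST_DICT, which in turn defaults to casting float16/bfloat16 to float32. Below is a minimal configuration sketch, assuming the variables are set before the auto-compare hook runs; the variable names and value format come from the diff that follows, while the surrounding setup is illustrative only.

import os

# Per-operator override: compute the CPU reference for linear ops in float64.
os.environ["LINEAR_OP_DTYPE_CAST_DICT"] = (
    "torch.float16->torch.float64,"
    "torch.bfloat16->torch.float64,"
    "torch.float32->torch.float64"
)

# Global fallback used by any operator without a dedicated variable;
# if unset, the hook falls back to float16/bfloat16 -> float32.
os.environ["OP_DTYPE_CAST_DICT"] = (
    "torch.float16->torch.float32,torch.bfloat16->torch.float32"
)

# ...then enable the auto-compare hook as usual (see the README changes above).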

95 changes: 63 additions & 32 deletions op_tools/op_autocompare_hook.py
@@ -15,6 +15,8 @@
compare_result,
is_random_number_gen_op,
garbage_collect,
get_dtype_cast_dict_form_str,
set_env_if_env_is_empty
)
from .save_op_args import save_op_args, serialize_args_to_dict

@@ -76,11 +78,16 @@ def grad_fun(grad_inputs, grad_outputs):
return grad_fun


set_env_if_env_is_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501


class OpAutoCompareHook(BaseHook):
AUTO_COMPARE_DTYPE_CAST_DICT = {
torch.half: torch.float32,
torch.bfloat16: torch.float32,
}

def __init__(self, name, func) -> None:
super().__init__(name, func)
@@ -97,39 +104,63 @@ def copy_input_to_cpu(self):
detach=True,
)

def update_dtype_cast_dict(self):
# some op on cpu backend not support half, bfloat16
op_name_suffix = self.name.split(".")[-1].upper() # LINEAR, ADD, MATMUL etc
heigher_priority_env_name = op_name_suffix + "_OP_DTYPE_CAST_DICT"
lower_priority_env_name = "OP_DTYPE_CAST_DICT"
default_env_value = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
heigher_priority_env_value = os.environ.get(heigher_priority_env_name, None)
lower_priority_env_value = os.environ.get(lower_priority_env_name, default_env_value)
self.dtype_cast_config_str = heigher_priority_env_value or lower_priority_env_value

self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)

ins_dtype = set()
for args in traverse_container(self.args):
if isinstance(args, torch.Tensor):
ins_dtype.add(args.dtype)
raw_dtypes = set(self.dtype_cast_dict.keys())
for dtype in raw_dtypes:
if dtype not in ins_dtype:
self.dtype_cast_dict.pop(dtype)

def run_forward_on_cpu(self):
self.result_device = to_device("cpu", self.result, detach=True)
try:
self.update_dtype_cast_dict()

self.args_cpu = to_device(
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e: # noqa: F841
self.dtype_cast_dict = OpAutoCompareHook.AUTO_COMPARE_DTYPE_CAST_DICT
# some op on cpu backend not support half, bfloat16
self.args_cpu = to_device(
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
else:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

if len(self.dtype_cast_dict) > 0 and 0:
self.result_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
self.result_cpu,
dtype_cast_dict={value : key for key, value in self.dtype_cast_dict.items()},
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
else:
args_cpu = self.args_cpu
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_outputs_device = to_device("cpu", grad_output, detach=True)
self.grad_outputs_cpu = to_device("cpu", self.grad_outputs_device, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
@@ -255,7 +286,7 @@ def compare_backward_relate(self):

id = self.forward_op_id
self = None
garbage_collect(id)
garbage_collect(id, 10)

def save_forward_args(self):
save_op_args(
@@ -277,7 +308,7 @@ def save_backward_args(self):
save_op_args(
self.name,
f"{self.identifier}/device/grad_outputs",
*tuple(self.grad_outputs_cpu),
*tuple(self.grad_outputs_device),
)
save_op_args(
self.name,
@@ -326,7 +357,7 @@ def after_call_op(self, result): # noqa:C901
else:
self = None

garbage_collect(id)
garbage_collect(id, 10)
return result

def is_should_apply(self, *args, **kwargs):
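The per-operator lookup added in update_dtype_cast_dict above derives the environment-variable name from the last component of the op name (upper-cased, with the _OP_DTYPE_CAST_DICT suffix) and gives it priority over the global OP_DTYPE_CAST_DICT. A standalone sketch of that resolution order, simplified from the hook code; the op name in the usage comment is illustrative:

import os

def resolve_dtype_cast_config(op_name: str) -> str:
    # "torch.nn.functional.linear" -> "LINEAR_OP_DTYPE_CAST_DICT"
    per_op_env = op_name.split(".")[-1].upper() + "_OP_DTYPE_CAST_DICT"
    default = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
    # A set per-op variable wins; otherwise fall back to the global variable or the default.
    return os.environ.get(per_op_env) or os.environ.get("OP_DTYPE_CAST_DICT", default)

# resolve_dtype_cast_config("torch.nn.functional.linear") -> the float64 mapping set by
# LINEAR_OP_DTYPE_CAST_DICT above; an op without its own variable gets the global/default mapping.

After resolving the config string, the hook also drops cast entries whose source dtype does not appear among the op's tensor inputs, so only the dtypes actually present get converted.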
9 changes: 4 additions & 5 deletions op_tools/op_dtype_cast_hook.py
@@ -9,6 +9,7 @@
traverse_container,
get_dtype_cast_dict_form_str,
is_opname_match,
garbage_collect
)
from .pretty_print import dict_data_list_to_table

@@ -17,11 +18,6 @@ class OpDtypeCastHook(BaseHook):

def __init__(self, name, func) -> None:
super().__init__(name, func)
self.dtype_cast_config_str = os.environ.get(
"OP_DTYPE_CAST_DICT",
"torch.float16->torch.float32,torch.bfloat16->torch.float32",
)
self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)

def before_call_op(self, *args, **kwargs):
self.dtype_cast_config_str = os.environ.get(
@@ -93,6 +89,9 @@ def after_call_op(self, result):
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))
id = self.id
self = None
garbage_collect(id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
25 changes: 19 additions & 6 deletions op_tools/op_fallback_hook.py
@@ -3,7 +3,7 @@
import os

from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table

@@ -36,6 +36,7 @@ def get_dtype_convert_back_dict(self):

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
self.dtype_cast_dict = dict()
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
@@ -56,24 +57,28 @@ def after_call_op(self, result):
with DisableHookGuard():
if self.result is not None and self.exception is None:
self.result_cpu = self.result
dtype_convert_back_dict = dict()
self.dtype_convert_back_dict = dict()
else:
# cpu backend do not support half or bfloat16
self.dtype_cast_dict = OpFallbackHook.FALLBACK_DTYPE_CAST_DICT
self.args = to_device(
"cpu",
self.args_device,
dtype_cast_dict=OpFallbackHook.FALLBACK_DTYPE_CAST_DICT,
dtype_cast_dict=self.dtype_cast_dict,
)
self.kwargs = to_device(
"cpu",
self.kwargs_device or {},
dtype_cast_dict=OpFallbackHook.FALLBACK_DTYPE_CAST_DICT,
dtype_cast_dict=self.dtype_cast_dict,
)
self.result_cpu = self.func(*self.args, **self.kwargs)
dtype_convert_back_dict = self.get_dtype_convert_back_dict()
self.dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
self.result = to_device(self.device, self.result_cpu, self.dtype_convert_back_dict)
self.dump_op_args()
id = self.id
self = None
garbage_collect(id)

def dump_op_args(self):
data_dict_list = []
@@ -89,7 +94,15 @@ def dump_op_args(self):
data_dict_list += packect_data_to_dict_list(self.name + " output(cpu)", serialize_args_to_dict(self.result_cpu))

table = dict_data_list_to_table(data_dict_list)
dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = "cpu_dtype_cast_info: " + str(self.dtype_cast_dict)

print("\n" * 2)
print(f"{self.name} forward_id: {self.id} {dtype_cast_info}")
print(f"{self.current_location}")
print(table)
print("\n" * 2)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
2 changes: 1 addition & 1 deletion op_tools/op_runner.py
@@ -61,7 +61,7 @@ def __init__(self, dir=".", hook=OpRunnerHook()) -> None:
self.dir = dir
self.hooks = []
self.add_hook(hook)
print(f"{dir}")
print(f"dir: {dir}")
self.load_forward_input()
self.load_forward_output()
self.load_backward_data()
8 changes: 4 additions & 4 deletions op_tools/test/test_compare_result.py
@@ -67,7 +67,7 @@ def test_compare_diff_int_list(self):
compare_info = compare_result("diff_int_list", result1, result2)
self.assertTrue(compare_info["allclose"] is False, compare_info)
self.assertTrue(compare_info["max_abs_diff"] == 9, compare_info)
self.assertTrue(abs(compare_info["max_relative_diff"] - 1) < 1e-3, compare_info)
self.assertTrue(compare_info["max_relative_diff"] <= 1)
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_same_torch_return_type(self):
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_compare_different_int(self):
compare_info = compare_result("different_int", result1, result2)
self.assertTrue(compare_info["allclose"] is False)
self.assertTrue(compare_info["max_abs_diff"] == i + 10)
self.assertTrue(abs(compare_info["max_relative_diff"] - ((i + 10) / i)) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] < (abs(result1 - result2) / result2))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_same_float(self):
Expand All @@ -112,7 +112,7 @@ def test_compare_same_float(self):
compare_info = compare_result("same_float", result1, result2)
self.assertTrue(compare_info["allclose"] is True)
self.assertTrue(compare_info["max_abs_diff"] == 0)
self.assertTrue(abs(compare_info["max_relative_diff"] - 0) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] <= (abs(result1 - result2) / (result2 + 1e-9)))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_different_float(self):
Expand All @@ -122,7 +122,7 @@ def test_compare_different_float(self):
compare_info = compare_result("different_float", result1, result2)
self.assertTrue(compare_info["allclose"] is False)
self.assertTrue(compare_info["max_abs_diff"] == i + 10)
self.assertTrue(abs(compare_info["max_relative_diff"] - ((i + 10) / i)) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] < (abs(result1 - result2) / result2))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_same_bool(self):
Expand Down
36 changes: 23 additions & 13 deletions op_tools/utils.py
@@ -270,7 +270,7 @@ def compare_result(name, a, b): # noqa: C901
for i in range(len(a_list)):
a_item = a_list[i]
b_item = b_list[i]
atol, rtol = 0, 0
atol_i, rtol_i = 0, 0
error_info_i = ""
if a_item is None and b_item is None:
allclose_i = True
@@ -281,10 +281,10 @@
error_info_i += f"Inconsistent dtypes: {a_item.dtype} {b_item.dtype}"
b_item = b_item.to(a_item.dtype)
if a_item.shape == b_item.shape:
atol, rtol = get_error_tolerance(a_item.dtype, name)
atol_i, rtol_i = get_error_tolerance(a_item.dtype, name)
if a_item.numel() > 0:
max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
allclose_i = tensor_allclose(a_item, b_item, atol=atol, rtol=rtol)
allclose_i = tensor_allclose(a_item, b_item, atol=atol_i, rtol=rtol_i)
else:
max_abs_diff_i, max_relative_diff_i = 0.0, 0.0
allclose_i = True
@@ -306,11 +306,11 @@
if not allclose_i:
error_info_i = f" value: {a_item} {b_item} "
elif isinstance(a_item, (float, int)):
atol = 1e-6
rtol = 1e-6
allclose_i = (math.isnan(a_item) and math.isnan(b_item)) or (abs(a_item - b_item) <= atol + rtol * abs(a_item))
atol_i = 1e-6
rtol_i = 1e-6
allclose_i = (math.isnan(a_item) and math.isnan(b_item)) or (abs(a_item - b_item) <= atol_i + rtol_i * abs(b_item))
max_abs_diff_i = abs(a_item - b_item)
max_relative_diff_i = max_abs_diff_i / (abs(a_item) + 1e-6)
max_relative_diff_i = max_abs_diff_i / (abs(b_item) + 1e-6)
if not allclose_i:
error_info_i = f" value: {a_item} {b_item} "
else:
@@ -320,8 +320,12 @@
allclose_i = False
error_info_i = str(e)
error_info_i += f" value: {a_item} {b_item}"
max_abs_diff_i = float("nan")
max_relative_diff_i = float("nan")
if not allclose_i:
max_abs_diff_i = float("nan")
max_relative_diff_i = float("nan")
else:
max_abs_diff_i = 0
max_relative_diff_i = 0
if len(a_list) > 1:
prefex = f"[{i}]"
else:
@@ -331,14 +335,16 @@
allclose = allclose_i and allclose
max_abs_diff = max(max_abs_diff_i, max_abs_diff)
max_relative_diff = max(max_relative_diff_i, max_relative_diff)
atol = max(atol_i, atol)
rtol = max(rtol_i, rtol)
result_list.append(
{
"name": f"{name + prefex:<30}",
"allclose": allclose_i,
"max_abs_diff": f"{max_abs_diff_i:10.9f}",
"max_relative_diff": f"{max_relative_diff_i:10.9f}",
"atol": f"{atol:10.9f}",
"rtol": f"{rtol:10.9f}",
"atol": f"{atol_i:10.9f}",
"rtol": f"{rtol_i:10.9f}",
"error_info": error_info_i,
}
)
@@ -355,8 +361,7 @@
}


def garbage_collect(id):
gc_cycle = int(os.getenv("OP_TOOLS_GARBAGE_COLLECTION_CYCLE", "100"))
def garbage_collect(id, gc_cycle=int(os.getenv("OP_TOOLS_GARBAGE_COLLECTION_CYCLE", "50"))):
if id % gc_cycle == 0:
gc.collect()

@@ -386,3 +391,8 @@ def current_location(name=None, stack_depth=-1, print_stack=False):

file, line, func, text = stack[stack_depth]
return f"{file}:{line} {func}: {text}"


def set_env_if_env_is_empty(env_name, env_value):
if os.environ.get(env_name, None) is None:
os.environ[env_name] = env_value
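get_dtype_cast_dict_form_str, imported by the hook above, turns a config string of this form into a {source_dtype: target_dtype} mapping. Its implementation is not part of this diff; the sketch below is only an assumption about its behavior, inferred from the value format used in the environment variables:

import torch

def parse_dtype_cast_dict(config_str):
    # Hypothetical stand-in for op_tools.utils.get_dtype_cast_dict_form_str,
    # whose body is not shown in this commit.
    cast_dict = {}
    for pair in config_str.split(","):
        src, dst = (s.strip() for s in pair.split("->"))
        # "torch.float16" -> torch.float16
        cast_dict[getattr(torch, src.split(".")[-1])] = getattr(torch, dst.split(".")[-1])
    return cast_dict

print(parse_dtype_cast_dict("torch.float16->torch.float64,torch.bfloat16->torch.float64"))
# {torch.float16: torch.float64, torch.bfloat16: torch.float64}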
