Automatic precision comparison supports customizing the data type used for CPU calculations per operator. #52

Merged: 1 commit, Oct 8, 2024
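This PR routes per-operator control of the CPU reference dtype through environment variables of the form `<OPNAME>_OP_DTYPE_CAST_DICT`, with the generic `OP_DTYPE_CAST_DICT` as a fallback. A minimal configuration sketch, assuming only the variable-name convention and value format visible in the diff below (the specific mappings are illustrative):

```python
import os

# Per-op override: run the CPU reference for linear ops in float64.
# The variable name is the op-name suffix, upper-cased, plus "_OP_DTYPE_CAST_DICT".
os.environ["LINEAR_OP_DTYPE_CAST_DICT"] = (
    "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64"
)

# Lower-priority mapping used for every op without a per-op variable.
os.environ["OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
```

Defaults seeded by the hook through `set_env_if_env_is_empty` never overwrite values the user has already exported.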
1,056 changes: 910 additions & 146 deletions README.md

Large diffs are not rendered by default.

95 changes: 63 additions & 32 deletions op_tools/op_autocompare_hook.py
@@ -15,6 +15,8 @@
compare_result,
is_random_number_gen_op,
garbage_collect,
get_dtype_cast_dict_form_str,
set_env_if_env_is_empty
)
from .save_op_args import save_op_args, serialize_args_to_dict

@@ -76,11 +78,16 @@ def grad_fun(grad_inputs, grad_outputs):
return grad_fun


set_env_if_env_is_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501


class OpAutoCompareHook(BaseHook):
AUTO_COMPARE_DTYPE_CAST_DICT = {
torch.half: torch.float32,
torch.bfloat16: torch.float32,
}

def __init__(self, name, func) -> None:
super().__init__(name, func)
@@ -97,39 +104,63 @@ def copy_input_to_cpu(self):
detach=True,
)

def update_dtype_cast_dict(self):
# some ops on the cpu backend do not support half or bfloat16
op_name_suffix = self.name.split(".")[-1].upper() # LINEAR, ADD, MATMUL etc
heigher_priority_env_name = op_name_suffix + "_OP_DTYPE_CAST_DICT"
lower_priority_env_name = "OP_DTYPE_CAST_DICT"
default_env_value = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
heigher_priority_env_value = os.environ.get(heigher_priority_env_name, None)
lower_priority_env_value = os.environ.get(lower_priority_env_name, default_env_value)
self.dtype_cast_config_str = heigher_priority_env_value or lower_priority_env_value

self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)

ins_dtype = set()
for args in traverse_container(self.args):
if isinstance(args, torch.Tensor):
ins_dtype.add(args.dtype)
raw_dtypes = set(self.dtype_cast_dict.keys())
for dtype in raw_dtypes:
if dtype not in ins_dtype:
self.dtype_cast_dict.pop(dtype)

def run_forward_on_cpu(self):
self.result_device = to_device("cpu", self.result, detach=True)
try:
self.update_dtype_cast_dict()

self.args_cpu = to_device(
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e: # noqa: F841
self.dtype_cast_dict = OpAutoCompareHook.AUTO_COMPARE_DTYPE_CAST_DICT
# some ops on the cpu backend do not support half or bfloat16
self.args_cpu = to_device(
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
else:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

if len(self.dtype_cast_dict) > 0 and 0:
self.result_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
self.result_cpu,
dtype_cast_dict={value : key for key, value in self.dtype_cast_dict.items()},
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
else:
args_cpu = self.args_cpu
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_outputs_device = to_device("cpu", grad_output, detach=True)
self.grad_outputs_cpu = to_device("cpu", self.grad_outputs_device, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
@@ -255,7 +286,7 @@ def compare_backward_relate(self):

id = self.forward_op_id
self = None
garbage_collect(id)
garbage_collect(id, 10)

def save_forward_args(self):
save_op_args(
@@ -277,7 +308,7 @@ def save_backward_args(self):
save_op_args(
self.name,
f"{self.identifier}/device/grad_outputs",
*tuple(self.grad_outputs_cpu),
*tuple(self.grad_outputs_device),
)
save_op_args(
self.name,
@@ -326,7 +357,7 @@ def after_call_op(self, result): # noqa:C901
else:
self = None

garbage_collect(id)
garbage_collect(id, 10)
return result

def is_should_apply(self, *args, **kwargs):
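The core of the change is `update_dtype_cast_dict`, which prefers the per-op variable over the generic one and then drops cast entries whose source dtype does not occur in the actual inputs. Below is a standalone sketch of that resolution order, assuming only the config-string format used above; `resolve_dtype_cast_dict` is an illustrative name, not part of the module:

```python
import os
import torch


def resolve_dtype_cast_dict(op_name, input_dtypes):
    # e.g. "torch.nn.functional.linear" -> "LINEAR_OP_DTYPE_CAST_DICT"
    per_op_env = op_name.split(".")[-1].upper() + "_OP_DTYPE_CAST_DICT"
    config = os.environ.get(per_op_env) or os.environ.get(
        "OP_DTYPE_CAST_DICT",
        "torch.float16->torch.float32,torch.bfloat16->torch.float32",
    )

    cast_dict = {}
    for pair in config.split(","):
        src, dst = (name.strip().split(".")[-1] for name in pair.split("->"))
        cast_dict[getattr(torch, src)] = getattr(torch, dst)

    # Keep only casts whose source dtype actually appears in the op's inputs.
    return {k: v for k, v in cast_dict.items() if k in input_dtypes}


# A float16 linear is compared against a float64 CPU reference once the
# LINEAR_OP_DTYPE_CAST_DICT default seeded above is in place.
print(resolve_dtype_cast_dict("torch.nn.functional.linear", {torch.float16}))
```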
9 changes: 4 additions & 5 deletions op_tools/op_dtype_cast_hook.py
@@ -9,6 +9,7 @@
traverse_container,
get_dtype_cast_dict_form_str,
is_opname_match,
garbage_collect
)
from .pretty_print import dict_data_list_to_table

@@ -17,11 +18,6 @@ class OpDtypeCastHook(BaseHook):

def __init__(self, name, func) -> None:
super().__init__(name, func)
self.dtype_cast_config_str = os.environ.get(
"OP_DTYPE_CAST_DICT",
"torch.float16->torch.float32,torch.bfloat16->torch.float32",
)
self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)

def before_call_op(self, *args, **kwargs):
self.dtype_cast_config_str = os.environ.get(
@@ -93,6 +89,9 @@ def after_call_op(self, result):
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))
id = self.id
self = None
garbage_collect(id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
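Because `OpDtypeCastHook` now reads `OP_DTYPE_CAST_DICT` in `before_call_op` instead of caching it in `__init__`, the mapping can change between operator calls. A hypothetical usage sketch; the `op_tools.OpDtypeCast()` context manager and the CUDA tensor are assumptions, not part of this diff:

```python
import os
import torch
import op_tools

x = torch.randn(4, 4, dtype=torch.float16, device="cuda")

with op_tools.OpDtypeCast():           # assumed public wrapper around OpDtypeCastHook
    y = x @ x                          # cast per the current OP_DTYPE_CAST_DICT

    # The mapping is re-read before every call, so it can be tightened on the fly.
    os.environ["OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float64"
    z = x @ x                          # this call now runs its CPU computation in float64
```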
25 changes: 19 additions & 6 deletions op_tools/op_fallback_hook.py
@@ -3,7 +3,7 @@
import os

from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table

@@ -36,6 +36,7 @@ def get_dtype_convert_back_dict(self):

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
self.dtype_cast_dict = dict()
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
@@ -56,24 +57,28 @@ def after_call_op(self, result):
with DisableHookGuard():
if self.result is not None and self.exception is None:
self.result_cpu = self.result
dtype_convert_back_dict = dict()
self.dtype_convert_back_dict = dict()
else:
# the cpu backend does not support half or bfloat16
self.dtype_cast_dict = OpFallbackHook.FALLBACK_DTYPE_CAST_DICT
self.args = to_device(
"cpu",
self.args_device,
dtype_cast_dict=OpFallbackHook.FALLBACK_DTYPE_CAST_DICT,
dtype_cast_dict=self.dtype_cast_dict,
)
self.kwargs = to_device(
"cpu",
self.kwargs_device or {},
dtype_cast_dict=OpFallbackHook.FALLBACK_DTYPE_CAST_DICT,
dtype_cast_dict=self.dtype_cast_dict,
)
self.result_cpu = self.func(*self.args, **self.kwargs)
dtype_convert_back_dict = self.get_dtype_convert_back_dict()
self.dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
self.result = to_device(self.device, self.result_cpu, self.dtype_convert_back_dict)
self.dump_op_args()
id = self.id
self = None
garbage_collect(id)

def dump_op_args(self):
data_dict_list = []
@@ -89,7 +94,15 @@ def dump_op_args(self):
data_dict_list += packect_data_to_dict_list(self.name + " output(cpu)", serialize_args_to_dict(self.result_cpu))

table = dict_data_list_to_table(data_dict_list)
dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = "cpu_dtype_cast_info: " + str(self.dtype_cast_dict)

print("\n" * 2)
print(f"{self.name} forward_id: {self.id} {dtype_cast_info}")
print(f"{self.current_location}")
print(table)
print("\n" * 2)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
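The fallback hook now keeps its cast mapping on `self` and prints it next to the op table. Converting results back to the device dtype is sketched below by inverting a single-entry mapping for the dtype actually seen in the input; how `get_dtype_convert_back_dict` builds its reverse map is not shown in this hunk:

```python
import torch

dtype_cast = {torch.half: torch.float32}              # fp16 input, CPU runs in fp32

x = torch.randn(8, dtype=torch.half)
x_cpu = x.to(dtype_cast[x.dtype])                     # cast for the CPU computation
result_cpu = x_cpu * 2                                # stand-in for the fallback op

convert_back = {v: k for k, v in dtype_cast.items()}  # float32 -> float16
result = result_cpu.to(convert_back[result_cpu.dtype])
assert result.dtype == torch.half                     # device dtype restored
```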
2 changes: 1 addition & 1 deletion op_tools/op_runner.py
@@ -61,7 +61,7 @@ def __init__(self, dir=".", hook=OpRunnerHook()) -> None:
self.dir = dir
self.hooks = []
self.add_hook(hook)
print(f"{dir}")
print(f"dir: {dir}")
self.load_forward_input()
self.load_forward_output()
self.load_backward_data()
8 changes: 4 additions & 4 deletions op_tools/test/test_compare_result.py
@@ -67,7 +67,7 @@ def test_compare_diff_int_list(self):
compare_info = compare_result("diff_int_list", result1, result2)
self.assertTrue(compare_info["allclose"] is False, compare_info)
self.assertTrue(compare_info["max_abs_diff"] == 9, compare_info)
self.assertTrue(abs(compare_info["max_relative_diff"] - 1) < 1e-3, compare_info)
self.assertTrue(compare_info["max_relative_diff"] <= 1)
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_same_torch_return_type(self):
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_compare_different_int(self):
compare_info = compare_result("different_int", result1, result2)
self.assertTrue(compare_info["allclose"] is False)
self.assertTrue(compare_info["max_abs_diff"] == i + 10)
self.assertTrue(abs(compare_info["max_relative_diff"] - ((i + 10) / i)) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] < (abs(result1 - result2) / result2))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_same_float(self):
Expand All @@ -112,7 +112,7 @@ def test_compare_same_float(self):
compare_info = compare_result("same_float", result1, result2)
self.assertTrue(compare_info["allclose"] is True)
self.assertTrue(compare_info["max_abs_diff"] == 0)
self.assertTrue(abs(compare_info["max_relative_diff"] - 0) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] <= (abs(result1 - result2) / (result2 + 1e-9)))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_different_float(self):
Expand All @@ -122,7 +122,7 @@ def test_compare_different_float(self):
compare_info = compare_result("different_float", result1, result2)
self.assertTrue(compare_info["allclose"] is False)
self.assertTrue(compare_info["max_abs_diff"] == i + 10)
self.assertTrue(abs(compare_info["max_relative_diff"] - ((i + 10) / i)) < 1e-3)
self.assertTrue(compare_info["max_relative_diff"] < (abs(result1 - result2) / result2))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_same_bool(self):
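The updated assertions follow the change in `op_tools/utils.py` below, where the relative difference is now measured against the second (reference) operand rather than the first. A worked scalar example under that convention:

```python
a, b = 20.0, 10.0                                    # device result vs. CPU reference

max_abs_diff = abs(a - b)                            # 10.0
max_relative_diff = max_abs_diff / (abs(b) + 1e-6)   # ~1.0: denominator is the reference value

# Under the previous a-based denominator the same pair gave 10.0 / 20.0 = 0.5.
```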
36 changes: 23 additions & 13 deletions op_tools/utils.py
@@ -270,7 +270,7 @@ def compare_result(name, a, b): # noqa: C901
for i in range(len(a_list)):
a_item = a_list[i]
b_item = b_list[i]
atol, rtol = 0, 0
atol_i, rtol_i = 0, 0
error_info_i = ""
if a_item is None and b_item is None:
allclose_i = True
@@ -281,10 +281,10 @@
error_info_i += f"Inconsistent dtypes: {a_item.dtype} {b_item.dtype}"
b_item = b_item.to(a_item.dtype)
if a_item.shape == b_item.shape:
atol, rtol = get_error_tolerance(a_item.dtype, name)
atol_i, rtol_i = get_error_tolerance(a_item.dtype, name)
if a_item.numel() > 0:
max_abs_diff_i, max_relative_diff_i = tensor_max_diff(a_item, b_item)
allclose_i = tensor_allclose(a_item, b_item, atol=atol, rtol=rtol)
allclose_i = tensor_allclose(a_item, b_item, atol=atol_i, rtol=rtol_i)
else:
max_abs_diff_i, max_relative_diff_i = 0.0, 0.0
allclose_i = True
@@ -306,11 +306,11 @@
if not allclose_i:
error_info_i = f" value: {a_item} {b_item} "
elif isinstance(a_item, (float, int)):
atol = 1e-6
rtol = 1e-6
allclose_i = (math.isnan(a_item) and math.isnan(b_item)) or (abs(a_item - b_item) <= atol + rtol * abs(a_item))
atol_i = 1e-6
rtol_i = 1e-6
allclose_i = (math.isnan(a_item) and math.isnan(b_item)) or (abs(a_item - b_item) <= atol_i + rtol_i * abs(b_item))
max_abs_diff_i = abs(a_item - b_item)
max_relative_diff_i = max_abs_diff_i / (abs(a_item) + 1e-6)
max_relative_diff_i = max_abs_diff_i / (abs(b_item) + 1e-6)
if not allclose_i:
error_info_i = f" value: {a_item} {b_item} "
else:
@@ -320,8 +320,8 @@
allclose_i = False
error_info_i = str(e)
error_info_i += f" value: {a_item} {b_item}"
max_abs_diff_i = float("nan")
max_relative_diff_i = float("nan")
if not allclose_i:
max_abs_diff_i = float("nan")
max_relative_diff_i = float("nan")
else:
max_abs_diff_i = 0
max_relative_diff_i = 0
if len(a_list) > 1:
prefex = f"[{i}]"
else:
@@ -331,14 +335,16 @@
allclose = allclose_i and allclose
max_abs_diff = max(max_abs_diff_i, max_abs_diff)
max_relative_diff = max(max_relative_diff_i, max_relative_diff)
atol = max(atol_i, atol)
rtol = max(rtol_i, rtol)
result_list.append(
{
"name": f"{name + prefex:<30}",
"allclose": allclose_i,
"max_abs_diff": f"{max_abs_diff_i:10.9f}",
"max_relative_diff": f"{max_relative_diff_i:10.9f}",
"atol": f"{atol:10.9f}",
"rtol": f"{rtol:10.9f}",
"atol": f"{atol_i:10.9f}",
"rtol": f"{rtol_i:10.9f}",
"error_info": error_info_i,
}
)
@@ -355,8 +361,7 @@
}


def garbage_collect(id):
gc_cycle = int(os.getenv("OP_TOOLS_GARBAGE_COLLECTION_CYCLE", "100"))
def garbage_collect(id, gc_cycle=int(os.getenv("OP_TOOLS_GARBAGE_COLLECTION_CYCLE", "50"))):
if id % gc_cycle == 0:
gc.collect()

@@ -386,3 +391,8 @@ def current_location(name=None, stack_depth=-1, print_stack=False):

file, line, func, text = stack[stack_depth]
return f"{file}:{line} {func}: {text}"


def set_env_if_env_is_empty(env_name, env_value):
if os.environ.get(env_name, None) is None:
os.environ[env_name] = env_value
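A short usage sketch of the two helpers touched here, with the import path taken from this file's location. Note that `garbage_collect` captures its `gc_cycle` default once, when `op_tools.utils` is imported, so `OP_TOOLS_GARBAGE_COLLECTION_CYCLE` only takes effect if exported beforehand:

```python
import os

from op_tools.utils import garbage_collect, set_env_if_env_is_empty

# Seeds a default without clobbering anything the user already exported.
os.environ["MATMUL_OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float32"
set_env_if_env_is_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64")
assert os.environ["MATMUL_OP_DTYPE_CAST_DICT"] == "torch.float16->torch.float32"

# gc.collect() fires on every gc_cycle-th op id; the default cycle is 50 unless
# OP_TOOLS_GARBAGE_COLLECTION_CYCLE was exported before import.
garbage_collect(100)       # 100 % 50 == 0 -> collect
garbage_collect(101, 10)   # explicit cycle, as the hooks above pass 10 -> no collect here
```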