Support printing the output of the various operator tools as tables
zhaoguochun1995 committed Sep 13, 2024
1 parent 53509a5 commit cf72166
Showing 16 changed files with 239 additions and 64 deletions.
34 changes: 14 additions & 20 deletions op_tools/op_autocompare_hook.py
@@ -16,6 +16,8 @@
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args

RANDOM_NUMBER_GEN_OPS = [
"torch.Tensor.random_",
"torch.Tensor.uniform_",
@@ -153,20 +155,25 @@ def after_call_op(self, result): # noqa:C901
)

if is_inplace_op(self.name):
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
compare_info = compare_result(self.name, self.args[0], args_cpu[0])
allclose = compare_info["allclose"]
if not allclose:
self.save_forward_args()

compare_info = compare_result(self.name, self.result_device, self.result_cpu)
allclose = compare_info["allclose"]

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]

self.forward_allclose = allclose
if not allclose:
print(f"OpAutoCompareHook: {self.name:<60} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
print(f"OpAutoCompareHook: {self.name:<60} output: {serialize_args_to_dict(self.result)['args']}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
self.save_forward_args()

self.backward_hook_handle = BackwardHookHandle(self)
@@ -228,21 +235,8 @@ def compare_all_grad(self):
if isinstance(arg, torch.Tensor) and (arg.requires_grad and self.args_grad[i] is None):
return

all_grad_allclose = True
for i in range(len(self.args)):
arg = self.args[i]

if isinstance(arg, torch.Tensor) and arg.requires_grad:
arg_cpu_grad = self.args_cpu_grad[i]
allclose = compare_result(
self.name + f" (ins[{i}].grad)",
self.args_grad[i],
arg_cpu_grad,
)["allclose"]

if not allclose:
all_grad_allclose = False
if not all_grad_allclose:
compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)
if not compare_info["allclose"]:
# Parameters are not saved when forward accuracy is normal
if self.forward_allclose:
self.save_forward_args()
10 changes: 7 additions & 3 deletions op_tools/op_dispatch_watch_hook.py
@@ -4,19 +4,23 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args


class OpDispatchWatcherHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} input: {serialize_args_to_dict(*args, **kwargs)}")
pass

def after_call_op(self, result):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} output: {serialize_args_to_dict(self.result)}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
30 changes: 21 additions & 9 deletions op_tools/op_dtype_cast_hook.py
@@ -10,6 +10,7 @@
get_dtype_cast_dict_form_str,
is_opname_match,
)
from .pretty_print import dict_data_list_to_table


class OpDtypeCastHook(BaseHook):
@@ -54,13 +55,19 @@ def before_call_op(self, *args, **kwargs):
for arg in traverse_container(self.args_raw):
self.raw_ins_list.append(arg)

self.data_dict_list = []
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = (
self.raw_ins_list[i].dtype
)
# print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
"target": f"input[{i}]",
"action": f"{self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)

def after_call_op(self, result):
if self.is_cpu_op:
@@ -76,11 +83,16 @@ def after_call_op(self, result):
i = -1
for out in traverse_container(self.result_raw):
i += 1
if (
isinstance(out, torch.Tensor)
and out.dtype in self.dtype_cast_back_dict.keys()
):
print(f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}") # noqa: E501
if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
data_dict = {
"name": self.name,
"target": f"output[{i}]",
"action": f"{out.dtype} -> {self.dtype_cast_back_dict[out.dtype]}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
22 changes: 20 additions & 2 deletions op_tools/op_fallback_hook.py
@@ -5,6 +5,7 @@
from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpFallbackHook(BaseHook):
@@ -38,7 +39,6 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
print(f"OpFallbackHook: {self.name:<50} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
self.args_device = self.args
self.kwargs_device = self.kwargs
self.args = to_device(
@@ -73,7 +73,25 @@ def after_call_op(self, result):
dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
print(f"OpFallbackHook: {self.name:<50} output: {serialize_args_to_dict(self.result)['args']} cpu output: {serialize_args_to_dict(self.result_cpu)['args']} dtype_convert_back_dict:{dtype_convert_back_dict}") # noqa: E501
self.dump_op_args()

def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result), prefix="device_output")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output ")

table = dict_data_list_to_table(data_dict_list)
print(table)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
29 changes: 27 additions & 2 deletions op_tools/op_time_measure_hook.py
@@ -6,6 +6,11 @@

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match
from .pretty_print import (
pretty_print_op_args,
dict_data_list_to_table,
packect_data_to_dict_list,
)


class BackwardHookHandle:
@@ -25,7 +30,17 @@ def grad_fun(grad_inputs, grad_outputs):
torch.cuda.current_stream().synchronize()
self.end_time = time.time()
self.backward_elasped = self.end_time - self.start_time
print(f"OpTimeMeasureHook: {self.name:<30} backward elasped: {(self.backward_elasped * 1000):>10.8f} ms grad_inputs: {serialize_args_to_dict(grad_inputs)} output: {serialize_args_to_dict(grad_outputs)}") # noqa: E501
data_dict_list = []
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_outputs), prefix="grad_outputs ")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_inputs), prefix="grad_inputs ")
table = dict_data_list_to_table(data_dict_list)
print(table)
elasped_info_dict = {
"backward_elasped": f"{(self.backward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

return grad_fun

@@ -57,7 +72,17 @@ def after_call_op(self, result):
self.result[i].grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

with DisableHookGuard():
print(f"OpTimeMeasureHook: {self.name:<30} forward elasped: {(self.foward_elasped * 1000):>10.8f} ms input: {serialize_args_to_dict(*self.args, **self.kwargs)} output: {serialize_args_to_dict(self.result)}") # noqa: E501
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
elasped_info_dict = {
"forward_elasped": f"{(self.foward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
57 changes: 57 additions & 0 deletions op_tools/pretty_print.py
@@ -0,0 +1,57 @@
from prettytable import PrettyTable


def dict_data_list_to_table(data_dict_list):
table = PrettyTable()
keys = list()
for data_dict in data_dict_list:
if isinstance(data_dict, dict):
for key in data_dict.keys():
if key not in keys:
keys.append(key)
else:
assert False, "data_dict should be dict"
table.field_names = keys
for data_dict in data_dict_list:
table.add_row([data_dict.get(key, "") for key in keys])
return table


def packect_data_to_dict_list(op_name, inputs_dict, prefix):
data_dict_list = []
args = inputs_dict.get("args", [])
kwargs = inputs_dict.get("kwargs", {})
arg_index = -1
for arg in args:
arg_index += 1
if isinstance(arg, dict):
item_name = op_name + f" {prefix}" + (f"[{arg_index}]" if len(args) > 1 else "")
data_dict = {"name": item_name}
data_dict.update(arg)
data_dict_list.append(data_dict)
elif isinstance(arg, (tuple, list)):
arg_sub_index = -1
for item in arg:
arg_sub_index += 1
item_name = op_name + f" {prefix}[{arg_index}]" + f"[{arg_sub_index}]"
if isinstance(item, dict):
data_dict = {"name": item_name}
data_dict.update(item)
data_dict_list.append(data_dict)
else:
data_dict_list.append({"name": item_name, "value": item})
for key, value in kwargs.items():
data_dict_list.append({"name": op_name + f" [{key}]", "value": value})

return data_dict_list


def pretty_print_op_args(op_name, inputs_dict, outputs_dict=None):

input_data_dict_list = packect_data_to_dict_list(op_name, inputs_dict, "inputs")
output_data_dict_list = packect_data_to_dict_list(op_name, outputs_dict, "outputs")
data_dict_list = input_data_dict_list + output_data_dict_list
table = dict_data_list_to_table(data_dict_list)
if len(data_dict_list) > 0:
print(table)
return table
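
For reference, a minimal usage sketch of the two helpers added above. This is illustrative only: it assumes the op_tools package from this commit is importable and that prettytable is installed; the row data below is made up.

from op_tools.pretty_print import dict_data_list_to_table, packect_data_to_dict_list

# Rows may carry different keys; the table takes the union of all keys and fills gaps with "".
rows = [
    {"name": "torch.add inputs[0]", "shape": "(3, 4)", "dtype": "float16", "device": "cuda:0"},
    {"name": "torch.add outputs", "shape": "(3, 4)", "dtype": "float16"},
]
print(dict_data_list_to_table(rows))

# packect_data_to_dict_list flattens a serialize_args_to_dict()-style payload into per-argument rows.
serialized = {
    "args": [{"shape": "(3, 4)", "dtype": "float16"}, {"shape": "(3, 4)", "dtype": "float16"}],
    "kwargs": {"alpha": 1.0},
}
print(dict_data_list_to_table(packect_data_to_dict_list("torch.add", serialized, prefix="inputs")))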
2 changes: 1 addition & 1 deletion op_tools/run_op_from_data.py
@@ -51,7 +51,7 @@ def main():
found_files = find_files(args.dir, "input.pth")

for file_path in found_files:
runner = OpRunner(file_path[0:file_path.rfind("/")])
runner = OpRunner(file_path[0 : file_path.rfind("/")])
if args.sync_time_measure:
timer = SyncExecuteTimer()
runner.add_hook(timer)
4 changes: 2 additions & 2 deletions op_tools/save_op_args.py
@@ -7,10 +7,10 @@
def serialize_args_to_dict(*args, **kwargs):
def tensor_to_dict(tensor):
return {
"shape": tensor.shape,
"shape": str(tuple(tensor.shape)),
"stride": tensor.stride(),
"numel": tensor.numel(),
"dtype": str(tensor.dtype),
"dtype": str(tensor.dtype).split(".")[-1],
"device": str(tensor.device),
"requires_grad": tensor.requires_grad,
"layout": str(tensor.layout),
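
A quick illustration of the serialization tweak above, which keeps the printed tables narrow: shape is now rendered as a plain tuple string and dtype drops the "torch." prefix. Sketch only, showing the two expressions from the new tensor_to_dict:

import torch

t = torch.randn(3, 4)
print(str(tuple(t.shape)))          # "(3, 4)"   (previously the raw torch.Size([3, 4]))
print(str(t.dtype).split(".")[-1])  # "float32"  (previously "torch.float32")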
8 changes: 6 additions & 2 deletions op_tools/test/test_event.py
@@ -9,8 +9,12 @@ class TestEvent(unittest.TestCase):
def test_event_measure_device_time(self):
x = torch.randn(3, 4).cuda()

start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
start_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)
end_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)

start_event.record(torch.cuda.current_stream())

8 changes: 7 additions & 1 deletion op_tools/test/test_op_autocompare.py
@@ -7,7 +7,13 @@

def f():
a = torch.rand(10, requires_grad=True, device="cuda").half()
a = torch.bernoulli(a) + a + torch.rand_like(a) + torch.empty_like(a).uniform_() + torch.empty_like(a).normal_()
a = (
torch.bernoulli(a)
+ a
+ torch.rand_like(a)
+ torch.empty_like(a).uniform_()
+ torch.empty_like(a).normal_()
)
b = a * 2 + torch.randperm(a.numel(), dtype=a.dtype, device=a.device).view(a.shape)
c = b + a
d = c - a
4 changes: 3 additions & 1 deletion op_tools/test/test_op_dtype_cast.py
@@ -46,7 +46,9 @@ def f():
# usage4
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = ""
os.environ["OP_DTYPE_CAST_LIST"] = "torch.Tensor.sort" # only cast this op
os.environ["OP_DTYPE_CAST_DICT"] = "torch.half->torch.float32" # camb 370 not support bfloat16
os.environ["OP_DTYPE_CAST_DICT"] = (
"torch.half->torch.float32" # camb 370 not support bfloat16
)
dtype_caster.start()
f()
dtype_caster.stop()
4 changes: 3 additions & 1 deletion op_tools/test/test_op_fallback.py
@@ -21,7 +21,9 @@ def f():

base = 10000
dim = 1024
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
)

x = torch.randn(3, 4).cuda().to(torch.float16) # camb_mlu370 not support bfloat16
y = x.clone()
8 changes: 6 additions & 2 deletions op_tools/test/test_opname_match.py
@@ -17,8 +17,12 @@ def test_opname_match(self):
self.assertEqual(is_opname_match("torch.addc", "torch.addc,torch.sub"), True)
self.assertEqual(is_opname_match("torch.addc", "torch.add,torch.sub"), False)
self.assertEqual(is_opname_match("torch.sub", "torch.addc,torch.sub"), True)
self.assertEqual(is_opname_match("torch.sub", "torch.addc,torch.subc,torch.mul"), False)
self.assertEqual(is_opname_match("torch.subc", "torch.addc,torch.sub,torch.mul"), False)
self.assertEqual(
is_opname_match("torch.sub", "torch.addc,torch.subc,torch.mul"), False
)
self.assertEqual(
is_opname_match("torch.subc", "torch.addc,torch.sub,torch.mul"), False
)
self.assertEqual(is_opname_match("torch.subc", ".*"), True)
self.assertEqual(is_opname_match("torch.subc", "torch.add,.*"), True)
self.assertEqual(is_opname_match("torch.subc", None), True)