Support table-formatted output in the various operator tools
zhaoguochun1995 committed Sep 13, 2024
1 parent 53509a5 commit 3830985
Showing 16 changed files with 207 additions and 35 deletions.
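Editor's note: the thrust of the commit is that scattered one-line print(f"...") calls in the operator hooks are replaced by tables rendered with the third-party prettytable package, centralized in the new op_tools/pretty_print.py module. A minimal sketch of the rendering pattern that module relies on (not code from this commit; the row values are made up):

    from prettytable import PrettyTable  # third-party: pip install prettytable

    # Build the union of keys across row dicts, then emit one row per dict,
    # filling missing cells with "" -- the same strategy dict_data_list_to_table uses.
    rows = [
        {"name": "torch.add inputs [0]", "shape": "(3, 4)", "dtype": "float32"},
        {"name": "torch.add outputs", "shape": "(3, 4)", "dtype": "float32"},
    ]
    keys = []
    for row in rows:
        for key in row:
            if key not in keys:
                keys.append(key)

    table = PrettyTable()
    table.field_names = keys
    for row in rows:
        table.add_row([row.get(key, "") for key in keys])
    print(table)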
9 changes: 7 additions & 2 deletions op_tools/op_autocompare_hook.py
@@ -16,6 +16,8 @@
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args

RANDOM_NUMBER_GEN_OPS = [
"torch.Tensor.random_",
"torch.Tensor.uniform_",
@@ -165,8 +167,11 @@ def after_call_op(self, result):  # noqa:C901

self.forward_allclose = allclose
if not allclose:
print(f"OpAutoCompareHook: {self.name:<60} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
print(f"OpAutoCompareHook: {self.name:<60} output: {serialize_args_to_dict(self.result)['args']}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
self.save_forward_args()

self.backward_hook_handle = BackwardHookHandle(self)
10 changes: 7 additions & 3 deletions op_tools/op_dispatch_watch_hook.py
@@ -4,19 +4,23 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args


class OpDispatchWatcherHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} input: {serialize_args_to_dict(*args, **kwargs)}")
pass

def after_call_op(self, result):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} output: {serialize_args_to_dict(self.result)}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
17 changes: 8 additions & 9 deletions op_tools/op_dtype_cast_hook.py
@@ -57,10 +57,10 @@ def before_call_op(self, *args, **kwargs):
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = (
self.raw_ins_list[i].dtype
)
print(
f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}"
) # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype

def after_call_op(self, result):
if self.is_cpu_op:
@@ -76,11 +76,10 @@ def after_call_op(self, result):
i = -1
for out in traverse_container(self.result_raw):
i += 1
if (
isinstance(out, torch.Tensor)
and out.dtype in self.dtype_cast_back_dict.keys()
):
print(f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}") # noqa: E501
if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
print(
f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}"
) # noqa: E501

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
26 changes: 24 additions & 2 deletions op_tools/op_fallback_hook.py
@@ -5,6 +5,7 @@
from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpFallbackHook(BaseHook):
@@ -38,7 +39,6 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
print(f"OpFallbackHook: {self.name:<50} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
self.args_device = self.args
self.kwargs_device = self.kwargs
self.args = to_device(
@@ -73,7 +73,29 @@ def after_call_op(self, result):
dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
print(f"OpFallbackHook: {self.name:<50} output: {serialize_args_to_dict(self.result)['args']} cpu output: {serialize_args_to_dict(self.result_cpu)['args']} dtype_convert_back_dict:{dtype_convert_back_dict}") # noqa: E501
self.dump_op_args()

def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(self.result), prefix="device_output"
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output "
)

table = dict_data_list_to_table(data_dict_list)
print(table)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
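Editor's note: OpFallbackHook now emits a single table per op via dump_op_args, with four row groups: device inputs, the CPU copies actually executed, the CPU outputs, and the outputs converted back to the device. The trailing spaces in prefixes such as "cpu_input " appear intended to pad the labels so the name column lines up across groups. A minimal sketch of that grouping, with made-up tensor metadata:

    from op_tools.pretty_print import packect_data_to_dict_list, dict_data_list_to_table

    device_arg = {"shape": "(3, 4)", "dtype": "float16", "device": "cuda:0"}
    cpu_arg = {"shape": "(3, 4)", "dtype": "float32", "device": "cpu"}

    rows = []
    rows += packect_data_to_dict_list("torch.mul", {"args": [device_arg]}, prefix="device_input ")
    rows += packect_data_to_dict_list("torch.mul", {"args": [cpu_arg]}, prefix="cpu_input    ")
    print(dict_data_list_to_table(rows))  # one table; padded prefixes keep the name column aligned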
33 changes: 31 additions & 2 deletions op_tools/op_time_measure_hook.py
@@ -6,6 +6,11 @@

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match
from .pretty_print import (
pretty_print_op_args,
dict_data_list_to_table,
packect_data_to_dict_list,
)


class BackwardHookHandle:
@@ -25,7 +30,21 @@ def grad_fun(grad_inputs, grad_outputs):
torch.cuda.current_stream().synchronize()
self.end_time = time.time()
self.backward_elasped = self.end_time - self.start_time
print(f"OpTimeMeasureHook: {self.name:<30} backward elasped: {(self.backward_elasped * 1000):>10.8f} ms grad_inputs: {serialize_args_to_dict(grad_inputs)} output: {serialize_args_to_dict(grad_outputs)}") # noqa: E501
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(grad_outputs), prefix="grad_outputs "
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(grad_inputs), prefix="grad_inputs "
)
table = dict_data_list_to_table(data_dict_list)
print(table)
elasped_info_dict = {
"backward_elasped": f"{(self.backward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

return grad_fun

@@ -57,7 +76,17 @@ def after_call_op(self, result):
self.result[i].grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

with DisableHookGuard():
print(f"OpTimeMeasureHook: {self.name:<30} forward elasped: {(self.foward_elasped * 1000):>10.8f} ms input: {serialize_args_to_dict(*self.args, **self.kwargs)} output: {serialize_args_to_dict(self.result)}") # noqa: E501
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
elasped_info_dict = {
"forward_elasped": f"{(self.foward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
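Editor's note: OpTimeMeasureHook now prints two tables per direction: one for the argument metadata (pretty_print_op_args on the forward pass, the grad tables above on the backward pass) and a one-row table for the elapsed time. The one-row table is plain dict_data_list_to_table; a sketch with a made-up timing value:

    from op_tools.pretty_print import dict_data_list_to_table

    elasped_info_dict = {
        "forward_elasped": f"{(0.000123 * 1000):>10.8f}",  # same formatting as the hook
        "unit": "ms",
        "forward_id": 42,
    }
    print(dict_data_list_to_table([elasped_info_dict]))
    # Prints a one-row table, roughly:
    # +-----------------+------+------------+
    # | forward_elasped | unit | forward_id |
    # +-----------------+------+------------+
    # |   0.12300000    |  ms  |     42     |
    # +-----------------+------+------------+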
57 changes: 57 additions & 0 deletions op_tools/pretty_print.py
@@ -0,0 +1,57 @@
from prettytable import PrettyTable


def dict_data_list_to_table(data_dict_list):
    # Render a list of row dicts as a PrettyTable; the column set is the
    # union of keys across rows, and missing cells are left blank.
    table = PrettyTable()
    keys = list()
    for data_dict in data_dict_list:
        if isinstance(data_dict, dict):
            for key in data_dict.keys():
                if key not in keys:
                    keys.append(key)
        else:
            assert False, "data_dict should be dict"
    table.field_names = keys
    for data_dict in data_dict_list:
        table.add_row([data_dict.get(key, "") for key in keys])
    return table


def packect_data_to_dict_list(op_name, inputs_dict, prefix):
    # Flatten serialize_args_to_dict output ({"args": [...], "kwargs": {...}})
    # into row dicts named "<op_name> <prefix>[i]" / "<op_name> [kwarg]".
    data_dict_list = []
    args = inputs_dict.get("args", [])
    kwargs = inputs_dict.get("kwargs", {})
    arg_index = -1
    for arg in args:
        arg_index += 1
        if isinstance(arg, dict):
            item_name = op_name + f" {prefix}" + (f"[{arg_index}]" if len(args) > 1 else "")
            data_dict = {"name": item_name}
            data_dict.update(arg)
            data_dict_list.append(data_dict)
        elif isinstance(arg, (tuple, list)):
            arg_sub_index = -1
            for item in arg:
                arg_sub_index += 1
                item_name = op_name + f" {prefix}[{arg_index}]" + f"[{arg_sub_index}]"
                if isinstance(item, dict):
                    data_dict = {"name": item_name}
                    data_dict.update(item)
                    data_dict_list.append(data_dict)
                else:
                    data_dict_list.append({"name": item_name, "value": item})
    for key, value in kwargs.items():
        data_dict_list.append({"name": op_name + f" [{key}]", "value": value})

    return data_dict_list


def pretty_print_op_args(op_name, inputs_dict, outputs_dict=None):
    input_data_dict_list = packect_data_to_dict_list(op_name, inputs_dict, "inputs")
    # outputs_dict defaults to None; skip it rather than crash on None.get().
    output_data_dict_list = packect_data_to_dict_list(op_name, outputs_dict, "outputs") if outputs_dict is not None else []
    data_dict_list = input_data_dict_list + output_data_dict_list
    table = dict_data_list_to_table(data_dict_list)
    if len(data_dict_list) > 0:
        print(table)
    return table
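Editor's note: a quick usage sketch of the two lower-level helpers, with a hand-built dict that mirrors the {"args": [...], "kwargs": {...}} shape serialize_args_to_dict returns (the tensor metadata values are illustrative):

    from op_tools.pretty_print import packect_data_to_dict_list, dict_data_list_to_table

    # One positional tensor arg plus one keyword arg, in the shape
    # serialize_args_to_dict produces.
    inputs_dict = {
        "args": [{"shape": "(3, 4)", "dtype": "float32", "device": "cuda:0"}],
        "kwargs": {"alpha": 2},
    }
    rows = packect_data_to_dict_list("torch.add", inputs_dict, prefix="inputs ")
    # rows[0]["name"] == "torch.add inputs "   (single arg, so no "[0]" suffix)
    # rows[1] == {"name": "torch.add [alpha]", "value": 2}
    print(dict_data_list_to_table(rows))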
2 changes: 1 addition & 1 deletion op_tools/run_op_from_data.py
@@ -51,7 +51,7 @@ def main():
found_files = find_files(args.dir, "input.pth")

for file_path in found_files:
runner = OpRunner(file_path[0:file_path.rfind("/")])
runner = OpRunner(file_path[0 : file_path.rfind("/")])
if args.sync_time_measure:
timer = SyncExecuteTimer()
runner.add_hook(timer)
4 changes: 2 additions & 2 deletions op_tools/save_op_args.py
@@ -7,10 +7,10 @@
def serialize_args_to_dict(*args, **kwargs):
def tensor_to_dict(tensor):
return {
"shape": tensor.shape,
"shape": str(tuple(tensor.shape)),
"stride": tensor.stride(),
"numel": tensor.numel(),
"dtype": str(tensor.dtype),
"dtype": str(tensor.dtype).split(".")[-1],
"device": str(tensor.device),
"requires_grad": tensor.requires_grad,
"layout": str(tensor.layout),
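Editor's note: the serialization tweaks make tensor metadata table-friendly: shape is rendered as a compact tuple string instead of a torch.Size repr, and dtype drops the torch. prefix. A before/after sketch for a float32 tensor of shape (3, 4):

    import torch

    t = torch.randn(3, 4)

    # before: "shape" was torch.Size([3, 4]) and "dtype" was "torch.float32"
    entry = {
        "shape": str(tuple(t.shape)),          # "(3, 4)"
        "dtype": str(t.dtype).split(".")[-1],  # "float32"
    }
    print(entry)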
8 changes: 6 additions & 2 deletions op_tools/test/test_event.py
@@ -9,8 +9,12 @@ class TestEvent(unittest.TestCase):
def test_event_measure_device_time(self):
x = torch.randn(3, 4).cuda()

start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
start_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)
end_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)

start_event.record(torch.cuda.current_stream())

8 changes: 7 additions & 1 deletion op_tools/test/test_op_autocompare.py
@@ -7,7 +7,13 @@

def f():
a = torch.rand(10, requires_grad=True, device="cuda").half()
a = torch.bernoulli(a) + a + torch.rand_like(a) + torch.empty_like(a).uniform_() + torch.empty_like(a).normal_()
a = (
torch.bernoulli(a)
+ a
+ torch.rand_like(a)
+ torch.empty_like(a).uniform_()
+ torch.empty_like(a).normal_()
)
b = a * 2 + torch.randperm(a.numel(), dtype=a.dtype, device=a.device).view(a.shape)
c = b + a
d = c - a
4 changes: 3 additions & 1 deletion op_tools/test/test_op_dtype_cast.py
@@ -46,7 +46,9 @@ def f():
# usage4
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = ""
os.environ["OP_DTYPE_CAST_LIST"] = "torch.Tensor.sort" # only cast this op
os.environ["OP_DTYPE_CAST_DICT"] = "torch.half->torch.float32" # camb 370 not support bfloat16
os.environ["OP_DTYPE_CAST_DICT"] = (
"torch.half->torch.float32" # camb 370 not support bfloat16
)
dtype_caster.start()
f()
dtype_caster.stop()
4 changes: 3 additions & 1 deletion op_tools/test/test_op_fallback.py
@@ -21,7 +21,9 @@ def f():

base = 10000
dim = 1024
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim))
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device="cuda", dtype=torch.float32) / dim)
)

x = torch.randn(3, 4).cuda().to(torch.float16)  # camb_mlu370 does not support bfloat16
y = x.clone()
8 changes: 6 additions & 2 deletions op_tools/test/test_opname_match.py
@@ -17,8 +17,12 @@ def test_opname_match(self):
self.assertEqual(is_opname_match("torch.addc", "torch.addc,torch.sub"), True)
self.assertEqual(is_opname_match("torch.addc", "torch.add,torch.sub"), False)
self.assertEqual(is_opname_match("torch.sub", "torch.addc,torch.sub"), True)
self.assertEqual(is_opname_match("torch.sub", "torch.addc,torch.subc,torch.mul"), False)
self.assertEqual(is_opname_match("torch.subc", "torch.addc,torch.sub,torch.mul"), False)
self.assertEqual(
is_opname_match("torch.sub", "torch.addc,torch.subc,torch.mul"), False
)
self.assertEqual(
is_opname_match("torch.subc", "torch.addc,torch.sub,torch.mul"), False
)
self.assertEqual(is_opname_match("torch.subc", ".*"), True)
self.assertEqual(is_opname_match("torch.subc", "torch.add,.*"), True)
self.assertEqual(is_opname_match("torch.subc", None), True)
19 changes: 19 additions & 0 deletions op_tools/test/test_pretty_print.py
@@ -0,0 +1,19 @@
from op_tools.save_op_args import serialize_args_to_dict
from op_tools.pretty_print import pretty_print_op_args
from op_tools.utils import compare_result
import torch
import ditorch

x = torch.randn(3, 4, device="cuda")
y = torch.randn(3, 4, 7, 8, device="cpu")

pretty_print_op_args(
op_name="torch.add",
inputs_dict=serialize_args_to_dict(x, x, x),
outputs_dict=serialize_args_to_dict(x),
)
pretty_print_op_args(
op_name="torch.stack",
inputs_dict=serialize_args_to_dict([x, x, x], dim=1),
outputs_dict=serialize_args_to_dict(x),
)
4 changes: 2 additions & 2 deletions op_tools/test/test_run_op_from_data.py
@@ -9,9 +9,9 @@ def run_command_in_sub_process(commands):
print(result.stdout)
print(result.stderr)
if result.returncode != 0:
print(F"Test {commands} FAILED")
print(f"Test {commands} FAILED")
else:
print(F"Test {commands} PASSED")
print(f"Test {commands} PASSED")
print("\n\n\n")

