Support table-formatted output and printing for the various operator tools
zhaoguochun1995 committed Sep 13, 2024
1 parent 53509a5 commit 8eb01c2
Showing 20 changed files with 301 additions and 61 deletions.
4 changes: 3 additions & 1 deletion op_tools/apply_hook.py
@@ -15,7 +15,9 @@ def is_should_apply_hook(name, func, args, kwargs=None):
return False
if inspect.isroutine(func) is False:
return False
if name.startswith("torch.Tensor.") and (name.endswith("__get__") or name.endswith("__set__")):
if name.startswith("torch.Tensor.") and (
name.endswith("__get__") or name.endswith("__set__")
):
return False
# Assume the CPU functionality of the vendor-provided torch has not been broken
args_on_cpu, device = is_cpu_op(args, kwargs)
4 changes: 3 additions & 1 deletion op_tools/base_hook.py
@@ -8,7 +8,9 @@ def is_should_apply_hook(name, func, *args, **kwargs):
return False
if callable(func) is False:
return False
if name.startswith("torch.Tensor.") and (name.endswith("__get__") or name.endswith("__set__")):
if name.startswith("torch.Tensor.") and (
name.endswith("__get__") or name.endswith("__set__")
):
return False
# Assume the CPU functionality of the vendor-provided torch has not been broken
args_on_cpu, device = is_cpu_op(*args, **kwargs)
4 changes: 3 additions & 1 deletion op_tools/custom_apply_hook.py
@@ -74,7 +74,9 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
"cast_dtype",
"op_capture",
]
assert feature in feature_options, f"feature must be one of {feature_options}, but got {feature}"
assert (
feature in feature_options
), f"feature must be one of {feature_options}, but got {feature}"
assert callable(condition_func)
if feature == "fallback":
hook_cls = OpFallbackHook
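For context, apply_feature takes the op name(s), one of the feature options listed above, and an optional condition callable. A minimal usage sketch, under the assumption that apply_feature can be imported from op_tools.custom_apply_hook and accepts a single op-name string for ops:

import torch
from op_tools.custom_apply_hook import apply_feature

# Hypothetical example: fall back torch.add to CPU only when the first
# argument is a half-precision tensor.
apply_feature(
    "torch.add",
    feature="fallback",
    condition_func=lambda *args, **kwargs: len(args) > 0
    and isinstance(args[0], torch.Tensor)
    and args[0].dtype == torch.float16,
)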
42 changes: 32 additions & 10 deletions op_tools/op_autocompare_hook.py
@@ -16,6 +16,8 @@
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args

RANDOM_NUMBER_GEN_OPS = [
"torch.Tensor.random_",
"torch.Tensor.uniform_",
@@ -137,7 +139,9 @@ def after_call_op(self, result): # noqa:C901
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[0].requires_grad:
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[
0
].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.result_cpu = self.func(*args_cpu, **self.kwargs_cpu)
@@ -153,20 +157,27 @@ def after_call_op(self, result): # noqa:C901
)

if is_inplace_op(self.name):
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
allclose = compare_result(self.name, self.args[0], args_cpu[0])[
"allclose"
]
if not allclose:
self.save_forward_args()

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]
allclose = compare_result(self.name, self.result_device, self.result_cpu)[
"allclose"
]

self.forward_allclose = allclose
if not allclose:
print(f"OpAutoCompareHook: {self.name:<60} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
print(f"OpAutoCompareHook: {self.name:<60} output: {serialize_args_to_dict(self.result)['args']}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
self.save_forward_args()

self.backward_hook_handle = BackwardHookHandle(self)
@@ -184,8 +195,12 @@ def after_call_op(self, result): # noqa:C901
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_outputs_cpu = to_device(
"cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
self.grad_inputs_cpu = to_device(
"cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
arg_cpu.grad.zero_()
@@ -196,7 +211,10 @@ def run_backward_on_cpu(self, grad_inputs, grad_output):

self.args_cpu_grad = []
for i in range(len(self.args_cpu)):
if isinstance(self.args_cpu[i], torch.Tensor) and self.args_cpu[i].grad is not None:
if (
isinstance(self.args_cpu[i], torch.Tensor)
and self.args_cpu[i].grad is not None
):
self.args_cpu_grad.append(self.args_cpu[i].grad)
else:
self.args_cpu_grad.append(None)
@@ -225,7 +243,9 @@ def compare_all_grad(self):
# Check if all gradients have been obtained
for i in range(len(self.args)):
arg = self.args[i]
if isinstance(arg, torch.Tensor) and (arg.requires_grad and self.args_grad[i] is None):
if isinstance(arg, torch.Tensor) and (
arg.requires_grad and self.args_grad[i] is None
):
return

all_grad_allclose = True
@@ -253,7 +273,9 @@ def compare_all_grad(self):
def set_input_grad(self, index, grad):
if not hasattr(self, "args_grad"):
self.args_grad = [None for i in range(len(self.args))]
self.args_grad[index] = to_device("cpu", grad, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.args_grad[index] = to_device(
"cpu", grad, dtype_cast_dict=self.dtype_cast_dict, detach=True
)

def save_forward_args(self):
save_op_args(
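The new op_tools/pretty_print.py module itself is not among the hunks rendered on this page, so the following is only a rough sketch of what pretty_print_op_args and dict_data_list_to_table might look like, inferred from the call sites in this commit. It assumes serialize_args_to_dict returns a dict with an "args" list (as the existing ['args'] accesses suggest); the real implementation may differ.

def dict_data_list_to_table(data_dict_list):
    # Render a list of flat dicts as an aligned plain-text table.
    if not data_dict_list:
        return ""
    headers = list(data_dict_list[0].keys())
    rows = [[str(d.get(h, "")) for h in headers] for d in data_dict_list]
    widths = [max(len(h), max((len(r[i]) for r in rows), default=0)) for i, h in enumerate(headers)]
    rule = "+" + "+".join("-" * (w + 2) for w in widths) + "+"
    def fmt(cells):
        return "| " + " | ".join(c.ljust(w) for c, w in zip(cells, widths)) + " |"
    return "\n".join([rule, fmt(headers), rule] + [fmt(r) for r in rows] + [rule])

def pretty_print_op_args(op_name, inputs_dict, outputs_dict):
    # One table row per serialized input/output argument of the op.
    rows = []
    for role, serialized in (("input", inputs_dict), ("output", outputs_dict)):
        for i, arg in enumerate(serialized.get("args", [])):
            rows.append({"op": op_name, "arg": f"{role}[{i}]", "value": str(arg)})
    if rows:
        print(dict_data_list_to_table(rows))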
12 changes: 9 additions & 3 deletions op_tools/op_capture_hook.py
@@ -25,7 +25,9 @@ def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
self.forward_op_id = f"{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
self.forward_op_id = (
f"{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
)
with DisableHookGuard():
name = self.name

@@ -39,12 +41,16 @@ def after_call_op(self, result):
id = f"{self.forward_op_id}/output"
save_op_args(self.name, id, self.result)

self.backward_hook_handle = BackwardHookHandle(self.name, self.forward_op_id)
self.backward_hook_handle = BackwardHookHandle(
self.name, self.forward_op_id
)

for result in traverse_container(self.result):
if isinstance(result, torch.Tensor):
if result.grad_fn is not None:
result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_hook())
result.grad_fn.register_hook(
self.backward_hook_handle.grad_fun_hook()
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):
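The capture hook grabs gradients by registering hooks on the autograd graph nodes (result.grad_fn) of the op's outputs. A small standalone illustration of that PyTorch mechanism, independent of op_tools:

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

def node_hook(grad_inputs, grad_outputs):
    # Runs after y.grad_fn has computed its gradients during backward().
    print("grad_outputs:", grad_outputs)
    print("grad_inputs:", grad_inputs)

y.grad_fn.register_hook(node_hook)
y.backward()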
10 changes: 7 additions & 3 deletions op_tools/op_dispatch_watch_hook.py
@@ -4,19 +4,23 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args


class OpDispatchWatcherHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} input: {serialize_args_to_dict(*args, **kwargs)}")
pass

def after_call_op(self, result):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} output: {serialize_args_to_dict(self.result)}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
30 changes: 27 additions & 3 deletions op_tools/op_fallback_hook.py
@@ -5,6 +5,7 @@
from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpFallbackHook(BaseHook):
@@ -38,7 +39,6 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
print(f"OpFallbackHook: {self.name:<50} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
self.args_device = self.args
self.kwargs_device = self.kwargs
self.args = to_device(
@@ -72,8 +72,32 @@ def after_call_op(self, result):
self.result_cpu = self.func(*self.args, **self.kwargs)
dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
print(f"OpFallbackHook: {self.name:<50} output: {serialize_args_to_dict(self.result)['args']} cpu output: {serialize_args_to_dict(self.result_cpu)['args']} dtype_convert_back_dict:{dtype_convert_back_dict}") # noqa: E501
self.result = to_device(
self.device, self.result_cpu, dtype_convert_back_dict
)
self.dump_op_args()

def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(self.result), prefix="device_output"
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output "
)

table = dict_data_list_to_table(data_dict_list)
print(table)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
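dump_op_args now collects four serialized views of the call (device/CPU inputs and outputs) into one table instead of the old single-line print. As with pretty_print_op_args, packect_data_to_dict_list is not shown on this page; a hedged sketch inferred only from these call sites, assuming serialize_args_to_dict returns "args" and "kwargs" entries:

def packect_data_to_dict_list(op_name, serialized_args, prefix=""):
    # One row per argument, tagged with the op name and a role prefix
    # such as "device_input" or "cpu_output".
    rows = []
    for i, arg in enumerate(serialized_args.get("args", [])):
        rows.append({"op": op_name, "arg": f"{prefix.strip()}[{i}]", "value": str(arg)})
    for key, value in serialized_args.get("kwargs", {}).items():
        rows.append({"op": op_name, "arg": f"{prefix.strip()}[{key}]", "value": str(value)})
    return rows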
30 changes: 22 additions & 8 deletions op_tools/op_runner.py
@@ -24,11 +24,19 @@ def after_backward(self):
class AsyncEventTimer(OpRunnerHook):
def __init__(self) -> None:
super().__init__()
self.forward_start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
self.forward_end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)

self.backward_start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
self.backward_end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
self.forward_start_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)
self.forward_end_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)

self.backward_start_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)
self.backward_end_event = torch.cuda.Event(
enable_timing=True, blocking=False, interprocess=False
)

def before_forward(self):
self.forward_start_event.record(torch.cuda.current_stream())
@@ -55,7 +63,9 @@ def after_forward(self):
torch.cuda.current_stream().synchronize()
self.forward_end_time = time.time()
self.forward_elasped_time = self.forward_end_time - self.forward_start_time
print(f"SyncExecuteTimer: {self.runner.name} forward elasped {self.forward_elasped_time * 1000:>.8f} ms ")
print(
f"SyncExecuteTimer: {self.runner.name} forward elasped {self.forward_elasped_time * 1000:>.8f} ms "
)

def before_backward(self):
torch.cuda.current_stream().synchronize()
@@ -65,7 +75,9 @@ def after_backward(self):
torch.cuda.current_stream().synchronize()
self.backward_end_time = time.time()
self.backward_elasped_time = self.backward_end_time - self.forward_start_time
print(f"SyncExecuteTimer: {self.runner.name} backward elasped {self.backward_elasped_time * 1000:>.8f} ms")
print(
f"SyncExecuteTimer: {self.runner.name} backward elasped {self.backward_elasped_time * 1000:>.8f} ms"
)


class OpAccyChecker(OpRunnerHook):
@@ -157,5 +169,7 @@ def run_backward(self):
self.run_before_backward()
for result in traverse_container(self.result):
if isinstance(result, torch.Tensor) and result.requires_grad:
result.backward(*self.grad_outputs["args"], **self.grad_outputs["kwargs"])
result.backward(
*self.grad_outputs["args"], **self.grad_outputs["kwargs"]
)
self.run_after_backward()
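Only the creation of the CUDA events is visible in the AsyncEventTimer hunk above; the readback of the elapsed time is not on this page. For reference, the usual PyTorch CUDA-event timing pattern looks like this (a generic illustration, not code from this commit):

import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()            # enqueued on the current stream
c = a @ b                 # the work being measured
end.record()
torch.cuda.synchronize()  # wait until both events have completed
print(f"matmul took {start.elapsed_time(end):.3f} ms")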
60 changes: 51 additions & 9 deletions op_tools/op_time_measure_hook.py
@@ -6,6 +6,11 @@

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match
from .pretty_print import (
pretty_print_op_args,
dict_data_list_to_table,
packect_data_to_dict_list,
)


class BackwardHookHandle:
@@ -25,7 +30,21 @@ def grad_fun(grad_inputs, grad_outputs):
torch.cuda.current_stream().synchronize()
self.end_time = time.time()
self.backward_elasped = self.end_time - self.start_time
print(f"OpTimeMeasureHook: {self.name:<30} backward elasped: {(self.backward_elasped * 1000):>10.8f} ms grad_inputs: {serialize_args_to_dict(grad_inputs)} output: {serialize_args_to_dict(grad_outputs)}") # noqa: E501
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(grad_outputs), prefix="grad_outputs "
)
data_dict_list += packect_data_to_dict_list(
self.name, serialize_args_to_dict(grad_inputs), prefix="grad_inputs "
)
table = dict_data_list_to_table(data_dict_list)
print(table)
elasped_info_dict = {
"backward_elasped": f"{(self.backward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

return grad_fun

@@ -46,18 +65,41 @@ def after_call_op(self, result):
self.backward_hook_handle = BackwardHookHandle(self.name, self.id)
if isinstance(self.result, torch.Tensor):
if self.result.grad_fn is not None:
self.result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())
self.result.grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())
elif isinstance(self.result, (tuple, list)) or type(self.result).__module__.startswith("torch.return_types"):
self.result.grad_fn.register_hook(
self.backward_hook_handle.grad_fun_posthook()
)
self.result.grad_fn.register_prehook(
self.backward_hook_handle.grad_fun_prehook()
)
elif isinstance(self.result, (tuple, list)) or type(
self.result
).__module__.startswith("torch.return_types"):
# torch.return_types is a structseq, aka a "namedtuple"-like thing defined by the Python C-API.
for i in range(len(self.result)):
if isinstance(self.result[i], torch.Tensor) and self.result[i].grad_fn is not None:
self.result[i].grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())

self.result[i].grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())
if (
isinstance(self.result[i], torch.Tensor)
and self.result[i].grad_fn is not None
):
self.result[i].grad_fn.register_hook(
self.backward_hook_handle.grad_fun_posthook()
)

self.result[i].grad_fn.register_prehook(
self.backward_hook_handle.grad_fun_prehook()
)

with DisableHookGuard():
print(f"OpTimeMeasureHook: {self.name:<30} forward elasped: {(self.foward_elasped * 1000):>10.8f} ms input: {serialize_args_to_dict(*self.args, **self.kwargs)} output: {serialize_args_to_dict(self.result)}") # noqa: E501
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
elasped_info_dict = {
"forward_elasped": f"{(self.foward_elasped * 1000):>10.8f}",
"unit": "ms",
"forward_id": self.id,
}
print(dict_data_list_to_table([elasped_info_dict]))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
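A minimal usage sketch of the timing-table output added above, assuming op_tools.pretty_print exposes dict_data_list_to_table with the signature used in this file (the values below are hypothetical):

from op_tools.pretty_print import dict_data_list_to_table

elasped_info_dict = {
    "forward_elasped": f"{1.234:>10.8f}",  # hypothetical elapsed time in ms
    "unit": "ms",
    "forward_id": 1,                       # hypothetical forward id
}
print(dict_data_list_to_table([elasped_info_dict]))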