Zgc/ditorch support print in table format (#34)
* Update the usage documentation and add instructions for the environment variables of the related functions

* Support table-formatted output for the various operator tools

* Support saving operator performance data to files and printing it in table form

* Support printing the autocompare summary results in tabular form and saving them to a file

* When multiple custom hooks are applied to the same interface, only the last hook takes effect

* Place the files generated by the operator tools uniformly in the op_tools_results directory

* Fix a bug that occurred when the offline operator runner ran an inplace operator more than once
zhaoguochun1995 authored Sep 14, 2024
1 parent d853ded commit 1cd0513
Showing 20 changed files with 361 additions and 85 deletions.
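
The table printing added throughout this commit funnels through `dict_data_list_to_table` in `op_tools/pretty_print.py`, whose implementation is not part of this diff. The sketch below is a plausible minimal version, assuming the `prettytable` package (the diff calls `table.get_csv_string()` and `print(table)`, both of which `prettytable.PrettyTable` supports); the real helper may differ.

```python
# A minimal sketch of dict_data_list_to_table, assuming prettytable is used.
from prettytable import PrettyTable


def dict_data_list_to_table(data_dict_list):
    # Collect the union of keys across all rows, preserving first-seen order.
    field_names = []
    for row in data_dict_list:
        for key in row.keys():
            if key not in field_names:
                field_names.append(key)

    table = PrettyTable()
    table.field_names = field_names
    for row in data_dict_list:
        # Fill missing columns with an empty string so every row aligns.
        table.add_row([row.get(name, "") for name in field_names])
    return table
```
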
3 changes: 2 additions & 1 deletion .gitignore
@@ -144,4 +144,5 @@ CMakeUserPresets.json
# Autogened on ascend
fusion_result.json

op_capture_result/
op_tools_results/
export_only_prof_dir/
6 changes: 6 additions & 0 deletions op_tools/custom_apply_hook.py
@@ -7,6 +7,7 @@
from .op_dispatch_watch_hook import OpDispatchWatcherHook
from .op_time_measure_hook import OpTimeMeasureHook
from .op_dtype_cast_hook import OpDtypeCastHook
from .base_hook import BaseHook


def get_func_name(func):
@@ -58,6 +59,11 @@ def condition_func(*args, **kwargs):
else:
condition_func = condition_funcs
assert callable(condition_func)
if isinstance(func, BaseHook):
print(
f"Multiple hooks are being applied to {name}; the previous hook {func.class_name()} will be replaced by {hook.class_name()}."  # noqa: E501
)
func = func.func

hook_obj = hook(name, func)
hook_obj.add_condition_func(condition_func)
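
The change above makes "last hook wins" explicit: if `func` is already a wrapped hook, it is unwrapped back to the raw callable (`func.func`) before the new hook wraps it, so hooks replace rather than stack. A self-contained sketch of that wrap/unwrap pattern; the two subclasses are hypothetical, and only the `func` attribute and `class_name()` method are taken from the diff:

```python
# Hedged sketch of the "last hook wins" behavior shown in custom_apply_hook.py.
class BaseHook:
    def __init__(self, name, func):
        self.name = name
        self.func = func  # the raw callable being wrapped

    @classmethod
    def class_name(cls):
        return cls.__name__

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class TimerHook(BaseHook):  # hypothetical subclass for illustration
    pass


class WatcherHook(BaseHook):  # hypothetical subclass for illustration
    pass


def apply_hook(name, func, hook_cls):
    if isinstance(func, BaseHook):
        # Unwrap the previous hook so hooks replace rather than stack.
        print(f"{name}: replacing {func.class_name()} with {hook_cls.class_name()}")
        func = func.func
    return hook_cls(name, func)


f = apply_hook("torch.add", lambda a, b: a + b, TimerHook)
f = apply_hook("torch.add", f, WatcherHook)  # TimerHook is replaced, not stacked
```
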
93 changes: 62 additions & 31 deletions op_tools/op_autocompare_hook.py
@@ -3,6 +3,7 @@
import gc
import os
import time
import atexit

from .base_hook import BaseHook, DisableHookGuard
from .utils import (
@@ -16,6 +17,8 @@
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args, dict_data_list_to_table

RANDOM_NUMBER_GEN_OPS = [
"torch.Tensor.random_",
"torch.Tensor.uniform_",
@@ -85,6 +88,9 @@ def grad_fun(grad):
return grad_fun


global_autocompare_result = []


class OpAutoCompareHook(BaseHook):
AUTO_COMPARE_DTYPE_CAST_DICT = {
torch.half: torch.float32,
@@ -95,7 +101,8 @@ def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
self.forward_op_id = f"autocompare/{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
self.forward_op_id = self.id
self.identifier = f"autocompare/{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
with DisableHookGuard():
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
@@ -153,20 +160,29 @@ def after_call_op(self, result): # noqa:C901
)

if is_inplace_op(self.name):
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
compare_info = compare_result(self.name, self.args[0], args_cpu[0])
allclose = compare_info["allclose"]
if not allclose:
self.save_forward_args()
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)

compare_info = compare_result(self.name, self.result_device, self.result_cpu)
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)
allclose = compare_info["allclose"]

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]

self.forward_allclose = allclose
if not allclose:
print(f"OpAutoCompareHook: {self.name:<60} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
print(f"OpAutoCompareHook: {self.name:<60} output: {serialize_args_to_dict(self.result)['args']}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
self.save_forward_args()

self.backward_hook_handle = BackwardHookHandle(self)
@@ -228,25 +244,14 @@ def compare_all_grad(self):
if isinstance(arg, torch.Tensor) and (arg.requires_grad and self.args_grad[i] is None):
return

all_grad_allclose = True
for i in range(len(self.args)):
arg = self.args[i]

if isinstance(arg, torch.Tensor) and arg.requires_grad:
arg_cpu_grad = self.args_cpu_grad[i]
allclose = compare_result(
self.name + f" (ins[{i}].grad)",
self.args_grad[i],
arg_cpu_grad,
)["allclose"]

if not allclose:
all_grad_allclose = False
if not all_grad_allclose:
compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)
if not compare_info["allclose"]:
# Forward args were not saved when forward accuracy was OK, so save them now
if self.forward_allclose:
self.save_forward_args()
self.save_backward_args
self.save_backward_args()
self = None
gc.collect()

@@ -258,38 +263,38 @@ def set_input_grad(self, index, grad):
def save_forward_args(self):
save_op_args(
self.name,
f"{self.forward_op_id}/device/input",
f"{self.identifier}/device/input",
*self.args,
**self.kwargs,
)
save_op_args(self.name, f"{self.forward_op_id}/device/output", self.result)
save_op_args(self.name, f"{self.identifier}/device/output", self.result)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/input",
f"{self.identifier}/cpu/input",
*self.args_cpu,
**self.kwargs_cpu,
)
save_op_args(self.name, f"{self.forward_op_id}/cpu/output", self.result_cpu)
save_op_args(self.name, f"{self.identifier}/cpu/output", self.result_cpu)

def save_backward_args(self):
save_op_args(
self.name,
f"{self.forward_op_id}/device/grad_outputs",
*tuple(self.grad_output),
f"{self.identifier}/device/grad_outputs",
*tuple(self.grad_outputs_cpu),
)
save_op_args(
self.name,
f"{self.forward_op_id}/device/grad_inputs",
f"{self.identifier}/device/grad_inputs",
*tuple(self.args_grad),
)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/grad_inputs",
f"{self.identifier}/cpu/grad_inputs",
*tuple(self.args_cpu_grad),
)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/grad_outputs",
f"{self.identifier}/cpu/grad_outputs",
*tuple(self.grad_outputs_cpu),
)

@@ -307,3 +312,29 @@ def is_should_apply(self, *args, **kwargs):
return False

return is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_LIST", ".*"))


def dump_all_autocompare_info():
if len(global_autocompare_result) == 0:
return
all_compare_info_list = []
while len(global_autocompare_result) > 0:
compare_info = global_autocompare_result.pop(0)
while len(compare_info["result_list"]) > 0:
compare_result = compare_info["result_list"].pop(0)
all_compare_info_list.append({"forward_id": compare_info["forward_id"], **compare_result})

table = dict_data_list_to_table(all_compare_info_list)
print(table)
data_string = table.get_csv_string()
file_name = f"op_tools_results/op_autocompare_result/op_autocompare_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
dir_name = os.path.dirname(file_name)
os.makedirs(dir_name, exist_ok=True)

with open(file_name, "w") as f:
f.write(data_string)  # the context manager closes the file
print(f"op autocompare info saved to {file_name}")


atexit.register(dump_all_autocompare_info)
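
`dump_all_autocompare_info` assumes each entry of `global_autocompare_result` carries an `allclose` flag and a `result_list` of per-tensor rows, plus the `forward_id` the hook attaches. The stub below only illustrates that assumed shape; the real `compare_result` in `op_tools/utils.py` computes the actual comparison, and any field names beyond `allclose`, `result_list`, and `forward_id` are assumptions.

```python
# Illustrative stub of the dict shape that dump_all_autocompare_info consumes.
import torch


def compare_result_stub(name, a, b, atol=1e-3, rtol=1e-3):
    allclose = torch.allclose(a.float(), b.float(), atol=atol, rtol=rtol)
    max_abs_diff = (a.float() - b.float()).abs().max().item()
    return {
        "allclose": allclose,
        "result_list": [
            {"name": name, "allclose": allclose, "max_abs_diff": max_abs_diff}
        ],
    }


info = compare_result_stub("torch.add", torch.ones(3), torch.ones(3))
info.update({"forward_id": 0})  # the hook adds the forward op id before appending
```
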
10 changes: 7 additions & 3 deletions op_tools/op_dispatch_watch_hook.py
@@ -4,19 +4,23 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args


class OpDispatchWatcherHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} input: {serialize_args_to_dict(*args, **kwargs)}")
pass

def after_call_op(self, result):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} output: {serialize_args_to_dict(self.result)}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
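
`pretty_print_op_args` (from `op_tools/pretty_print.py`) replaces the raw input/output prints here and in the autocompare hook, but its source is not in this diff. A plausible sketch follows; the `"args"` key is visible in the prints this commit removes, while the `"kwargs"` key and the row layout are assumptions:

```python
# Hedged sketch of pretty_print_op_args; the real helper may format differently.
from op_tools.pretty_print import dict_data_list_to_table


def pretty_print_op_args(name, inputs_dict, outputs_dict):
    rows = []
    for i, arg in enumerate(inputs_dict.get("args", [])):
        rows.append({"name": name, "target": f"input[{i}]", "value": str(arg)})
    for key, value in inputs_dict.get("kwargs", {}).items():
        rows.append({"name": name, "target": f"kwargs[{key}]", "value": str(value)})
    for i, out in enumerate(outputs_dict.get("args", [])):
        rows.append({"name": name, "target": f"output[{i}]", "value": str(out)})
    if rows:
        print(dict_data_list_to_table(rows))
```
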
30 changes: 21 additions & 9 deletions op_tools/op_dtype_cast_hook.py
@@ -10,6 +10,7 @@
get_dtype_cast_dict_form_str,
is_opname_match,
)
from .pretty_print import dict_data_list_to_table


class OpDtypeCastHook(BaseHook):
@@ -54,13 +55,19 @@ def before_call_op(self, *args, **kwargs):
for arg in traverse_container(self.args_raw):
self.raw_ins_list.append(arg)

self.data_dict_list = []
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = (
self.raw_ins_list[i].dtype
)
# print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
"target": f"input[{i}]",
"action": f"{self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)

def after_call_op(self, result):
if self.is_cpu_op:
Expand All @@ -76,11 +83,16 @@ def after_call_op(self, result):
i = -1
for out in traverse_container(self.result_raw):
i += 1
if (
isinstance(out, torch.Tensor)
and out.dtype in self.dtype_cast_back_dict.keys()
):
print(f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}") # noqa: E501
if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
data_dict = {
"name": self.name,
"target": f"output[{i}]",
"action": f"{out.dtype} -> {self.dtype_cast_back_dict[out.dtype]}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
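
The dtype-cast hook is driven by a config string parsed by `get_dtype_cast_dict_form_str` (imported above from `op_tools/utils.py`). The exact format is not shown in this diff; below is a sketch under the assumption that the string maps dtypes pairwise, e.g. `"torch.float16->torch.float32,torch.bfloat16->torch.float32"` — both separators are guesses.

```python
# Hedged sketch of get_dtype_cast_dict_form_str; the separator characters
# ("," between pairs, "->" inside a pair) are assumptions about the format.
import torch


def get_dtype_cast_dict_form_str(config):
    dtype_cast_dict = {}
    if not config:
        return dtype_cast_dict
    for pair in config.split(","):
        src, dst = pair.split("->")
        # Resolve "torch.float16" to the torch.float16 dtype object without eval.
        dtype_cast_dict[getattr(torch, src.strip().split(".")[-1])] = getattr(
            torch, dst.strip().split(".")[-1]
        )
    return dtype_cast_dict


assert get_dtype_cast_dict_form_str("torch.float16->torch.float32") == {
    torch.float16: torch.float32
}
```
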
22 changes: 20 additions & 2 deletions op_tools/op_fallback_hook.py
@@ -5,6 +5,7 @@
from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpFallbackHook(BaseHook):
@@ -38,7 +39,6 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
print(f"OpFallbackHook: {self.name:<50} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
self.args_device = self.args
self.kwargs_device = self.kwargs
self.args = to_device(
@@ -73,7 +73,25 @@ def after_call_op(self, result):
dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
print(f"OpFallbackHook: {self.name:<50} output: {serialize_args_to_dict(self.result)['args']} cpu output: {serialize_args_to_dict(self.result_cpu)['args']} dtype_convert_back_dict:{dtype_convert_back_dict}") # noqa: E501
self.dump_op_args()

def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result), prefix="device_output")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output ")

table = dict_data_list_to_table(data_dict_list)
print(table)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
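
`packect_data_to_dict_list` (name kept as spelled in the repo) turns one serialized arg set into table rows tagged with a prefix (`device_input `, `cpu_output `, ...) so all four views line up in a single table. A minimal sketch, assuming the serialized dict exposes an `args` list as seen in the prints this commit removes:

```python
# Hedged sketch of packect_data_to_dict_list; the row fields are assumptions.
def packect_data_to_dict_list(op_name, serialized_args, prefix=""):
    rows = []
    for i, arg in enumerate(serialized_args.get("args", [])):
        rows.append({"name": op_name, "target": f"{prefix}[{i}]", "value": str(arg)})
    return rows


rows = packect_data_to_dict_list(
    "torch.add", {"args": ["shape=(2,3) dtype=float32"]}, prefix="device_input "
)
```
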
5 changes: 4 additions & 1 deletion op_tools/op_runner.py
@@ -3,7 +3,7 @@
import os
import torch
import time
from .utils import to_device, get_function_from_string, traverse_container
from .utils import to_device, get_function_from_string, traverse_container, is_inplace_op
from .op_autocompare_hook import OpAutoCompareHook


@@ -147,6 +147,9 @@ def load_backward_data(self):
self.grad_outputs_cpu = None

def run_forward(self):
if is_inplace_op(self.name):
self.args = to_device("cuda", self.args_cpu)
self.kwargs = to_device("cuda", self.kwargs_cpu)
self.run_before_forward()
self.result = self.func(*self.args, **self.kwargs)
self.run_after_forward()
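
The fix above matters because an inplace operator mutates its inputs: after the first replay, the saved device tensors no longer hold the original values, so every later run starts from wrong data. Re-copying from the pristine CPU snapshots before each forward restores the precondition. A self-contained illustration of the failure and the fix, using `clone()` in place of `to_device("cuda", ...)` so it runs without a GPU:

```python
# Why the offline runner must refresh inputs before replaying an inplace op.
import torch

args_cpu = (torch.ones(3),)  # pristine snapshot loaded from disk

stale = tuple(a.clone() for a in args_cpu)
for _ in range(2):
    torch.Tensor.add_(stale[0], 1)  # inplace: the input drifts on every replay
print(stale[0])  # tensor([3., 3., 3.]) -- the second run saw a mutated input

for _ in range(2):
    fresh = tuple(a.clone() for a in args_cpu)  # refresh per run, like the fix
    torch.Tensor.add_(fresh[0], 1)
print(fresh[0])  # tensor([2., 2., 2.]) -- every run starts from the snapshot
```
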