Zgc/ditorch support print in table format #34

Merged: 7 commits, Sep 14, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -144,4 +144,5 @@ CMakeUserPresets.json
# Autogened on ascend
fusion_result.json

op_capture_result/
op_tools_results/
export_only_prof_dir/
35 changes: 34 additions & 1 deletion README.md
@@ -419,6 +419,21 @@ OpDtypeCastHook: torch.Tensor.sum 0th out torc
```

### Conditions for custom operator tools to take effect
```
def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
...
```

The op_tools.apply_feature interface can be applied to torch interfaces as well as other third-party interfaces. The condition_func parameter customizes when the tool takes effect: the tool is applied only when condition_func returns True, and condition_func takes the same input arguments as the operator itself. A usage sketch follows the feature list below.
The feature parameter selects the capability to apply; the following values are currently supported:
- fallback: operator fallback
- cast_dtype: operator data type casting
- op_capture: operator argument capture
- autocompare: operator accuracy comparison (the device implementation and the CPU implementation must expose the same calling interface)
- dump_op_args: operator argument printing
- measure_op_time: operator execution time measurement
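
A minimal usage sketch, assuming `ops` accepts a single operator name string; the half-precision condition and the choice of torch.nn.functional.linear are purely illustrative:

```
import torch
import ditorch  # adapts torch to the vendor device
import op_tools

# Hypothetical condition: apply the feature only when a half-precision tensor is passed in.
def fp16_only(*args, **kwargs):
    return any(isinstance(a, torch.Tensor) and a.dtype == torch.float16 for a in args)

# Fall back torch.nn.functional.linear to the CPU implementation whenever the condition holds.
op_tools.apply_feature("torch.nn.functional.linear", feature="fallback", condition_func=fp16_only)

x = torch.randn(4, 8, dtype=torch.float16).cuda()
w = torch.randn(16, 8, dtype=torch.float16).cuda()
y = torch.nn.functional.linear(x, w)  # routed to the CPU implementation by the fallback hook
```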


```
import torch
import ditorch
@@ -496,4 +511,22 @@ apply OpDtypeCastHook on torch.div
OpDtypeCastHook: torch.div 0th arg torch.float32 -> torch.float16 config:torch.float32->torch.float16
OpDtypeCastHook: torch.div 1th arg torch.float32 -> torch.float16 config:torch.float32->torch.float16
OpDtypeCastHook: torch.div 0th out torch.float16 -> torch.float32 config:torch.float32->torch.float16
```

### Related environment variables
| Tool | Environment variable | Value | Description | Notes |
|---------------------------------|----------------------------------------------|-----------------------------------------------------------|---------------------------|-----------------------------|
| [Operator argument capture tool](#tool1) | OP_CAPTURE_DISABLE_LIST | torch.add,torch.nn.functional.linear,torch.Tensor.relu_ | Do not capture arguments for these operators | Full operator names, comma-separated for multiple operators |
| [Operator argument capture tool](#tool1) | OP_CAPTURE_LIST | Same as above | Capture arguments only for these operators | Same as above |
| [Accuracy analysis tool](#tool2) | OP_AUTOCOMPARE_LIST | Same as above | Compare accuracy only for the specified operators | Same as above |
| [Accuracy analysis tool](#tool2) | OP_AUTOCOMPARE_DISABLE_LIST | Same as above | Ignore the specified operators during accuracy comparison | Same as above |
| [Operator dtype cast tool](#tool5) | OP_DTYPE_CAST_DISABLE_LIST | Same as above | Skip the specified operators when casting dtypes | Same as above |
| [Operator dtype cast tool](#tool5) | OP_DTYPE_CAST_LIST | Same as above | Cast dtypes only for the specified operators | Same as above |
| [Accuracy analysis tool](#tool2) | AUTOCOMPARE_ERROR_TOLERANCE | atol,rtol | allclose parameters | If set, the given tolerances override the defaults |
| [Accuracy analysis tool](#tool2) | AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16 | atol,rtol | allclose parameters | If set and the dtype matches, the given tolerances are used |
| [Accuracy analysis tool](#tool2) | AUTOCOMPARE_ERROR_TOLERANCE_BFLOAT16 | atol,rtol | allclose parameters | Same as above |
| [Accuracy analysis tool](#tool2) | AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32 | atol,rtol | allclose parameters | Same as above |
| [Accuracy analysis tool](#tool2) | AUTOCOMPARE_ERROR_TOLERANCE_FLOAT64 | atol,rtol | allclose parameters | Same as above |
| [Accuracy analysis tool](#tool2) | LINEAR_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16 | atol,rtol | allclose parameters | If set and both the operator name and dtype match, the given tolerances are used. The operator-name prefix is the part of the full name after the last '.', e.g. ADD_ for torch.add and LINEAR_ for torch.nn.functional.linear |
| [Operator dtype cast tool](#tool5) | OP_DTYPE_CAST_DICT | torch.float16->torch.float32,torch.bfloat16->torch.float32 | Source and target data types for the cast | Separate multiple pairs with commas |
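
These variables are read via os.getenv when the tools run, so they can be exported in the shell or, as in this hedged sketch, set from Python before the tools are applied (the operator names and tolerance values below are purely illustrative):

```
import os

# Compare accuracy only for these operators, and loosen the float16 tolerances (atol,rtol).
os.environ["OP_AUTOCOMPARE_LIST"] = "torch.add,torch.nn.functional.linear"
os.environ["AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16"] = "1e-3,1e-3"

# Route float16/bfloat16 operators through float32 in the dtype cast tool.
os.environ["OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
```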

6 changes: 6 additions & 0 deletions op_tools/custom_apply_hook.py
@@ -7,6 +7,7 @@
from .op_dispatch_watch_hook import OpDispatchWatcherHook
from .op_time_measure_hook import OpTimeMeasureHook
from .op_dtype_cast_hook import OpDtypeCastHook
from .base_hook import BaseHook


def get_func_name(func):
@@ -58,6 +59,11 @@ def condition_func(*args, **kwargs):
else:
condition_func = condition_funcs
assert callable(condition_func)
if issubclass(type(func), BaseHook):
print(
f"The {name} is applying multiple hooks, and the previous hook {func.class_name()} will be replaced by the {hook.class_name()}." # noqa: E501
)
func = func.func

hook_obj = hook(name, func)
hook_obj.add_condition_func(condition_func)
93 changes: 62 additions & 31 deletions op_tools/op_autocompare_hook.py
@@ -3,6 +3,7 @@
import gc
import os
import time
import atexit

from .base_hook import BaseHook, DisableHookGuard
from .utils import (
@@ -16,6 +17,8 @@
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args, dict_data_list_to_table

RANDOM_NUMBER_GEN_OPS = [
"torch.Tensor.random_",
"torch.Tensor.uniform_",
@@ -85,6 +88,9 @@ def grad_fun(grad):
return grad_fun


global_autocompare_result = []


class OpAutoCompareHook(BaseHook):
AUTO_COMPARE_DTYPE_CAST_DICT = {
torch.half: torch.float32,
@@ -95,7 +101,8 @@ def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
self.forward_op_id = f"autocompare/{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
self.forward_op_id = self.id
self.identifier = f"autocompare/{self.id}/{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"
with DisableHookGuard():
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
@@ -153,20 +160,29 @@ def after_call_op(self, result):  # noqa:C901
)

if is_inplace_op(self.name):
allclose = compare_result(self.name, self.args[0], args_cpu[0])["allclose"]
compare_info = compare_result(self.name, self.args[0], args_cpu[0])
allclose = compare_info["allclose"]
if not allclose:
self.save_forward_args()
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)

compare_info = compare_result(self.name, self.result_device, self.result_cpu)
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)
allclose = compare_info["allclose"]

if self.result is None:
print(f"{self.name} output is None, acc not checked")
return

allclose = compare_result(self.name, self.result_device, self.result_cpu)["allclose"]

self.forward_allclose = allclose
if not allclose:
print(f"OpAutoCompareHook: {self.name:<60} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
print(f"OpAutoCompareHook: {self.name:<60} output: {serialize_args_to_dict(self.result)['args']}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
self.save_forward_args()

self.backward_hook_handle = BackwardHookHandle(self)
@@ -228,25 +244,14 @@ def compare_all_grad(self):
if isinstance(arg, torch.Tensor) and (arg.requires_grad and self.args_grad[i] is None):
return

all_grad_allclose = True
for i in range(len(self.args)):
arg = self.args[i]

if isinstance(arg, torch.Tensor) and arg.requires_grad:
arg_cpu_grad = self.args_cpu_grad[i]
allclose = compare_result(
self.name + f" (ins[{i}].grad)",
self.args_grad[i],
arg_cpu_grad,
)["allclose"]

if not allclose:
all_grad_allclose = False
if not all_grad_allclose:
compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)
compare_info.update({"forward_id": self.forward_op_id})
global_autocompare_result.append(compare_info)
if not compare_info["allclose"]:
# Parameters are not saved when forward accuracy is normal
if self.forward_allclose:
self.save_forward_args()
self.save_backward_args
self.save_backward_args()
self = None
gc.collect()

@@ -258,38 +263,38 @@ def set_input_grad(self, index, grad):
def save_forward_args(self):
save_op_args(
self.name,
f"{self.forward_op_id}/device/input",
f"{self.identifier}/device/input",
*self.args,
**self.kwargs,
)
save_op_args(self.name, f"{self.forward_op_id}/device/output", self.result)
save_op_args(self.name, f"{self.identifier}/device/output", self.result)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/input",
f"{self.identifier}/cpu/input",
*self.args_cpu,
**self.kwargs_cpu,
)
save_op_args(self.name, f"{self.forward_op_id}/cpu/output", self.result_cpu)
save_op_args(self.name, f"{self.identifier}/cpu/output", self.result_cpu)

def save_backward_args(self):
save_op_args(
self.name,
f"{self.forward_op_id}/device/grad_outputs",
*tuple(self.grad_output),
f"{self.identifier}/device/grad_outputs",
*tuple(self.grad_outputs_cpu),
)
save_op_args(
self.name,
f"{self.forward_op_id}/device/grad_inputs",
f"{self.identifier}/device/grad_inputs",
*tuple(self.args_grad),
)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/grad_inputs",
f"{self.identifier}/cpu/grad_inputs",
*tuple(self.args_cpu_grad),
)
save_op_args(
self.name,
f"{self.forward_op_id}/cpu/grad_outputs",
f"{self.identifier}/cpu/grad_outputs",
*tuple(self.grad_outputs_cpu),
)

@@ -307,3 +312,29 @@ def is_should_apply(self, *args, **kwargs):
return False

return is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_LIST", ".*"))


def dump_all_autocompare_info():
if len(global_autocompare_result) == 0:
return
all_compare_info_list = []
while len(global_autocompare_result) > 0:
compare_info = global_autocompare_result.pop(0)
while len(compare_info["result_list"]) > 0:
compare_result = compare_info["result_list"].pop(0)
all_compare_info_list.append({"forward_id": compare_info["forward_id"], **compare_result})

table = dict_data_list_to_table(all_compare_info_list)
print(table)
data_string = table.get_csv_string()
file_name = f"op_tools_results/op_autocompare_result/op_autocompare_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
dir = file_name[0 : file_name.rfind("/")]
os.makedirs(dir, exist_ok=True)

with open(file_name, "w") as f:
f.write(data_string)
print(f"op autocompare info saved to {file_name}")


atexit.register(dump_all_autocompare_info)
10 changes: 7 additions & 3 deletions op_tools/op_dispatch_watch_hook.py
@@ -4,19 +4,23 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args


class OpDispatchWatcherHook(BaseHook):
def __init__(self, name, func) -> None:
super().__init__(name, func)

def before_call_op(self, *args, **kwargs):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} input: {serialize_args_to_dict(*args, **kwargs)}")
pass

def after_call_op(self, result):
with DisableHookGuard():
print(f"OpDispatchWatcherHook: {self.name} output: {serialize_args_to_dict(self.result)}")
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
30 changes: 21 additions & 9 deletions op_tools/op_dtype_cast_hook.py
@@ -10,6 +10,7 @@
get_dtype_cast_dict_form_str,
is_opname_match,
)
from .pretty_print import dict_data_list_to_table


class OpDtypeCastHook(BaseHook):
@@ -54,13 +55,19 @@ def before_call_op(self, *args, **kwargs):
for arg in traverse_container(self.args_raw):
self.raw_ins_list.append(arg)

self.data_dict_list = []
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = (
self.raw_ins_list[i].dtype
)
# print(f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype} config:{self.dtype_cast_config_str}") # noqa: E501
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
"target": f"input[{i}]",
"action": f"{self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)

def after_call_op(self, result):
if self.is_cpu_op:
@@ -76,11 +83,16 @@ def after_call_op(self, result):
i = -1
for out in traverse_container(self.result_raw):
i += 1
if (
isinstance(out, torch.Tensor)
and out.dtype in self.dtype_cast_back_dict.keys()
):
print(f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}") # noqa: E501
if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
data_dict = {
"name": self.name,
"target": f"output[{i}]",
"action": f"{out.dtype} -> {self.dtype_cast_back_dict[out.dtype]}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
22 changes: 20 additions & 2 deletions op_tools/op_fallback_hook.py
@@ -5,6 +5,7 @@
from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpFallbackHook(BaseHook):
@@ -38,7 +39,6 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
print(f"OpFallbackHook: {self.name:<50} input: {serialize_args_to_dict(*self.args, **self.kwargs)}")
self.args_device = self.args
self.kwargs_device = self.kwargs
self.args = to_device(
@@ -73,7 +73,25 @@ def after_call_op(self, result):
dtype_convert_back_dict = self.get_dtype_convert_back_dict()

self.result = to_device(self.device, self.result_cpu, dtype_convert_back_dict)
print(f"OpFallbackHook: {self.name:<50} output: {serialize_args_to_dict(self.result)['args']} cpu output: {serialize_args_to_dict(self.result_cpu)['args']} dtype_convert_back_dict:{dtype_convert_back_dict}") # noqa: E501
self.dump_op_args()

def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
)
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result), prefix="device_output")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output ")

table = dict_data_list_to_table(data_dict_list)
print(table)

def is_should_apply(self, *args, **kwargs):
BLACK_OP_LIST = ["torch.Tensor.cpu"]
Expand Down
5 changes: 4 additions & 1 deletion op_tools/op_runner.py
@@ -3,7 +3,7 @@
import os
import torch
import time
from .utils import to_device, get_function_from_string, traverse_container
from .utils import to_device, get_function_from_string, traverse_container, is_inplace_op
from .op_autocompare_hook import OpAutoCompareHook


@@ -147,6 +147,9 @@ def load_backward_data(self):
self.grad_outputs_cpu = None

def run_forward(self):
if is_inplace_op(self.name):
self.args = to_device("cuda", self.args_cpu)
self.kwargs = to_device("cuda", self.kwargs_cpu)
self.run_before_forward()
self.result = self.func(*self.args, **self.kwargs)
self.run_after_forward()