Zgc/ditorch support overflow apply to third party interface (#73)
* Fixed the issue where autocompare reported the input as modified in cases like torch.exp(x, out=x)

* support applying overflow_check to non-torch functions

* update README.md

* add option manager
zhaoguochun1995 authored Oct 28, 2024
1 parent a38822b commit fb14343
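
For context, a hedged usage sketch of the feature this commit adds: `op_tools.apply_feature` and the `"overflow_check"` feature name appear in the diff below, while applying it to a non-torch callable via a dotted path (`my_module.my_kernel`) is an assumption based on the commit message and is purely illustrative.

```
# Hedged sketch, not part of the diff.
import torch
import op_tools

# Torch operators, as exercised by the new test added in this commit.
op_tools.apply_feature(ops=["torch.add", "torch.mul"], feature="overflow_check")
# Hypothetical third-party function; dotted-path support for non-torch
# callables is assumed from the commit message.
op_tools.apply_feature(ops="my_module.my_kernel", feature="overflow_check")

x = torch.full((3,), 60000.0, device="cuda", dtype=torch.float16)
y = torch.mul(x, x)  # overflows float16, so the hook should report inf here
```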
Showing 12 changed files with 111 additions and 52 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -10,8 +10,22 @@ ditorch is device-agnostic torch, designed to shield the torch differences across hardware vendors, providing

Just add two lines of code and you can use it on domestic chips the same way as official pytorch.
```
import torch
import ditorch
>>> import torch
>>> import ditorch
ditorch.framework: torch_npu:2.1.0.post3 pid: 1729023
>>> x = torch.randn(3,4,device="cuda")
>>>
>>> y = x + x
>>> x
Warning: Device do not support double dtype now, dtype cast repalce with float.
tensor([[ 1.3310, 1.0011, -1.0679, -1.5444],
[-0.7345, -0.9888, -1.7310, -0.3305],
[-0.6676, -1.7792, 0.7108, -0.9981]], device='cuda:0')
>>> y
tensor([[ 2.6619, 2.0023, -2.1359, -3.0887],
[-1.4691, -1.9777, -3.4620, -0.6609],
[-1.3353, -3.5583, 1.4216, -1.9962]], device='cuda:0')
>>>
```

[ditorch + Ascend910 pytorch原生测例通过情况](ditorch/test/ascend_summary_of_pytorch_test_case_testing.csv.tar)
4 changes: 4 additions & 0 deletions op_tools/custom_apply_hook.py
@@ -7,6 +7,7 @@
from .op_observe_hook import OpObserveHook
from .op_time_measure_hook import OpTimeMeasureHook
from .op_dtype_cast_hook import OpDtypeCastHook
from .op_overflow_check_hook import OpOverflowCheckHook
from .base_hook import BaseHook


@@ -79,6 +80,7 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
"dump_op_args",
"cast_dtype",
"op_capture",
"overflow_check",
]
assert feature in feature_options, f"feature must be one of {feature_options}, but got {feature}"
assert callable(condition_func)
Expand All @@ -94,6 +96,8 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
hook_cls = OpDtypeCastHook
elif feature == "op_capture":
hook_cls = OpCaptureHook
elif feature == "overflow_check":
hook_cls = OpOverflowCheckHook

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
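
To illustrate the `condition_func` parameter visible in the `apply_feature` hunk above, a minimal sketch; the signature and the `callable` assertion come from the diff, while the specific predicate is illustrative.

```
import torch
import op_tools

# Run the overflow check only when the first positional argument is a
# half-precision tensor; apply_feature defaults condition_func to
# `lambda *args, **kwargs: True` and asserts it is callable.
def fp16_inputs_only(*args, **kwargs):
    return bool(args) and isinstance(args[0], torch.Tensor) and args[0].dtype == torch.float16

op_tools.apply_feature(
    ops=["torch.add", "torch.matmul"],
    feature="overflow_check",
    condition_func=fp16_inputs_only,
)
```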
27 changes: 14 additions & 13 deletions op_tools/op_autocompare_hook.py
@@ -16,7 +16,8 @@
is_random_number_gen_op,
garbage_collect,
get_dtype_cast_dict_form_str,
set_env_if_env_is_empty
set_option_if_empty,
get_option,
)
from .save_op_args import save_op_args, serialize_args_to_dict

@@ -37,7 +38,7 @@ def append(self, forward_id, compare_info):
for result in compare_info["result_list"]:
self.global_autocompare_result.append({"forward_id": forward_id, **result})

if len(self.global_autocompare_result) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "100")):
if len(self.global_autocompare_result) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "100")):
self.write_to_file()

def write_to_file(self):
@@ -85,13 +86,13 @@ def cleanup(self):
self.hook_handle = None


set_env_if_env_is_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501


class OpAutoCompareHook(BaseHook):
@@ -117,8 +118,8 @@ def update_dtype_cast_dict(self):
heigher_priority_env_name = op_name_suffix + "_OP_DTYPE_CAST_DICT"
lower_priority_env_name = "OP_DTYPE_CAST_DICT"
default_env_value = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
heigher_priority_env_value = os.environ.get(heigher_priority_env_name, None)
lower_priority_env_value = os.environ.get(lower_priority_env_name, default_env_value)
heigher_priority_env_value = get_option(heigher_priority_env_name, None)
lower_priority_env_value = get_option(lower_priority_env_name, default_env_value)
self.dtype_cast_config_str = heigher_priority_env_value or lower_priority_env_value

self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)
@@ -407,10 +408,10 @@ def is_should_apply(self, *args, **kwargs):
if self.name.startswith("torch.empty"):
return False

if is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_AUTOCOMPARE_DISABLE_LIST", "")):
return False

return is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_AUTOCOMPARE_LIST", ".*"))


atexit.register(dump_all_autocompare_info)
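
To make the lookup order in `update_dtype_cast_dict` concrete, a hedged example of overriding the autocompare reference dtype for a single op; the option names and the `src->dst` comma-separated format are taken from the defaults in this file, and setting them as environment variables relies on `get_option` falling back to `os.environ` (see utils.py below).

```
import os

# Per-op option (higher priority) beats the generic one (lower priority),
# which beats the built-in default of
# "torch.float16->torch.float32,torch.bfloat16->torch.float32".
os.environ["LINEAR_OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float64"
os.environ["OP_DTYPE_CAST_DICT"] = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
```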
7 changes: 3 additions & 4 deletions op_tools/op_capture_hook.py
@@ -1,9 +1,8 @@
# Copyright (c) 2024, DeepLink.
import os
import torch
import time
from .base_hook import BaseHook, DisableHookGuard
from .utils import traverse_container, is_opname_match, garbage_collect
from .utils import traverse_container, is_opname_match, garbage_collect, get_option
from .save_op_args import save_op_args, serialize_args_to_dict
from .pretty_print import dict_data_list_to_table, packect_data_to_dict_list

@@ -71,10 +70,10 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_CAPTURE_DISABLE_LIST", "")):
return False

if not is_opname_match(self.name, os.getenv("OP_CAPTURE_LIST", ".*")):
if not is_opname_match(self.name, get_option("OP_CAPTURE_LIST", ".*")):
return False

return True
7 changes: 4 additions & 3 deletions op_tools/op_dtype_cast_hook.py
@@ -11,7 +11,8 @@
is_opname_match,
is_view_op,
is_dtype_cast_op,
garbage_collect
garbage_collect,
get_option,
)
from .pretty_print import dict_data_list_to_table

@@ -122,7 +123,7 @@ def after_call_op(self, result):
return result

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_DTYPE_CAST_DISABLE_LIST", "")):
return False

if is_view_op(self.name):
@@ -131,4 +132,4 @@ def is_should_apply(self, *args, **kwargs):
if is_dtype_cast_op(self.name, *args, **kwargs):
return False

return is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_DTYPE_CAST_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_fallback_hook.py
@@ -1,9 +1,8 @@
# Copyright (c) 2024, DeepLink.
import torch
import os

from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect, get_option
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table

@@ -114,9 +113,9 @@ def is_should_apply(self, *args, **kwargs):
if self.name in BLACK_OP_LIST:
return False

if is_opname_match(self.name, os.getenv("OP_FALLBACK_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_FALLBACK_DISABLE_LIST", "")):
return False
# if name in VIEW_OPS:
# return False

return is_opname_match(self.name, os.getenv("OP_FALLBACK_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_FALLBACK_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_observe_hook.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import os
from .utils import is_opname_match
from .utils import is_opname_match, get_option
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
@@ -26,6 +25,6 @@ def after_call_op(self, result):
print("\n" * 2)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_OBSERVE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_OBSERVE_DISABLE_LIST", "")):
return False
return is_opname_match(self.name, os.getenv("OP_OBSERVE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_OBSERVE_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_overflow_check_hook.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import os
from .utils import is_opname_match, traverse_container, is_inf_or_nan, garbage_collect, compute_tensor_features
from .utils import is_opname_match, traverse_container, is_inf_or_nan, garbage_collect, compute_tensor_features, get_option
from .base_hook import BaseHook, DisableHookGuard
import torch

@@ -106,6 +105,6 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_OVERFLOW_CHECK_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_OVERFLOW_CHECK_DISABLE_LIST", "")):
return False
return is_opname_match(self.name, os.getenv("OP_OVERFLOW_CHECK_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_OVERFLOW_CHECK_LIST", ".*"))
8 changes: 4 additions & 4 deletions op_tools/op_time_measure_hook.py
@@ -6,7 +6,7 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match, traverse_container, garbage_collect
from .utils import is_opname_match, traverse_container, garbage_collect, get_option
from .pretty_print import (
dict_data_list_to_table,
packect_data_to_dict_list,
@@ -39,7 +39,7 @@ def append(self, forward_id, elasped_info):
else:
self.global_elasped_info_dict[forward_id].update(elasped_info)

if len(self.global_elasped_info_dict) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "5000")):
if len(self.global_elasped_info_dict) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "5000")):
self.write_to_file()

def write_to_file(self):
@@ -153,10 +153,10 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_TIME_MEASURE_DISABLE_LIST", "")):
return False

return is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_TIME_MEASURE_LIST", ".*"))


atexit.register(dump_all_op_elasped_info)
5 changes: 3 additions & 2 deletions op_tools/process_monitor.py
@@ -5,6 +5,7 @@
import os
from prettytable import PrettyTable
from op_tools.pretty_print import dict_data_list_to_table
from op_tools.utils import get_option

is_ascend_npu_env = subprocess.run("npu-smi info", shell=True, capture_output=True, text=True).returncode == 0
is_camb_mlu_env = subprocess.run("cnmon", shell=True, capture_output=True, text=True).returncode == 0
@@ -104,13 +105,13 @@ def __init__(self, pid) -> None:
device_name = "ascend"
if is_camb_mlu_env:
device_name = "camb"
self.file_name = f"op_tools_results/process_monitor_result/process_monitor_result_{device_name}_pid{pid}_{os.getenv('label', '')}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
self.file_name = f"op_tools_results/process_monitor_result/process_monitor_result_{device_name}_pid{pid}_{get_option('label', '')}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
self.dir = self.file_name[0 : self.file_name.rfind("/")]

def append(self, info):
self.global_result.append(info)

if len(self.global_result) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "1")):
if len(self.global_result) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "1")):
self.write_to_file()

def write_to_file(self):
9 changes: 9 additions & 0 deletions op_tools/test/test_custom_apply_hook.py
@@ -88,6 +88,15 @@ def test_cast_dtype(self):
y = torch.rand(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_overflow_check(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="overflow_check",
)
x = torch.tensor([1, 2, 3], device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.tensor([4, 5, 6], device="cuda", dtype=torch.float16, requires_grad=True)
_test_function(x, y)

def test_condition_fallback(self):
def condition_func(a, b, **kwargs):
if a.dtype == torch.float16:
57 changes: 45 additions & 12 deletions op_tools/utils.py
@@ -299,12 +299,12 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
op_name_processed = op_name.split(".")[-1].upper() + "_"
env_name = "AUTOCOMPARE_ERROR_TOLERANCE_" + dtype_name.upper()
high_priority_env_name = op_name_processed + env_name
if os.getenv(high_priority_env_name) is not None:
atol, rtol = map(float, os.getenv(high_priority_env_name).split(","))
elif os.getenv(env_name) is not None:
atol, rtol = map(float, os.getenv(env_name).split(","))
elif os.getenv("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, os.getenv("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
if get_option(high_priority_env_name) is not None:
atol, rtol = map(float, get_option(high_priority_env_name).split(","))
elif get_option(env_name) is not None:
atol, rtol = map(float, get_option(env_name).split(","))
elif get_option("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, get_option("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
return atol, rtol

if dtype == torch.float16:
Expand All @@ -317,8 +317,8 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
return get_error_tolerance_for_type("FLOAT64", 1e-8, 1e-8)
else:
atol, rtol = 1e-3, 1e-3
if os.getenv("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, os.getenv("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
if get_option("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, get_option("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
return atol, rtol
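
To illustrate the three-level tolerance lookup above, a hedged example of the option values it expects; the "atol,rtol" format follows the `map(float, ...split(","))` parsing shown above, and the per-op prefix is the op name after the last dot, upper-cased.

```
import os

# Most specific first: per-op, per-dtype tolerance for torch.matmul in float16.
os.environ["MATMUL_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16"] = "1e-3,1e-3"
# Per-dtype tolerance used when no per-op override exists.
os.environ["AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16"] = "1e-4,1e-4"
# Global fallback for every op and dtype.
os.environ["AUTOCOMPARE_ERROR_TOLERANCE"] = "1e-5,1e-5"
```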


@@ -510,7 +510,7 @@ def current_location(name=None, stack_depth=-1, print_stack=False):
else:
break

if print_stack or int(os.getenv("OP_TOOLS_PRINT_STACK", "0")) > 0:
if print_stack or int(get_option("OP_TOOLS_PRINT_STACK", "0")) > 0:
for i in range(len(stack) - 2):
file, line, func, text = stack[i]
print(f"{file}:{line} {func} {text}")
@@ -519,6 +519,39 @@ def current_location(name=None, stack_depth=-1, print_stack=False):
return f"{file}:{line} {func}: {text}"


def set_env_if_env_is_empty(env_name, env_value):
if os.environ.get(env_name, None) is None:
os.environ[env_name] = env_value
class OptionManger:
def __init__(self):
self.options = {}

def set_option(self, option_name, value):
if option_name in self.options:
if self.options[option_name] == value:
return
print(f"option: {option_name}={value}")
self.options[option_name] = value
os.environ[option_name] = str(value)

def get_option(self, option_name, default_value=None):
if option_name in os.environ:
value = os.environ[option_name]
self.set_option(option_name, value)
return value
else:
if default_value is not None:
self.set_option(option_name, default_value)
return default_value

def is_option_set(self, option_name):
return option_name in os.environ


global_option_manger = OptionManger()


def set_option_if_empty(option_name, value):
if not global_option_manger.is_option_set(option_name):
global_option_manger.set_option(option_name, value)


def get_option(option_name, default_value=None):
return global_option_manger.get_option(option_name, default_value)
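
A short usage sketch of the new option manager, based only on the code above; it assumes `OP_TOOLS_MAX_CACHE_SIZE` is not already present in the environment when the first call runs.

```
import os
from op_tools.utils import get_option, set_option_if_empty

set_option_if_empty("OP_TOOLS_MAX_CACHE_SIZE", "100")   # first writer wins; prints "option: OP_TOOLS_MAX_CACHE_SIZE=100"
set_option_if_empty("OP_TOOLS_MAX_CACHE_SIZE", "5000")  # ignored: the option is already set
print(get_option("OP_TOOLS_MAX_CACHE_SIZE"))            # -> "100"
print(os.environ["OP_TOOLS_MAX_CACHE_SIZE"])            # -> "100"; set_option mirrors values into os.environ
```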
