Zgc/ditorch support overflow apply to third party interface #73

Merged
18 changes: 16 additions & 2 deletions README.md
@@ -10,8 +10,22 @@ ditorch is device-agnostic torch, intended to hide the torch differences across hardware vendors and

Only two lines of code need to be added to use it on domestic (Chinese) AI chips just like official PyTorch.
```
import torch
import ditorch
>>> import torch
>>> import ditorch
ditorch.framework: torch_npu:2.1.0.post3 pid: 1729023
>>> x = torch.randn(3,4,device="cuda")
>>>
>>> y = x + x
>>> x
Warning: Device does not support double dtype now, dtype cast replaced with float.
tensor([[ 1.3310, 1.0011, -1.0679, -1.5444],
[-0.7345, -0.9888, -1.7310, -0.3305],
[-0.6676, -1.7792, 0.7108, -0.9981]], device='cuda:0')
>>> y
tensor([[ 2.6619, 2.0023, -2.1359, -3.0887],
[-1.4691, -1.9777, -3.4620, -0.6609],
[-1.3353, -3.5583, 1.4216, -1.9962]], device='cuda:0')
>>>
```

[ditorch + Ascend910: native PyTorch test case pass status](ditorch/test/ascend_summary_of_pytorch_test_case_testing.csv.tar)
4 changes: 4 additions & 0 deletions op_tools/custom_apply_hook.py
@@ -7,6 +7,7 @@
from .op_observe_hook import OpObserveHook
from .op_time_measure_hook import OpTimeMeasureHook
from .op_dtype_cast_hook import OpDtypeCastHook
from .op_overflow_check_hook import OpOverflowCheckHook
from .base_hook import BaseHook


@@ -79,6 +80,7 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
"dump_op_args",
"cast_dtype",
"op_capture",
"overflow_check",
]
assert feature in feature_options, f"feature must be one of {feature_options}, but got {feature}"
assert callable(condition_func)
@@ -94,6 +96,8 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
hook_cls = OpDtypeCastHook
elif feature == "op_capture":
hook_cls = OpCaptureHook
elif feature == "overflow_check":
hook_cls = OpOverflowCheckHook

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
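
With `overflow_check` registered as a feature option, the hook can be attached to selected operators through `op_tools.apply_feature`, the same entry point the new unit test in this PR exercises. A minimal sketch (the op list, device, and tensor shapes are illustrative; on a vendor device `import ditorch` would normally come first):

```python
import torch
import op_tools

# Attach the overflow-check hook to a handful of ops; op names are given as
# strings, so third-party / vendor-extension ops can be listed the same way.
op_tools.apply_feature(
    ops=["torch.add", "torch.mul", "torch.nn.functional.linear"],
    feature="overflow_check",
)

x = torch.randn(3, 4, device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.randn(3, 4, device="cuda", dtype=torch.float16, requires_grad=True)
z = x + y  # inputs and outputs of hooked ops are checked for inf/nan
```

An optional `condition_func(*args, **kwargs)` can also be passed to restrict the hook to calls whose runtime arguments match; it defaults to always-true, as the signature in the hunk above shows.
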
27 changes: 14 additions & 13 deletions op_tools/op_autocompare_hook.py
@@ -16,7 +16,8 @@
is_random_number_gen_op,
garbage_collect,
get_dtype_cast_dict_form_str,
set_env_if_env_is_empty
set_option_if_empty,
get_option,
)
from .save_op_args import save_op_args, serialize_args_to_dict

@@ -37,7 +38,7 @@ def append(self, forward_id, compare_info):
for result in compare_info["result_list"]:
self.global_autocompare_result.append({"forward_id": forward_id, **result})

if len(self.global_autocompare_result) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "100")):
if len(self.global_autocompare_result) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "100")):
self.write_to_file()

def write_to_file(self):
@@ -85,13 +86,13 @@ def cleanup(self):
self.hook_handle = None


set_env_if_env_is_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_env_if_env_is_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("LINEAR_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("EMBEDDING_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("NORMALIZE_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("NORM_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("CROSS_ENTROPY_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("MUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501
set_option_if_empty("MATMUL_OP_DTYPE_CAST_DICT", "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64") # noqa: E501


class OpAutoCompareHook(BaseHook):
@@ -117,8 +118,8 @@ def update_dtype_cast_dict(self):
heigher_priority_env_name = op_name_suffix + "_OP_DTYPE_CAST_DICT"
lower_priority_env_name = "OP_DTYPE_CAST_DICT"
default_env_value = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
heigher_priority_env_value = os.environ.get(heigher_priority_env_name, None)
lower_priority_env_value = os.environ.get(lower_priority_env_name, default_env_value)
heigher_priority_env_value = get_option(heigher_priority_env_name, None)
lower_priority_env_value = get_option(lower_priority_env_name, default_env_value)
self.dtype_cast_config_str = heigher_priority_env_value or lower_priority_env_value

self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)
@@ -407,10 +408,10 @@ def is_should_apply(self, *args, **kwargs):
if self.name.startswith("torch.empty"):
return False

if is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_AUTOCOMPARE_DISABLE_LIST", "")):
return False

return is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_AUTOCOMPARE_LIST", ".*"))


atexit.register(dump_all_autocompare_info)
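
The autocompare hook now resolves its dtype-cast configuration through the new option helpers rather than raw `os.environ` reads: a per-op `<OPNAME>_OP_DTYPE_CAST_DICT` option outranks the generic `OP_DTYPE_CAST_DICT`, which in turn falls back to a built-in default. A small sketch of that lookup order, with illustrative values:

```python
from op_tools.utils import set_option_if_empty, get_option, get_dtype_cast_dict_form_str

# Seed a per-op override only if nothing has configured it yet; an existing
# environment variable or an earlier set_option call takes precedence.
set_option_if_empty(
    "LINEAR_OP_DTYPE_CAST_DICT",
    "torch.float16->torch.float64,torch.bfloat16->torch.float64,torch.float32->torch.float64",
)

default_cast = "torch.float16->torch.float32,torch.bfloat16->torch.float32"
config_str = get_option("LINEAR_OP_DTYPE_CAST_DICT", None) or get_option("OP_DTYPE_CAST_DICT", default_cast)

# Parse the "src->dst,..." string into the mapping the hook uses for casting.
dtype_cast_dict = get_dtype_cast_dict_form_str(config_str)
```
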
7 changes: 3 additions & 4 deletions op_tools/op_capture_hook.py
@@ -1,9 +1,8 @@
# Copyright (c) 2024, DeepLink.
import os
import torch
import time
from .base_hook import BaseHook, DisableHookGuard
from .utils import traverse_container, is_opname_match, garbage_collect
from .utils import traverse_container, is_opname_match, garbage_collect, get_option
from .save_op_args import save_op_args, serialize_args_to_dict
from .pretty_print import dict_data_list_to_table, packect_data_to_dict_list

@@ -71,10 +70,10 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_CAPTURE_DISABLE_LIST", "")):
return False

if not is_opname_match(self.name, os.getenv("OP_CAPTURE_LIST", ".*")):
if not is_opname_match(self.name, get_option("OP_CAPTURE_LIST", ".*")):
return False

return True
7 changes: 4 additions & 3 deletions op_tools/op_dtype_cast_hook.py
@@ -11,7 +11,8 @@
is_opname_match,
is_view_op,
is_dtype_cast_op,
garbage_collect
garbage_collect,
get_option,
)
from .pretty_print import dict_data_list_to_table

@@ -122,7 +123,7 @@ def after_call_op(self, result):
return result

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_DTYPE_CAST_DISABLE_LIST", "")):
return False

if is_view_op(self.name):
@@ -131,4 +132,4 @@ def is_should_apply(self, *args, **kwargs):
if is_dtype_cast_op(self.name, *args, **kwargs):
return False

return is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_DTYPE_CAST_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_fallback_hook.py
@@ -1,9 +1,8 @@
# Copyright (c) 2024, DeepLink.
import torch
import os

from .base_hook import BaseHook, DisableHookGuard
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect
from .utils import to_device, is_cpu_op, is_opname_match, garbage_collect, get_option
from .save_op_args import serialize_args_to_dict
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table

@@ -114,9 +113,9 @@ def is_should_apply(self, *args, **kwargs):
if self.name in BLACK_OP_LIST:
return False

if is_opname_match(self.name, os.getenv("OP_FALLBACK_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_FALLBACK_DISABLE_LIST", "")):
return False
# if name in VIEW_OPS:
# return False

return is_opname_match(self.name, os.getenv("OP_FALLBACK_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_FALLBACK_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_observe_hook.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import os
from .utils import is_opname_match
from .utils import is_opname_match, get_option
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
@@ -26,6 +25,6 @@ def after_call_op(self, result):
print("\n" * 2)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_OBSERVE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_OBSERVE_DISABLE_LIST", "")):
return False
return is_opname_match(self.name, os.getenv("OP_OBSERVE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_OBSERVE_LIST", ".*"))
7 changes: 3 additions & 4 deletions op_tools/op_overflow_check_hook.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import os
from .utils import is_opname_match, traverse_container, is_inf_or_nan, garbage_collect, compute_tensor_features
from .utils import is_opname_match, traverse_container, is_inf_or_nan, garbage_collect, compute_tensor_features, get_option
from .base_hook import BaseHook, DisableHookGuard
import torch

@@ -106,6 +105,6 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_OVERFLOW_CHECK_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_OVERFLOW_CHECK_DISABLE_LIST", "")):
return False
return is_opname_match(self.name, os.getenv("OP_OVERFLOW_CHECK_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_OVERFLOW_CHECK_LIST", ".*"))
8 changes: 4 additions & 4 deletions op_tools/op_time_measure_hook.py
@@ -6,7 +6,7 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match, traverse_container, garbage_collect
from .utils import is_opname_match, traverse_container, garbage_collect, get_option
from .pretty_print import (
dict_data_list_to_table,
packect_data_to_dict_list,
@@ -39,7 +39,7 @@ def append(self, forward_id, elasped_info):
else:
self.global_elasped_info_dict[forward_id].update(elasped_info)

if len(self.global_elasped_info_dict) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "5000")):
if len(self.global_elasped_info_dict) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "5000")):
self.write_to_file()

def write_to_file(self):
@@ -153,10 +153,10 @@ def after_call_op(self, result):
garbage_collect()

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
if is_opname_match(self.name, get_option("OP_TIME_MEASURE_DISABLE_LIST", "")):
return False

return is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_LIST", ".*"))
return is_opname_match(self.name, get_option("OP_TIME_MEASURE_LIST", ".*"))


atexit.register(dump_all_op_elasped_info)
5 changes: 3 additions & 2 deletions op_tools/process_monitor.py
@@ -5,6 +5,7 @@
import os
from prettytable import PrettyTable
from op_tools.pretty_print import dict_data_list_to_table
from op_tools.utils import get_option

is_ascend_npu_env = subprocess.run("npu-smi info", shell=True, capture_output=True, text=True).returncode == 0
is_camb_mlu_env = subprocess.run("cnmon", shell=True, capture_output=True, text=True).returncode == 0
@@ -104,13 +105,13 @@ def __init__(self, pid) -> None:
device_name = "ascend"
if is_camb_mlu_env:
device_name = "camb"
self.file_name = f"op_tools_results/process_monitor_result/process_monitor_result_{device_name}_pid{pid}_{os.getenv('label', '')}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
self.file_name = f"op_tools_results/process_monitor_result/process_monitor_result_{device_name}_pid{pid}_{get_option('label', '')}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv" # noqa: E501
self.dir = self.file_name[0 : self.file_name.rfind("/")]

def append(self, info):
self.global_result.append(info)

if len(self.global_result) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "1")):
if len(self.global_result) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "1")):
self.write_to_file()

def write_to_file(self):
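
The process monitor follows the same buffer-then-flush pattern as the other result collectors, now keyed off `get_option("OP_TOOLS_MAX_CACHE_SIZE", "1")`. A stripped-down illustration of that pattern (a hypothetical writer, not the class in this file):

```python
from op_tools.utils import get_option

class BufferedResultWriter:
    """Illustrative buffer-and-flush pattern used by the result collectors."""

    def __init__(self, path):
        self.path = path
        self.rows = []

    def append(self, row):
        self.rows.append(row)
        # Flush once the buffer exceeds the configured cache size.
        if len(self.rows) > int(get_option("OP_TOOLS_MAX_CACHE_SIZE", "1")):
            self.write_to_file()

    def write_to_file(self):
        with open(self.path, "a") as f:
            for row in self.rows:
                f.write(",".join(str(v) for v in row) + "\n")
        self.rows.clear()
```
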
9 changes: 9 additions & 0 deletions op_tools/test/test_custom_apply_hook.py
@@ -88,6 +88,15 @@ def test_cast_dtype(self):
y = torch.rand(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_overflow_check(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="overflow_check",
)
x = torch.tensor([1, 2, 3], device="cuda", dtype=torch.float16, requires_grad=True)
y = torch.tensor([4, 5, 6], device="cuda", dtype=torch.float16, requires_grad=True)
_test_function(x, y)

def test_condition_fallback(self):
def condition_func(a, b, **kwargs):
if a.dtype == torch.float16:
57 changes: 45 additions & 12 deletions op_tools/utils.py
@@ -299,12 +299,12 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
op_name_processed = op_name.split(".")[-1].upper() + "_"
env_name = "AUTOCOMPARE_ERROR_TOLERANCE_" + dtype_name.upper()
high_priority_env_name = op_name_processed + env_name
if os.getenv(high_priority_env_name) is not None:
atol, rtol = map(float, os.getenv(high_priority_env_name).split(","))
elif os.getenv(env_name) is not None:
atol, rtol = map(float, os.getenv(env_name).split(","))
elif os.getenv("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, os.getenv("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
if get_option(high_priority_env_name) is not None:
atol, rtol = map(float, get_option(high_priority_env_name).split(","))
elif get_option(env_name) is not None:
atol, rtol = map(float, get_option(env_name).split(","))
elif get_option("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, get_option("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
return atol, rtol

if dtype == torch.float16:
@@ -317,8 +317,8 @@ def get_error_tolerance_for_type(dtype_name, atol, rtol):
return get_error_tolerance_for_type("FLOAT64", 1e-8, 1e-8)
else:
atol, rtol = 1e-3, 1e-3
if os.getenv("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, os.getenv("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
if get_option("AUTOCOMPARE_ERROR_TOLERANCE") is not None:
atol, rtol = map(float, get_option("AUTOCOMPARE_ERROR_TOLERANCE").split(","))
return atol, rtol


@@ -510,7 +510,7 @@ def current_location(name=None, stack_depth=-1, print_stack=False):
else:
break

if print_stack or int(os.getenv("OP_TOOLS_PRINT_STACK", "0")) > 0:
if print_stack or int(get_option("OP_TOOLS_PRINT_STACK", "0")) > 0:
for i in range(len(stack) - 2):
file, line, func, text = stack[i]
print(f"{file}:{line} {func} {text}")
@@ -519,6 +519,39 @@
return f"{file}:{line} {func}: {text}"


def set_env_if_env_is_empty(env_name, env_value):
if os.environ.get(env_name, None) is None:
os.environ[env_name] = env_value
class OptionManger:
def __init__(self):
self.options = {}

def set_option(self, option_name, value):
if option_name in self.options:
if self.options[option_name] == value:
return
print(f"option: {option_name}={value}")
self.options[option_name] = value
os.environ[option_name] = str(value)

def get_option(self, option_name, default_value=None):
if option_name in os.environ:
value = os.environ[option_name]
self.set_option(option_name, value)
return value
else:
if default_value is not None:
self.set_option(option_name, default_value)
return default_value

def is_option_set(self, option_name):
return option_name in os.environ


global_option_manger = OptionManger()


def set_option_if_empty(option_name, value):
if not global_option_manger.is_option_set(option_name):
global_option_manger.set_option(option_name, value)


def get_option(option_name, default_value=None):
return global_option_manger.get_option(option_name, default_value)
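
The new `OptionManger` in `op_tools/utils.py` centralizes what used to be scattered `os.getenv` calls: an option already present in the environment wins, a default passed to `get_option` is recorded otherwise, value changes are printed once, and every value is mirrored back into `os.environ` so subprocesses inherit it. A short usage sketch of those semantics (the option names are ones used elsewhere in this PR; the values are illustrative):

```python
import os
from op_tools.utils import get_option, set_option_if_empty

# A value already present in the environment always wins over the default.
os.environ["OP_TOOLS_MAX_CACHE_SIZE"] = "200"
assert get_option("OP_TOOLS_MAX_CACHE_SIZE", "100") == "200"

# set_option_if_empty only seeds options that are still unset.
set_option_if_empty("OP_OVERFLOW_CHECK_LIST", "torch.matmul")
set_option_if_empty("OP_OVERFLOW_CHECK_LIST", ".*")  # no effect, already set
assert get_option("OP_OVERFLOW_CHECK_LIST") == "torch.matmul"

# Options are mirrored into os.environ, so child processes see them too.
assert os.environ["OP_OVERFLOW_CHECK_LIST"] == "torch.matmul"
```
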