Zgc/ditorch fix ljx/sup op tools test #55

Merged
merged 9 commits on Oct 11, 2024
6 changes: 3 additions & 3 deletions op_tools/custom_apply_hook.py
@@ -71,7 +71,7 @@ def condition_func(*args, **kwargs):


def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
- assert isinstance(ops, (str, list))
+ assert isinstance(ops, (str, list, tuple))
feature_options = [
"fallback",
"autocompare",
@@ -97,8 +97,8 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
- elif isinstance(ops, list):
+ elif isinstance(ops, (list, tuple)):
for op in ops:
apply_hook_to_ops(op, hook_cls, condition_func)
else:
assert False, f"ops must be str or list, but got {type(ops)}"
assert False, f"ops must be str, tuple or list, but got {type(ops)}"
4 changes: 3 additions & 1 deletion op_tools/op_time_measure_hook.py
@@ -93,14 +93,15 @@ def grad_fun(grad_inputs, grad_outputs):
data_dict_list += packect_data_to_dict_list(self.name + " grad_outputs", serialize_args_to_dict(grad_outputs))
data_dict_list += packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(grad_inputs))
table = dict_data_list_to_table(data_dict_list)
print("\n" * 2, f"{self.name} forward_id: {self.id}")
print(table)
elasped_info_dict = {
"name": self.name,
"forward_id": self.id,
"backward_elasped": f"{(self.backward_elasped * 1000):>10.8f}",
"unit": "ms",
}
- print(dict_data_list_to_table([elasped_info_dict]))
+ print(dict_data_list_to_table([elasped_info_dict]), "\n" * 2)
elasped_info_dict["grad_inputs"] = serialize_args_to_dict(grad_inputs)
elasped_info_dict["grad_outputs"] = serialize_args_to_dict(grad_outputs)
time_measure_result_cache.append(self.id, elasped_info_dict)
Expand Down Expand Up @@ -140,6 +141,7 @@ def after_call_op(self, result):
"unit": "ms",
}
print("\n" * 2)
print(f"{self.name} forward_id: {self.id}")
print(self.current_location)
print(forward_args_table)
print(dict_data_list_to_table([elasped_info_dict]))
2 changes: 2 additions & 0 deletions op_tools/requirements.txt
@@ -0,0 +1,2 @@
prettytable
psutil
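
The new `prettytable` dependency presumably backs the `dict_data_list_to_table` output printed by the hooks above; `psutil` is not exercised in the diff shown here. A minimal, hypothetical sketch of rendering a list of dicts with prettytable (not the repo's actual helper):

```python
from prettytable import PrettyTable

def dicts_to_table(data_dict_list):
    # Use the union of keys as columns, preserving first-seen order.
    columns = []
    for row in data_dict_list:
        for key in row:
            if key not in columns:
                columns.append(key)
    table = PrettyTable(columns)
    for row in data_dict_list:
        table.add_row([row.get(col, "") for col in columns])
    return table

print(dicts_to_table([{"name": "torch.add", "forward_id": 1, "unit": "ms"}]))
```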
83 changes: 83 additions & 0 deletions op_tools/test/test_compare_result.py
@@ -145,6 +145,89 @@ def test_compare_different_bool(self):
self.assertTrue(math.isnan(compare_info["max_relative_diff"]))
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_empty_tensor(self):
result1 = torch.empty(0).cuda()
result2 = torch.empty(0).cuda()
compare_info = compare_result("empty_tensor", result1, result2)
self.assertTrue(compare_info["allclose"])

def test_compare_empty_list(self):
result1 = []
result2 = []
compare_info = compare_result("empty_list", result1, result2)
self.assertTrue(compare_info["allclose"])

def test_compare_diff_shape_tensor(self):
result1 = torch.randn(10, 10).cuda()
result2 = torch.randn(20, 20).cuda()
compare_info = compare_result("diff_shape_tensor", result1, result2)
self.assertFalse(compare_info["allclose"])
self.assertIn("Inconsistent shape", compare_info["error_info"])

def test_compare_mixed_types(self):
result1 = [1, 2.0, 3]
result2 = [1, 2, 3.0]
compare_info = compare_result("mixed_types", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_invalid_type(self):
compare_info = compare_result("invalid_type", {}, [])
self.assertTrue(compare_info["allclose"])

def test_compare_invalid_value_a(self):
result1 = ["1", 2.0, 3]
result2 = [1, 2, 3.0]
compare_info = compare_result("invalid_string_a", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_invalid_value_b(self):
result1 = [1, 2.0, 3]
result2 = ["1", 2, 3.0]
compare_info = compare_result("invalid_string_b", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_same_dict(self):
result1 = {"1": 1}
result2 = {"1": 1}
compare_info = compare_result("same_dict", result1, result2)
self.assertTrue(compare_info["allclose"])

def test_compare_different_dict(self):
result1 = {"1": 2}
result2 = {"1": 1}
compare_info = compare_result("different_dict", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_different_dict2(self):
result1 = {"1": 2}
result2 = {"2": 2}
compare_info = compare_result("different_dict", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_same_dict_list_value(self):
result1 = {"1": [1, 2, 3]}
result2 = {"1": [1, 2, 3]}
compare_info = compare_result("same_dict_list_value", result1, result2)
self.assertTrue(compare_info["allclose"])

def test_compare_different_dict_list_value(self):
result1 = {"1": [2, 4, 6]}
result2 = {"1": [1, 2, 3]}
compare_info = compare_result("different_dict_list_value", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_dict_different_shape(self):
result1 = {"1": [2, 4, 6], "2": [4, 5, 6]}
result2 = {"1": [1, 2, 3]}
compare_info = compare_result("dict_different_shape", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_dict_different_list_shape(self):
result1 = {"1": [2, 4, 6, 8]}
result2 = {"1": [1, 2, 3]}
compare_info = compare_result("dict_different_list_shape", result1, result2)
self.assertFalse(compare_info["allclose"])

def test_compare_invalid_input(self):
self.assertTrue(compare_result("empty_list", [], [])["allclose"]) # 输入空列表
self.assertTrue(compare_result("empty_tesnsor", torch.empty(0).cuda(), torch.empty(0).cuda())["allclose"]) # 输入空张量
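
For orientation, a minimal sketch of the call pattern these tests assert on; the import path for `compare_result` is an assumption, since the diff does not show it:

```python
import torch
from op_tools.utils import compare_result  # import path assumed

a = torch.randn(4, 4).cuda()
info = compare_result("demo", a, a.clone())
# Keys asserted throughout this test file:
assert info["allclose"]
assert isinstance(info["result_list"], list)
```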
37 changes: 33 additions & 4 deletions op_tools/test/test_custom_apply_hook.py
@@ -15,7 +15,6 @@ def _test_function(x, y):
b.device.type == x.device.type
c.device.type == x.device.type
d.device.type == x.device.type

a.requires_grad == x.requires_grad
b.requires_grad == x.requires_grad
c.requires_grad == x.requires_grad
@@ -28,12 +27,10 @@ def _test_function(x, y):
b.dtype == x.dtype
c.dtype == x.dtype
d.dtype == x.dtype

a.shape == x.shape
b.shape == x.shape
c.shape == x.shape
d.shape == x.shape

assert a.grad is None
assert b.grad is None
assert c.grad is None
@@ -42,7 +39,6 @@ def _test_function(x, y):
assert b.is_leaf is False
assert c.is_leaf is False
assert d.is_leaf is False

assert (x.grad is not None) == x.requires_grad
assert (y.grad is not None) == y.requires_grad

@@ -155,6 +151,39 @@ def test_condition_autocompare_linear(self):
out = func(input, weight, bias)
out.sum().backward()

def test_not_str_list(self):
op_tools.apply_feature(ops=("torch.add", "torch.sub", "torch.mul", "torch.div"), feature="fallback")
x = torch.tensor([1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.tensor([4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_str_null(self):
with self.assertRaises(ValueError):
op_tools.apply_feature(ops="", feature="fallback")
x = torch.tensor([1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.tensor([4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_list_not_str(self):
op_tools.apply_feature(ops=[torch.add, torch.add, torch.sub, torch.div], feature="fallback")
x = torch.tensor([1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.tensor([4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_error_feature(self):
with self.assertRaises(AssertionError):
op_tools.apply_feature(ops=["torch.add", "torch.mul", "torch.sub", "torch.div"], feature="falback")
x = torch.tensor([1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.tensor([4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_error_ops(self):
with self.assertRaises(AttributeError):
op_tools.apply_feature(ops=[torch.ad], feature="fallback")
x = torch.tensor([1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.tensor([4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)


if __name__ == "__main__":
unittest.main()
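
As `test_condition_autocompare_linear` above suggests, `apply_feature` also takes a `condition_func` that receives the op's call arguments and gates the hook per call. A hedged sketch (the predicate is illustrative, not from the repo):

```python
import torch
import op_tools

# Only autocompare linear calls whose input tensor is float16;
# every other call runs without the hook.
op_tools.apply_feature(
    ops="torch.nn.functional.linear",
    feature="autocompare",
    condition_func=lambda *args, **kwargs: len(args) > 0 and args[0].dtype == torch.float16,
)
```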
44 changes: 44 additions & 0 deletions op_tools/test/test_get_error_tolerance.py
@@ -13,6 +13,49 @@ def _test_get_error_tolerance(self, dtype, atol, rtol, op_name="test.op_name"):
self.assertTrue(atol > 0)
self.assertTrue(rtol > 0)

def _tearDown(self):
# Clean up environment variables to avoid side effects
env_list = [
"AUTOCOMPARE_ERROR_TOLERANCE",
"AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32",
"AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16",
"AUTOCOMPARE_ERROR_TOLERANCE_FLOAT64",
"OP_NAME_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT16",
"OP_NAME_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32",
"OP_NAME_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT64",
]
if "AUTOCOMPARE_ERROR_TOLERANCE" in os.environ:
del os.environ["AUTOCOMPARE_ERROR_TOLERANCE"]
for env_name in os.environ:
del os.environ[env_name]

def test_default_error_tolerance(self):
# Testing default tolerances without any environment variables set
self._test_get_error_tolerance(torch.float16, 1e-4, 1e-4)
self._test_get_error_tolerance(torch.bfloat16, 1e-3, 1e-3)
self._test_get_error_tolerance(torch.float32, 1e-5, 1e-5)
self._test_get_error_tolerance(torch.float64, 1e-8, 1e-8)
self._tearDown()

def test_environment_variable_override(self):
os.environ["AUTOCOMPARE_ERROR_TOLERANCE"] = "2,3"
self._test_get_error_tolerance(torch.float16, 2, 3)
self._test_get_error_tolerance(torch.float32, 2, 3)

os.environ["AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32"] = "5e-5,6e-5"
self._test_get_error_tolerance(torch.float32, 5e-5, 6e-5)
self._tearDown()

def test_operation_specific_override(self):
os.environ["OP_NAME_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32"] = "30,40"
self._test_get_error_tolerance(torch.float32, 30, 40, op_name="test.op_name")
self._tearDown()

def test_unknown_dtype(self):
# Test with a dtype that is not explicitly defined
self._test_get_error_tolerance(torch.int32, 1e-3, 1e-3)
self._tearDown()

def test_get_error_tolerance(self):
os.environ["AUTOCOMPARE_ERROR_TOLERANCE"] = "2,3"
self._test_get_error_tolerance(torch.float16, 2, 3)
@@ -28,6 +71,7 @@ def test_get_error_tolerance(self):
self._test_get_error_tolerance(torch.int32, 2, 3)
os.environ["OP_NAME_AUTOCOMPARE_ERROR_TOLERANCE_FLOAT32"] = "30,40"
self._test_get_error_tolerance(torch.float32, 30, 40, op_name="test.op_name")
self._tearDown()


if __name__ == "__main__":
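
The tests above imply a lookup order: an op-specific dtype variable beats the dtype variable, which beats the global `AUTOCOMPARE_ERROR_TOLERANCE`, which beats built-in defaults, with values written as `atol,rtol`. A minimal sketch of that resolution, reconstructed from the tests alone (not the repo's implementation):

```python
import os
import torch

# Defaults matching test_default_error_tolerance.
DEFAULT_TOLERANCE = {
    torch.float16: (1e-4, 1e-4),
    torch.bfloat16: (1e-3, 1e-3),
    torch.float32: (1e-5, 1e-5),
    torch.float64: (1e-8, 1e-8),
}

def get_error_tolerance(dtype, op_name=""):
    # e.g. torch.float32 -> "FLOAT32"; "test.op_name" -> "OP_NAME"
    dtype_tag = str(dtype).replace("torch.", "").upper()
    op_tag = op_name.split(".")[-1].upper()
    for env_name in (
        f"{op_tag}_AUTOCOMPARE_ERROR_TOLERANCE_{dtype_tag}",
        f"AUTOCOMPARE_ERROR_TOLERANCE_{dtype_tag}",
        "AUTOCOMPARE_ERROR_TOLERANCE",
    ):
        value = os.environ.get(env_name)
        if value:
            atol, rtol = (float(v) for v in value.split(","))
            return atol, rtol
    # Unknown dtypes fall back to 1e-3 per test_unknown_dtype.
    return DEFAULT_TOLERANCE.get(dtype, (1e-3, 1e-3))
```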
50 changes: 48 additions & 2 deletions op_tools/test/test_op_autocompare.py
@@ -1,8 +1,11 @@
# Copyright (c) 2024, DeepLink.
# This provides autocompare tests for some simple tensor operations, covering two usage styles.
# Some other basic operations may be missing.

import torch
import ditorch

import op_tools
import os


def f():
@@ -31,9 +34,52 @@ def f():
with op_tools.OpAutoCompare():
f()


# usage2
comparer = op_tools.OpAutoCompare()
comparer.start()
for i in range(3):
f()
comparer.stop()


# usage3
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
comparer.start()
f()
comparer.stop()

# usage4
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = ""
os.environ["OP_AUTOCOMPARE_LIST"] = "torch.Tensor.backward" # 与EXCLUDE_OPS重复
comparer.start()
f()
comparer.stop()

# usage5
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = ""
os.environ["OP_AUTOCOMPARE_LIST"] = "" # 空
comparer.start()
f()
comparer.stop()

# usage6
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = "torch.Tensor.sort"
os.environ["OP_AUTOCOMPARE_LIST"] = "torch.Tensor.sort,torch.Tensor.add" # 重叠
comparer.start()
f()
comparer.stop()

# usage7
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
if "OP_AUTOCOMPARE_LIST" in os.environ:
del os.environ["OP_AUTOCOMPARE_LIST"] # 删除
comparer.start()
f()
comparer.stop()

# usage8
os.environ["OP_AUTOCOMPARE_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
os.environ["OP_AUTOCOMPARE_LIST"] = "torch.Tensor.uniform_,torch.empty_like" # 与random_number_gen_ops重叠
comparer.start()
f()
comparer.stop()
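
Usages 3-8 probe how `OP_AUTOCOMPARE_DISABLE_LIST` and `OP_AUTOCOMPARE_LIST` interact when empty, unset, or overlapping. A hedged sketch of the presumed filter, assuming the disable list wins on overlap and that an empty or unset `OP_AUTOCOMPARE_LIST` means "hook all ops" (semantics this diff does not confirm):

```python
import os

def should_autocompare(op_name):
    disabled = set(filter(None, os.environ.get("OP_AUTOCOMPARE_DISABLE_LIST", "").split(",")))
    enabled = set(filter(None, os.environ.get("OP_AUTOCOMPARE_LIST", "").split(",")))
    if op_name in disabled:
        return False  # disable list presumably takes precedence (usage6)
    if enabled:
        return op_name in enabled  # explicit allow-list (usage4, usage8)
    return True  # empty or unset allow-list: hook everything (usage5, usage7)
```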
24 changes: 23 additions & 1 deletion op_tools/test/test_op_capture.py
@@ -43,7 +43,29 @@ def f():

# usage4
os.environ["OP_CAPTURE_DISABLE_LIST"] = ""
os.environ["OP_CAPTURE_LIST"] = "torch.Tensor.sort" # only capture these op
os.environ["OP_CAPTURE_LIST"] = "torch.Tensor.backward" # capture与EXCLUDE_OPS重复
capture.start()
f()
capture.stop()

# usage5
os.environ["OP_CAPTURE_DISABLE_LIST"] = ""
os.environ["OP_CAPTURE_LIST"] = "" # 空
capture.start()
f()
capture.stop()

# usage6
os.environ["OP_CAPTURE_DISABLE_LIST"] = "torch.Tensor.sort"
os.environ["OP_CAPTURE_LIST"] = "torch.Tensor.sort,torch.Tensor.add" # capture和disable重叠
capture.start()
f()
capture.stop()

# usage7
os.environ["OP_CAPTURE_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
if "OP_CAPTURE_LIST" in os.environ:
del os.environ["OP_CAPTURE_LIST"] # 删除
capture.start()
f()
capture.stop()