Support hooks that take effect only when user-defined conditions are met
zhaoguochun1995 committed Sep 9, 2024
1 parent ae53f6a commit 7c3c294
Showing 6 changed files with 209 additions and 120 deletions.
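For context: `apply_feature` now takes a `condition_func` that is called with each op invocation's arguments and decides, per call, whether the hook takes effect. A minimal usage sketch drawn from the tests added in this commit (device, shapes, and the helper name are illustrative):

import torch
import op_tools

# Fall back torch.add to CPU only when the first input is float16 (illustrative condition).
def fallback_if_fp16(a, b, **kwargs):
    return a.dtype == torch.float16

op_tools.apply_feature("torch.add", feature="fallback", condition_func=fallback_if_fp16)

x = torch.randn(3, 4, dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.randn(3, 4, dtype=torch.float16, device="cuda", requires_grad=True)
z = torch.add(x, y)  # condition returns True here, so the fallback hook applies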
12 changes: 10 additions & 2 deletions op_tools/custom_apply_hook.py
@@ -64,7 +64,13 @@ def apply_hook_to_ops(ops, hook, condition_funcs=[]):

def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
assert isinstance(ops, (str, list))
feature_options = ["fallback", "autocompare", "op_time_measure", "dump_op_args"]
feature_options = [
"fallback",
"autocompare",
"measure_op_time",
"dump_op_args",
"cast_dtype",
]
assert (
feature in feature_options
), f"feature must be one of {feature_options}, but got {feature}"
@@ -73,10 +79,12 @@ def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
hook_cls = OpFallbackHook
elif feature == "autocompare":
hook_cls = OpAutoCompareHook
elif feature == "op_time_measure":
elif feature == "measure_op_time":
hook_cls = OpTimeMeasureHook
elif feature == "dump_op_args":
hook_cls = OpDispatchWatcherHook
elif feature == "cast_dtype":
hook_cls = OpDtypeCastHook

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
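With this change, callers use the renamed `measure_op_time` feature, and the new `cast_dtype` feature routes to `OpDtypeCastHook`. A short sketch based on the tests added below (op lists, shapes, and dtypes are illustrative):

import torch
import op_tools

# Renamed: the feature formerly called "op_time_measure" is now "measure_op_time".
op_tools.apply_feature(ops=["torch.add", "torch.sub"], feature="measure_op_time")

# New: "cast_dtype" applies OpDtypeCastHook to the listed ops.
op_tools.apply_feature(ops=["torch.mul", "torch.div"], feature="cast_dtype")

x = torch.randn(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.randn(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
d = torch.div(torch.mul(torch.add(x, y), 2), torch.sub(x, y).abs() + 1)
d.backward(torch.ones_like(d))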
23 changes: 17 additions & 6 deletions op_tools/op_autocompare_hook.py
@@ -176,10 +176,12 @@ def before_call_op(self, *args, **kwargs):
self.args_cpu = to_device(
"cpu",
self.args,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs or {},
detach=True,
)

def after_call_op(self, result):
@@ -189,7 +191,7 @@ def after_call_op(self, result):
self.result = result
try:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.result_device = to_device("cpu", self.result)
self.result_device = to_device("cpu", self.result, True)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e:
@@ -199,11 +201,13 @@ def after_call_op(self, result):
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[
@@ -220,6 +224,7 @@ def after_call_op(self, result):
"cpu",
self.result,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)

if is_inplace_op(self.name):
@@ -259,12 +264,16 @@ def after_call_op(self, result):
if isinstance(arg, torch.Tensor) and arg.requires_grad:
self.backward_hook_handle.register_tensor_hook(index, arg)

self.args = to_device("cpu", self.args)
self.kwargs = to_device("cpu", self.kwargs or {})
self.args = to_device("cpu", self.args, detach=True)
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, self.dtype_cast_dict)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, self.dtype_cast_dict)
self.grad_outputs_cpu = to_device(
"cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
self.grad_inputs_cpu = to_device(
"cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
arg_cpu.grad.zero_()
@@ -332,7 +341,9 @@ def compare_all_grad(self):
def set_input_grad(self, index, grad):
if not hasattr(self, "args_grad"):
self.args_grad = [None for i in range(len(self.args))]
self.args_grad[index] = to_device("cpu", grad, self.dtype_cast_dict)
self.args_grad[index] = to_device(
"cpu", grad, dtype_cast_dict=self.dtype_cast_dict, detach=True
)

def save_forward_args(self):
save_op_args(
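The `detach=True` arguments added throughout this hook keep the CPU copies out of autograd, which avoids the error quoted in the in-code comment above; a standalone reproduction, independent of op_tools:

import torch

x = torch.ones(3, requires_grad=True)

try:
    x.add_(1)  # in-place update of a leaf tensor that requires grad
except RuntimeError as e:
    print(e)  # "a leaf Variable that requires grad is being used in an in-place operation."

x_cpu = x.detach().clone().cpu()  # detached copy: safe to mutate in place
x_cpu.add_(1)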
17 changes: 14 additions & 3 deletions op_tools/op_dtype_cast_hook.py
@@ -33,9 +33,17 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
self.args = to_device(self.device, self.args, self.dtype_cast_dict, False)
self.args = to_device(
self.device,
self.args,
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.kwargs = to_device(
self.device, self.kwargs or {}, self.dtype_cast_dict, False
self.device,
self.kwargs or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.dtype_cast_back_dict = {}
self.ins_list = []
@@ -62,7 +70,10 @@ def after_call_op(self, result):
with DisableHookGuard():
self.result_raw = result
self.result = to_device(
self.device, self.result, self.dtype_cast_back_dict, False
self.device,
self.result,
dtype_cast_dict=self.dtype_cast_back_dict,
detach=False,
)
i = -1
for out in traverse_container(self.result_raw):
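This hook's pattern is cast-before-call, cast-back-after-call, driven by `dtype_cast_dict` and `dtype_cast_back_dict`. A minimal sketch of the idea outside op_tools (the float16-to-float32 mapping and the `cast` helper are assumptions for illustration, not the hook's API):

import torch

dtype_cast_dict = {torch.float16: torch.float32}                   # assumed forward mapping
dtype_cast_back_dict = {v: k for k, v in dtype_cast_dict.items()}  # restores the original dtype

def cast(t, mapping):
    # Cast only tensors whose dtype appears in the mapping; pass everything else through.
    if isinstance(t, torch.Tensor) and t.dtype in mapping:
        return t.to(mapping[t.dtype])
    return t

x = torch.randn(4, 5, dtype=torch.float16)
y = torch.randn(4, 5, dtype=torch.float16)

out = torch.add(cast(x, dtype_cast_dict), cast(y, dtype_cast_dict))  # computed in float32
out = cast(out, dtype_cast_back_dict)                                # cast back to float16
assert out.dtype == torch.float16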
167 changes: 167 additions & 0 deletions op_tools/test/test_custom_apply_hook.py
@@ -0,0 +1,167 @@
# Copyright (c) 2024, DeepLink.
import torch
import op_tools
import unittest


def _test_function(x, y):
a = torch.add(x, y) * 2
b = torch.sub(a, y) / 3
c = torch.mul(b, a) + 1
d = torch.div(c, b) - 2
d.backward(torch.ones_like(d))
assert a.is_cpu == x.is_cpu
assert b.is_cpu == x.is_cpu
assert c.is_cpu == x.is_cpu
assert d.is_cpu == x.is_cpu

assert a.requires_grad == x.requires_grad
assert b.requires_grad == x.requires_grad
assert c.requires_grad == x.requires_grad
assert d.requires_grad == x.requires_grad
assert a.device == x.device
assert b.device == x.device
assert c.device == x.device
assert d.device == x.device
assert a.dtype == x.dtype
assert b.dtype == x.dtype
assert c.dtype == x.dtype
assert d.dtype == x.dtype

assert a.shape == x.shape
assert b.shape == x.shape
assert c.shape == x.shape
assert d.shape == x.shape

assert a.grad is None
assert b.grad is None
assert c.grad is None
assert d.grad is None
assert a.is_leaf == False
assert b.is_leaf == False
assert c.is_leaf == False
assert d.is_leaf == False

assert (x.grad is not None) == x.requires_grad
assert (y.grad is not None) == y.requires_grad


class TestCustomApplyHook(unittest.TestCase):
def test_fallback_op(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"], feature="fallback"
)
x = torch.tensor(
[1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True
)
y = torch.tensor(
[4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True
)
_test_function(x, y)

def test_dump_all_args(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="autocompare",
)
x = torch.tensor(
[1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True
)
y = torch.tensor(
[4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True
)

_test_function(x, y)

def test_measure_op_time(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="measure_op_time",
)
x = torch.tensor(
[1, 2, 3], device="cuda", dtype=torch.float16, requires_grad=True
)
y = torch.tensor(
[4, 5, 6], device="cuda", dtype=torch.float16, requires_grad=True
)
_test_function(x, y)

def test_cast_dtype(self):
op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="cast_dtype",
)
x = torch.randn(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.rand(4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

def test_condition_fallback(self):
def condition_func(a, b, **kwargs):
if a.dtype == torch.float16:
print(f"fallback beacuse input dtype is float16")
return True
else:
print(f"not fallback beacuse input dtype is {a.dtype}")
return False

op_tools.apply_feature(
ops=["torch.add", "torch.sub", "torch.mul", "torch.div"],
feature="dump_op_args",
condition_func=condition_func,
)
x = torch.tensor(
[1, 2, 3], dtype=torch.float16, device="cuda", requires_grad=True
)
y = torch.tensor(
[4, 5, 6], dtype=torch.float16, device="cuda", requires_grad=True
)
_test_function(x, y)

x = torch.tensor(
[1, 2, 3], dtype=torch.float32, device="cuda", requires_grad=True
)
y = torch.tensor(
[4, 5, 6], dtype=torch.float32, device="cuda", requires_grad=True
)
_test_function(x, y)

def test_condition_autocompare(self):
def condition_func1(a, b, **kwargs):
if a.dtype == torch.float16:
print(f"autocompare beacuse input dtype is float16")
return True
else:
print(f"not autocompare beacuse input dtype is {a.dtype}")
return False

def condition_func2(a, b, **kwargs):
if a.dim() == 2:
print(f"autocompare beacuse input dim is 2")
return True
else:
print(f"not autocompare beacuse input dim is {a.dim()}")
return False

op_tools.apply_feature(
"torch.add", feature="autocompare", condition_func=condition_func1
)
op_tools.apply_feature(
"torch.sub", feature="autocompare", condition_func=condition_func2
)
op_tools.apply_feature(
"torch.mul", feature="autocompare", condition_func=condition_func1
)
op_tools.apply_feature(
"torch.div", feature="autocompare", condition_func=condition_func2
)

x = torch.randn(3, 4, dtype=torch.float16, device="cuda", requires_grad=True)
y = torch.randn(3, 4, dtype=torch.float16, device="cuda", requires_grad=True)
_test_function(x, y)

x = torch.randn(3, 4, dtype=torch.float32, device="cuda", requires_grad=True)
y = torch.randn(3, 4, dtype=torch.float32, device="cuda", requires_grad=True)
_test_function(x, y)


if __name__ == "__main__":
unittest.main()
(Diffs for the remaining two changed files are not shown here.)
