Support hook to take effect only when user-defined conditions are met (#23)

* Support hook to take effect only when user-defined conditions are met

zhaoguochun1995 authored Sep 9, 2024
1 parent 1d809a1 commit 6b39137
Showing 8 changed files with 258 additions and 71 deletions.
2 changes: 1 addition & 1 deletion op_tools/__init__.py
@@ -8,4 +8,4 @@
OpDtypeCast,
)

from .custom_apply_hook import fallback_ops, dump_all_ops_args
from .custom_apply_hook import apply_feature
16 changes: 14 additions & 2 deletions op_tools/base_hook.py
@@ -7,7 +7,7 @@
def is_should_apply_hook(name, func, *args, **kwargs):
if name is None:
return False
if inspect.isroutine(func) == False:
if callable(func) == False:
return False
if name.startswith("torch.Tensor.") and (
name.endswith("__get__") or name.endswith("__set__")
@@ -47,6 +47,18 @@ def __init__(self, name, func) -> None:
self.exception = None
self.func = func
self.wrapper_func = self.construct_wrapper_func()
self.condition_funcs = []

def add_condition_func(self, func):
if not callable(func):
raise ValueError("condition_func must be callable")
self.condition_funcs.append(func)

def conditions_met(self, *args, **kwargs):
for func in self.condition_funcs:
if not func(*args, **kwargs):
return False
return True

def before_call_op(self, *args, **kwargs):
self.args = args
@@ -80,9 +92,9 @@ def is_should_apply(self, *args, **kwargs):

def __call__(self, *args, **kwargs):
self.args_on_cpu, self.device = is_cpu_op(*args, **kwargs)
# import pdb; pdb.set_trace()
if (
self.enable
and self.conditions_met(*args, **kwargs)
and not self.args_on_cpu
and is_should_apply_hook(self.name, self.func, *args, **kwargs)
and self.is_should_apply(*args, **kwargs)
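The base_hook.py change is the core of the feature: every hook instance now carries a list of user-supplied condition functions, and __call__ applies the hook only when all of them return True for the intercepted call. A minimal sketch of the predicate contract (the predicate below is illustrative, not part of this commit):

    import torch

    def fp16_inputs_only(*args, **kwargs):
        # Receives exactly the (*args, **kwargs) passed to the hooked op;
        # return True to let the hook take effect for this call.
        return len(args) > 0 and getattr(args[0], "dtype", None) == torch.float16

    # Any BaseHook subclass can register it:
    #     hook.add_condition_func(fp16_inputs_only)
    # and __call__ then checks hook.conditions_met(*args, **kwargs) before hooking.
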
67 changes: 44 additions & 23 deletions op_tools/custom_apply_hook.py
@@ -25,8 +25,10 @@ def get_func_name(func):
return name


def apply_hook_to_ops(ops, hook):
def apply_hook_to_ops(ops, hook, condition_funcs=[]):
index = -1
for op in traverse_container(ops):
index += 1
if isinstance(op, str):
func = get_function_from_string(op)
name = op
@@ -46,29 +48,48 @@ def apply_hook_to_ops(ops, hook):
continue
if not hasattr(func, "__name__"):
func.__name__ = name.split(".")[-1]
if isinstance(condition_funcs, list):
if len(condition_funcs) > index:
condition_func = condition_funcs[index]
else:
condition_func = lambda *args, **kwargs: True
else:
condition_func = condition_funcs
assert callable(condition_func)

hook_obj = hook(name, func)
hook_obj.add_condition_func(condition_func)
setattr(module, func.__name__, hook_obj)


def fallback_ops(ops):
apply_hook_to_ops(ops, OpFallbackHook)


def fallback_op_if(op, condition=lambda *args, **kwargs: False):
apply_hook_to_ops(op, OpFallbackHook)


def dump_ops_args(ops):
apply_hook_to_ops(ops, OpDispatchWatcherHook)


def dump_all_ops_args():
apply_hook_to_ops(torch, OpDispatchWatcherHook)


def autocompare_ops(ops):
apply_hook_to_ops(ops, OpAutoCompareHook)


def measure_ops_elasped(ops):
apply_hook_to_ops(ops, OpTimeMeasureHook)
def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
assert isinstance(ops, (str, list))
feature_options = [
"fallback",
"autocompare",
"measure_op_time",
"dump_op_args",
"cast_dtype",
]
assert (
feature in feature_options
), f"feature must be one of {feature_options}, but got {feature}"
assert callable(condition_func)
if feature == "fallback":
hook_cls = OpFallbackHook
elif feature == "autocompare":
hook_cls = OpAutoCompareHook
elif feature == "measure_op_time":
hook_cls = OpTimeMeasureHook
elif feature == "dump_op_args":
hook_cls = OpDispatchWatcherHook
elif feature == "cast_dtype":
hook_cls = OpDtypeCastHook

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
elif isinstance(ops, list):
for op in ops:
apply_hook_to_ops(op, hook_cls, condition_func)
else:
assert False, f"ops must be str or list, but got {type(ops)}"
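With this commit, apply_feature replaces the old per-feature helpers (fallback_ops, dump_all_ops_args, and friends) as the single entry point exported from op_tools (see the __init__.py diff above). A usage sketch based on the signature in this diff; the op names and the predicate are illustrative, and condition_func defaults to always-True:

    import torch
    import op_tools

    # Fall back torch.add to CPU, but only for float16 inputs.
    op_tools.apply_feature(
        "torch.add",
        feature="fallback",
        condition_func=lambda *args, **kwargs: getattr(args[0], "dtype", None) == torch.float16,
    )

    # A list of op names works too; the same condition_func (here the default) is
    # applied to every op in the list.
    op_tools.apply_feature(["torch.sub", "torch.mul"], feature="autocompare")
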
23 changes: 17 additions & 6 deletions op_tools/op_autocompare_hook.py
@@ -176,10 +176,12 @@ def before_call_op(self, *args, **kwargs):
self.args_cpu = to_device(
"cpu",
self.args,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs or {},
detach=True,
)

def after_call_op(self, result):
@@ -189,7 +191,7 @@ def after_call_op(self, result):
self.result = result
try:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.result_device = to_device("cpu", self.result)
self.result_device = to_device("cpu", self.result, detach=True)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e:
@@ -199,11 +201,13 @@ def after_call_op(self, result):
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[
@@ -220,6 +224,7 @@ def after_call_op(self, result):
"cpu",
self.result,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)

if is_inplace_op(self.name):
@@ -259,12 +264,16 @@ def after_call_op(self, result):
if isinstance(arg, torch.Tensor) and arg.requires_grad:
self.backward_hook_handle.register_tensor_hook(index, arg)

self.args = to_device("cpu", self.args)
self.kwargs = to_device("cpu", self.kwargs or {})
self.args = to_device("cpu", self.args, detach=True)
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, self.dtype_cast_dict)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, self.dtype_cast_dict)
self.grad_outputs_cpu = to_device(
"cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
self.grad_inputs_cpu = to_device(
"cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
arg_cpu.grad.zero_()
@@ -332,7 +341,9 @@ def compare_all_grad(self):
def set_input_grad(self, index, grad):
if not hasattr(self, "args_grad"):
self.args_grad = [None for i in range(len(self.args))]
self.args_grad[index] = to_device("cpu", grad, self.dtype_cast_dict)
self.args_grad[index] = to_device(
"cpu", grad, dtype_cast_dict=self.dtype_cast_dict, detach=True
)

def save_forward_args(self):
save_op_args(
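All of the op_autocompare_hook.py edits pass detach=True when copying tensors to CPU for the reference run, so the copies are cut from the autograd graph before the hook replays the op; without that, in-place ops on leaf tensors that require grad raise the RuntimeError quoted in the comment above. A standalone illustration, independent of this repository's to_device helper:

    import torch

    x = torch.ones(3, requires_grad=True)

    ref = x.to("cpu")   # still a leaf that requires grad
    # ref.add_(1)       # RuntimeError: a leaf Variable that requires grad is
    #                   # being used in an in-place operation.

    ref = x.detach().to("cpu")   # what detach=True asks to_device to do first
    ref.add_(1)                  # allowed: the copy is outside the autograd graph
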
17 changes: 14 additions & 3 deletions op_tools/op_dtype_cast_hook.py
@@ -33,9 +33,17 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
self.args = to_device(self.device, self.args, self.dtype_cast_dict, False)
self.args = to_device(
self.device,
self.args,
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.kwargs = to_device(
self.device, self.kwargs or {}, self.dtype_cast_dict, False
self.device,
self.kwargs or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.dtype_cast_back_dict = {}
self.ins_list = []
@@ -62,7 +70,10 @@ def after_call_op(self, result):
with DisableHookGuard():
self.result_raw = result
self.result = to_device(
self.device, self.result, self.dtype_cast_back_dict, False
self.device,
self.result,
dtype_cast_dict=self.dtype_cast_back_dict,
detach=False,
)
i = -1
for out in traverse_container(self.result_raw):
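The op_dtype_cast_hook.py edits only switch the to_device calls to keyword arguments (dtype_cast_dict=..., detach=False), which keeps the call sites unambiguous now that to_device also takes a detach flag. The dtype round-trip the hook performs is unchanged; roughly, as a simplified sketch (not the hook's actual code):

    import torch

    dtype_cast_dict = {torch.float16: torch.float32}                   # upcast inputs
    dtype_cast_back_dict = {v: k for k, v in dtype_cast_dict.items()}  # downcast results

    def cast(t, table):
        # Cast tensors whose dtype appears in the table; leave everything else alone.
        return t.to(table[t.dtype]) if isinstance(t, torch.Tensor) and t.dtype in table else t

    a = torch.ones(4, dtype=torch.float16)
    b = torch.ones(4, dtype=torch.float16)
    out = torch.add(cast(a, dtype_cast_dict), cast(b, dtype_cast_dict))  # runs in float32
    out = cast(out, dtype_cast_back_dict)                                # back to float16
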
(Diff for the remaining changed files not shown.)
