Support hook to take effect only when user-defined conditions are met #23

Merged
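
For context, a minimal usage sketch of the feature this PR adds, written against the new public apply_feature API; the op name and the condition itself are illustrative, not taken from the PR:

import torch
import op_tools

# Fall back torch.add to the CPU implementation only when the first argument
# is a half-precision tensor; every registered condition must return True
# for the hook to take effect.
op_tools.apply_feature(
    "torch.add",
    feature="fallback",
    condition_func=lambda *args, **kwargs: (
        isinstance(args[0], torch.Tensor) and args[0].dtype == torch.float16
    ),
)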
op_tools/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -8,4 +8,4 @@
OpDtypeCast,
)

from .custom_apply_hook import fallback_ops, dump_all_ops_args
from .custom_apply_hook import apply_feature
op_tools/base_hook.py (16 changes: 14 additions & 2 deletions)
@@ -7,7 +7,7 @@
def is_should_apply_hook(name, func, *args, **kwargs):
if name is None:
return False
if inspect.isroutine(func) == False:
if callable(func) == False:
return False
if name.startswith("torch.Tensor.") and (
name.endswith("__get__") or name.endswith("__set__")
@@ -47,6 +47,18 @@ def __init__(self, name, func) -> None:
self.exception = None
self.func = func
self.wrapper_func = self.construct_wrapper_func()
self.condition_funcs = []

def add_condition_func(self, func):
if not callable(func):
raise ValueError("condition_func must be callable")
self.condition_funcs.append(func)

def conditions_met(self, *args, **kwargs):
for func in self.condition_funcs:
if not func(*args, **kwargs):
return False
return True

def before_call_op(self, *args, **kwargs):
self.args = args
@@ -80,9 +92,9 @@ def is_should_apply(self, *args, **kwargs):

def __call__(self, *args, **kwargs):
self.args_on_cpu, self.device = is_cpu_op(*args, **kwargs)
# import pdb; pdb.set_trace()
if (
self.enable
and self.conditions_met(*args, **kwargs)
and not self.args_on_cpu
and is_should_apply_hook(self.name, self.func, *args, **kwargs)
and self.is_should_apply(*args, **kwargs)
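
A standalone sketch of the all-conditions-must-pass semantics that conditions_met implements above; the condition functions here are hypothetical and the snippet does not use the hook class itself:

# Hypothetical condition functions: the hook applies only if every one returns True.
conditions = [
    lambda *args, **kwargs: len(args) > 0,
    lambda *args, **kwargs: kwargs.get("enabled", True),
]

def conditions_met(*args, **kwargs):
    return all(cond(*args, **kwargs) for cond in conditions)

print(conditions_met(1, 2))                 # True
print(conditions_met(1, 2, enabled=False))  # False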
op_tools/custom_apply_hook.py (67 changes: 44 additions & 23 deletions)
@@ -25,8 +25,10 @@ def get_func_name(func):
return name


def apply_hook_to_ops(ops, hook):
def apply_hook_to_ops(ops, hook, condition_funcs=[]):
index = -1
for op in traverse_container(ops):
index += 1
if isinstance(op, str):
func = get_function_from_string(op)
name = op
@@ -46,29 +48,48 @@ def apply_hook_to_ops(ops, hook):
continue
if not hasattr(func, "__name__"):
func.__name__ = name.split(".")[-1]
if isinstance(condition_funcs, list):
if len(condition_funcs) > index:
condition_func = condition_funcs[index]
else:
condition_func = lambda *args, **kwargs: True
else:
condition_func = condition_funcs
assert callable(condition_func)

hook_obj = hook(name, func)
hook_obj.add_condition_func(condition_func)
setattr(module, func.__name__, hook_obj)


def fallback_ops(ops):
apply_hook_to_ops(ops, OpFallbackHook)


def fallback_op_if(op, condition=lambda *args, **kwargs: False):
apply_hook_to_ops(op, OpFallbackHook)


def dump_ops_args(ops):
apply_hook_to_ops(ops, OpDispatchWatcherHook)


def dump_all_ops_args():
apply_hook_to_ops(torch, OpDispatchWatcherHook)


def autocompare_ops(ops):
apply_hook_to_ops(ops, OpAutoCompareHook)


def measure_ops_elasped(ops):
apply_hook_to_ops(ops, OpTimeMeasureHook)
def apply_feature(ops, feature, condition_func=lambda *args, **kwargs: True):
assert isinstance(ops, (str, list))
feature_options = [
"fallback",
"autocompare",
"measure_op_time",
"dump_op_args",
"cast_dtype",
]
assert (
feature in feature_options
), f"feature must be one of {feature_options}, but got {feature}"
assert callable(condition_func)
if feature == "fallback":
hook_cls = OpFallbackHook
elif feature == "autocompare":
hook_cls = OpAutoCompareHook
elif feature == "measure_op_time":
hook_cls = OpTimeMeasureHook
elif feature == "dump_op_args":
hook_cls = OpDispatchWatcherHook
elif feature == "cast_dtype":
hook_cls = OpDtypeCastHook

if isinstance(ops, str):
apply_hook_to_ops(ops, hook_cls, condition_func)
elif isinstance(ops, list):
for op in ops:
apply_hook_to_ops(op, hook_cls, condition_func)
else:
assert False, f"ops must be str or list, but got {type(ops)}"
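
A sketch of calling apply_feature with a list of ops and one shared condition function, as the list branch above supports; the op names and the size threshold are illustrative assumptions:

import torch
import op_tools

def large_input_only(*args, **kwargs):
    # Apply the hook only when some tensor argument is reasonably large.
    return any(isinstance(a, torch.Tensor) and a.numel() > 1024 for a in args)

# Each op in the list gets the same condition function attached to its hook.
op_tools.apply_feature(
    ["torch.matmul", "torch.nn.functional.linear"],
    feature="autocompare",
    condition_func=large_input_only,
)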
op_tools/op_autocompare_hook.py (23 changes: 17 additions & 6 deletions)
@@ -176,10 +176,12 @@ def before_call_op(self, *args, **kwargs):
self.args_cpu = to_device(
"cpu",
self.args,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs or {},
detach=True,
)

def after_call_op(self, result):
Expand All @@ -189,7 +191,7 @@ def after_call_op(self, result):
self.result = result
try:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.result_device = to_device("cpu", self.result)
self.result_device = to_device("cpu", self.result, detach=True)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e:
@@ -199,11 +201,13 @@ def after_call_op(self, result):
"cpu",
self.args_cpu,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
self.kwargs_cpu = to_device(
"cpu",
self.kwargs_cpu or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[
@@ -220,6 +224,7 @@ def after_call_op(self, result):
"cpu",
self.result,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)

if is_inplace_op(self.name):
@@ -259,12 +264,16 @@ def after_call_op(self, result):
if isinstance(arg, torch.Tensor) and arg.requires_grad:
self.backward_hook_handle.register_tensor_hook(index, arg)

self.args = to_device("cpu", self.args)
self.kwargs = to_device("cpu", self.kwargs or {})
self.args = to_device("cpu", self.args, detach=True)
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, self.dtype_cast_dict)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, self.dtype_cast_dict)
self.grad_outputs_cpu = to_device(
"cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
self.grad_inputs_cpu = to_device(
"cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True
)
for arg_cpu in traverse_container(self.args_cpu):
if isinstance(arg_cpu, torch.Tensor) and arg_cpu.grad is not None:
arg_cpu.grad.zero_()
@@ -332,7 +341,9 @@ def compare_all_grad(self):
def set_input_grad(self, index, grad):
if not hasattr(self, "args_grad"):
self.args_grad = [None for i in range(len(self.args))]
self.args_grad[index] = to_device("cpu", grad, self.dtype_cast_dict)
self.args_grad[index] = to_device(
"cpu", grad, dtype_cast_dict=self.dtype_cast_dict, detach=True
)

def save_forward_args(self):
save_op_args(
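
A standalone illustration of the error that motivates the detach=True arguments added throughout this file; it reproduces the RuntimeError quoted in the in-place/view comment above, independent of op_tools:

import torch

x = torch.ones(3, requires_grad=True)
try:
    x.add_(1)  # in-place op on a leaf tensor that requires grad
except RuntimeError as e:
    print(e)   # a leaf Variable that requires grad is being used in an in-place operation

# A detached copy carries no grad requirement, so in-place updates are allowed.
x_cpu = x.detach().clone()
x_cpu.add_(1)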
op_tools/op_dtype_cast_hook.py (17 changes: 14 additions & 3 deletions)
@@ -33,9 +33,17 @@ def before_call_op(self, *args, **kwargs):
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return
self.args = to_device(self.device, self.args, self.dtype_cast_dict, False)
self.args = to_device(
self.device,
self.args,
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.kwargs = to_device(
self.device, self.kwargs or {}, self.dtype_cast_dict, False
self.device,
self.kwargs or {},
dtype_cast_dict=self.dtype_cast_dict,
detach=False,
)
self.dtype_cast_back_dict = {}
self.ins_list = []
@@ -62,7 +70,10 @@ def after_call_op(self, result):
with DisableHookGuard():
self.result_raw = result
self.result = to_device(
self.device, self.result, self.dtype_cast_back_dict, False
self.device,
self.result,
dtype_cast_dict=self.dtype_cast_back_dict,
detach=False,
)
i = -1
for out in traverse_container(self.result_raw):
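
For clarity, one assumed shape of the dtype mappings these keyword arguments now name explicitly; this is an illustrative configuration, not taken from the diff, and detach=False keeps the cast tensors attached to the autograd graph:

import torch

# Assumed example: cast half-precision inputs up to fp32 before the op runs,
# then cast results back with the inverse mapping after it returns.
dtype_cast_dict = {torch.float16: torch.float32}
dtype_cast_back_dict = {v: k for k, v in dtype_cast_dict.items()}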