Merged
64 changes: 59 additions & 5 deletions captum/optim/_core/output_hook.py
@@ -1,5 +1,6 @@
-import warnings
-from typing import Callable, Iterable, Tuple
+from collections import OrderedDict
+from typing import Callable, Dict, Iterable, Optional, Tuple
+from warnings import warn

 import torch
@@ -15,6 +16,9 @@ def __init__(self, target_modules: Iterable[nn.Module]) -> None:

             target_modules (Iterable of nn.Module): A list of nn.Module targets.
         """
+        for module in target_modules:
+            # Clean up any old hooks that weren't properly deleted
+            _remove_all_forward_hooks(module, "module_outputs_forward_hook")
         self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
         self.hooks = [
             module.register_forward_hook(self._forward_hook())
@@ -33,13 +37,13 @@ def is_ready(self) -> bool:

     def _forward_hook(self) -> Callable:
         """
-        Return the forward_hook function.
+        Return the module_outputs_forward_hook forward hook function.

         Returns:
-            forward_hook (Callable): The forward_hook function.
+            forward_hook (Callable): The module_outputs_forward_hook function.
         """

-        def forward_hook(
+        def module_outputs_forward_hook(
             module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
         ) -> None:
             assert module in self.outputs.keys()
@@ -57,7 +61,7 @@ def forward_hook(
                     "that you are passing model layers in your losses."
                 )

-        return forward_hook
+        return module_outputs_forward_hook

     def consume_outputs(self) -> ModuleOutputMapping:
         """
@@ -130,3 +134,53 @@ def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
         finally:
             self.layers.remove_hooks()
         return activations_dict
+
+
+def _remove_all_forward_hooks(
+    module: torch.nn.Module, hook_fn_name: Optional[str] = None
+) -> None:
+    """
+    This function removes all forward hooks in the specified module, without requiring
+    any hook handles. This lets us clean up & remove any hooks that weren't property
+    deleted.
+
+    Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
NarineK (Contributor) commented on Jan 18, 2022:

As you mentioned, this is dangerous to do because we can remove hooks that we didn't set whose function name happens to be the same. Why don't we remove only the hooks that we set, right after we are done with the hook? I thought that all hooks would be removed here, won't they?
https://github.com/pytorch/captum/blob/6e7f0bd761ca538a6fce30dbd4e0e6007fc6abe5/captum/optim/_core/optimization.py#L97

If there is a bug and we miss some of them, then we should rather make sure that we remove them after optimization is finished.
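For illustration, a minimal sketch of the handle-based cleanup being suggested here; the class name HandleTrackingHook is made up for this example, but ModuleOutputsHook already keeps its handles in self.hooks in the same way:

import torch
import torch.nn as nn

class HandleTrackingHook:
    """Illustrative only: register one forward hook per module and keep the handles."""

    def __init__(self, modules):
        self.outputs = {}
        # Keep the RemovableHandle objects returned by register_forward_hook,
        # so cleanup only ever touches hooks this instance created.
        self.handles = [m.register_forward_hook(self._make_hook(m)) for m in modules]

    def _make_hook(self, module):
        def hook(mod, inputs, output):
            self.outputs[module] = output

        return hook

    def remove_hooks(self):
        # Remove only our own hooks; unrelated hooks on the same modules stay intact.
        for handle in self.handles:
            handle.remove()

model = nn.Identity()
tracker = HandleTrackingHook([model])
try:
    model(torch.ones(1))
finally:
    tracker.remove_hooks()
print(len(model._forward_hooks))  # 0 -- nothing stale is left behind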

ProGamerGov (Contributor, Author) commented on Jan 18, 2022:

The idea was that the problem would be fixed if the individual ran InputOptimization a second time, but I now realize that this would break the ability to use multiple instances of InputOptimization.

ProGamerGov (Contributor, Author) commented:

Maybe we could somehow let the user decide when to run the cleanup code?

[_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in target_modules]

ProGamerGov (Contributor, Author) commented:

There does not appear to be a way to detect whether or not the hooks are still being used, so letting the user decide if they want to perform this fix is probably the best choice.

ProGamerGov (Contributor, Author) commented on Jan 18, 2022:

We could probably add this function to the API for users:

def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]]) -> None:
    """
    Remove any InputOptimization hooks from the specified modules. This may be useful
    in the event that something goes wrong in between creating the InputOptimization
    instance and running the optimization function, or if InputOptimization fails
    without properly removing its hooks.

    Warning: This function will remove all the hooks placed by InputOptimization
    instances on the target modules, and thus can interfere with using multiple
    InputOptimization instances.

    Args:

        modules (nn.Module or list of nn.Module): Any module instances that contain
            hooks created by InputOptimization, for which the removal of the hooks is
            required.
    """
    if not hasattr(modules, "__iter__"):
        modules = [modules]
    # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions
    [_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in modules]
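For illustration, a hypothetical usage sketch assuming such a cleanup_module_hooks helper were exposed; loss_fn and image stand in for the usual InputOptimization arguments, and none of this is in the current API:

model = torch.nn.Identity()
obj = opt.InputOptimization(model, loss_fn, image)
# ... something fails before the optimization run removes its hooks ...
cleanup_module_hooks(model)  # strips any leftover "module_outputs_forward_hook" entries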

NarineK (Contributor) commented:

@ProGamerGov, in your example, does it occur after we initialize opt.InputOptimization(model, loss_fn, image) a second time?

Do you have a notebook where we can debug the error?

ProGamerGov (Contributor, Author) commented on Jan 18, 2022:

@NarineK Just set up / install Captum with the optim module, and run this snippet of code to reproduce it:

!git clone https://github.com/progamergov/captum
%cd captum
!git checkout "optim-wip-fix-hook-bug"
!pip3 install -e .
import sys
sys.path.append('/content/captum')
%cd ..
import torch
import captum.optim._core.output_hook as output_hook

def test_bug():
    model = torch.nn.Identity()
    for i in range(5):
        _ = output_hook.ModuleOutputsHook([model])
    print(model._forward_hooks.items())  # There will be 5 hooks

test_bug()

The InputOptimization init function just calls ModuleOutputsHook, so we can do the same here to make it easier to reproduce.
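For reference, a sketch of the expected behavior once the __init__ cleanup from this diff is applied: each new ModuleOutputsHook strips any existing "module_outputs_forward_hook" entries before registering its own, so the same loop should leave a single hook rather than five.

def test_fix():
    model = torch.nn.Identity()
    for i in range(5):
        _ = output_hook.ModuleOutputsHook([model])
    # Only the hook registered by the most recent instance should remain.
    print(len(model._forward_hooks))  # expected: 1

test_fix()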

ProGamerGov (Contributor, Author) commented on Jan 19, 2022:

It occurs because we often reuse the same target instance, in notebooks for example, and thus the hooks attached to the target instance are not removed. I don't think it's something we can avoid, but we can make users aware of it and provide an option to mitigate its effects.

NarineK (Contributor) commented:

cc-ing @vivekmig - Vivek can help with this issue

ProGamerGov (Contributor, Author) commented:

@NarineK I can remove the hook removal functions from this PR so we can merge it. Then in a later PR we can revisit this minor hook duplication bug, as it's a very niche issue that won't interfere with anything at the moment.

+    caution should be exercised when removing all hooks. Users are recommended to give
+    their hook function a unique name that can be used to safely identify and remove
+    the target forward hooks.
+
+    Args:
+
+        module (nn.Module): The module instance to remove forward hooks from.
+        hook_fn_name (str, optional): Optionally only remove specific forward hooks
+            based on their function's __name__ attribute.
+            Default: None
+    """
+
+    if hook_fn_name is None:
+        warn("Removing all active hooks will break some PyTorch modules & systems.")
+
+    def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
+        if hasattr(module, "_forward_hooks"):
+            if m._forward_hooks != OrderedDict():
+                if name is not None:
+                    dict_items = list(m._forward_hooks.items())
+                    m._forward_hooks = OrderedDict(
+                        [(i, fn) for i, fn in dict_items if fn.__name__ != name]
+                    )
+                else:
+                    m._forward_hooks: Dict[int, Callable] = OrderedDict()
+
+    def _remove_child_hooks(
+        target_module: torch.nn.Module, hook_name: Optional[str] = None
+    ) -> None:
+        for name, child in target_module._modules.items():
+            if child is not None:
+                _remove_hooks(child, hook_name)
+                _remove_child_hooks(child, hook_name)
+
+    # Remove hooks from target submodules
+    _remove_child_hooks(module, hook_fn_name)
+
+    # Remove hooks from the target module
+    _remove_hooks(module, hook_fn_name)
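A brief usage sketch of the name-based filtering in _remove_all_forward_hooks above; the module and hook functions here are made-up examples:

import torch

def module_outputs_forward_hook(module, inp, output):  # same __name__ as Captum's hook
    pass

def my_other_hook(module, inp, output):
    pass

m = torch.nn.Linear(4, 4)
m.register_forward_hook(module_outputs_forward_hook)
m.register_forward_hook(my_other_hook)

# Only hooks whose function __name__ matches the given name are removed;
# my_other_hook is left in place.
_remove_all_forward_hooks(m, "module_outputs_forward_hook")
print([fn.__name__ for fn in m._forward_hooks.values()])  # ['my_other_hook']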