Skip to content

Commit c3f9a54

Browse files
authored
Update optimization.py
1 parent 1f6c2d5 commit c3f9a54

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

captum/optim/_core/optimization.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717

1818
from captum.optim._core.loss import default_loss_summarize
19-
from captum.optim._core.output_hook import ModuleOutputsHook
19+
from captum.optim._core.output_hook import ModuleOutputsHook, _remove_all_forward_hooks
2020
from captum.optim._param.image.images import InputParameterization, NaturalImage
2121
from captum.optim._param.image.transforms import RandomScale, RandomSpatialJitter
2222
from captum.optim._utils.typing import (
@@ -196,6 +196,29 @@ def continue_while(
196196
return continue_while
197197

198198

199+
def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]) -> None:
200+
"""
201+
Remove any InputOptimization hooks from the specified modules. This may be useful
202+
in the event that something goes wrong in between creating the InputOptimization
203+
instance and running the optimization function, or if InputOptimization fails
204+
without properly removing it's hooks.
205+
206+
Warning: This function will remove all the hooks placed by InputOptimization
207+
instances on the target modules, and thus can interfere with using multiple
208+
InputOptimization instances.
209+
210+
Args:
211+
212+
modules (nn.Module or list of nn.Module): Any module instances that contain
213+
hooks created by InputOptimization, for which the removal of the hooks is
214+
required.
215+
"""
216+
if not hasattr(modules, "__iter__"):
217+
modules = [modules]
218+
# Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions
219+
[_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in modules]
220+
221+
199222
__all__ = [
200223
"InputOptimization",
201224
"n_steps",

0 commit comments

Comments
 (0)