|  | 
| 16 | 16 |     ) | 
| 17 | 17 | 
 | 
| 18 | 18 | from captum.optim._core.loss import default_loss_summarize | 
| 19 |  | -from captum.optim._core.output_hook import ModuleOutputsHook | 
|  | 19 | +from captum.optim._core.output_hook import ModuleOutputsHook, _remove_all_forward_hooks | 
| 20 | 20 | from captum.optim._param.image.images import InputParameterization, NaturalImage | 
| 21 | 21 | from captum.optim._param.image.transforms import RandomScale, RandomSpatialJitter | 
| 22 | 22 | from captum.optim._utils.typing import ( | 
| @@ -196,6 +196,29 @@ def continue_while( | 
| 196 | 196 |     return continue_while | 
| 197 | 197 | 
 | 
| 198 | 198 | 
 | 
|  | 199 | +def cleanup_module_hooks(modules: Union[nn.Module, List[nn.Module]) -> None: | 
|  | 200 | +    """ | 
|  | 201 | +    Remove any InputOptimization hooks from the specified modules. This may be useful | 
|  | 202 | +    in the event that something goes wrong in between creating the InputOptimization | 
|  | 203 | +    instance and running the optimization function, or if InputOptimization fails | 
|  | 204 | +    without properly removing it's hooks. | 
|  | 205 | +
 | 
|  | 206 | +    Warning: This function will remove all the hooks placed by InputOptimization | 
|  | 207 | +    instances on the target modules, and thus can interfere with using multiple | 
|  | 208 | +    InputOptimization instances. | 
|  | 209 | +
 | 
|  | 210 | +    Args: | 
|  | 211 | +
 | 
|  | 212 | +        modules (nn.Module or list of nn.Module): Any module instances that contain | 
|  | 213 | +            hooks created by InputOptimization, for which the removal of the hooks is | 
|  | 214 | +            required. | 
|  | 215 | +    """ | 
|  | 216 | +    if not hasattr(modules, "__iter__"): | 
|  | 217 | +        modules = [modules] | 
|  | 218 | +    # Captum ModuleOutputsHook uses "module_outputs_forward_hook" hook functions | 
|  | 219 | +    [_remove_all_forward_hooks(module, "module_outputs_forward_hook") for module in modules] | 
|  | 220 | + | 
|  | 221 | + | 
| 199 | 222 | __all__ = [ | 
| 200 | 223 |     "InputOptimization", | 
| 201 | 224 |     "n_steps", | 
|  | 
0 commit comments