diff --git a/captum/optim/_core/output_hook.py b/captum/optim/_core/output_hook.py
index 682ecd94df..4903155e74 100644
--- a/captum/optim/_core/output_hook.py
+++ b/captum/optim/_core/output_hook.py
@@ -32,13 +32,13 @@ def is_ready(self) -> bool:
 
     def _forward_hook(self) -> Callable:
         """
-        Return the forward_hook function.
+        Return the module_outputs_forward_hook forward hook function.
 
         Returns:
-            forward_hook (Callable): The forward_hook function.
+            forward_hook (Callable): The module_outputs_forward_hook function.
         """
 
-        def forward_hook(
+        def module_outputs_forward_hook(
            module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
         ) -> None:
             assert module in self.outputs.keys()
@@ -56,7 +56,7 @@ def forward_hook(
                     "that you are passing model layers in your losses."
                 )
 
-        return forward_hook
+        return module_outputs_forward_hook
 
     def consume_outputs(self) -> ModuleOutputMapping:
         """
diff --git a/tests/optim/core/test_output_hook.py b/tests/optim/core/test_output_hook.py
index 0d227d2d8b..6c6100a7fb 100644
--- a/tests/optim/core/test_output_hook.py
+++ b/tests/optim/core/test_output_hook.py
@@ -1,19 +1,219 @@
 #!/usr/bin/env python3
 
-from typing import cast
+from collections import OrderedDict
+from typing import List, Optional, cast
 
 import captum.optim._core.output_hook as output_hook
 import torch
 from captum.optim.models import googlenet
-from tests.helpers.basic import BaseTest
+from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
+
+
+def _count_forward_hooks(
+    module: torch.nn.Module, hook_fn_name: Optional[str] = None
+) -> int:
+    """
+    Count the number of active forward hooks on the specified module instance.
+
+    Args:
+
+        module (nn.Module): The model module instance to count the number of
+            forward hooks on.
+        hook_fn_name (str, optional): Optionally only count specific forward
+            hooks based on their function's __name__ attribute.
+            Default: None
+
+    Returns:
+        num_hooks (int): The number of active hooks in the specified module.
+    """
+
+    num_hooks: List[int] = [0]
+
+    def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
+        if hasattr(m, "_forward_hooks"):
+            if m._forward_hooks != OrderedDict():
+                dict_items = list(m._forward_hooks.items())
+                for i, fn in dict_items:
+                    if hook_fn_name is None or fn.__name__ == name:
+                        num_hooks[0] += 1
+
+    def _count_child_hooks(
+        target_module: torch.nn.Module,
+        hook_name: Optional[str] = None,
+    ) -> None:
+
+        for name, child in target_module._modules.items():
+            if child is not None:
+                _count_hooks(child, hook_name)
+                _count_child_hooks(child, hook_name)
+
+    _count_child_hooks(module, hook_fn_name)
+    _count_hooks(module, hook_fn_name)
+    return num_hooks[0]
+
+
+class TestModuleOutputsHook(BaseTest):
+    def test_init_single_target(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0]]
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+        self.assertEqual(len(hook_module.hooks), len(target_modules))
+
+        n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
+        self.assertEqual(n_hooks, len(target_modules))
+
+        outputs = dict.fromkeys(target_modules, None)
+        self.assertEqual(outputs, hook_module.outputs)
+        self.assertEqual(list(hook_module.targets), target_modules)
+        self.assertFalse(hook_module.is_ready)
+
+    def test_init_multiple_targets(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0], model[1]]
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+        self.assertEqual(len(hook_module.hooks), len(target_modules))
+
+        n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
+        self.assertEqual(n_hooks, len(target_modules))
+
+        outputs = dict.fromkeys(target_modules, None)
+        self.assertEqual(outputs, hook_module.outputs)
+        self.assertEqual(list(hook_module.targets), target_modules)
+        self.assertFalse(hook_module.is_ready)
+
+    def test_init_multiple_targets_remove_hooks(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0], model[1]]
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+
+        n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
+        self.assertEqual(n_hooks, len(target_modules))
+
+        hook_module.remove_hooks()
+
+        n_hooks = _count_forward_hooks(model, "module_outputs_forward_hook")
+        self.assertEqual(n_hooks, 0)
+
+    def test_reset_outputs_multiple_targets(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0], model[1]]
+        test_input = torch.randn(1, 3, 4, 4)
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+        self.assertFalse(hook_module.is_ready)
+
+        _ = model(test_input)
+
+        self.assertTrue(hook_module.is_ready)
+
+        outputs_dict = hook_module.outputs
+        i = 0
+        for target, activations in outputs_dict.items():
+            self.assertEqual(target, target_modules[i])
+            assertTensorAlmostEqual(self, activations, test_input)
+            i += 1
+
+        hook_module._reset_outputs()
+
+        self.assertFalse(hook_module.is_ready)
+
+        expected_outputs = dict.fromkeys(target_modules, None)
+        self.assertEqual(hook_module.outputs, expected_outputs)
+
+    def test_consume_outputs_multiple_targets(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0], model[1]]
+        test_input = torch.randn(1, 3, 4, 4)
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+        self.assertFalse(hook_module.is_ready)
+
+        _ = model(test_input)
+
+        self.assertTrue(hook_module.is_ready)
+
+        test_outputs_dict = hook_module.outputs
+        self.assertIsInstance(test_outputs_dict, dict)
+        self.assertEqual(len(test_outputs_dict), len(target_modules))
+
+        i = 0
+        for target, activations in test_outputs_dict.items():
+            self.assertEqual(target, target_modules[i])
+            assertTensorAlmostEqual(self, activations, test_input)
+            i += 1
+
+        test_output = hook_module.consume_outputs()
+
+        self.assertFalse(hook_module.is_ready)
+
+        i = 0
+        for target, activations in test_output.items():
+            self.assertEqual(target, target_modules[i])
+            assertTensorAlmostEqual(self, activations, test_input)
+            i += 1
+
+        expected_outputs = dict.fromkeys(target_modules, None)
+        self.assertEqual(hook_module.outputs, expected_outputs)
+
+    def test_consume_outputs_warning(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+        target_modules = [model[0], model[1]]
+        test_input = torch.randn(1, 3, 4, 4)
+
+        hook_module = output_hook.ModuleOutputsHook(target_modules)
+        self.assertFalse(hook_module.is_ready)
+
+        _ = model(test_input)
+
+        self.assertTrue(hook_module.is_ready)
+
+        hook_module._reset_outputs()
+
+        self.assertFalse(hook_module.is_ready)
+
+        with self.assertWarns(Warning):
+            _ = hook_module.consume_outputs()
 
 
 class TestActivationFetcher(BaseTest):
-    def test_activation_fetcher(self) -> None:
+    def test_activation_fetcher_simple_model(self) -> None:
+        model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
+
+        catch_activ = output_hook.ActivationFetcher(model, targets=[model[0]])
+        test_input = torch.randn(1, 3, 224, 224)
+        activ_out = catch_activ(test_input)
+
+        self.assertIsInstance(activ_out, dict)
+        self.assertEqual(len(activ_out), 1)
+        activ = activ_out[model[0]]
+        assertTensorAlmostEqual(self, activ, test_input)
+
+    def test_activation_fetcher_single_target(self) -> None:
         model = googlenet(pretrained=True)
 
         catch_activ = output_hook.ActivationFetcher(model, targets=[model.mixed4d])
         activ_out = catch_activ(torch.zeros(1, 3, 224, 224))
         self.assertIsInstance(activ_out, dict)
+        self.assertEqual(len(activ_out), 1)
         m4d_activ = activ_out[model.mixed4d]
         self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14])
+
+    def test_activation_fetcher_multiple_targets(self) -> None:
+        model = googlenet(pretrained=True)
+
+        catch_activ = output_hook.ActivationFetcher(
+            model, targets=[model.mixed4d, model.mixed5b]
+        )
+        activ_out = catch_activ(torch.zeros(1, 3, 224, 224))
+
+        self.assertIsInstance(activ_out, dict)
+        self.assertEqual(len(activ_out), 2)
+
+        m4d_activ = activ_out[model.mixed4d]
+        self.assertEqual(list(cast(torch.Tensor, m4d_activ).shape), [1, 528, 14, 14])
+
+        m5b_activ = activ_out[model.mixed5b]
+        self.assertEqual(list(cast(torch.Tensor, m5b_activ).shape), [1, 1024, 7, 7])
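
Note on the rename: giving the inner hook a distinctive __name__ (module_outputs_forward_hook instead of the generic forward_hook) is what lets the new _count_forward_hooks test helper pick out hooks registered by ModuleOutputsHook from each module's _forward_hooks registry. A minimal standalone sketch of that idea, assuming only standard torch.nn.Module hook behavior (the hook body and model below are illustrative placeholders, not part of this patch):

    import torch

    # A named hook function can later be recognized by inspecting the private
    # _forward_hooks registry that register_forward_hook fills in.
    def module_outputs_forward_hook(module, inp, out):
        pass  # placeholder body; the real hook stores `out` per target module

    model = torch.nn.Sequential(torch.nn.Identity(), torch.nn.Identity())
    handle = model[0].register_forward_hook(module_outputs_forward_hook)

    # Collect the names of all forward hooks registered on child modules.
    hook_names = [
        fn.__name__
        for child in model.children()
        for fn in child._forward_hooks.values()
    ]
    assert hook_names == ["module_outputs_forward_hook"]

    # Removing the handle empties the registry, which is what
    # test_init_multiple_targets_remove_hooks checks via the helper.
    handle.remove()
    assert len(model[0]._forward_hooks) == 0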