| 
4 | 4 | import warnings  | 
5 | 5 | from abc import abstractmethod  | 
6 | 6 | from os.path import join  | 
7 |  | -from typing import (  | 
8 |  | -    Any,  | 
9 |  | -    Callable,  | 
10 |  | -    Iterator,  | 
11 |  | -    List,  | 
12 |  | -    NamedTuple,  | 
13 |  | -    Optional,  | 
14 |  | -    Tuple,  | 
15 |  | -    Type,  | 
16 |  | -    Union,  | 
17 |  | -)  | 
 | 7 | +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union  | 
18 | 8 | 
 
  | 
19 | 9 | import torch  | 
20 | 10 | from captum._utils.av import AV  | 
21 |  | -from captum._utils.common import _get_module_from_name, _parse_version  | 
22 |  | -from captum._utils.gradient import (  | 
23 |  | -    _compute_jacobian_wrt_params,  | 
24 |  | -    _compute_jacobian_wrt_params_with_sample_wise_trick,  | 
25 |  | -)  | 
 | 11 | +from captum._utils.common import _parse_version  | 
26 | 12 | from captum._utils.progress import NullProgress, progress  | 
27 | 13 | from captum.influence._core.influence import DataInfluence  | 
28 | 14 | from captum.influence._utils.common import (  | 
29 | 15 |     _check_loss_fn,  | 
 | 16 | +    _compute_jacobian_sample_wise_grads_per_batch,  | 
30 | 17 |     _format_inputs_dataset,  | 
31 | 18 |     _get_k_most_influential_helper,  | 
32 | 19 |     _gradient_dot_product,  | 
 | 20 | +    _influence_route_to_helpers,  | 
33 | 21 |     _load_flexible_state_dict,  | 
34 | 22 |     _self_influence_by_batches_helper,  | 
 | 23 | +    _set_active_parameters,  | 
 | 24 | +    KMostInfluentialResults,  | 
35 | 25 | )  | 
36 | 26 | from captum.log import log_usage  | 
37 | 27 | from torch import Tensor  | 
 | 
69 | 59 | """  | 
70 | 60 | 
 
  | 
71 | 61 | 
 
  | 
72 |  | -class KMostInfluentialResults(NamedTuple):  | 
73 |  | -    """  | 
74 |  | -    This namedtuple stores the results of using the `influence` method. This method  | 
75 |  | -    is implemented by all subclasses of `TracInCPBase` to calculate  | 
76 |  | -    proponents / opponents. The `indices` field stores the indices of the  | 
77 |  | -    proponents / opponents for each example in the test dataset. For example, if  | 
78 |  | -    finding opponents, `indices[i][j]` stores the index in the training data of the  | 
79 |  | -    example with the `j`-th highest influence score on the `i`-th example in the test  | 
80 |  | -    dataset. Similarly, the `influence_scores` field stores the actual influence  | 
81 |  | -    scores, so that `influence_scores[i][j]` is the influence score of example  | 
82 |  | -    `indices[i][j]` in the training data on example `i` of the test dataset.  | 
83 |  | -    Please see `TracInCPBase.influence` for more details.  | 
84 |  | -    """  | 
85 |  | - | 
86 |  | -    indices: Tensor  | 
87 |  | -    influence_scores: Tensor  | 
88 |  | - | 
89 |  | - | 
90 | 62 | class TracInCPBase(DataInfluence):  | 
91 | 63 |     """  | 
92 | 64 |     To implement the `influence` method, classes inheriting from `TracInCPBase` will  | 
@@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str:  | 
448 | 420 |         return cls.__name__  | 
449 | 421 | 
 
  | 
450 | 422 | 
 
  | 
451 |  | -def _influence_route_to_helpers(  | 
452 |  | -    influence_instance: TracInCPBase,  | 
453 |  | -    inputs: Union[Tuple[Any, ...], DataLoader],  | 
454 |  | -    k: Optional[int] = None,  | 
455 |  | -    proponents: bool = True,  | 
456 |  | -    **kwargs,  | 
457 |  | -) -> Union[Tensor, KMostInfluentialResults]:  | 
458 |  | -    """  | 
459 |  | -    This is a helper function called by `TracInCP.influence` and  | 
460 |  | -    `TracInCPFast.influence`. Those methods share a common logic in that they assume  | 
461 |  | -    an instance of their respective classes implement 2 private methods  | 
462 |  | -    (``_influence`, `_get_k_most_influential`), and the logic of  | 
463 |  | -    which private method to call is common, as described in the documentation of the  | 
464 |  | -    `influence` method. The arguments and return values of this function are the exact  | 
465 |  | -    same as the `influence` method. Note that `influence_instance` refers to the  | 
466 |  | -    instance for which the `influence` method was called.  | 
467 |  | -    """  | 
468 |  | -    if k is None:  | 
469 |  | -        return influence_instance._influence(inputs, **kwargs)  | 
470 |  | -    else:  | 
471 |  | -        return influence_instance._get_k_most_influential(  | 
472 |  | -            inputs,  | 
473 |  | -            k,  | 
474 |  | -            proponents,  | 
475 |  | -            **kwargs,  | 
476 |  | -        )  | 
477 |  | - | 
478 |  | - | 
479 | 423 | class TracInCP(TracInCPBase):  | 
480 | 424 |     def __init__(  | 
481 | 425 |         self,  | 
@@ -630,23 +574,7 @@ def __init__(  | 
630 | 574 |         """  | 
631 | 575 |         self.layer_modules = None  | 
632 | 576 |         if layers is not None:  | 
633 |  | -            assert isinstance(layers, List), "`layers` should be a list!"  | 
634 |  | -            assert len(layers) > 0, "`layers` cannot be empty!"  | 
635 |  | -            assert isinstance(  | 
636 |  | -                layers[0], str  | 
637 |  | -            ), "`layers` should contain str layer names."  | 
638 |  | -            self.layer_modules = [  | 
639 |  | -                _get_module_from_name(self.model, layer) for layer in layers  | 
640 |  | -            ]  | 
641 |  | -            for layer, layer_module in zip(layers, self.layer_modules):  | 
642 |  | -                for name, param in layer_module.named_parameters():  | 
643 |  | -                    if not param.requires_grad:  | 
644 |  | -                        warnings.warn(  | 
645 |  | -                            "Setting required grads for layer: {}, name: {}".format(  | 
646 |  | -                                ".".join(layer), name  | 
647 |  | -                            )  | 
648 |  | -                        )  | 
649 |  | -                        param.requires_grad = True  | 
 | 577 | +            self.layer_modules = _set_active_parameters(model, layers)  | 
650 | 578 | 
 
  | 
651 | 579 |     @log_usage()  | 
652 | 580 |     def influence(  # type: ignore[override]  | 
@@ -1463,19 +1391,6 @@ def _basic_computation_tracincp(  | 
1463 | 1391 |                     argument is only used if `sample_wise_grads_per_batch` was true in  | 
1464 | 1392 |                     initialization.  | 
1465 | 1393 |         """  | 
1466 |  | -        if self.sample_wise_grads_per_batch:  | 
1467 |  | -            return _compute_jacobian_wrt_params_with_sample_wise_trick(  | 
1468 |  | -                self.model,  | 
1469 |  | -                inputs,  | 
1470 |  | -                targets,  | 
1471 |  | -                loss_fn,  | 
1472 |  | -                reduction_type,  | 
1473 |  | -                self.layer_modules,  | 
1474 |  | -            )  | 
1475 |  | -        return _compute_jacobian_wrt_params(  | 
1476 |  | -            self.model,  | 
1477 |  | -            inputs,  | 
1478 |  | -            targets,  | 
1479 |  | -            loss_fn,  | 
1480 |  | -            self.layer_modules,  | 
 | 1394 | +        return _compute_jacobian_sample_wise_grads_per_batch(  | 
 | 1395 | +            self, inputs, targets, loss_fn, reduction_type  | 
1481 | 1396 |         )  | 
0 commit comments