|
4 | 4 | import warnings |
5 | 5 | from abc import abstractmethod |
6 | 6 | from os.path import join |
7 | | -from typing import ( |
8 | | - Any, |
9 | | - Callable, |
10 | | - Iterator, |
11 | | - List, |
12 | | - NamedTuple, |
13 | | - Optional, |
14 | | - Tuple, |
15 | | - Type, |
16 | | - Union, |
17 | | -) |
| 7 | +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union |
18 | 8 |
|
19 | 9 | import torch |
20 | 10 | from captum._utils.av import AV |
21 | | -from captum._utils.common import _get_module_from_name, _parse_version |
22 | | -from captum._utils.gradient import ( |
23 | | - _compute_jacobian_wrt_params, |
24 | | - _compute_jacobian_wrt_params_with_sample_wise_trick, |
25 | | -) |
| 11 | +from captum._utils.common import _parse_version |
26 | 12 | from captum._utils.progress import NullProgress, progress |
27 | 13 | from captum.influence._core.influence import DataInfluence |
28 | 14 | from captum.influence._utils.common import ( |
29 | 15 | _check_loss_fn, |
| 16 | + _compute_jacobian_sample_wise_grads_per_batch, |
30 | 17 | _format_inputs_dataset, |
31 | 18 | _get_k_most_influential_helper, |
32 | 19 | _gradient_dot_product, |
| 20 | + _influence_route_to_helpers, |
33 | 21 | _load_flexible_state_dict, |
34 | 22 | _self_influence_by_batches_helper, |
| 23 | + _set_active_parameters, |
| 24 | + KMostInfluentialResults, |
35 | 25 | ) |
36 | 26 | from captum.log import log_usage |
37 | 27 | from torch import Tensor |
|
69 | 59 | """ |
70 | 60 |
|
71 | 61 |
|
72 | | -class KMostInfluentialResults(NamedTuple): |
73 | | - """ |
74 | | - This namedtuple stores the results of using the `influence` method. This method |
75 | | - is implemented by all subclasses of `TracInCPBase` to calculate |
76 | | - proponents / opponents. The `indices` field stores the indices of the |
77 | | - proponents / opponents for each example in the test dataset. For example, if |
78 | | - finding opponents, `indices[i][j]` stores the index in the training data of the |
79 | | - example with the `j`-th highest influence score on the `i`-th example in the test |
80 | | - dataset. Similarly, the `influence_scores` field stores the actual influence |
81 | | - scores, so that `influence_scores[i][j]` is the influence score of example |
82 | | - `indices[i][j]` in the training data on example `i` of the test dataset. |
83 | | - Please see `TracInCPBase.influence` for more details. |
84 | | - """ |
85 | | - |
86 | | - indices: Tensor |
87 | | - influence_scores: Tensor |
88 | | - |
89 | | - |
90 | 62 | class TracInCPBase(DataInfluence): |
91 | 63 | """ |
92 | 64 | To implement the `influence` method, classes inheriting from `TracInCPBase` will |
@@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str: |
448 | 420 | return cls.__name__ |
449 | 421 |
|
450 | 422 |
|
451 | | -def _influence_route_to_helpers( |
452 | | - influence_instance: TracInCPBase, |
453 | | - inputs: Union[Tuple[Any, ...], DataLoader], |
454 | | - k: Optional[int] = None, |
455 | | - proponents: bool = True, |
456 | | - **kwargs, |
457 | | -) -> Union[Tensor, KMostInfluentialResults]: |
458 | | - """ |
459 | | - This is a helper function called by `TracInCP.influence` and |
460 | | - `TracInCPFast.influence`. Those methods share a common logic in that they assume |
461 | | - an instance of their respective classes implement 2 private methods |
462 | | - (``_influence`, `_get_k_most_influential`), and the logic of |
463 | | - which private method to call is common, as described in the documentation of the |
464 | | - `influence` method. The arguments and return values of this function are the exact |
465 | | - same as the `influence` method. Note that `influence_instance` refers to the |
466 | | - instance for which the `influence` method was called. |
467 | | - """ |
468 | | - if k is None: |
469 | | - return influence_instance._influence(inputs, **kwargs) |
470 | | - else: |
471 | | - return influence_instance._get_k_most_influential( |
472 | | - inputs, |
473 | | - k, |
474 | | - proponents, |
475 | | - **kwargs, |
476 | | - ) |
477 | | - |
478 | | - |
479 | 423 | class TracInCP(TracInCPBase): |
480 | 424 | def __init__( |
481 | 425 | self, |
@@ -630,23 +574,7 @@ def __init__( |
630 | 574 | """ |
631 | 575 | self.layer_modules = None |
632 | 576 | if layers is not None: |
633 | | - assert isinstance(layers, List), "`layers` should be a list!" |
634 | | - assert len(layers) > 0, "`layers` cannot be empty!" |
635 | | - assert isinstance( |
636 | | - layers[0], str |
637 | | - ), "`layers` should contain str layer names." |
638 | | - self.layer_modules = [ |
639 | | - _get_module_from_name(self.model, layer) for layer in layers |
640 | | - ] |
641 | | - for layer, layer_module in zip(layers, self.layer_modules): |
642 | | - for name, param in layer_module.named_parameters(): |
643 | | - if not param.requires_grad: |
644 | | - warnings.warn( |
645 | | - "Setting required grads for layer: {}, name: {}".format( |
646 | | - ".".join(layer), name |
647 | | - ) |
648 | | - ) |
649 | | - param.requires_grad = True |
| 577 | + self.layer_modules = _set_active_parameters(model, layers) |
650 | 578 |
|
651 | 579 | @log_usage() |
652 | 580 | def influence( # type: ignore[override] |
@@ -1463,19 +1391,6 @@ def _basic_computation_tracincp( |
1463 | 1391 | argument is only used if `sample_wise_grads_per_batch` was true in |
1464 | 1392 | initialization. |
1465 | 1393 | """ |
1466 | | - if self.sample_wise_grads_per_batch: |
1467 | | - return _compute_jacobian_wrt_params_with_sample_wise_trick( |
1468 | | - self.model, |
1469 | | - inputs, |
1470 | | - targets, |
1471 | | - loss_fn, |
1472 | | - reduction_type, |
1473 | | - self.layer_modules, |
1474 | | - ) |
1475 | | - return _compute_jacobian_wrt_params( |
1476 | | - self.model, |
1477 | | - inputs, |
1478 | | - targets, |
1479 | | - loss_fn, |
1480 | | - self.layer_modules, |
| 1394 | + return _compute_jacobian_sample_wise_grads_per_batch( |
| 1395 | + self, inputs, targets, loss_fn, reduction_type |
1481 | 1396 | ) |
0 commit comments