Skip to content

Commit 461aedd

Browse files
Fulton Wangfacebook-github-bot
authored andcommitted
add ArnoldiInfluenceFunction (#1187)
Summary: Pull Request resolved: #1187 This diff implements `ArnoldiInfluenceFunction`, which was described, along with `NaiveInfluenceFunction` in D40541294. Please see that diff for detailed description. Previously implementations of both methods had been 1 diff. Now, `ArnoldiInfluenceFunction` is separated out for easier review. Reviewed By: vivekmig Differential Revision: D42006733 fbshipit-source-id: 3d8bd87aaf23411025fecc3e7b5d965879358be9
1 parent c315e65 commit 461aedd

File tree

8 files changed

+1677
-8
lines changed

8 files changed

+1677
-8
lines changed

captum/influence/_core/arnoldi_influence_function.py

Lines changed: 1022 additions & 0 deletions
Large diffs are not rendered by default.

captum/influence/_utils/common.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,57 @@ def _influence_route_to_helpers(
857857
)
858858

859859

860+
def _parameter_dot(
861+
params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...]
862+
) -> Tensor:
863+
"""
864+
returns the dot-product of 2 tensors, represented as tuple of tensors.
865+
"""
866+
return torch.Tensor(
867+
sum(
868+
torch.sum(param_1 * param_2)
869+
for (param_1, param_2) in zip(params_1, params_2)
870+
)
871+
)
872+
873+
874+
def _parameter_add(
875+
params_1: Tuple[Tensor, ...], params_2: Tuple[Tensor, ...]
876+
) -> Tuple[Tensor, ...]:
877+
"""
878+
returns the sum of 2 tensors, represented as tuple of tensors.
879+
"""
880+
return tuple(param_1 + param_2 for (param_1, param_2) in zip(params_1, params_2))
881+
882+
883+
def _parameter_multiply(params: Tuple[Tensor, ...], c: Tensor) -> Tuple[Tensor, ...]:
884+
"""
885+
multiplies all tensors in a tuple of tensors by a given scalar
886+
"""
887+
return tuple(param * c for param in params)
888+
889+
890+
def _parameter_to(params: Tuple[Tensor, ...], **to_kwargs) -> Tuple[Tensor, ...]:
891+
"""
892+
applies the `to` method to all tensors in a tuple of tensors
893+
"""
894+
return tuple(param.to(**to_kwargs) for param in params)
895+
896+
897+
def _parameter_linear_combination(
898+
paramss: List[Tuple[Tensor, ...]], cs: Tensor
899+
) -> Tuple[Tensor, ...]:
900+
"""
901+
scales each parameter (tensor of tuples) in a list by the corresponding scalar in a
902+
1D tensor of the same length, and sums up the scaled parameters
903+
"""
904+
assert len(cs.shape) == 1
905+
result = _parameter_multiply(paramss[0], cs[0])
906+
for (params, c) in zip(paramss[1:], cs[1:]):
907+
result = _parameter_add(result, _parameter_multiply(params, c))
908+
return result
909+
910+
860911
def _compute_jacobian_sample_wise_grads_per_batch(
861912
influence_inst: Union["TracInCP", "InfluenceFunctionBase"],
862913
inputs: Tuple[Any, ...],
@@ -1013,7 +1064,9 @@ def _functional_call(model, d, features):
10131064
def _dataset_fn(dataloader, batch_fn, reduce_fn, *batch_fn_args, **batch_fn_kwargs):
10141065
"""
10151066
Applies `batch_fn` to each batch in `dataloader`, reducing the results using
1016-
`reduce_fn`. This is useful for computing Hessians over an entire dataloader.
1067+
`reduce_fn`. This is useful for computing Hessians and Hessian-vector
1068+
products over an entire dataloader, and is used by both `NaiveInfluenceFunction`
1069+
and `ArnoldiInfluenceFunction`.
10171070
"""
10181071
_dataloader = iter(dataloader)
10191072

0 commit comments

Comments
 (0)