@@ -851,6 +851,57 @@ def _influence_route_to_helpers(
851851        )
852852
853853
854+ def  _parameter_dot (
855+     params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
856+ ) ->  Tensor :
857+     """ 
858+     returns the dot-product of 2 tensors, represented as tuple of tensors. 
859+     """ 
860+     return  torch .Tensor (
861+         sum (
862+             torch .sum (param_1  *  param_2 )
863+             for  (param_1 , param_2 ) in  zip (params_1 , params_2 )
864+         )
865+     )
866+ 
867+ 
868+ def  _parameter_add (
869+     params_1 : Tuple [Tensor , ...], params_2 : Tuple [Tensor , ...]
870+ ) ->  Tuple [Tensor , ...]:
871+     """ 
872+     returns the sum of 2 tensors, represented as tuple of tensors. 
873+     """ 
874+     return  tuple (param_1  +  param_2  for  (param_1 , param_2 ) in  zip (params_1 , params_2 ))
875+ 
876+ 
877+ def  _parameter_multiply (params : Tuple [Tensor , ...], c : Tensor ) ->  Tuple [Tensor , ...]:
878+     """ 
879+     multiplies all tensors in a tuple of tensors by a given scalar 
880+     """ 
881+     return  tuple (param  *  c  for  param  in  params )
882+ 
883+ 
884+ def  _parameter_to (params : Tuple [Tensor , ...], ** to_kwargs ) ->  Tuple [Tensor , ...]:
885+     """ 
886+     applies the `to` method to all tensors in a tuple of tensors 
887+     """ 
888+     return  tuple (param .to (** to_kwargs ) for  param  in  params )
889+ 
890+ 
891+ def  _parameter_linear_combination (
892+     paramss : List [Tuple [Tensor , ...]], cs : Tensor 
893+ ) ->  Tuple [Tensor , ...]:
894+     """ 
895+     scales each parameter (tensor of tuples) in a list by the corresponding scalar in a 
896+     1D tensor of the same length, and sums up the scaled parameters 
897+     """ 
898+     assert  len (cs .shape ) ==  1 
899+     result  =  _parameter_multiply (paramss [0 ], cs [0 ])
900+     for  (params , c ) in  zip (paramss [1 :], cs [1 :]):
901+         result  =  _parameter_add (result , _parameter_multiply (params , c ))
902+     return  result 
903+ 
904+ 
854905def  _compute_jacobian_sample_wise_grads_per_batch (
855906    influence_inst : Union ["TracInCP" , "InfluenceFunctionBase" ],
856907    inputs : Tuple [Any , ...],
@@ -1007,7 +1058,9 @@ def _functional_call(model, d, features):
10071058def  _dataset_fn (dataloader , batch_fn , reduce_fn , * batch_fn_args , ** batch_fn_kwargs ):
10081059    """ 
10091060    Applies `batch_fn` to each batch in `dataloader`, reducing the results using 
1010-     `reduce_fn`.  This is useful for computing Hessians over an entire dataloader. 
1061+     `reduce_fn`.  This is useful for computing Hessians and Hessian-vector 
1062+     products over an entire dataloader, and is used by both `NaiveInfluenceFunction` 
1063+     and `ArnoldiInfluenceFunction`. 
10111064    """ 
10121065    _dataloader  =  iter (dataloader )
10131066
0 commit comments