@@ -444,7 +444,7 @@ def _check_loss_fn(
444444    influence_instance : Union ["TracInCPBase" , "InfluenceFunctionBase" ],
445445    loss_fn : Optional [Union [Module , Callable ]],
446446    loss_fn_name : str ,
447-     sample_wise_grads_per_batch : Optional [ bool ]  =  None ,
447+     sample_wise_grads_per_batch : bool  =  True ,
448448) ->  str :
449449    """ 
450450    This checks whether `loss_fn` satisfies the requirements assumed of all 
@@ -469,16 +469,13 @@ def _check_loss_fn(
469469    # attribute. 
470470    if  hasattr (loss_fn , "reduction" ):
471471        reduction  =  loss_fn .reduction   # type: ignore 
472-         if  sample_wise_grads_per_batch   is   None :
472+         if  sample_wise_grads_per_batch :
473473            assert  reduction  in  [
474474                "sum" ,
475475                "mean" ,
476-             ], 'reduction for `loss_fn` must be "sum" or "mean"' 
477-             reduction_type  =  str (reduction )
478-         elif  sample_wise_grads_per_batch :
479-             assert  reduction  in  ["sum" , "mean" ], (
476+             ], (
480477                'reduction for `loss_fn` must be "sum" or "mean" when ' 
481-                 "`sample_wise_grads_per_batch` is True" 
478+                 "`sample_wise_grads_per_batch` is True (i.e. the default value)  " 
482479            )
483480            reduction_type  =  str (reduction )
484481        else :
@@ -490,18 +487,7 @@ def _check_loss_fn(
490487        # if we are unable to access the reduction used by `loss_fn`, we warn 
491488        # the user about the assumptions we are making regarding the reduction 
492489        # used by `loss_fn` 
493-         if  sample_wise_grads_per_batch  is  None :
494-             warnings .warn (
495-                 f'Since `{ loss_fn_name }  
496-                 f'implementation  assumes that `{ loss_fn_name }  
497-                 "function that reduces the per-example losses by taking their *sum*. " 
498-                 f"If `{ loss_fn_name }  
499-                 f"taking their mean, please set the reduction attribute of " 
500-                 f'`{ loss_fn_name }  
501-                 f'`{ loss_fn_name }  
502-             )
503-             reduction_type  =  "sum" 
504-         elif  sample_wise_grads_per_batch :
490+         if  sample_wise_grads_per_batch :
505491            warnings .warn (
506492                f"Since `{ loss_fn_name }  
507493                "`sample_wise_grads_per_batch` is True, the implementation assumes " 
0 commit comments