@@ -353,10 +353,12 @@ def attribute(
                     formatted_feature_mask,
                     attr_progress,
                     flattened_initial_eval,
+                    initial_eval,
                     n_outputs,
                     total_attrib,
                     weights,
                     attrib_type,
+                    perturbations_per_eval,
                     **kwargs,
                 )
             else:
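
Note: this call-site change threads the raw `initial_eval` and `perturbations_per_eval` through to the cross-tensor masking path, so several features can now be ablated in one forward call. A minimal usage sketch of the behavior this enables (the toy model, shapes, and target are illustrative, not taken from this PR):

import torch
from captum.attr import FeatureAblation

model = torch.nn.Linear(6, 2)      # toy forward_func returning (batch, 2)
inputs = torch.randn(4, 6)         # 4 examples, 6 scalar features
ablator = FeatureAblation(model)
# Up to three features are ablated per forward pass, so the model sees an
# effective batch of 3 * 4 = 12 rows per call instead of 4.
attr = ablator.attribute(inputs, target=0, perturbations_per_eval=3)
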
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[BaselineType] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
+
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )

-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             # agg mode: (*initial_eval.shape)
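
Note: the block above replaces the one-feature-at-a-time `_ablation_generator` with chunked processing: feature ids are grouped into chunks of up to `perturbations_per_eval`, and every input tensor is repeated along dim 0 once per feature in the chunk. A standalone sketch of the indexing, with hypothetical shapes (not Captum code):

import torch

num_examples, perturbations_per_eval = 4, 3
x = torch.randn(num_examples, 6)
all_feature_idxs = [0, 1, 2, 3, 4]   # 5 features -> chunks [0, 1, 2] and [3, 4]

x_repeated = torch.cat([x] * perturbations_per_eval, dim=0)   # (12, 6)
for i in range(0, len(all_feature_idxs), perturbations_per_eval):
    chunk = all_feature_idxs[i : i + perturbations_per_eval]
    n = min(perturbations_per_eval, len(chunk))
    batch = x_repeated[: n * num_examples]   # trimmed for the short last chunk
    # rows [j * num_examples : (j + 1) * num_examples] get feature chunk[j]
    # swapped for its baseline before the single batched forward call
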
@@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )

             if attr_progress is not None:
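
Note: the forward call now receives `current_target` and `current_additional_args` because the effective batch holds `n * num_examples` rows, so per-example targets and extra args must be tiled to stay row-aligned. A sketch of the expected alignment (written without Captum's internal `_expand_target` helper, whose exact signature is not shown in this diff):

import torch

target = torch.tensor([1, 0, 1, 1])         # one label per original example
n = 3                                       # features ablated in this chunk
expanded = torch.cat([target] * n, dim=0)   # (12,), aligned with repeated inputs
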
@@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks(

             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights

-    def _ablation_generator(
-        self,
-        inputs: Tuple[Tensor, ...],
-        baselines: BaselineType,
-        input_mask: Tuple[Tensor, ...],
-        feature_idx_to_tensor_idx: Dict[int, List[int]],
-        **kwargs: Any,
-    ) -> Generator[
-        Tuple[
-            Tuple[Tensor, ...],
-            Tuple[Optional[Tensor], ...],
-        ],
-        None,
-        None,
-    ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
-        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
-            ablated_inputs, current_masks = (
-                self._construct_ablated_input_across_tensors(
-                    inputs,
-                    input_mask,
-                    baselines,
-                    feature_idx,
-                    feature_idx_to_tensor_idx[feature_idx],
-                )
-            )
-            yield ablated_inputs, current_masks
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
-
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
                 current_masks.append(None)
                 continue
-            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
-            if isinstance(baseline, torch.Tensor):
-                baseline = baseline.reshape(
-                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
                 )
-            assert baseline is not None, "baseline must be provided"
-            ablated_input = (
-                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
-            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    1 - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
+            current_masks.append(torch.stack(tensor_mask, dim=0))
             ablated_inputs.append(ablated_input)
-            current_masks.append(tensor_mask)
+
         return tuple(ablated_inputs), tuple(current_masks)

     def _initial_eval_to_processed_initial_eval_fut(
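
Note: `_construct_ablated_input_across_tensors` now ablates one feature per example-sized slice of the repeated batch and stacks the per-feature masks along a new leading dim. A worked example of the slice arithmetic on toy tensors (not from the PR):

import torch

num_examples = 1
x = torch.tensor([[1.0, 2.0, 3.0]])        # one example, three columns
feature_mask = torch.tensor([[0, 0, 1]])   # feature 0 spans two columns
baseline = 0.0
chunk = [0, 1]                             # feature ids ablated together

batch = torch.cat([x] * len(chunk), dim=0)  # (2, 3) repeated input
for j, fid in enumerate(chunk):
    mask = (feature_mask == fid).long()
    start, end = j * num_examples, (j + 1) * num_examples
    batch[start:end] = x * (1 - mask) + baseline * mask
# batch[0] == [0., 0., 3.]   (feature 0 ablated)
# batch[1] == [1., 2., 0.]   (feature 1 ablated)
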
@@ -784,7 +839,7 @@ def _attribute_progress_setup(
            formatted_inputs, feature_mask, **kwargs
        )
        total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
            if enable_cross_tensor_attribution
            else sum(
                math.ceil(count / perturbations_per_eval) for count in feature_counts
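
Note: with cross-tensor attribution enabled, the progress total now counts batched forward calls rather than individual features. A quick check of the arithmetic:

import math

feature_counts = [4, 3]        # features discovered per input tensor
perturbations_per_eval = 3
total = math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
assert total == 3              # ceil(7 / 3) forwards instead of 7
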
@@ -1194,13 +1249,46 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # the input and must not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbations; not the same as perturbations_per_eval
+        # when there are not enough features left to perturb
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward output when perturbations_per_eval == 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows in the same ratio, i.e., is not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True

         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
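
Note: the new assertion requires `forward_func` to return an output whose leading dim scales with batch size; an aggregated scalar cannot be split back into per-perturbation results. A sketch of what passes and what fails (toy functions, not from the PR):

import torch

def per_example(x):   # OK: dim 0 grows with the batch
    return x.sum(dim=1)

def aggregated(x):    # rejected when perturbations_per_eval > 1
    return x.sum()

x = torch.randn(4, 6)
x3 = torch.cat([x] * 3, dim=0)
assert per_example(x3).shape[0] == 3 * per_example(x).shape[0]
# aggregated(x3).shape == torch.Size([]) -> fails the shape check above
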
@@ -1209,9 +1297,6 @@ def _process_ablated_out_full(
         eval_diff = flattened_initial_eval - modified_eval
         eval_diff_shape = eval_diff.shape

-        # append the shape of one input example
-        # to make it broadcastable to mask
-
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
                 if mask is not None:
@@ -1224,6 +1309,7 @@ def _process_ablated_out_full(
                 )
                 eval_diff = eval_diff.to(total_attrib[i].device)
                 total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+
         return total_attrib, weights

     def _fut_tuple_to_accumulate_fut_list(