@@ -353,10 +353,12 @@ def attribute(
353353                    formatted_feature_mask ,
354354                    attr_progress ,
355355                    flattened_initial_eval ,
356+                     initial_eval ,
356357                    n_outputs ,
357358                    total_attrib ,
358359                    weights ,
359360                    attrib_type ,
361+                     perturbations_per_eval ,
360362                    ** kwargs ,
361363                )
362364            else :
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
470472        formatted_feature_mask : Tuple [Tensor , ...],
471473        attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
472474        flattened_initial_eval : Tensor ,
475+         initial_eval : Tensor ,
473476        n_outputs : int ,
474477        total_attrib : List [Tensor ],
475478        weights : List [Tensor ],
476479        attrib_type : dtype ,
480+         perturbations_per_eval : int ,
477481        ** kwargs : Any ,
478482    ) ->  Tuple [List [Tensor ], List [Tensor ]]:
479483        feature_idx_to_tensor_idx : Dict [int , List [int ]] =  {}
@@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks(
482486                if  feature_idx .item () not  in feature_idx_to_tensor_idx :
483487                    feature_idx_to_tensor_idx [feature_idx .item ()] =  []
484488                feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
489+         all_feature_idxs  =  list (feature_idx_to_tensor_idx .keys ())
490+ 
491+         additional_args_repeated : object 
492+         if  perturbations_per_eval  >  1 :
493+             # Repeat features and additional args for batch size. 
494+             all_features_repeated  =  tuple (
495+                 torch .cat ([formatted_inputs [j ]] *  perturbations_per_eval , dim = 0 )
496+                 for  j  in  range (len (formatted_inputs ))
497+             )
498+             additional_args_repeated  =  (
499+                 _expand_additional_forward_args (
500+                     formatted_additional_forward_args , perturbations_per_eval 
501+                 )
502+                 if  formatted_additional_forward_args  is  not None 
503+                 else  None 
504+             )
505+             target_repeated  =  _expand_target (target , perturbations_per_eval )
506+         else :
507+             all_features_repeated  =  formatted_inputs 
508+             additional_args_repeated  =  formatted_additional_forward_args 
509+             target_repeated  =  target 
510+         num_examples  =  formatted_inputs [0 ].shape [0 ]
511+ 
512+         current_additional_args : object 
513+         if  isinstance (baselines , tuple ):
514+             reshaped  =  False 
515+             reshaped_baselines : list [Union [Tensor , int , float ]] =  []
516+             for  baseline  in  baselines :
517+                 if  isinstance (baseline , Tensor ):
518+                     reshaped  =  True 
519+                     reshaped_baselines .append (
520+                         baseline .reshape ((1 ,) +  tuple (baseline .shape ))
521+                     )
522+                 else :
523+                     reshaped_baselines .append (baseline )
524+             baselines  =  tuple (reshaped_baselines ) if  reshaped  else  baselines 
525+         for  i  in  range (0 , len (all_feature_idxs ), perturbations_per_eval ):
526+             current_feature_idxs  =  all_feature_idxs [i  : i  +  perturbations_per_eval ]
527+             current_num_ablated_features  =  min (
528+                 perturbations_per_eval , len (current_feature_idxs )
529+             )
530+ 
531+             # Store appropriate inputs and additional args based on batch size. 
532+             if  current_num_ablated_features  !=  perturbations_per_eval :
533+                 current_additional_args  =  (
534+                     _expand_additional_forward_args (
535+                         formatted_additional_forward_args , current_num_ablated_features 
536+                     )
537+                     if  formatted_additional_forward_args  is  not None 
538+                     else  None 
539+                 )
540+                 current_target  =  _expand_target (target , current_num_ablated_features )
541+                 expanded_inputs  =  tuple (
542+                     feature_repeated [0  : current_num_ablated_features  *  num_examples ]
543+                     for  feature_repeated  in  all_features_repeated 
544+                 )
545+             else :
546+                 current_additional_args  =  additional_args_repeated 
547+                 current_target  =  target_repeated 
548+                 expanded_inputs  =  all_features_repeated 
549+ 
550+             current_inputs , current_masks  =  (
551+                 self ._construct_ablated_input_across_tensors (
552+                     expanded_inputs ,
553+                     formatted_feature_mask ,
554+                     baselines ,
555+                     current_feature_idxs ,
556+                     feature_idx_to_tensor_idx ,
557+                     current_num_ablated_features ,
558+                 )
559+             )
485560
486-         for  (
487-             current_inputs ,
488-             current_mask ,
489-         ) in  self ._ablation_generator (
490-             formatted_inputs ,
491-             baselines ,
492-             formatted_feature_mask ,
493-             feature_idx_to_tensor_idx ,
494-             ** kwargs ,
495-         ):
496561            # modified_eval has (n_feature_perturbed * n_outputs) elements 
497562            # shape: 
498563            #   agg mode: (*initial_eval.shape) 
@@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks(
501566            modified_eval  =  _run_forward (
502567                self .forward_func ,
503568                current_inputs ,
504-                 target ,
505-                 formatted_additional_forward_args ,
569+                 current_target ,
570+                 current_additional_args ,
506571            )
507572
508573            if  attr_progress  is  not None :
@@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks(
515580
516581            total_attrib , weights  =  self ._process_ablated_out_full (
517582                modified_eval ,
518-                 current_mask ,
583+                 current_masks ,
519584                flattened_initial_eval ,
520-                 formatted_inputs ,
585+                 initial_eval ,
586+                 current_inputs ,
521587                n_outputs ,
588+                 num_examples ,
522589                total_attrib ,
523590                weights ,
524591                attrib_type ,
592+                 perturbations_per_eval ,
525593            )
526594        return  total_attrib , weights 
527595
528-     def  _ablation_generator (
529-         self ,
530-         inputs : Tuple [Tensor , ...],
531-         baselines : BaselineType ,
532-         input_mask : Tuple [Tensor , ...],
533-         feature_idx_to_tensor_idx : Dict [int , List [int ]],
534-         ** kwargs : Any ,
535-     ) ->  Generator [
536-         Tuple [
537-             Tuple [Tensor , ...],
538-             Tuple [Optional [Tensor ], ...],
539-         ],
540-         None ,
541-         None ,
542-     ]:
543-         if  isinstance (baselines , torch .Tensor ):
544-             baselines  =  baselines .reshape ((1 ,) +  tuple (baselines .shape ))
545- 
546-         # Process one feature per time, rather than processing every input tensor 
547-         for  feature_idx  in  feature_idx_to_tensor_idx .keys ():
548-             ablated_inputs , current_masks  =  (
549-                 self ._construct_ablated_input_across_tensors (
550-                     inputs ,
551-                     input_mask ,
552-                     baselines ,
553-                     feature_idx ,
554-                     feature_idx_to_tensor_idx [feature_idx ],
555-                 )
556-             )
557-             yield  ablated_inputs , current_masks 
558- 
559596    def  _construct_ablated_input_across_tensors (
560597        self ,
561598        inputs : Tuple [Tensor , ...],
562599        input_mask : Tuple [Tensor , ...],
563600        baselines : BaselineType ,
564-         feature_idx : int ,
565-         tensor_idxs : List [int ],
601+         feature_idxs : List [int ],
602+         feature_idx_to_tensor_idx : Dict [int , List [int ]],
603+         current_num_ablated_features : int ,
566604    ) ->  Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
567- 
568605        ablated_inputs  =  []
569606        current_masks : List [Optional [Tensor ]] =  []
607+         tensor_idxs  =  {
608+             tensor_idx 
609+             for  sublist  in  (
610+                 feature_idx_to_tensor_idx [feature_idx ] for  feature_idx  in  feature_idxs 
611+             )
612+             for  tensor_idx  in  sublist 
613+         }
614+ 
570615        for  i , input_tensor  in  enumerate (inputs ):
571616            if  i  not  in tensor_idxs :
572617                ablated_inputs .append (input_tensor )
573618                current_masks .append (None )
574619                continue 
575-             tensor_mask  =  (input_mask [i ] ==  feature_idx ).to (input_tensor .device ).long ()
620+             tensor_mask  =  []
621+             ablated_input  =  input_tensor .clone ()
576622            baseline  =  baselines [i ] if  isinstance (baselines , tuple ) else  baselines 
577-             if   isinstance ( baseline ,  torch . Tensor ):
578-                 baseline  =  baseline . reshape (
579-                     ( 1 ,)  *  ( input_tensor .dim ()  -   baseline . dim ())  +   tuple ( baseline . shape ) 
623+             for   j ,  feature_idx   in   enumerate ( feature_idxs ):
624+                 original_input_size  =  (
625+                     input_tensor .shape [ 0 ]  //   current_num_ablated_features 
580626                )
581-             assert  baseline  is  not None , "baseline must be provided" 
582-             ablated_input  =  (
583-                 input_tensor  *  (1  -  tensor_mask ).to (input_tensor .dtype )
584-             ) +  (baseline  *  tensor_mask .to (input_tensor .dtype ))
627+                 start_idx  =  j  *  original_input_size 
628+                 end_idx  =  (j  +  1 ) *  original_input_size 
629+ 
630+                 mask  =  (input_mask [i ] ==  feature_idx ).to (input_tensor .device ).long ()
631+                 if  mask .ndim  ==  0 :
632+                     mask  =  mask .reshape ((1 ,) *  input_tensor .dim ())
633+                 tensor_mask .append (mask )
634+ 
635+                 assert  baseline  is  not None , "baseline must be provided" 
636+                 ablated_input [start_idx :end_idx ] =  input_tensor [start_idx :end_idx ] *  (
637+                     1  -  mask 
638+                 ) +  (baseline  *  mask .to (input_tensor .dtype ))
639+             current_masks .append (torch .stack (tensor_mask , dim = 0 ))
585640            ablated_inputs .append (ablated_input )
586-              current_masks . append ( tensor_mask ) 
641+ 
587642        return  tuple (ablated_inputs ), tuple (current_masks )
588643
589644    def  _initial_eval_to_processed_initial_eval_fut (
@@ -784,7 +839,7 @@ def _attribute_progress_setup(
784839            formatted_inputs , feature_mask , ** kwargs 
785840        )
786841        total_forwards  =  (
787-             int (sum (feature_counts ))
842+             math . ceil ( int (sum (feature_counts ))  /   perturbations_per_eval )
788843            if  enable_cross_tensor_attribution 
789844            else  sum (
790845                math .ceil (count  /  perturbations_per_eval ) for  count  in  feature_counts 
@@ -1194,13 +1249,46 @@ def _process_ablated_out_full(
11941249        modified_eval : Tensor ,
11951250        current_mask : Tuple [Optional [Tensor ], ...],
11961251        flattened_initial_eval : Tensor ,
1252+         initial_eval : Tensor ,
11971253        inputs : TensorOrTupleOfTensorsGeneric ,
11981254        n_outputs : int ,
1255+         num_examples : int ,
11991256        total_attrib : List [Tensor ],
12001257        weights : List [Tensor ],
12011258        attrib_type : dtype ,
1259+         perturbations_per_eval : int ,
12021260    ) ->  Tuple [List [Tensor ], List [Tensor ]]:
12031261        modified_eval  =  self ._parse_forward_out (modified_eval )
1262+         # if perturbations_per_eval > 1, the output shape must grow with 
1263+         # input and not be aggregated 
1264+         current_batch_size  =  inputs [0 ].shape [0 ]
1265+ 
1266+         # number of perturbation, which is not the same as 
1267+         # perturbations_per_eval when not enough features to perturb 
1268+         n_perturb  =  current_batch_size  /  num_examples 
1269+         if  perturbations_per_eval  >  1  and  not  self ._is_output_shape_valid :
1270+ 
1271+             current_output_shape  =  modified_eval .shape 
1272+ 
1273+             # use initial_eval as the forward of perturbations_per_eval = 1 
1274+             initial_output_shape  =  initial_eval .shape 
1275+ 
1276+             assert  (
1277+                 # check if the output is not a scalar 
1278+                 current_output_shape 
1279+                 and  initial_output_shape 
1280+                 # check if the output grow in same ratio, i.e., not agg 
1281+                 and  current_output_shape [0 ] ==  n_perturb  *  initial_output_shape [0 ]
1282+             ), (
1283+                 "When perturbations_per_eval > 1, forward_func's output " 
1284+                 "should be a tensor whose 1st dim grow with the input " 
1285+                 f"batch size: when input batch size is { num_examples } , " 
1286+                 f"the output shape is { initial_output_shape } ; " 
1287+                 f"when input batch size is { current_batch_size } , " 
1288+                 f"the output shape is { current_output_shape } " 
1289+             )
1290+ 
1291+             self ._is_output_shape_valid  =  True 
12041292
12051293        # reshape the leading dim for n_feature_perturbed 
12061294        # flatten each feature's eval outputs into 1D of (n_outputs) 
@@ -1209,9 +1297,6 @@ def _process_ablated_out_full(
12091297        eval_diff  =  flattened_initial_eval  -  modified_eval 
12101298        eval_diff_shape  =  eval_diff .shape 
12111299
1212-         # append the shape of one input example 
1213-         # to make it broadcastable to mask 
1214- 
12151300        if  self .use_weights :
12161301            for  weight , mask  in  zip (weights , current_mask ):
12171302                if  mask  is  not None :
@@ -1224,6 +1309,7 @@ def _process_ablated_out_full(
12241309            )
12251310            eval_diff  =  eval_diff .to (total_attrib [i ].device )
12261311            total_attrib [i ] +=  (eval_diff  *  mask .to (attrib_type )).sum (dim = 0 )
1312+ 
12271313        return  total_attrib , weights 
12281314
12291315    def  _fut_tuple_to_accumulate_fut_list (