2828)
2929from  captum ._utils .exceptions  import  FeatureAblationFutureError 
3030from  captum ._utils .progress  import  progress , SimpleProgress 
31- from  captum ._utils .typing  import  BaselineType , TargetType , TensorOrTupleOfTensorsGeneric 
31+ from  captum ._utils .typing  import  (
32+     BaselineTupleType ,
33+     BaselineType ,
34+     TargetType ,
35+     TensorOrTupleOfTensorsGeneric ,
36+ )
3237from  captum .attr ._utils .attribution  import  PerturbationAttribution 
3338from  captum .attr ._utils .common  import  _format_input_baseline 
3439from  captum .log  import  log_usage 
@@ -353,10 +358,12 @@ def attribute(
353358                    formatted_feature_mask ,
354359                    attr_progress ,
355360                    flattened_initial_eval ,
361+                     initial_eval ,
356362                    n_outputs ,
357363                    total_attrib ,
358364                    weights ,
359365                    attrib_type ,
366+                     perturbations_per_eval ,
360367                    ** kwargs ,
361368                )
362369            else :
@@ -470,10 +477,12 @@ def _attribute_with_cross_tensor_feature_masks(
470477        formatted_feature_mask : Tuple [Tensor , ...],
471478        attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
472479        flattened_initial_eval : Tensor ,
480+         initial_eval : Tensor ,
473481        n_outputs : int ,
474482        total_attrib : List [Tensor ],
475483        weights : List [Tensor ],
476484        attrib_type : dtype ,
485+         perturbations_per_eval : int ,
477486        ** kwargs : Any ,
478487    ) ->  Tuple [List [Tensor ], List [Tensor ]]:
479488        feature_idx_to_tensor_idx : Dict [int , List [int ]] =  {}
@@ -482,17 +491,78 @@ def _attribute_with_cross_tensor_feature_masks(
482491                if  feature_idx .item () not  in feature_idx_to_tensor_idx :
483492                    feature_idx_to_tensor_idx [feature_idx .item ()] =  []
484493                feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
494+         all_feature_idxs  =  list (feature_idx_to_tensor_idx .keys ())
495+ 
496+         additional_args_repeated : object 
497+         if  perturbations_per_eval  >  1 :
498+             # Repeat features and additional args for batch size. 
499+             all_features_repeated  =  tuple (
500+                 torch .cat ([formatted_inputs [j ]] *  perturbations_per_eval , dim = 0 )
501+                 for  j  in  range (len (formatted_inputs ))
502+             )
503+             additional_args_repeated  =  (
504+                 _expand_additional_forward_args (
505+                     formatted_additional_forward_args , perturbations_per_eval 
506+                 )
507+                 if  formatted_additional_forward_args  is  not None 
508+                 else  None 
509+             )
510+             target_repeated  =  _expand_target (target , perturbations_per_eval )
511+         else :
512+             all_features_repeated  =  formatted_inputs 
513+             additional_args_repeated  =  formatted_additional_forward_args 
514+             target_repeated  =  target 
515+         num_examples  =  formatted_inputs [0 ].shape [0 ]
516+ 
517+         current_additional_args : object 
518+         if  isinstance (baselines , tuple ):
519+             reshaped  =  False 
520+             reshaped_baselines : list [Union [Tensor , int , float ]] =  []
521+             for  baseline  in  baselines :
522+                 if  isinstance (baseline , Tensor ):
523+                     reshaped  =  True 
524+                     reshaped_baselines .append (
525+                         baseline .reshape ((1 ,) +  tuple (baseline .shape ))
526+                     )
527+                 else :
528+                     reshaped_baselines .append (baseline )
529+             baselines  =  tuple (reshaped_baselines ) if  reshaped  else  baselines 
530+         for  i  in  range (0 , len (all_feature_idxs ), perturbations_per_eval ):
531+             current_feature_idxs  =  all_feature_idxs [i  : i  +  perturbations_per_eval ]
532+             current_num_ablated_features  =  min (
533+                 perturbations_per_eval , len (current_feature_idxs )
534+             )
535+ 
536+             # Store appropriate inputs and additional args based on batch size. 
537+             if  current_num_ablated_features  !=  perturbations_per_eval :
538+                 current_additional_args  =  (
539+                     _expand_additional_forward_args (
540+                         formatted_additional_forward_args , current_num_ablated_features 
541+                     )
542+                     if  formatted_additional_forward_args  is  not None 
543+                     else  None 
544+                 )
545+                 current_target  =  _expand_target (target , current_num_ablated_features )
546+                 expanded_inputs  =  tuple (
547+                     feature_repeated [0  : current_num_ablated_features  *  num_examples ]
548+                     for  feature_repeated  in  all_features_repeated 
549+                 )
550+             else :
551+                 current_additional_args  =  additional_args_repeated 
552+                 current_target  =  target_repeated 
553+                 expanded_inputs  =  all_features_repeated 
554+ 
555+             current_inputs , current_masks  =  (
556+                 self ._construct_ablated_input_across_tensors (
557+                     expanded_inputs ,
558+                     formatted_feature_mask ,
559+                     baselines ,
560+                     current_feature_idxs ,
561+                     feature_idx_to_tensor_idx ,
562+                     current_num_ablated_features ,
563+                 )
564+             )
485565
486-         for  (
487-             current_inputs ,
488-             current_mask ,
489-         ) in  self ._ablation_generator (
490-             formatted_inputs ,
491-             baselines ,
492-             formatted_feature_mask ,
493-             feature_idx_to_tensor_idx ,
494-             ** kwargs ,
495-         ):
496566            # modified_eval has (n_feature_perturbed * n_outputs) elements 
497567            # shape: 
498568            #   agg mode: (*initial_eval.shape) 
@@ -501,8 +571,8 @@ def _attribute_with_cross_tensor_feature_masks(
501571            modified_eval  =  _run_forward (
502572                self .forward_func ,
503573                current_inputs ,
504-                 target ,
505-                 formatted_additional_forward_args ,
574+                 current_target ,
575+                 current_additional_args ,
506576            )
507577
508578            if  attr_progress  is  not None :
@@ -515,75 +585,65 @@ def _attribute_with_cross_tensor_feature_masks(
515585
516586            total_attrib , weights  =  self ._process_ablated_out_full (
517587                modified_eval ,
518-                 current_mask ,
588+                 current_masks ,
519589                flattened_initial_eval ,
520-                 formatted_inputs ,
590+                 initial_eval ,
591+                 current_inputs ,
521592                n_outputs ,
593+                 num_examples ,
522594                total_attrib ,
523595                weights ,
524596                attrib_type ,
597+                 perturbations_per_eval ,
525598            )
526599        return  total_attrib , weights 
527600
528-     def  _ablation_generator (
529-         self ,
530-         inputs : Tuple [Tensor , ...],
531-         baselines : BaselineType ,
532-         input_mask : Tuple [Tensor , ...],
533-         feature_idx_to_tensor_idx : Dict [int , List [int ]],
534-         ** kwargs : Any ,
535-     ) ->  Generator [
536-         Tuple [
537-             Tuple [Tensor , ...],
538-             Tuple [Optional [Tensor ], ...],
539-         ],
540-         None ,
541-         None ,
542-     ]:
543-         if  isinstance (baselines , torch .Tensor ):
544-             baselines  =  baselines .reshape ((1 ,) +  tuple (baselines .shape ))
545- 
546-         # Process one feature per time, rather than processing every input tensor 
547-         for  feature_idx  in  feature_idx_to_tensor_idx .keys ():
548-             ablated_inputs , current_masks  =  (
549-                 self ._construct_ablated_input_across_tensors (
550-                     inputs ,
551-                     input_mask ,
552-                     baselines ,
553-                     feature_idx ,
554-                     feature_idx_to_tensor_idx [feature_idx ],
555-                 )
556-             )
557-             yield  ablated_inputs , current_masks 
558- 
559601    def  _construct_ablated_input_across_tensors (
560602        self ,
561603        inputs : Tuple [Tensor , ...],
562604        input_mask : Tuple [Tensor , ...],
563605        baselines : BaselineType ,
564-         feature_idx : int ,
565-         tensor_idxs : List [int ],
606+         feature_idxs : List [int ],
607+         feature_idx_to_tensor_idx : Dict [int , List [int ]],
608+         current_num_ablated_features : int ,
566609    ) ->  Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
567- 
568610        ablated_inputs  =  []
569611        current_masks : List [Optional [Tensor ]] =  []
612+         tensor_idxs  =  {
613+             tensor_idx 
614+             for  sublist  in  (
615+                 feature_idx_to_tensor_idx [feature_idx ] for  feature_idx  in  feature_idxs 
616+             )
617+             for  tensor_idx  in  sublist 
618+         }
619+ 
570620        for  i , input_tensor  in  enumerate (inputs ):
571621            if  i  not  in tensor_idxs :
572622                ablated_inputs .append (input_tensor )
573623                current_masks .append (None )
574624                continue 
575-             tensor_mask  =  (input_mask [i ] ==  feature_idx ).to (input_tensor .device ).long ()
625+             tensor_mask  =  []
626+             ablated_input  =  input_tensor .clone ()
576627            baseline  =  baselines [i ] if  isinstance (baselines , tuple ) else  baselines 
577-             if   isinstance ( baseline ,  torch . Tensor ):
578-                 baseline  =  baseline . reshape (
579-                     ( 1 ,)  *  ( input_tensor .dim ()  -   baseline . dim ())  +   tuple ( baseline . shape ) 
628+             for   j ,  feature_idx   in   enumerate ( feature_idxs ):
629+                 original_input_size  =  (
630+                     input_tensor .shape [ 0 ]  //   current_num_ablated_features 
580631                )
581-             assert  baseline  is  not None , "baseline must be provided" 
582-             ablated_input  =  (
583-                 input_tensor  *  (1  -  tensor_mask ).to (input_tensor .dtype )
584-             ) +  (baseline  *  tensor_mask .to (input_tensor .dtype ))
632+                 start_idx  =  j  *  original_input_size 
633+                 end_idx  =  (j  +  1 ) *  original_input_size 
634+ 
635+                 mask  =  (input_mask [i ] ==  feature_idx ).to (input_tensor .device ).long ()
636+                 if  mask .ndim  ==  0 :
637+                     mask  =  mask .reshape ((1 ,) *  input_tensor .dim ())
638+                 tensor_mask .append (mask )
639+ 
640+                 assert  baseline  is  not None , "baseline must be provided" 
641+                 ablated_input [start_idx :end_idx ] =  input_tensor [start_idx :end_idx ] *  (
642+                     1  -  mask 
643+                 ) +  (baseline  *  mask .to (input_tensor .dtype ))
644+             current_masks .append (torch .stack (tensor_mask , dim = 0 ))
585645            ablated_inputs .append (ablated_input )
586-              current_masks . append ( tensor_mask ) 
646+ 
587647        return  tuple (ablated_inputs ), tuple (current_masks )
588648
589649    def  _initial_eval_to_processed_initial_eval_fut (
@@ -784,7 +844,7 @@ def _attribute_progress_setup(
784844            formatted_inputs , feature_mask , ** kwargs 
785845        )
786846        total_forwards  =  (
787-             int (sum (feature_counts ))
847+             math . ceil ( int (sum (feature_counts ))  /   perturbations_per_eval )
788848            if  enable_cross_tensor_attribution 
789849            else  sum (
790850                math .ceil (count  /  perturbations_per_eval ) for  count  in  feature_counts 
@@ -1194,13 +1254,46 @@ def _process_ablated_out_full(
11941254        modified_eval : Tensor ,
11951255        current_mask : Tuple [Optional [Tensor ], ...],
11961256        flattened_initial_eval : Tensor ,
1257+         initial_eval : Tensor ,
11971258        inputs : TensorOrTupleOfTensorsGeneric ,
11981259        n_outputs : int ,
1260+         num_examples : int ,
11991261        total_attrib : List [Tensor ],
12001262        weights : List [Tensor ],
12011263        attrib_type : dtype ,
1264+         perturbations_per_eval : int ,
12021265    ) ->  Tuple [List [Tensor ], List [Tensor ]]:
12031266        modified_eval  =  self ._parse_forward_out (modified_eval )
1267+         # if perturbations_per_eval > 1, the output shape must grow with 
1268+         # input and not be aggregated 
1269+         current_batch_size  =  inputs [0 ].shape [0 ]
1270+ 
1271+         # number of perturbation, which is not the same as 
1272+         # perturbations_per_eval when not enough features to perturb 
1273+         n_perturb  =  current_batch_size  /  num_examples 
1274+         if  perturbations_per_eval  >  1  and  not  self ._is_output_shape_valid :
1275+ 
1276+             current_output_shape  =  modified_eval .shape 
1277+ 
1278+             # use initial_eval as the forward of perturbations_per_eval = 1 
1279+             initial_output_shape  =  initial_eval .shape 
1280+ 
1281+             assert  (
1282+                 # check if the output is not a scalar 
1283+                 current_output_shape 
1284+                 and  initial_output_shape 
1285+                 # check if the output grow in same ratio, i.e., not agg 
1286+                 and  current_output_shape [0 ] ==  n_perturb  *  initial_output_shape [0 ]
1287+             ), (
1288+                 "When perturbations_per_eval > 1, forward_func's output " 
1289+                 "should be a tensor whose 1st dim grow with the input " 
1290+                 f"batch size: when input batch size is { num_examples }  
1291+                 f"the output shape is { initial_output_shape }  
1292+                 f"when input batch size is { current_batch_size }  
1293+                 f"the output shape is { current_output_shape }  
1294+             )
1295+ 
1296+             self ._is_output_shape_valid  =  True 
12041297
12051298        # reshape the leading dim for n_feature_perturbed 
12061299        # flatten each feature's eval outputs into 1D of (n_outputs) 
@@ -1209,9 +1302,6 @@ def _process_ablated_out_full(
12091302        eval_diff  =  flattened_initial_eval  -  modified_eval 
12101303        eval_diff_shape  =  eval_diff .shape 
12111304
1212-         # append the shape of one input example 
1213-         # to make it broadcastable to mask 
1214- 
12151305        if  self .use_weights :
12161306            for  weight , mask  in  zip (weights , current_mask ):
12171307                if  mask  is  not None :
@@ -1224,6 +1314,7 @@ def _process_ablated_out_full(
12241314            )
12251315            eval_diff  =  eval_diff .to (total_attrib [i ].device )
12261316            total_attrib [i ] +=  (eval_diff  *  mask .to (attrib_type )).sum (dim = 0 )
1317+ 
12271318        return  total_attrib , weights 
12281319
12291320    def  _fut_tuple_to_accumulate_fut_list (
0 commit comments