@@ -353,10 +353,12 @@ def attribute(
                     formatted_feature_mask,
                     attr_progress,
                     flattened_initial_eval,
+                    initial_eval,
                     n_outputs,
                     total_attrib,
                     weights,
                     attrib_type,
+                    perturbations_per_eval,
                     **kwargs,
                 )
             else:
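For context, this hunk threads `initial_eval` and `perturbations_per_eval` from the public `attribute` entry point into the cross-tensor path. A minimal usage sketch of the public API (the model and shapes are hypothetical; `perturbations_per_eval` is an existing `attribute` parameter):

```python
import torch
from captum.attr import FeatureAblation

model = torch.nn.Linear(6, 2)        # hypothetical forward_func
inputs = torch.rand(4, 6)            # batch of 4 examples

fa = FeatureAblation(model)
attr = fa.attribute(
    inputs,
    target=0,
    perturbations_per_eval=3,        # batch 3 feature ablations per forward pass
    # enable_cross_tensor_attribution=True,  # if exposed in your version
)
```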
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
@@ -482,17 +486,66 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        # Process features in batches of size perturbations_per_eval,
+        # rather than processing every input tensor separately
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+
+            current_inputs = ()
+            current_masks = []
+            for (
+                single_perturb_input,
+                single_perturb_masks,
+            ) in self._ablation_generator(
+                formatted_inputs,
+                baselines,
+                formatted_feature_mask,
+                current_feature_idxs,
+                feature_idx_to_tensor_idx,
+                **kwargs,
+            ):
+                if len(current_inputs) == 0:
+                    current_inputs = single_perturb_input
+                else:
+                    current_inputs = tuple(
+                        torch.cat((current_inputs[j], single_perturb_input[j]), dim=0)
+                        for j in range(len(current_inputs))
+                    )
+                current_masks.append(list(single_perturb_masks))

-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             #   agg mode: (*initial_eval.shape)
@@ -501,8 +554,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )

             if attr_progress is not None:
@@ -515,13 +568,16 @@ def _attribute_with_cross_tensor_feature_masks(

             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights

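The loop above concatenates up to `perturbations_per_eval` ablated copies of the batch along dim 0, so a single forward pass evaluates them all. A minimal sketch of the idea, assuming a single input tensor, a scalar baseline, and a model returning one output per example (`batched_ablation_scores` is a hypothetical name, not part of the library):

```python
import torch

def batched_ablation_scores(model, x, mask, feature_idxs, baseline=0.0):
    # one ablated copy of the batch per feature group, stacked along dim 0
    copies = [
        torch.where(mask == f, torch.full_like(x, baseline), x)
        for f in feature_idxs
    ]
    out = model(torch.cat(copies, dim=0))   # one forward for all copies
    # split back into per-feature chunks of the original batch size
    return out.reshape(len(feature_idxs), x.shape[0])
```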
@@ -530,6 +586,7 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
@@ -540,11 +597,8 @@ def _ablation_generator(
         None,
         None,
     ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
         # Process one feature at a time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
+        for feature_idx in feature_idxs:
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
                     inputs,
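`_construct_ablated_input_across_tensors` (unchanged by this commit) does the per-feature work the generator now receives explicit `feature_idxs` for. A simplified sketch of that kind of cross-tensor ablation, written as a hypothetical free function with scalar baselines:

```python
import torch
from torch import Tensor
from typing import List, Optional, Tuple

def ablate_across_tensors(
    inputs: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
    baselines: Tuple[float, ...],
    feature_idx: int,
) -> Tuple[Tuple[Tensor, ...], List[Optional[Tensor]]]:
    ablated: List[Tensor] = []
    masks: List[Optional[Tensor]] = []
    for inp, mask, base in zip(inputs, feature_mask, baselines):
        hit = mask == feature_idx            # positions owned by this feature
        if not bool(hit.any()):
            ablated.append(inp)
            masks.append(None)               # feature absent from this tensor
            continue
        fill = torch.full_like(inp, base)    # swap baselines into masked spots
        ablated.append(torch.where(hit, fill, inp))
        masks.append(hit)
    return tuple(ablated), masks
```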
@@ -784,7 +838,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
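With cross-tensor attribution, all tensors are ablated together for each feature group, so the progress total becomes the number of feature groups divided by the batching factor, rounded up. A worked example with hypothetical counts:

```python
import math

feature_counts = [4, 6]            # feature groups discovered per input tensor
perturbations_per_eval = 4
total_forwards = math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
assert total_forwards == 3         # 10 groups in batches of 4 -> 3 forwards
```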
@@ -1187,43 +1241,75 @@ def _process_ablated_out(
             weights[i] += current_mask.float().sum(dim=0)

         total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0)
         return total_attrib, weights

     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Optional[Tensor], ...],
+        current_mask: List[List[Optional[Tensor]]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # if perturbations_per_eval > 1, the output shape must grow with the
+        # input and must not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbations, which is not the same as perturbations_per_eval
+        # when there are not enough features left to perturb
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward output for perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows in the same ratio, i.e., is not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True

         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
         modified_eval = modified_eval.reshape(-1, n_outputs)
         # eval_diff in shape (n_feature_perturbed, n_outputs)
         eval_diff = flattened_initial_eval - modified_eval
-        eval_diff_shape = eval_diff.shape
-
-        # append the shape of one input example
-        # to make it broadcastable to mask

-        if self.use_weights:
-            for weight, mask in zip(weights, current_mask):
-                if mask is not None:
-                    weight += mask.float().sum(dim=0)
-        for i, mask in enumerate(current_mask):
-            if mask is None or inputs[i].numel() == 0:
-                continue
-            eval_diff = eval_diff.reshape(
-                eval_diff_shape + (inputs[i].dim() - 1) * (1,)
-            )
-            eval_diff = eval_diff.to(total_attrib[i].device)
-            total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+        for j in range(int(n_perturb)):
+            single_perturb_mask = current_mask[j]
+            if self.use_weights:
+                for weight, mask in zip(weights, single_perturb_mask):
+                    if mask is not None:
+                        weight += mask.float()
+            for i, mask in enumerate(single_perturb_mask):
+                this_input = inputs[i][j * num_examples : (j + 1) * num_examples]
+                if mask is None or this_input.numel() == 0:
+                    continue
+                eval_diff_j = eval_diff[j].reshape(
+                    eval_diff[j].shape + (this_input.dim() - 1) * (1,)
+                )
+                eval_diff_j = eval_diff_j.to(total_attrib[i].device)
+                total_attrib[i] += eval_diff_j * mask.to(attrib_type)
         return total_attrib, weights

     def _fut_tuple_to_accumulate_fut_list(
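The output-shape assertion added in `_process_ablated_out_full` rejects forward functions that aggregate over the batch, since their outputs cannot be split back per perturbation. A small illustration of the two cases, using hypothetical models:

```python
import torch

x = torch.rand(2, 3)                  # num_examples = 2
x3 = x.repeat(3, 1)                   # 3 perturbations batched -> batch of 6

per_example = lambda t: t.sum(dim=1)  # 1st dim grows with the batch: accepted
aggregated = lambda t: t.sum()        # scalar output: triggers the assert
                                      # when perturbations_per_eval > 1

assert per_example(x3).shape[0] == 3 * per_example(x).shape[0]
assert aggregated(x3).dim() == 0      # no batch dim to split per perturbation
```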