 )
 from captum._utils.exceptions import FeatureAblationFutureError
 from captum._utils.progress import progress, SimpleProgress
-from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineTupleType,
+    BaselineType,
+    TargetType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._utils.attribution import PerturbationAttribution
 from captum.attr._utils.common import _format_input_baseline
 from captum.log import log_usage
@@ -353,10 +358,12 @@ def attribute(
                 formatted_feature_mask,
                 attr_progress,
                 flattened_initial_eval,
+                initial_eval,
                 n_outputs,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
                 **kwargs,
             )
         else:
@@ -470,10 +477,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +491,78 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
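+        # Original (un-repeated) batch size of each input tensor.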
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
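+        # Add a leading dim of size 1 to tensor baselines so they broadcast
+        # over the batch dimension of the (possibly repeated) inputs.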
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[Union[Tensor, int, float]] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
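+        # Ablate features in chunks of up to perturbations_per_eval ids,
+        # evaluating each chunk in a single expanded forward pass.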
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
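+                # The last chunk may contain fewer features; trim the repeated
+                # inputs, target, and additional args to match.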
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
+
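+            # Ablate every feature in this chunk across all tensors it covers.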
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )
 
-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             # agg mode: (*initial_eval.shape)
@@ -501,8 +571,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )
 
             if attr_progress is not None:
@@ -515,75 +585,65 @@ def _attribute_with_cross_tensor_feature_masks(
 
             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights
 
-    def _ablation_generator(
-        self,
-        inputs: Tuple[Tensor, ...],
-        baselines: BaselineType,
-        input_mask: Tuple[Tensor, ...],
-        feature_idx_to_tensor_idx: Dict[int, List[int]],
-        **kwargs: Any,
-    ) -> Generator[
-        Tuple[
-            Tuple[Tensor, ...],
-            Tuple[Optional[Tensor], ...],
-        ],
-        None,
-        None,
-    ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
-        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
-            ablated_inputs, current_masks = (
-                self._construct_ablated_input_across_tensors(
-                    inputs,
-                    input_mask,
-                    baselines,
-                    feature_idx,
-                    feature_idx_to_tensor_idx[feature_idx],
-                )
-            )
-            yield ablated_inputs, current_masks
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
-
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
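+        # Gather the indices of every input tensor touched by any feature id
+        # in this chunk.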
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
                 current_masks.append(None)
                 continue
-            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
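+            # Select this tensor's baseline (tuple baselines are per-tensor).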
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
-            if isinstance(baseline, torch.Tensor):
-                baseline = baseline.reshape(
-                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
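+            # Each consecutive block of original_input_size rows corresponds to
+            # one feature id in this chunk.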
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
                 )
-            assert baseline is not None, "baseline must be provided"
-            ablated_input = (
-                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
-            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
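+                # Binary mask marking the positions assigned to this feature id.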
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    1 - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
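+            # Stack per-feature masks along a new leading dim:
+            # (current_num_ablated_features, *mask_shape).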
+            current_masks.append(torch.stack(tensor_mask, dim=0))
             ablated_inputs.append(ablated_input)
-            current_masks.append(tensor_mask)
+
         return tuple(ablated_inputs), tuple(current_masks)
 
     def _initial_eval_to_processed_initial_eval_fut(
@@ -784,7 +844,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
@@ -1194,13 +1254,46 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # If perturbations_per_eval > 1, the output shape must grow with the
+        # input batch size and must not be aggregated.
+        current_batch_size = inputs[0].shape[0]
+
+        # Number of perturbations in this batch; not the same as
+        # perturbations_per_eval when there are not enough features left.
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # Use initial_eval as the reference output for perturbations_per_eval = 1.
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows at the same ratio, i.e., is not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
 
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
@@ -1209,9 +1302,6 @@ def _process_ablated_out_full(
         eval_diff = flattened_initial_eval - modified_eval
         eval_diff_shape = eval_diff.shape
 
-        # append the shape of one input example
-        # to make it broadcastable to mask
-
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
                 if mask is not None:
@@ -1224,6 +1314,7 @@ def _process_ablated_out_full(
             )
             eval_diff = eval_diff.to(total_attrib[i].device)
            total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+
        return total_attrib, weights
 
    def _fut_tuple_to_accumulate_fut_list(