Description
Motivation
The computation of Concept Activation Vectors (CAVs) is a fundamental component of concept-based explanation methods, most notably Testing with Concept Activation Vectors (TCAV), which is already implemented in Captum. The current standard for computing CAVs in Captum relies on training a Support Vector Machine (SVM) for each concept. This can be a significant computational bottleneck, especially for modern, high-dimensional models. To address this, we propose adding FastCAV as an efficient drop-in replacement. We provide benchmarking and theoretical justification to support its effectiveness.
Captum FastCAV API Design
Background
FastCAV is a novel approach, introduced at ICML 2025 and based on insights into superposition, that accelerates the computation of CAVs. FastCAV defines the CAV as the vector from the global mean of all activations to the mean of the positive class activations. As such, it acts as a drop-in replacement for the prevalent SVM when classes are balanced. Additionally, we state the mathematical assumptions under which FastCAV is equivalent to an SVM.
Requires:
- Model: The trained neural network model $f$ to be interpreted.
- Layer: A specific layer $l$ within the model from which to extract activations.
- Concept Dataset $D_c$: A set of input examples that visually represent the concept to be analyzed.
- Control Dataset $D_r$: A set of random or control examples used as a baseline.
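Written out with the inputs listed above, the FastCAV direction described in the Background can be stated as a single expression. The notation $a_l(x)$ for the activations of layer $l$ on input $x$ is an assumption introduced here for illustration and does not come from the original text:

$$
v_{\text{FastCAV}} = \frac{1}{|D_c|} \sum_{x \in D_c} a_l(x) \;-\; \frac{1}{|D_c| + |D_r|} \left( \sum_{x \in D_c} a_l(x) + \sum_{x \in D_r} a_l(x) \right)
$$

The first term is the mean of the concept (positive class) activations and the second is the global mean over concept and control activations together.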
Pseudocode:
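def fit(self, x: Tensor, y: Tensor) -> None:
    # Sorted unique labels; the last one is treated as the positive (concept) class.
    self.classes_ = y.unique()
    self.mean = x.mean(dim=0)
    # CAV: mean offset of the positive-class activations from the global mean.
    self.coef_ = (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
    # The intercept places the decision boundary at the global mean.
    self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)
Benchmarking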
Performance benchmarking:
Using the example in the tutorial tutorials/TCAV_Image.ipynb, we achieve an average speed-up of ~8.26 over the DefaultClassifier (SVM: 1.09
For larger models, the speed-up is considerably greater, as shown in Table 1 of the paper.
Visual Benchmarking
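Given the example in the tutorial tutorials/TCAV_Image.ipynb, we get the following results with FastCAV and SVM for experimental_set_rand:
[Figure: results for experimental_set_rand, FastCAV and SVM]
and for experimental_set_zig_dot:
[Figure: results for experimental_set_zig_dot, FastCAV and SVM]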
Mathematical Assumptions:
Under the following assumptions, FastCAV is equivalent to an SVM:
- Gaussian Distribution: The activation vectors for both the random samples and the concept samples are assumed to follow independent multivariate Gaussian distributions.
- Equal Mixture: The set of concept examples and the set of random examples are of equal size ($|D_c| = |D_r|$), resulting in a uniform mixture of the two Gaussian distributions.
- Isotropic Covariance: The within-class covariance matrices are assumed to be isotropic, meaning they are proportional to the identity matrix. This is a critical assumption that makes the FastCAV solution equivalent to the solution of a Fisher discriminant analysis.
- High-Dimensionality: The method is analyzed in the context of high-dimensional activation spaces, where the number of dimensions $d$ is significantly larger than the number of samples $n$ ($d \gg n$). In such spaces, the set of support vectors used by an SVM is likely to contain most of the training samples, making the SVM solution approximate the Fisher discriminant solution and, by extension, the FastCAV solution.
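The link between the equal-mixture and isotropic-covariance assumptions and the Fisher discriminant can be sketched in a few lines. The symbols $\mu_c$, $\mu_r$, $\mu$, $S_W$ and $s$ are introduced here for illustration: $\mu_c$ and $\mu_r$ are the mean activations of the concept and random sets, $\mu$ is the global mean, and $S_W = s I$ with $s > 0$ is the isotropic within-class scatter.

$$
\begin{aligned}
\mu &= \tfrac{1}{2}(\mu_c + \mu_r) && \text{(equal mixture, } |D_c| = |D_r|\text{)} \\
v_{\text{FastCAV}} &= \mu_c - \mu = \tfrac{1}{2}(\mu_c - \mu_r) \\
w_{\text{Fisher}} &\propto S_W^{-1}(\mu_c - \mu_r) = \tfrac{1}{s}(\mu_c - \mu_r) && \text{(isotropic covariance)}
\end{aligned}
$$

Both vectors are positive multiples of $\mu_c - \mu_r$, so under these two assumptions they point in the same direction. The remaining step, from the Fisher discriminant to the SVM, relies on the high-dimensionality argument above.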
Proposed Captum API Design:
The integration of FastCAV into Captum is designed as a high-performance drop-in replacement for the DefaultClassifier, which uses an SVM. It is exposed to the user through the FastCAVClassifier class, which uses a FastCAVLinearModel internally.
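For example, in the tutorial tutorials/TCAV_Image.ipynb:
from captum.attr import LayerIntegratedGradients
from captum.concept import TCAV
# `classifier` is the module that exposes the proposed FastCAVClassifier.
fast_clf = classifier.FastCAVClassifier()
mytcav = TCAV(model=model,
              layers=layers,
              classifier=fast_clf,
              layer_attr_method=LayerIntegratedGradients(
                  model, None, multiply_by_inputs=False))
FastCAV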
This is a low-level utility class, similar in interface to scikit-learn classifiers. It contains the core logic for computing the CAV. It is not a torch.nn.Module and is not intended for direct use within most Captum workflows, but provides the fundamental algorithm.
Constructor:
FastCAV(**kwargs)
Argument Descriptions:
- kwargs: The constructor currently accepts but ignores any keyword arguments (kwargs) to maintain a consistent interface with other classifiers.
Methods:
- fit(x: Tensor, y: Tensor): Takes a tensor x of activations and a tensor y of labels and computes coef_ and intercept_.
  - x: A 2D tensor of shape (n_samples, n_features) containing the input data (e.g., model activations).
  - y: A 1D tensor of shape (n_samples,) containing binary labels (0 or 1).
- predict(x: Tensor) -> Tensor: Predicts class labels for new data points based on the fitted hyperplane.
- classes() -> Tensor: Returns the unique class labels the model was fitted on.
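The interface described above could look roughly like the following minimal sketch. The class name FastCAVSketch is hypothetical, and the predict rule (positive class when the linear score is above zero) is an assumption, since the proposal does not spell it out. Only fit follows the pseudocode given earlier.

```python
import torch
from torch import Tensor


class FastCAVSketch:
    """Illustrative stand-in for the proposed FastCAV class, not the actual implementation."""

    def __init__(self, **kwargs) -> None:
        # Keyword arguments are accepted and ignored, mirroring the described interface.
        pass

    def fit(self, x: Tensor, y: Tensor) -> None:
        # Sorted unique labels; the last one is treated as the positive (concept) class.
        self.classes_ = torch.unique(y)
        self.mean = x.mean(dim=0)
        # Direction from the global mean to the positive-class mean, shape (1, n_features).
        self.coef_ = (x[y == self.classes_[-1]] - self.mean).mean(dim=0).unsqueeze(0)
        # Intercept chosen so that the global mean has a score of zero, shape (1, 1).
        self.intercept_ = (-self.coef_ @ self.mean).unsqueeze(1)

    def predict(self, x: Tensor) -> Tensor:
        # Assumption: a positive score maps to the positive class, otherwise the other class.
        scores = (x @ self.coef_.T + self.intercept_).squeeze(1)
        return torch.where(scores > 0, self.classes_[-1], self.classes_[0])

    def classes(self) -> Tensor:
        return self.classes_


# Toy data: two positive and two negative samples in a 2D activation space.
x = torch.tensor([[3.0, 1.0], [5.0, 1.0], [-1.0, 1.0], [-3.0, 1.0]])
y = torch.tensor([1, 1, 0, 0])
clf = FastCAVSketch()
clf.fit(x, y)
pred = clf.predict(x)
```

On this toy data the global mean is (1, 1) and the positive-class mean is (4, 1), so the fitted direction lies along the first axis only.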
fastcav_train_linear_model
This function serves as the bridge between the FastCAV logic and Captum's LinearModel interface. It is designed to be used as the train_fn for a LinearModel. It fits a FastCAV instance and then uses its learned parameters to configure a LinearModel object.
Signature:
fastcav_train_linear_model(model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], norm_input: bool = False, **fit_kwargs: Any) -> Dict[str, float]
Argument Descriptions:
- model: The LinearModel instance to be configured.
- dataloader: A DataLoader providing the training data (activations and labels). Right now this follows sklearn_train_linear_model and iterates through the entire dataloader to collect all data into memory.
- construct_kwargs: Keyword arguments passed to the FastCAV constructor. Since FastCAV ignores its kwargs, this exists only to maintain a consistent interface.
- norm_input: A boolean indicating whether to normalize the input data before fitting.
- fit_kwargs: Additional keyword arguments for the fit method (currently unused by FastCAV).
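The collect-then-fit flow described for this function can be illustrated without Captum. In the sketch below, the function name fastcav_fit_linear_sketch is hypothetical, a plain torch.nn.Linear stands in for Captum's LinearModel, and returning the elapsed time under a train_time key is an assumption. It is not the proposed implementation, only an outline of the steps.

```python
import time
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fastcav_fit_linear_sketch(linear: nn.Linear, dataloader: DataLoader) -> Dict[str, float]:
    """Hypothetical helper: gather all batches, fit FastCAV, copy the result into `linear`."""
    start = time.time()
    xs, ys = [], []
    # Iterate through the entire dataloader and hold all data in memory.
    for x_batch, y_batch in dataloader:
        xs.append(x_batch)
        ys.append(y_batch)
    x, y = torch.cat(xs), torch.cat(ys)

    classes = torch.unique(y)  # sorted; last entry is the positive class
    mean = x.mean(dim=0)
    coef = (x[y == classes[-1]] - mean).mean(dim=0).unsqueeze(0)  # shape (1, n_features)
    intercept = -coef @ mean  # shape (1,)

    # Write the closed-form solution into the linear layer; no gradient steps are needed.
    with torch.no_grad():
        linear.weight.copy_(coef)
        linear.bias.copy_(intercept)
    return {"train_time": time.time() - start}


# Toy usage: four 2D activations served in batches of two.
x = torch.tensor([[3.0, 1.0], [5.0, 1.0], [-1.0, 1.0], [-3.0, 1.0]])
y = torch.tensor([1, 1, 0, 0])
loader = DataLoader(TensorDataset(x, y), batch_size=2)
linear = nn.Linear(in_features=2, out_features=1)
stats = fastcav_fit_linear_sketch(linear, loader)
```

Because the solution is closed-form, the only cost beyond loading the data is two mean computations, which is where the speed-up over iterative SVM training comes from.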
FastCAVLinearModel
This class acts as a bridge, wrapping the FastCAV logic into the LinearModel interface that is standard within Captum's concept utilities. This allows it to be used by higher-level abstractions.
Constructor:
FastCAVLinearModel(**kwargs)
- The constructor simply calls the parent LinearModel constructor, passing fastcav_train_linear_model as the train_fn. Any kwargs are stored and passed to the training function during the fit call.
FastCAVClassifier
This is the main user-facing class. It is a drop-in replacement for DefaultClassifier for users of high-level APIs like TCAV, or of future concept-based explanation methods such as ACE. By simply switching the classifier, users can leverage the performance benefits of FastCAV without changing the rest of their workflow. It inherits from DefaultClassifier and uses a FastCAVLinearModel as its internal engine.
Constructor:
FastCAVClassifier()