`dattri` is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms.

TRAIS-Lab/dattri


dattri: A Library for Efficient Data Attribution

dattri is a PyTorch library for developing, benchmarking, and deploying efficient data attribution algorithms. You may use dattri to

  • Deploy existing data attribution methods to PyTorch models
    • e.g., Influence Function, TracIn, RPS, TRAK, ...
  • Develop new data attribution methods with efficient implementations of low-level utility functions
    • e.g., Hessian (HVP/IHVP), Fisher Information Matrix (IFVP), random projection, dropout ensembling, ...
  • Benchmark data attribution methods with standard benchmark settings
    • e.g., MNIST-10+LR/MLP, CIFAR-10/2+ResNet-9, MAESTRO + Music Transformer, Shakespeare + nanoGPT, ...

Quick Start

Installation

git clone https://github.com/TRAIS-Lab/dattri
cd dattri
pip install -e .

To use all features with CUDA acceleration, you may install the full version with

pip install -e ".[all]"

Note

It is highly recommended to run dattri on a CUDA-capable device, especially for moderately large or larger models and datasets. CUDA is required to install the full version of dattri.

Note

If you are using dattri[all], please use pip<23 and torch<2.3 due to a known issue with the fast_jl library.

Apply Data Attribution Methods to PyTorch Models

dattri can apply different data attribution methods to PyTorch models. You only need to define:

  1. the loss function used for model training (used as the target function to be attributed if no other target function is provided).
  2. the trained model checkpoints.
  3. the data loaders for training and test samples (e.g., train_loader, test_loader).
  4. (optional) a target function to be attributed, if it differs from the loss function.

The following example uses IFAttributor to apply data attribution to a PyTorch model.

import torch
from torch import nn

from dattri.algorithm.influence_function import IFAttributor
from dattri.benchmark.datasets.mnist import train_mnist_lr, create_mnist_dataset
from dattri.func.utils import flatten_func
from dattri.benchmark.utils import SubsetSampler


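# Load the MNIST train/test splits stored under ./data
dataset_train, dataset_test = create_mnist_dataset("./data")

# Use the first 1000 training samples and the first 100 test samples
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=64,
    sampler=SubsetSampler(range(1000)),
)
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=64,
    sampler=SubsetSampler(range(100)),
)

# Train a logistic regression model on the training subset
model = train_mnist_lr(train_loader)

# Target function: cross-entropy loss of the model evaluated with the given parameters;
# flatten_func(model) wraps f so the parameters can be passed as one flattened tensor
@flatten_func(model)
def f(params, data_target_pair):
    x, y = data_target_pair
    loss = nn.CrossEntropyLoss()
    yhat = torch.func.functional_call(model, params, x)
    return loss(yhat, y)

# Attribute with respect to all trainable parameters
model_params = {k: p for k, p in model.named_parameters() if p.requires_grad}
# Influence function attributor using a conjugate-gradient (CG) IHVP solver
attributor = IFAttributor(
    target_func=f,
    params=model_params,
    ihvp_solver="cg",
    ihvp_kwargs={"max_iter": 10, "regularization": 1e-2},
)

# Cache training-set information, then score training samples against test samples
attributor.cache(train_loader)
score = attributor.attribute(train_loader, test_loader)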
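For intuition, here is a minimal from-scratch sketch of the quantity an influence-function attributor estimates, written in plain PyTorch (torch.func) rather than with dattri; it is not dattri's implementation. It assumes a hypothetical toy linear-regression problem small enough to form the Hessian explicitly. The score of training sample i for test sample j is g_train[i]^T (H + λI)^{-1} g_test[j], where the damping λ presumably plays a role similar to the "regularization" entry in ihvp_kwargs above. Sign conventions differ across implementations; here a larger score means that upweighting the training sample is predicted to lower the test loss more.

```python
import torch
from torch.func import grad, hessian, vmap

torch.manual_seed(0)

# Hypothetical toy data: linear regression with 3 features.
w_true = torch.tensor([1.0, -2.0, 0.5])
X_train = torch.randn(20, 3)
y_train = X_train @ w_true + 0.1 * torch.randn(20)
X_test = torch.randn(4, 3)
y_test = X_test @ w_true


def sample_loss(w, x, y):
    # Per-sample squared-error loss, used here as the target function.
    return 0.5 * (x @ w - y) ** 2


def train_loss(w):
    # Mean training loss; its Hessian is the matrix the IHVP inverts.
    return 0.5 * ((X_train @ w - y_train) ** 2).mean()


# "Trained" parameters: the least-squares solution.
w = torch.linalg.lstsq(X_train, y_train.unsqueeze(1)).solution.squeeze(1)

# Per-sample gradients w.r.t. w, shape (num_samples, num_params).
per_sample_grad = vmap(grad(sample_loss), in_dims=(None, 0, 0))
g_train = per_sample_grad(w, X_train, y_train)
g_test = per_sample_grad(w, X_test, y_test)

# Explicit Hessian plus damping, then one inverse-Hessian-vector product per test gradient.
H = hessian(train_loss)(w)
damping = 1e-2
ihvp = torch.linalg.solve(H + damping * torch.eye(3), g_test.T)  # (num_params, num_test)

# score[i, j] = g_train[i]^T (H + damping * I)^{-1} g_test[j]
score = g_train @ ihvp  # (num_train, num_test)
print(score.shape)
```

At realistic model sizes H cannot be materialized, which is why solvers such as CG replace the explicit torch.linalg.solve call with iterative IHVPs.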

Use low-level utility functions to develop new data attribution methods

HVP/IHVP

The Hessian-vector product (HVP) and inverse-Hessian-vector product (IHVP) are widely used in data attribution methods. dattri provides efficient implementations of these operators built on torch.func. This example shows how to use the conjugate gradient (CG) implementation of IHVP.

import torch
from dattri.func.hessian import ihvp_cg, ihvp_at_x_cg

def f(x, param):
    return torch.sin(x / param).sum()

x = torch.randn(2)
param = torch.randn(1)
v = torch.randn(5, 2)  # 5 vectors, each with the same dimension as x

# ihvp_cg method
# argnums=0: the Hessian is taken w.r.t. the first element of (x, param), i.e. x plays the role of the model parameters
ihvp_func = ihvp_cg(f, argnums=0, max_iter=2)
ihvp_result_1 = ihvp_func((x, param), v)  # both (x, param) and v are inputs
# ihvp_at_x_cg method: (x, param) is cached
ihvp_at_x_func = ihvp_at_x_cg(f, x, param, argnums=0, max_iter=2)
ihvp_result_2 = ihvp_at_x_func(v)  # only v is an input
# the two methods give the same result
assert torch.allclose(ihvp_result_1, ihvp_result_2)
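To show what a CG-based IHVP does conceptually, here is a minimal sketch in plain PyTorch. It is not dattri's ihvp_cg: the toy objective f, the helper names hvp and ihvp_cg_sketch, and the stopping rule are assumptions for illustration. The HVP uses forward-over-reverse differentiation from torch.func, and CG solves (H + λI)u = v using only HVPs, so the Hessian is never formed.

```python
import torch
from torch.func import grad, jvp

torch.manual_seed(0)
M = torch.randn(5, 5, dtype=torch.float64)
A = M @ M.T  # symmetric positive semi-definite


def f(x):
    # Toy objective whose Hessian, A + diag(exp(x)), is positive definite.
    return 0.5 * x @ A @ x + torch.exp(x).sum()


def hvp(func, x, v):
    # Forward-over-reverse: the JVP of grad(func) at x along v equals H(x) @ v.
    return jvp(grad(func), (x,), (v,))[1]


def ihvp_cg_sketch(func, x, v, max_iter=50, regularization=0.0, tol=1e-20):
    """Approximately solve (H(x) + regularization * I) u = v with conjugate
    gradient, accessing H only through Hessian-vector products."""

    def matvec(p):
        return hvp(func, x, p) + regularization * p

    u = torch.zeros_like(v)
    r = v.clone()  # residual v - matvec(u), with u = 0
    p = r.clone()  # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if rs_old < tol:  # squared residual norm small enough: converged
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        u = u + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return u


x = torch.randn(5, dtype=torch.float64)
v = torch.randn(5, dtype=torch.float64)
H = A + torch.diag(torch.exp(x))  # explicit Hessian, used only to check the sketch
u = ihvp_cg_sketch(f, x, v, regularization=1e-2)
print(u)
```

CG only needs the matrix-vector product, which is why the regularization term can be folded into matvec instead of modifying a stored Hessian.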

Random Projection

High-dimensional vectors retain most of their relative information (such as pairwise distances) when randomly projected to a much lower dimension, as shown by the Johnson-Lindenstrauss lemma. To reduce computational cost, random projection is therefore widely used in data attribution methods. The following is an example of using random_project; the implementation leverages fast_jl.

import torch
from dattri.func.random_projection import random_project

# hypothetical input: a batch of 64 flattened gradients, each with 10,000 entries
tensor = torch.randn(64, 10000)

# initialize the projector based on users' needs
project_func = random_project(tensor, tensor.size(0), proj_dim=512)

# obtain projected tensors
projected_tensor = project_func(tensor)

Typically, tensor holds gradients of the loss/target function, so its feature dimension is large (equal to the number of model parameters).
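To make the idea concrete, here is a minimal sketch using a dense Gaussian projection matrix in plain PyTorch. It is a stand-in, not random_project: fast_jl uses a faster structured projection, and the helper gaussian_projector and the tensor shapes are hypothetical. Scaling the entries by 1/sqrt(proj_dim) keeps squared norms and inner products unbiased in expectation.

```python
import torch

torch.manual_seed(0)


def gaussian_projector(dim, proj_dim, seed=0):
    # Dense Gaussian JL matrix with N(0, 1/proj_dim) entries, so E[||g @ P||^2] = ||g||^2.
    gen = torch.Generator().manual_seed(seed)
    P = torch.randn(dim, proj_dim, generator=gen) / proj_dim**0.5
    return lambda g: g @ P


grads = torch.randn(8, 4096)  # hypothetical batch of flattened gradients
project = gaussian_projector(grads.size(1), proj_dim=512)
projected = project(grads)  # shape (8, 512)

# Norm ratios concentrate around 1 (relative spread roughly 1/sqrt(2 * proj_dim)).
ratio = projected.norm(dim=1) / grads.norm(dim=1)
print(ratio)
```

Fixing the generator seed matters: training and test gradients must be projected with the same matrix for their projected inner products to be comparable.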

Dropout Ensemble

Recent studies have found that ensemble methods can significantly improve the performance of data attribution, and dropout ensembling is one such method. You can prepare a model for it with

from dattri.model_utils.dropout import activate_dropout

# initialize a torch.nn.Module model
# (MLP stands for your own model class that contains dropout layers)
model = MLP()

# (option 1) activate all dropout layers
model = activate_dropout(model, dropout_prob=0.2)

# (option 2) activate specific dropout layers
# here "dropout1" and "dropout2" are the names of dropout layers within the model
model = activate_dropout(model, ["dropout1", "dropout2"], dropout_prob=0.2)
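The following is a minimal sketch of one way a dropout ensemble can be used for attribution, written in plain PyTorch without activate_dropout. Everything here is hypothetical: the MLP class, the helper names, the choice of Grad-Dot as the base score, and averaging scores over members. Each random seed fixes one dropout mask (one ensemble member), and the per-member scores are averaged.

```python
import torch
from torch import nn


class MLP(nn.Module):
    """Hypothetical MNIST-sized MLP with one named dropout layer."""

    def __init__(self, p=0.2):
        super().__init__()
        self.fc1 = nn.Linear(784, 32)
        self.dropout1 = nn.Dropout(p)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        return self.fc2(self.dropout1(torch.relu(self.fc1(x))))


def enable_dropout_only(model):
    # Eval mode everywhere except nn.Dropout layers, which keep sampling masks.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    return model


def per_sample_grads(model, xs, ys, seed):
    loss_fn = nn.CrossEntropyLoss()
    grads = []
    for x, y in zip(xs, ys):
        # Re-seeding before each forward pass gives every sample the same
        # dropout mask, i.e. the same subnetwork for this ensemble member.
        torch.manual_seed(seed)
        model.zero_grad()
        loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0)).backward()
        grads.append(torch.cat([p.grad.flatten() for p in model.parameters()]))
    return torch.stack(grads)


def dropout_ensemble_grad_dot(model, train, test, num_members=4):
    # Grad-Dot scores computed per member (seed), then averaged over members.
    enable_dropout_only(model)
    member_scores = []
    for seed in range(num_members):
        g_train = per_sample_grads(model, *train, seed)
        g_test = per_sample_grads(model, *test, seed)
        member_scores.append(g_train @ g_test.T)
    return torch.stack(member_scores).mean(dim=0)


torch.manual_seed(0)
model = MLP()
train = (torch.randn(6, 784), torch.randint(0, 10, (6,)))
test = (torch.randn(3, 784), torch.randint(0, 10, (3,)))
scores = dropout_ensemble_grad_dot(model, train, test)
print(scores.shape)
```

Because torch.manual_seed resets the global RNG, a real implementation would likely manage generator state more carefully than this sketch does.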

Algorithms Supported

  • IF: Explicit, CG, LiSSA, Arnoldi, DataInf, EK-FAC
  • TracIn: TracInCP, Grad-Dot, Grad-Cos
  • RPS: RPS-L2
  • TRAK: TRAK
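For reference, the gradient-similarity methods in the TracIn family can be written in a few lines. This pure-Python sketch assumes the flattened per-sample gradients and per-checkpoint learning rates are already available; it illustrates the formulas, not dattri's implementation. Grad-Dot and Grad-Cos compare gradients at a single checkpoint, while TracInCP sums learning-rate-weighted dot products over checkpoints.

```python
import math


def grad_dot(g_train, g_test):
    # Grad-Dot: inner product of a training gradient and a test gradient
    return sum(a * b for a, b in zip(g_train, g_test))


def grad_cos(g_train, g_test):
    # Grad-Cos: Grad-Dot normalized by both gradient norms (cosine similarity)
    norms = math.sqrt(grad_dot(g_train, g_train)) * math.sqrt(grad_dot(g_test, g_test))
    return grad_dot(g_train, g_test) / norms


def tracin_cp(train_grads, test_grads, learning_rates):
    # TracInCP: sum over checkpoints k of lr_k * <grad_train_k, grad_test_k>
    return sum(
        lr * grad_dot(g_tr, g_te)
        for lr, g_tr, g_te in zip(learning_rates, train_grads, test_grads)
    )


# Hypothetical 2-parameter gradients
print(grad_dot([1.0, 2.0], [3.0, 4.0]))  # 11.0
print(tracin_cp([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]], [0.1, 0.01]))
```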

Metrics Supported

  • Leave-one-out (LOO) correlation
  • Linear datamodeling score (LDS)
  • Area under the ROC curve (AUC) for noisy label detection
  • Brittleness test for checking flipped label
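Two of these metrics can be sketched in pure Python. LDS, as commonly defined, samples training subsets, predicts each subset's effect on a test output as the sum of the attribution scores of its members, and reports the Spearman rank correlation with the outputs of models actually retrained on those subsets, averaged over test samples. Noisy-label AUC treats attribution-derived scores (e.g., self-influence, an assumption here) as detectors of flipped labels. The function names and input layouts are hypothetical, not dattri's metric API.

```python
import math


def average_ranks(values):
    # Rank values from 1..n, giving tied values their average rank
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for k in range(start, end + 1):
            ranks[order[k]] = (start + end) / 2 + 1
        start = end + 1
    return ranks


def spearman(a, b):
    # Spearman correlation = Pearson correlation of the ranks
    ra, rb = average_ranks(a), average_ranks(b)
    ma, mb = sum(ra) / len(ra), sum(rb) / len(rb)
    cov = sum((x - ma) * (y - mb) for x, y in zip(ra, rb))
    var_a = sum((x - ma) ** 2 for x in ra)
    var_b = sum((y - mb) ** 2 for y in rb)
    return cov / math.sqrt(var_a * var_b)


def lds(scores, subsets, subset_outputs):
    """scores[i][j]: attribution of train sample i to test sample j.
    subsets[s]: train indices in subset s.
    subset_outputs[s][j]: output on test sample j of a model retrained on subset s."""
    num_test = len(scores[0])
    correlations = []
    for j in range(num_test):
        predicted = [sum(scores[i][j] for i in subset) for subset in subsets]
        actual = [outputs[j] for outputs in subset_outputs]
        correlations.append(spearman(predicted, actual))
    return sum(correlations) / num_test


def noisy_label_auc(scores, is_noisy):
    # Probability that a random noisy sample outscores a random clean one (ties count 1/2)
    noisy = [s for s, flag in zip(scores, is_noisy) if flag]
    clean = [s for s, flag in zip(scores, is_noisy) if not flag]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in noisy for q in clean)
    return wins / (len(noisy) * len(clean))


# Toy usage: 3 training samples, 1 test sample, 4 retrained subsets
print(lds([[1.0], [2.0], [3.0]], [[0], [1], [0, 2], [1, 2]], [[0.1], [0.2], [0.3], [0.9]]))
print(noisy_label_auc([0.9, 0.1, 0.8, 0.2], [True, False, True, False]))
```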

Benchmark Settings Supported

Dataset | Model | Task | Sample size (train, test) | Parameter size | Metrics | Data source
--- | --- | --- | --- | --- | --- | ---
MNIST-10 | LR | Image classification | (5000, 500) | 7840 | LOO/LDS/AUC | link
MNIST-10 | MLP | Image classification | (5000, 500) | 0.11M | LOO/LDS/AUC | link
CIFAR-2 | ResNet-9 | Image classification | (5000, 500) | 4.83M | LDS | link
CIFAR-10 | ResNet-9 | Image classification | (5000, 500) | 4.83M | AUC | link
MAESTRO | Music Transformer | Music generation | (5000, 178) | 13.3M | LDS | link
Shakespeare | nanoGPT | Text generation | (3921, 435) | 10.7M | LDS | link

Benchmark Results

MNIST+LR/MLP

[Figure: MNIST+LR/MLP benchmark results]

LDS performance on larger models

[Figure: LDS performance on larger models]

AUC performance

[Figure: AUC performance]

Development Plan

  • More (larger) benchmark settings to come
    • ImageNet + ResNet-18
    • TinyStories + nanoGPT
    • Comparison with other libraries
  • More algorithms and low-level utility functions to come
    • KNN filter
    • TF-IDF filter
    • RelativeIF
    • KNN Shapley
    • In-Run Shapley
  • Better documentation
    • Quick-start Colab notebooks
