
RistoAle97/centered-kernel-alignment


🤖 CKA PyTorch 🤖

CKA (Centered Kernel Alignment) in PyTorch.


Warning

This repository was built mainly for personal and academic use, since Captum has not yet implemented its own variant of CKA. As such, do not expect this project to work with every model.


✒️ About

Note

Centered Kernel Alignment (CKA) [1] is a similarity index for comparing feature representations in neural networks, based on the Hilbert-Schmidt Independence Criterion (HSIC) [2]. Given a set of examples, CKA compares the representations that the layers we want to compare produce for those examples.
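As a rough illustration of where the layer representations come from, the sketch below collects the outputs of two layers with PyTorch forward hooks and flattens each one to an (examples × features) matrix. The toy model, the chosen layers, and the `activations` dict are hypothetical and are not part of the ckatorch API; flattening everything after the batch dimension is just one common choice.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; any nn.Module with addressable submodules works the same way.
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
)

activations = {}


def save_output(name):
    # Returns a hook that stores the layer output, flattened to (batch, features).
    def hook(module, inputs, output):
        activations[name] = output.detach().flatten(start_dim=1)

    return hook


# Hook the two layers we want to compare (positions 0 and 2 in the Sequential).
handles = [
    model[0].register_forward_hook(save_output("layer_a")),
    model[2].register_forward_hook(save_output("layer_b")),
]

torch.manual_seed(0)
examples = torch.randn(64, 10)  # the same n = 64 examples go through both layers
with torch.no_grad():
    model(examples)

for handle in handles:
    handle.remove()  # detach the hooks once the activations are collected

X = activations["layer_a"]  # shape (n, s1) = (64, 32)
Y = activations["layer_b"]  # shape (n, s2) = (64, 16)
print(X.shape, Y.shape)
```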

Given two matrices $\boldsymbol{X} \in \mathbb{R}^{n\times s_1}$ and $\boldsymbol{Y} \in \mathbb{R}^{n\times s_2}$ representing the outputs of two layers on the same $n$ examples, we can define two auxiliary $n \times n$ Gram matrices $\boldsymbol{K}=\boldsymbol{XX^T}$ and $\boldsymbol{L}=\boldsymbol{YY^T}$ and compute the dot-product similarity between them:

$$\langle vec(\boldsymbol{XX^T}), vec(\boldsymbol{YY^T})\rangle = tr(\boldsymbol{XX^T YY^T}) = \lVert \boldsymbol{Y^T X} \rVert_F^2.$$
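A quick numerical check of the identity above, written as a minimal PyTorch sketch on random data (sizes chosen arbitrarily). The last form only needs the $s_2 \times s_1$ matrix $\boldsymbol{Y^T X}$, so it never builds the two $n \times n$ Gram matrices.

```python
import torch

torch.manual_seed(0)
n, s1, s2 = 8, 5, 3
X = torch.randn(n, s1, dtype=torch.float64)
Y = torch.randn(n, s2, dtype=torch.float64)

K = X @ X.T  # n x n Gram matrix of X
L = Y @ Y.T  # n x n Gram matrix of Y

dot_vec = torch.dot(K.flatten(), L.flatten())                   # <vec(XX^T), vec(YY^T)>
trace_form = torch.trace(K @ L)                                 # tr(XX^T YY^T)
frob_form = torch.linalg.matrix_norm(Y.T @ X, ord="fro") ** 2   # ||Y^T X||_F^2

print(dot_vec.item(), trace_form.item(), frob_form.item())
```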

Then, the $HSIC$ on $K$ and $L$ is defined as

$$HSIC_0(\boldsymbol{K}, \boldsymbol{L}) = \frac{tr(\boldsymbol{KHLH})}{(n - 1)^2},$$

where $\boldsymbol{H} = \boldsymbol{I_n} - \frac{1}{n}\boldsymbol{J_n}$ is the centering matrix and $\boldsymbol{J_n}$ is an $n \times n$ matrix filled with ones. Finally, to obtain the CKA value we only need to normalize $HSIC_0$

$$CKA(\boldsymbol{K}, \boldsymbol{L}) = \frac{HSIC_0(\boldsymbol{K}, \boldsymbol{L})}{\sqrt{HSIC_0(\boldsymbol{K}, \boldsymbol{K}) HSIC_0(\boldsymbol{L}, \boldsymbol{L})}}.$$
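The full-batch linear CKA can be transcribed almost literally from the equations above. This is a minimal sketch, not the implementation used in this repository; `centering_matrix`, `hsic0`, and `linear_cka` are hypothetical helper names.

```python
import torch


def centering_matrix(n: int, dtype=torch.float64) -> torch.Tensor:
    # H = I_n - (1/n) J_n, where J_n is the n x n all-ones matrix
    return torch.eye(n, dtype=dtype) - torch.ones(n, n, dtype=dtype) / n


def hsic0(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    # Biased estimator: tr(K H L H) / (n - 1)^2
    n = K.shape[0]
    H = centering_matrix(n, dtype=K.dtype)
    return torch.trace(K @ H @ L @ H) / (n - 1) ** 2


def linear_cka(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not the ckatorch API): linear kernels K = X X^T, L = Y Y^T
    K = X @ X.T
    L = Y @ Y.T
    return hsic0(K, L) / torch.sqrt(hsic0(K, K) * hsic0(L, L))


torch.manual_seed(0)
X = torch.randn(100, 20, dtype=torch.float64)
Y = torch.randn(100, 10, dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(20, 20, dtype=torch.float64))  # random orthogonal matrix

same = linear_cka(X, X)               # a layer compared with itself
rotated = linear_cka(X, 3.0 * X @ Q)  # isotropic scaling + orthogonal transform of X
unrelated = linear_cka(X, Y)          # independent random features
print(same.item(), rotated.item(), unrelated.item())
```

Since the result depends on $\boldsymbol{X}$ only through $\boldsymbol{XX^T}$, scaling $\boldsymbol{X}$ or multiplying it by an orthogonal matrix leaves CKA unchanged, which is what the second call checks. Note that `K`, `L`, and `H` are all $n \times n$, so memory grows quadratically with the number of examples.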

Note

However, naive computation of linear CKA (i.e., the previous equation) requires keeping the activations for the entire dataset in memory, which is challenging for wide and deep networks [3].

Therefore, we use an unbiased estimator of HSIC, so that the value of CKA is independent of the batch size:

$$HSIC_1(\boldsymbol{K}, \boldsymbol{L})=\frac{1}{n(n-3)}\left( tr(\boldsymbol{\tilde{K}\tilde{L}}) + \frac{\boldsymbol{1^T\tilde{K}11^T\tilde{L}1}}{(n-1)(n-2)} - \frac{2}{n-2}\boldsymbol{1^T\tilde{K}\tilde{L}1}\right),$$

where $\boldsymbol{\tilde{K}}$ and $\boldsymbol{\tilde{L}}$ are obtained by setting the diagonal entries of $\boldsymbol{K}$ and $\boldsymbol{L}$ to zero. Finally, we can compute the minibatch version of CKA by averaging $HSIC_1$ scores over $k$ minibatches

$$CKA_{minibatch}=\frac{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{L_i})}{\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{K_i}, \boldsymbol{K_i})}\sqrt{\frac{1}{k} \displaystyle\sum_{i=1}^{k} HSIC_1(\boldsymbol{L_i}, \boldsymbol{L_i})}},$$

with $\boldsymbol{K_i}=\boldsymbol{X_iX_i^T}$ and $\boldsymbol{L_i}=\boldsymbol{Y_iY_i^T}$, where $\boldsymbol{X_i} \in \mathbb{R}^{m \times s_1}$ and $\boldsymbol{Y_i} \in \mathbb{R}^{m \times s_2}$ are now matrices containing the activations of the $i^{th}$ minibatch of $m$ examples sampled without replacement [3]; in this case, $n$ in $HSIC_1$ is the minibatch size $m$.
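Below is a minimal sketch of the unbiased estimator and the minibatch average, assuming each minibatch is given as an `(X_i, Y_i)` pair of activation matrices. `hsic1` and `minibatch_cka` are hypothetical names, not the ckatorch API, and the "next layer" activations are simulated with a random ReLU projection purely for illustration.

```python
import torch


def hsic1(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    # Unbiased HSIC estimator; n is the number of rows (requires n > 3)
    n = K.shape[0]
    K_t = K.clone().fill_diagonal_(0)  # K tilde: diagonal set to zero
    L_t = L.clone().fill_diagonal_(0)  # L tilde: diagonal set to zero
    KL = K_t @ L_t
    trace_term = torch.trace(KL)                            # tr(K~ L~)
    sum_term = K_t.sum() * L_t.sum() / ((n - 1) * (n - 2))  # 1^T K~ 1 1^T L~ 1 / ((n-1)(n-2))
    cross_term = 2.0 * KL.sum() / (n - 2)                   # 2/(n-2) * 1^T K~ L~ 1
    return (trace_term + sum_term - cross_term) / (n * (n - 3))


def minibatch_cka(batches) -> torch.Tensor:
    # Hypothetical helper (not the ckatorch API).
    # batches: iterable of (X_i, Y_i) activation pairs, one pair per minibatch
    sum_kl, sum_kk, sum_ll, k = 0.0, 0.0, 0.0, 0
    for X_i, Y_i in batches:
        K_i = X_i @ X_i.T
        L_i = Y_i @ Y_i.T
        sum_kl = sum_kl + hsic1(K_i, L_i)
        sum_kk = sum_kk + hsic1(K_i, K_i)
        sum_ll = sum_ll + hsic1(L_i, L_i)
        k += 1
    # Average each HSIC_1 score over the k minibatches, then normalize
    return (sum_kl / k) / (torch.sqrt(sum_kk / k) * torch.sqrt(sum_ll / k))


torch.manual_seed(0)
X_full = torch.randn(512, 20, dtype=torch.float64)
W = torch.randn(20, 10, dtype=torch.float64)
Y_full = torch.relu(X_full @ W)  # hypothetical "next layer" activations

# Disjoint chunks of m = 64 examples: k = 8 minibatches drawn without replacement
batches = list(zip(X_full.split(64), Y_full.split(64)))
print(minibatch_cka(batches).item())
```

Only three running sums are kept between minibatches, so the activations of earlier batches can be discarded. In practice the dataset would be shuffled before being split into minibatches.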


📦 Installation

This project requires Python >= 3.10.

Create a new venv

Note

This creates a new virtual environment under .venv in the working directory. If you create the venv with uv, you do not need to activate it, since uv looks for an environment in the working directory and in its parent directories.

# If you have uv installed
uv venv

# Otherwise
python -m venv .venv
source .venv/bin/activate  # on Linux or macOS
.\.venv\Scripts\activate.bat  # on Windows, using cmd
.\.venv\Scripts\Activate.ps1  # on Windows, using PowerShell

Install the package

Note

This will install PyTorch compiled with CUDA.

You can install the package by either:

  • using pip

    pip install git+https://github.com/RistoAle97/centered-kernel-alignment

    This will not install the dev dependencies listed in pyproject.toml.

  • cloning the repository and installing the dependencies

    git clone https://github.com/RistoAle97/centered-kernel-alignment
    
    # If you have uv installed
    uv pip install -e centered-kernel-alignment
    uv pip install ckatorch[dev]  # only if you want to contribute to the repo
    
    # Otherwise
    pip install -e centered-kernel-alignment
    pip install ckatorch[dev]  # as above; remember to open a pull request afterwards

Take a look at the examples directory to understand how to compute CKA in two basic scenarios.


🖼️ Plots

[Plots: a model compared with itself; different models compared]

📚 Bibliography

[1] Kornblith, Simon, et al. "Similarity of neural network representations revisited." International Conference on Machine Learning. PMLR, 2019.

[2] Wang, Tinghua, Xiaolu Dai, and Yuze Liu. "Learning with Hilbert–Schmidt independence criterion: A review and new perspectives." Knowledge-Based Systems 234 (2021): 107567.

[3] Nguyen, Thao, Maithra Raghu, and Simon Kornblith. "Do wide and deep networks learn the same things? Uncovering how neural network representations vary with width and depth." arXiv preprint arXiv:2010.15327 (2020).

This project is also based on the following repositories:


📝 License

This project is MIT licensed.
