Add option to load pretrained R1/R2 matrices
tobiasvanderwerff committed Oct 4, 2024
1 parent a378156 commit 2c8acdd
Showing 1 changed file with 56 additions and 20 deletions.
76 changes: 56 additions & 20 deletions torchao/quantization/spin_quant.py
@@ -4,6 +4,7 @@
Based on https://github.com/facebookresearch/SpinQuant
"""

from pathlib import Path
import typing

import torch
@@ -31,12 +32,12 @@ def forward(self, x):
return x


def apply_spinquant(model: Transformer):
def apply_spinquant(model: Transformer, use_r1=False, use_r2=True, use_r4=True, pretrained_rotation_path=None):
"""
Apply SpinQuant to a Transformer model: https://arxiv.org/abs/2405.16406
Currently, this only applies the R1 + R2 + R4 rotation matrices to the model
(not R3, and no Cayley optimization).
Currently, this can optionally apply the R1, R2, and R4 rotation matrices to
the model (not R3, and no Cayley optimization).
"""
assert isinstance(model, Transformer), "Only Transformer models are supported"

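For reference, a minimal usage sketch of the new keyword arguments (the model setup and the checkpoint location are placeholders, not part of this commit; the path reuses the example mentioned further down in the diff):

```python
# Hypothetical usage sketch: apply_spinquant's signature is taken from this
# commit, but the model construction and checkpoint path are placeholders.
from torchao.quantization.spin_quant import apply_spinquant

model = ...  # a torchao Transformer instance, loaded elsewhere

apply_spinquant(
    model,
    use_r1=True,   # also fuses LayerNorm weights into the adjacent linears
    use_r2=True,
    use_r4=True,
    pretrained_rotation_path="7B_W4A16KV16_lr_1.5_seed_0/R.bin",  # optional .bin checkpoint
)
```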
@@ -45,25 +46,59 @@ def apply_spinquant(model: Transformer):
model.to(device=device)
torch.manual_seed(0)  # for reproducibility of random Hadamard matrices

fuse_layernorm_weights_into_linear(model)
apply_spinquant_r1(model, device)
apply_spinquant_r2(model, device)
apply_spinquant_r4(model, device)
# For testing purposes (remove later)
# Weights link: https://drive.google.com/drive/folders/1nV9juzE6_OHr10y6Ke5KCyOiGqDr0srX
# pretrained_rotation_path = "7B_W4A16KV16_lr_1.5_seed_0/R.bin"

if pretrained_rotation_path is not None:
assert Path(pretrained_rotation_path).is_file(), "Pretrained rotation path does not exist"
assert Path(pretrained_rotation_path).suffix == ".bin", "Expected a .bin file."

if use_r1:
fuse_layernorm_weights_into_linear(model)
apply_spinquant_r1(model, device, pretrained_rotation_path)
if use_r2:
apply_spinquant_r2(model, device, pretrained_rotation_path)
if use_r4:
apply_spinquant_r4(model, device)

model.to(device=original_device)


def apply_spinquant_r1(model, device):
R1 = random_hadamard_matrix(model.config.dim, device)
def apply_spinquant_r1(model, device, pretrained_rotation_path=None):
"""Apply the SpinQuant R1 rotation matrix to the model."""

# Load R1 matrix
if pretrained_rotation_path is not None:
R1 = torch.load(pretrained_rotation_path)["R1"].to(device).to(torch.float64)
assert R1.shape == (model.config.dim, model.config.dim), f"{R1.shape} vs {model.config.dim}"
else:
R1 = random_hadamard_matrix(model.config.dim, device)

_rotate_model_r1(model, R1)


def apply_spinquant_r2(model, device):
R2 = random_hadamard_matrix(model.config.head_dim, device)
_rotate_model_r2(model, R2)
def apply_spinquant_r2(model, device, pretrained_rotation_path=None):
"""Apply the SpinQuant R2 rotation matrix to the model."""

# Load R2 matrices
R2s = []
head_dim = model.config.head_dim
for i, _ in enumerate(model.layers):
if pretrained_rotation_path is not None:
key = f"model.layers.{i}.self_attn.R2"
R2s_ = torch.load(pretrained_rotation_path)
R2 = R2s_[key].to(device).to(torch.float64)
assert R2.shape == (head_dim, head_dim), f"{R2.shape} != ({head_dim}, {head_dim})"
else:
R2 = random_hadamard_matrix(head_dim, device)
R2s.append(R2)

_rotate_model_r2(model, R2s)


def apply_spinquant_r4(model, device):
"""Apply the SpinQuant R4 rotation matrix to the model."""
_rotate_model_r4(model)
_add_activation_wrappers_r4(model)

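Based on the keys read above ("R1" and "model.layers.{i}.self_attn.R2"), a pretrained rotation checkpoint can be sanity-checked with a short sketch like this (the filename is a placeholder):

```python
# Inspection sketch for a pretrained rotation file, assuming the key layout
# used in this commit; "R.bin" is a hypothetical path.
import torch

rotations = torch.load("R.bin", map_location="cpu")
print(rotations["R1"].shape)  # expected: (model.config.dim, model.config.dim)
for key, value in rotations.items():
    if key.endswith("self_attn.R2"):
        print(key, tuple(value.shape))  # expected: (head_dim, head_dim)
```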
@@ -102,14 +137,16 @@ def _rotate_model_r1(model, R1):


@torch.inference_mode()
def _rotate_model_r2(model, R2):
def _rotate_model_r2(model, R2s):
"""Rotate the W_v and W_o weights of the multi-head self-attention modules."""

# Apply R2 rotation to all multi-head self-attention modules
for layer in model.layers:
for idx, layer in enumerate(model.layers):
attn = layer.attention
head_dim = model.config.head_dim

R2 = R2s[idx]

# Rotate W_o
apply_exact_had_to_linear(attn.wo, had_dim=head_dim, output=False, R2=R2)

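As a conceptual aside (not the torchao implementation, which goes through apply_exact_had_to_linear), the per-head R2 rotation relies on the fact that, for an orthogonal R2, rotating each head's slice of W_v by R2^T and the matching input block of W_o by R2 leaves the product, and hence the attention output, unchanged. A sketch under assumed weight layouts:

```python
# Conceptual sketch only. Assumes wv.weight has shape (n_heads * head_dim, dim)
# and wo.weight has shape (dim, n_heads * head_dim), with heads stored
# contiguously; these layout assumptions are not taken from the commit.
import torch

def rotate_value_and_output(wv: torch.nn.Linear, wo: torch.nn.Linear,
                            R2: torch.Tensor, head_dim: int) -> None:
    # Rotate each head's output rows of W_v by R2^T ...
    Wv = wv.weight.data.to(torch.float64).reshape(-1, head_dim, wv.in_features)
    Wv = torch.matmul(R2.T.to(torch.float64), Wv)        # (n_heads, head_dim, dim)
    wv.weight.data = Wv.reshape(wv.weight.shape).to(wv.weight.dtype)

    # ... and the matching input columns of W_o by R2, so that
    # W_o R2 (R2^T v) == W_o v for every head.
    Wo = wo.weight.data.to(torch.float64).reshape(wo.out_features, -1, head_dim)
    Wo = torch.matmul(Wo, R2.to(torch.float64))           # (dim, n_heads, head_dim)
    wo.weight.data = Wo.reshape(wo.weight.shape).to(wo.weight.dtype)
```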
@@ -169,10 +206,9 @@ def fuse_layernorm_weights_into_linear(model):
something.)
"""
# Embedding fusion (from utils/fuse_norm_utils.py:43)
# I currently don't understand why this is necessary, so I'm omitting it (I
# contacted the authors about it:
# https://github.com/facebookresearch/SpinQuant/issues/14). It doesn't seem
# to affect performance (tested on int4wo)
# I currently don't understand why this is necessary, so I contacted the
# authors about it:
# https://github.com/facebookresearch/SpinQuant/issues/14.
for W in [model.tok_embeddings]:
W_ = W.weight.data.double()
W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
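For what it's worth, the embedding "fusion" above just centers each embedding row (subtracts its per-row mean); a toy check:

```python
# Toy illustration of the centering performed above; shapes are arbitrary.
import torch

W = torch.randn(4, 8, dtype=torch.float64)    # stand-in for a (vocab, dim) embedding
W_centered = W - W.mean(dim=-1, keepdim=True)
print(W_centered.mean(dim=-1))                 # ~0 for every row
```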
@@ -196,7 +232,7 @@ def fuse_layernorm_weights_into_linear(model):
def _rotate_mlp_output(layer, R1):
W = layer.feed_forward.w2
dtype = W.weight.dtype
W_ = W.weight.data.to( dtype=torch.float64)
W_ = W.weight.data.to(dtype=torch.float64)
W.weight.data = torch.matmul(R1.T, W_).to(dtype=dtype)
if W.bias is not None:
b = W.bias.data.to(dtype=torch.float64)
