Add SpinQuant to generate.py (#1069)
* Only import SpinQuant when necessary

There is no need to import the large Hadamard matrices required for SpinQuant unless SpinQuant is actually used

* Add SpinQuant to `generate.py`

* Custom op for the Hadamard transform for torch.compile compatibility

* Add spinquant to arg parser info

* Add SpinQuant benchmark results to README

* Add performance testing details

* Fix broken custom op for PyTorch < 2.4
tobiasvanderwerff authored Oct 22, 2024
1 parent f1b4c8e commit 3044ee5
Showing 5 changed files with 91 additions and 14 deletions.
2 changes: 1 addition & 1 deletion torchao/_models/llama/eval.py
@@ -31,7 +31,6 @@
from tokenizer import get_tokenizer
import time
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.prototype.spinquant import apply_spinquant

def run_evaluation(
checkpoint_path: Path,
@@ -71,6 +70,7 @@ def run_evaluation(

if quantization:
if "spinquant" in quantization:
from torchao.prototype.spinquant import apply_spinquant
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
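The eval.py change above defers the SpinQuant import until the option is actually requested, so the large hard-coded Hadamard matrices in `torchao.prototype.spinquant._hadamard_matrices` are never loaded otherwise. A minimal sketch of the pattern, using a hypothetical helper name (only `apply_spinquant` comes from the diff):

```python
def maybe_apply_spinquant(model, quantization: str) -> None:
    """Hypothetical helper mirroring the branch added to eval.py."""
    if "spinquant" in quantization:
        # Deferred import: only pay the cost of loading the Hadamard
        # matrices when SpinQuant is actually used.
        from torchao.prototype.spinquant import apply_spinquant
        apply_spinquant(model)
```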
5 changes: 4 additions & 1 deletion torchao/_models/llama/generate.py
@@ -217,6 +217,9 @@ def main(
float8_dynamic_activation_float8_weight,
)
from torchao.quantization.granularity import PerTensor, PerRow
if "spinquant" in quantization:
from torchao.prototype.spinquant import apply_spinquant
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -460,7 +463,7 @@ def callback(x):
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
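A rough Python-API equivalent of what `generate.py -q spinquant` now does, wrapped in a hypothetical helper. `model` is assumed to be the already-loaded Llama `Transformer` these scripts use, and chaining a weight-only quantization afterwards is an illustration based on the independent `if` branches above, not something the commit documents:

```python
from torchao.prototype.spinquant import apply_spinquant
from torchao.quantization import quantize_, int8_weight_only


def quantize_with_spinquant(model) -> None:
    """Hypothetical helper; mirrors the generate.py branches shown above."""
    apply_spinquant(model)                # R1/R2 fused into the weights, R4 applied at runtime
    quantize_(model, int8_weight_only())  # optional follow-up weight-only quantization
```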
33 changes: 31 additions & 2 deletions torchao/prototype/spinquant/README.md
@@ -4,8 +4,37 @@ Re-implementation of SpinQuant based on the official code implementation (https:

## Usage

Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:
For optimal performance on CUDA GPUs, install the Fast Hadamard Transform package:

```shell
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```
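If the package is not installed, the Hadamard utilities fall back to a slower pure-PyTorch path (see the `try`/`except` around `fast_hadamard_transform` in `hadamard_utils.py` below). A quick, illustrative way to check which path you will get:

```python
try:
    import fast_hadamard_transform  # noqa: F401  (optional CUDA kernel)
    print("fast-hadamard-transform found: using the CUDA kernel")
except ImportError:
    print("fast-hadamard-transform not found: falling back to the slower PyTorch path")
```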

## Performance

See https://github.com/pytorch/ao/pull/983 for Wikitext benchmark results.

Tested on:

- Llama-2-7b
- PyTorch 2.4.1
- NVIDIA A100
- CUDA 12.1

Without `torch.compile`:

| Configuration | Average tokens/sec | Average Bandwidth (GB/s) | Peak Memory Usage (GB) | Model Size (GB) |
|----------------|--------------------|--------------------------|------------------------|-----------------|
| Baseline | 27.33 | 361.21 | 13.62 | 13.21 |
| Spinquant (R4) | 23.01 | 304.10 | 14.24 | 13.22 |

With `torch.compile`:

| Configuration | Average tokens/sec | Average Bandwidth (GB/s) | Peak Memory Usage (GB) | Model Size (GB) |
|----------------------|--------------------|--------------------------|------------------------|-----------------|
| Baseline | 114.08 | 1507.58 | 13.88 | 13.21 |
| Spinquant (R4) | 109.59 | 1448.61 | 13.72 | 13.22 |
| Spinquant (R1+R2+R4) | 109.64 | 1449.28 | 14.90 | 13.22 |


NB: R1 and R2 are fused into the linear weights before inference, so they are not expected to add overhead at inference time.
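A small self-contained check of that claim, under an illustrative convention (rotate the activations by an orthogonal matrix, then fold the rotation into the linear weight offline; the exact SpinQuant parameterization may differ):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
x = torch.randn(4, d)                      # activations
W = torch.randn(d, d)                      # linear weight, (out_features, in_features)
R, _ = torch.linalg.qr(torch.randn(d, d))  # random orthogonal rotation

at_runtime = F.linear(x @ R, W)            # rotate activations, then apply the linear layer
fused = F.linear(x, W @ R.T)               # rotation folded into the weight ahead of time

assert torch.allclose(at_runtime, fused, atol=1e-5)
```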
55 changes: 50 additions & 5 deletions torchao/prototype/spinquant/hadamard_utils.py
@@ -11,10 +11,12 @@

import torch

from torchao.ops import lib
from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

try:
from fast_hadamard_transform import hadamard_transform
from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform

def matmul_hadU(X, hadK, K):
if X.is_cuda:
@@ -32,16 +34,59 @@ def matmul_hadU(X, hadK, K):
return matmul_hadU_slow(X, hadK, K)


def register_custom_op_impl(name):
def decorator(func):
if TORCH_VERSION_AT_LEAST_2_4:
return torch.library.custom_op(f"{name}", mutates_args=())(func)
else:
lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor")
return torch.library.impl(f"{name}", "cuda")(func)
return decorator


def register_custom_op_abstract(name):
def decorator(func):
if TORCH_VERSION_AT_LEAST_2_4:
return torch.library.register_fake(f"{name}")(func)
else:
return torch.library.impl_abstract(f"{name}")(func)
return decorator


@register_custom_op_impl("torchao::hadamard_transform")
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
"""
Arguments:
x: (..., dim)
scale: float. Multiply the output by this number.
Returns:
out: (..., dim)
Multiply each row of x by the Hadamard transform matrix.
Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
If dim is not a power of 2, we implicitly pad x with zero so that dim is the next power of 2.
Source: https://github.com/Dao-AILab/fast-hadamard-transform
"""
return _fast_hadamard_transform(x, scale)


@register_custom_op_abstract("torchao::hadamard_transform")
def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
torch._check(x.dim() >= 1, lambda: f"input should be at least a 1D tensor, got {x.dim()}D")
return torch.empty_like(x)


class HadamardTransform(torch.autograd.Function):
"""The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))"""

@staticmethod
def forward(ctx, u):
return hadamard_transform(u)
return _fast_hadamard_transform(u)

@staticmethod
def backward(ctx, grad):
return hadamard_transform(grad)
return _fast_hadamard_transform(grad)


def is_pow2(n):
@@ -144,9 +189,9 @@ def matmul_hadU_slow(X, hadK, K):
def matmul_hadU_fast(X, hadK, K):
n = X.shape[-1]
if K == 1:
return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
return torch.ops.torchao.hadamard_transform.default(X.contiguous()) / torch.tensor(n).sqrt()
input = X.view(-1, K, n // K)
input = HadamardTransform.apply(input.contiguous()) / torch.tensor(n).sqrt()
input = torch.ops.torchao.hadamard_transform.default(input.contiguous()) / torch.tensor(n).sqrt()
input = hadK.to(input.device).to(input.dtype) @ input
return input.reshape(X.shape)

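The reason for routing calls through `torch.ops.torchao.hadamard_transform` (plus the registered fake implementation for shape inference) is that `torch.compile` can trace through the otherwise-opaque CUDA kernel instead of graph-breaking on it. A rough usage sketch, assuming torchao and the fast-hadamard-transform package are installed and the input lives on a CUDA device:

```python
import torch
import torchao.prototype.spinquant.hadamard_utils  # noqa: F401  (registers torchao::hadamard_transform)


@torch.compile
def normalized_hadamard(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the K == 1 branch of matmul_hadU_fast above.
    n = x.shape[-1]
    return torch.ops.torchao.hadamard_transform.default(x.contiguous()) / torch.tensor(n).sqrt()


# Example (hypothetical shapes):
# y = normalized_hadamard(torch.randn(4, 4096, device="cuda", dtype=torch.float16))
```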
10 changes: 5 additions & 5 deletions torchao/prototype/spinquant/spinquant.py
@@ -103,7 +103,7 @@ def apply_spinquant_r4(model, device):
_add_activation_wrappers_r4(model)


@torch.inference_mode()
@torch.no_grad()
def _fuse_layernorm_into_linear(layernorm: RMSNorm, linear_layers: typing.Iterable[torch.nn.Linear]):
"""Fuse the linear operations in Layernorm into the adjacent linear blocks."""
for linear in linear_layers:
@@ -127,7 +127,7 @@ def _fuse_layernorm_into_linear(layernorm: RMSNorm, linear_layers: typing.Iterable[torch.nn.Linear]):
layernorm.weight.data = torch.ones_like(layernorm.weight.data)


@torch.inference_mode()
@torch.no_grad()
def _rotate_model_r1(model, R1):
_rotate_embeddings(model, R1)
_rotate_head(model, R1)
@@ -139,7 +139,7 @@ def _rotate_model_r1(model, R1):
_rotate_mlp_output(layer, R1)


@torch.inference_mode()
@torch.no_grad()
def _rotate_model_r2(model, R2s):
"""Rotate the W_v and W_o weights of the multi-head self-attention modules."""

@@ -168,7 +168,7 @@ def _rotate_model_r2(model, R2s):
attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)


@torch.inference_mode()
@torch.no_grad()
def _rotate_model_r4(model):
"""Rotate the MLP output weights."""

@@ -193,7 +193,7 @@ def _add_activation_wrappers_r4(model):
)


@torch.inference_mode()
@torch.no_grad()
def fuse_layernorm_into_linear(model):
"""
Fuse RMSNorm weights into the subsequent linear layers.
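The swap from `@torch.inference_mode()` to `@torch.no_grad()` throughout this file is not explained in the commit message; a plausible motivation (an assumption, not stated in the commit) is that tensors created under inference mode are permanently flagged as inference tensors and can error when later reused in autograd- or compile-enabled code, whereas `no_grad` only disables gradient tracking:

```python
import torch

with torch.inference_mode():
    w_inference = torch.randn(3, 3)
with torch.no_grad():
    w_no_grad = torch.randn(3, 3)

print(w_inference.is_inference())  # True: reusing this tensor where autograd is enabled can raise errors
print(w_no_grad.is_inference())    # False: safe to reuse once gradient tracking is back on
```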
