This repository is currently under heavy development - if you have suggestions on the API or use cases you'd like to see covered, please open an issue.
`torchao` is a PyTorch library for quantization and sparsity. `torchao` makes liberal use of several new features in PyTorch, so it is recommended to use it with the current nightly or the latest stable version of PyTorch.
**Stable Release**
```shell
pip install torchao
```

**Nightly Release**
```shell
pip install torchao-nightly
```
From source
git clone https://github.com/pytorch/ao
cd ao
pip install -r requirements.txt
pip install -r dev-requirements.txt
There are two options; -If you plan to be developing the library run:
python setup.py develop
If you want to install from source run
python setup.py install
**Note:** If you run into any issues while building the `ao` C++ extensions, you can instead build without them using:
```shell
USE_CPP=0 python setup.py install
```
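After either install path, a quick import check (not part of the official instructions, just a sanity check) confirms the package is visible to your interpreter:
```shell
python -c "import torchao; print(torchao.__file__)"
```
With torchao installed, the following example autoquantizes and compiles a small model: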
```python
import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization and compilation
q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
q_model(input)
```
Sparsity follows a two-step flow: first find a sparse subnetwork for accuracy, then convert the weights to an accelerated sparse representation for performance:

```python
import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: Finding a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1, 4),
                                  zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: Accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
```
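As a quick usage check (not part of the snippet above), the sparsified model can be called like any other module; this sketch assumes a GPU with 2:4 semi-structured sparsity support (e.g. an A100):

```python
# Sanity check: run the 2:4-sparse model on an input matching Linear(64, 64)
x = torch.randn(64, 64, dtype=torch.bfloat16, device='cuda')
with torch.inference_mode():
    out = model(x)
print(out.shape)  # torch.Size([64, 64])
```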
To learn more, try out our APIs; you can check out API examples covering:
- Quantization algorithms
  - Int8 weight-only quantization (a hypothetical usage sketch follows the lists below)
  - Int4 weight-only quantization
  - GPTQ and SmoothQuant for low latency inference
  - High level `torchao.autoquant` API and kernel autotuner targeting SOTA performance across varying model shapes on consumer and enterprise GPUs
- Sparsity algorithms such as Wanda that help improve accuracy of sparse networks
- Support for lower precision dtypes such as
  - nf4, which was used to implement QLoRA without writing custom Triton or CUDA code
  - uint4
  - MX, implementing training and inference support with tensors using the OCP MX spec data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales constrained to powers of two. This work is a prototype as the hardware support is not available yet.
- Bleeding Edge Kernels for experimental kernels without backwards compatibility guarantees
  - GaLore for memory efficient finetuning
  - fused HQQ Gemm Kernel for compute bound workloads
- Composability with `torch.compile`: We rely heavily on `torch.compile` to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing custom CUDA/Triton kernels.
- Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
- Performance: We measure our performance on every commit using an A10G. We also regularly run performance benchmarks on the torchbench suite.
- Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU based server (w/ torch.compile) and mobile backends (w/ ExecuTorch).
- Packaging kernels should be easy: We support custom CUDA and Triton extensions so you can focus on writing your kernels and we'll ensure that they work on most operating systems and devices
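For illustration, a weight-only int8 quantization call composed with `torch.compile` might look like the sketch below. The exact entry point is an assumption: helpers such as `change_linear_weights_to_int8_woqtensors` live in `torchao.quantization.quant_api` in some releases and may differ in yours, so check the quantization docs for the API your installed version exposes.

```python
import torch
# Assumed entry point; the helper name may differ across torchao versions
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# Swap the Linear weights for int8 weight-only quantized tensor subclasses in place
change_linear_weights_to_int8_woqtensors(model)

# Compile so inductor can generate efficient fused kernels for the quantized weights
model = torch.compile(model, mode='max-autotune')
out = model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))
```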
torchao has been integrated with other libraries, including:
- torchtune leverages our 8 and 4 bit weight-only quantization techniques with optional support for GPTQ
- ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
- HQQ leverages our int4mm kernel for low latency inference
Our kernels have been used to achieve SOTA inference performance on
`torchao` is released under the BSD 3 license.