This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open an issue.
torchao is a PyTorch library for quantization and sparsity.
torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch.
Stable Release
```shell
pip install torchao
```
Nightly Release
```shell
pip install torchao-nightly
```
From source
```shell
git clone https://github.com/pytorch/ao
cd ao
pip install .
```
Quantization
```python
import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization
torchao.autoquant(model, input)

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
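Roughly speaking, autoquant benchmarks the available quantization options for each layer at the shapes it actually sees and keeps the fastest one. The toy sketch below shows only that selection step in plain Python, timing stand-in workloads; pick_fastest, wall_clock, and the candidate names are hypothetical and are not torchao APIs.

```python
import time
from typing import Callable, Dict, Tuple

Candidate = Callable[[], object]


def wall_clock(fn: Candidate, repeats: int = 5) -> float:
    """Best-of-`repeats` wall-clock time of fn() in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


def pick_fastest(candidates: Dict[str, Candidate],
                 benchmark: Callable[[Candidate], float] = wall_clock) -> Tuple[str, Dict[str, float]]:
    """Benchmark every candidate and return (name of the cheapest, all timings)."""
    timings = {name: benchmark(fn) for name, fn in candidates.items()}
    best = min(timings, key=timings.get)
    return best, timings


# Hypothetical stand-ins for "keep this layer dense" vs. "swap in a quantized
# kernel"; a real autotuner would run the layer's actual matmul at the shapes it sees.
candidates = {
    "dense": lambda: sum(i * i for i in range(20_000)),
    "quantized": lambda: sum(i * i for i in range(5_000)),
}
best, timings = pick_fastest(candidates)
print(best, timings)
```

Passing a custom benchmark callable makes the selection logic easy to test deterministically, and best-of-N timing reduces noise from one-off slow runs.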
Sparsity
```python
import torch
from torch.sparse import to_sparse_semi_structured
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: Finding a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

# zero out 2 values in every block of 4 consecutive weights (2:4 pattern)
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1, 4),
                                  zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: Accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
```
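With sparse_block_shape=(1,4) and zeros_per_block=2, the sparsifier above targets the 2:4 semi-structured pattern accepted by to_sparse_semi_structured: at most two nonzero values in every block of four consecutive weights. The plain-Python sketch below applies simple magnitude-based 2:4 pruning to one weight row to make the pattern concrete; it is a simplified illustration, not WeightNormSparsifier's implementation, and prune_2_of_4 and is_2_of_4 are hypothetical helpers.

```python
from typing import List


def prune_2_of_4(row: List[float]) -> List[float]:
    """Return a copy of `row` with the 2 smallest-magnitude values of every
    block of 4 consecutive weights set to zero (2:4 magnitude pruning)."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for start in range(0, len(row), 4):
        block = range(start, start + 4)
        # indices of the two entries with the smallest absolute value in this block
        for i in sorted(block, key=lambda j: abs(row[j]))[:2]:
            out[i] = 0.0
    return out


def is_2_of_4(row: List[float]) -> bool:
    """True if every block of 4 consecutive values contains at least 2 zeros."""
    return all(sum(v == 0.0 for v in row[s:s + 4]) >= 2 for s in range(0, len(row), 4))


weights = [0.1, -0.9, 0.3, 0.05, 2.0, -0.2, 0.0, 1.5]
pruned = prune_2_of_4(weights)
print(pruned)  # [0.0, -0.9, 0.3, 0.0, 2.0, 0.0, 0.0, 1.5]
```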
To learn more and try out our APIs, check out the API examples in the repository.
Supported Features
- Quantization algorithms
  - Int8 weight-only quantization
  - Int4 weight-only quantization
  - GPTQ and SmoothQuant for low-latency inference
  - High-level torchao.autoquant API and kernel autotuner targeting SOTA performance across varying model shapes on consumer and enterprise GPUs
- Sparsity algorithms such as Wanda that help improve the accuracy of sparse networks
- Support for lower-precision dtypes such as
  - nf4, which was used to implement QLoRA without writing custom Triton or CUDA code
  - uint4
- Bleeding-edge kernels: experimental kernels without backward-compatibility guarantees, such as
  - GaLore for memory-efficient finetuning
  - A fused HQQ GEMM kernel for compute-bound workloads
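To make the first feature concrete: int8 weight-only quantization stores each weight as an 8-bit integer plus a floating-point scale, while activations stay in floating point. The plain-Python sketch below assumes a common symmetric, per-output-channel scheme; it illustrates the arithmetic only and is not torchao's kernel or API (quantize_int8_per_row and linear_int8_weight_only are hypothetical names).

```python
from typing import List, Tuple


def quantize_int8_per_row(w: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-row (per-output-channel) int8 quantization: w ~= q * scale."""
    q, scales = [], []
    for row in w:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0  # map the largest |w| to 127
        q.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q, scales


def linear_int8_weight_only(x: List[float], q: List[List[int]], scales: List[float]) -> List[float]:
    """y = W @ x with W stored as int8; activations stay in floating point.

    Because each row shares one scale, multiplying the dot product by the scale
    is equivalent to dequantizing the row first.
    """
    return [scale * sum(qi * xi for qi, xi in zip(row, x)) for row, scale in zip(q, scales)]


w = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.3]]
x = [1.0, 2.0, 3.0]
q, scales = quantize_int8_per_row(w)
y = linear_int8_weight_only(x, q, scales)
y_ref = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
print(y, y_ref)  # close but not identical, because of int8 rounding
```

Per-row scales keep one outlier row from inflating the rounding error of every other row; the cost is one extra float per output channel.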
Design Goals
- Composability with torch.compile: We rely heavily on torch.compile to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing our own custom CUDA/Triton kernels.
- Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
- Performance: We measure our performance on every commit using an A10G GPU. We also regularly run performance benchmarks on the torchbench suite.
- Heterogeneous hardware: Efficient kernels that can run on CPU/GPU-based servers (with torch.compile) and mobile backends (with ExecuTorch).
- Packaging kernels should be easy: We support custom CUDA and Triton extensions, so you can focus on writing your kernels and we'll ensure that they work on most operating systems and devices.
torchao has been integrated with other libraries, including:
- torchtune leverages our 8-bit and 4-bit weight-only quantization techniques, with optional support for GPTQ.
- ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
- HQQ leverages our int4mm kernel for low-latency inference.
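The 8da4w scheme mentioned above splits the work into an offline step (int4 weights with precomputed scales) and an online step (int8 activation scales computed from each input at runtime, hence "dynamic"). The plain-Python sketch below illustrates that split; the symmetric [-7, 7] int4 range, per-group weight scales with group_size=2, and all function names are assumptions for illustration, not the torchao or ExecuTorch implementation.

```python
from typing import List, Tuple

QGroup = Tuple[List[int], float]  # (integer values, scale) for one group of columns


def quantize_symmetric(values: List[float], bits: int) -> QGroup:
    """Quantize a vector to signed `bits`-bit integers in [-qmax, qmax] with one scale."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8, 7 for int4
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return [max(-qmax, min(qmax, round(v / scale))) for v in values], scale


def prepare_int4_weights(w: List[List[float]], group_size: int) -> List[List[QGroup]]:
    """Offline step: quantize each weight row to int4, one scale per group of columns."""
    return [[quantize_symmetric(row[g:g + group_size], bits=4)
             for g in range(0, len(row), group_size)]
            for row in w]


def linear_8da4w(x: List[float], qw: List[List[QGroup]], group_size: int) -> List[float]:
    """Online step: quantize the activation to int8 at call time, then use integer dot products."""
    qx, sx = quantize_symmetric(x, bits=8)  # dynamic: the scale comes from this input
    out = []
    for row_groups in qw:
        acc = 0.0
        for g, (q_group, sw) in enumerate(row_groups):
            start = g * group_size
            int_dot = sum(a * b for a, b in zip(qx[start:start + group_size], q_group))
            acc += int_dot * sx * sw  # rescale each group's integer result to float
        out.append(acc)
    return out


w = [[0.2, -0.4, 0.1, 0.3], [-1.0, 0.5, 0.25, 0.75]]
x = [1.0, -2.0, 0.5, 3.0]
qw = prepare_int4_weights(w, group_size=2)
y = linear_8da4w(x, qw, group_size=2)
y_ref = [sum(wi * xi for wi, xi in zip(row, x)) for row in w]
print(y, y_ref)  # approximate: int4 weights are much coarser than int8
```

Applying the scales after each group's integer dot product keeps the inner loop in integer arithmetic; the cost is one floating-point rescale per group.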
Our kernels have been used to achieve SOTA inference performance on
torchao is released under the BSD 3-Clause license.