jeromeku/ao (forked from pytorch/ao)

torchao: PyTorch Architecture Optimization (AO). A repository to host AO techniques and performant kernels that work with PyTorch.


torchao: PyTorch Architecture Optimization

This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like to see covered, please open an issue.

Introduction

torchao is a PyTorch library for quantization and sparsity.

Get Started

Installation

torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch.

Stable Release

pip install torchao

Nightly Release

pip install torchao-nightly

From source

git clone https://github.com/pytorch/ao
cd ao
pip install -r requirements.txt
pip install -r dev-requirements.txt

There are two options. If you plan to develop the library, run:

python setup.py develop

If you only want to install from source, run:

python setup.py install 

Note: If you run into issues while building the ao C++ extensions, you can build without them:

USE_CPP=0 python setup.py install

Quantization

import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# perform autoquantization and compilation
q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
q_model(example_input)
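As we understand it, autoquant chooses a quantization technique for each linear layer automatically. To illustrate what one such technique does, here is a conceptual pure-Python sketch of symmetric per-output-channel int8 weight-only quantization. Each weight row gets its own scale `max|w| / 127`, the weights are stored as int8, and the matmul rescales each row while activations stay in floating point. This is not torchao's implementation, and the helper names `quantize_per_channel_int8` and `weight_only_linear` are hypothetical.

```python
from typing import List, Tuple


def quantize_per_channel_int8(weight: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-output-channel int8 quantization (hypothetical helper)."""
    q_rows, scales = [], []
    for row in weight:
        amax = max(abs(w) for w in row)
        # one scale per output row; fall back to 1.0 for an all-zero row
        scale = amax / 127 if amax > 0 else 1.0
        q_rows.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return q_rows, scales


def weight_only_linear(x: List[float], q_rows: List[List[int]], scales: List[float]) -> List[float]:
    """Compute y = W @ x, rescaling each int8 row by its scale; x stays in float."""
    return [scale * sum(q * xi for q, xi in zip(row, x)) for row, scale in zip(q_rows, scales)]


W = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.75]]
x = [1.0, 2.0, 3.0]
q, s = quantize_per_channel_int8(W)
y = weight_only_linear(x, q, s)  # close to the float result [-0.75, -0.25]
```

Because every row has its own scale, a single large weight only costs precision in its own row.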

Sparsity

import torch
from torch.sparse import to_sparse_semi_structured
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model (semi-structured sparsity requires an NVIDIA Ampere or newer GPU)
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: find a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

# 2:4 pattern: zero 2 of every 4 consecutive weights, applied to all blocks
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                  sparse_block_shape=(1, 4),
                                  zeros_per_block=2)

# attach FakeSparsity, compute the masks, then fold them into the weights
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with 2:4 sparse weights

# Performance: accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
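The `WeightNormSparsifier` settings above (`sparse_block_shape=(1, 4)`, `zeros_per_block=2`, `sparsity_level=1.0`) target the 2:4 semi-structured pattern. Roughly, in every block of four consecutive weights along a row, the two smallest-magnitude values are zeroed. The pure-Python sketch below illustrates that rule on a single row and assumes the row length is a multiple of four. The helper `prune_2_4` is hypothetical; the real sparsifier works on tensors through masks.

```python
from typing import List


def prune_2_4(row: List[float]) -> List[float]:
    """Zero the 2 smallest-magnitude values in each consecutive group of 4 (hypothetical helper)."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for start in range(0, len(row), 4):
        # indices of the two smallest |w| within this block of four
        smallest = sorted(range(start, start + 4), key=lambda i: abs(row[i]))[:2]
        for i in smallest:
            out[i] = 0.0
    return out


print(prune_2_4([0.1, -0.9, 0.4, 0.05, 1.2, -0.3, 0.0, 0.7]))
# [0.0, -0.9, 0.4, 0.0, 1.2, 0.0, 0.0, 0.7]
```

The fixed "2 of every 4" structure is what lets the sparse kernels skip half the multiply-adds with a regular, hardware-friendly layout.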

To learn more and try out our APIs, check out the API examples in

Supported Features

  1. Quantization algorithms
  2. Sparsity algorithms, such as Wanda, that help improve the accuracy of sparse networks
  3. Support for lower-precision dtypes such as
    • nf4, which was used to implement QLoRA without writing custom Triton or CUDA code
    • uint4
    • MX: training and inference support for tensors using the OCP MX spec data types, which can be described as groupwise-scaled float8/float6/float4/int8 with scales constrained to powers of two. This work is a prototype, as hardware support is not yet available.
  4. Bleeding-edge kernels: experimental kernels without backward-compatibility guarantees
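To make the MX description above concrete, here is a simplified pure-Python sketch of groupwise scaling with power-of-two scales. It is not a spec-conformant MX implementation and not torchao's code. It assumes int8 elements and picks each block's shared exponent so the block's largest magnitude fits in [-127, 127]. The only detail taken from the OCP MX spec is the block size of 32; element encodings and the scale's storage format are ignored. `mx_like_quantize` and `mx_like_dequantize` are hypothetical names.

```python
import math
from typing import List, Tuple

BLOCK_SIZE = 32  # the OCP MX spec shares one scale across blocks of 32 elements


def mx_like_quantize(values: List[float]) -> List[Tuple[int, List[int]]]:
    """Return (shared_exponent, int8 elements) per block; simplified, hypothetical helper."""
    blocks = []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        # smallest exponent with amax / 2**exp <= 127, so the scale is a power of two
        exp = math.ceil(math.log2(amax / 127)) if amax > 0 else 0
        scale = 2.0 ** exp
        blocks.append((exp, [max(-127, min(127, round(v / scale))) for v in block]))
    return blocks


def mx_like_dequantize(blocks: List[Tuple[int, List[int]]]) -> List[float]:
    """Multiply each element by its block's power-of-two scale."""
    return [q * 2.0 ** exp for exp, qs in blocks for q in qs]


values = [i / 10 for i in range(-20, 44)]  # 64 values -> 2 blocks
blocks = mx_like_quantize(values)
restored = mx_like_dequantize(blocks)
```

Restricting scales to powers of two means a scale can be stored as an exponent alone and applied with a shift. The trade-off is that a block may not use the full [-127, 127] range: in this example the first block's largest element maps to 64.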

Our Goals

  • Composability with torch.compile: We rely heavily on torch.compile to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing our own custom CUDA/Triton kernels.
  • Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
  • Performance: We measure our performance on every commit using an A10G. We also regularly run performance benchmarks on the torchbench suite.
  • Heterogeneous Hardware: Efficient kernels that can run on CPU/GPU-based servers (w/ torch.compile) and mobile backends (w/ ExecuTorch).
  • Packaging kernels should be easy: We support custom CUDA and Triton extensions so you can focus on writing your kernels, and we'll ensure that they work on most operating systems and devices.

Integrations

torchao has been integrated with other libraries, including:

  • torchtune leverages our 8-bit and 4-bit weight-only quantization techniques, with optional support for GPTQ
  • ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
  • HQQ leverages our int4mm kernel for low latency inference
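For readers unfamiliar with the 8da4w scheme mentioned above, here is the arithmetic for one output element. Weights are quantized ahead of time to int4 with one scale per group. Activations are quantized to int8 at runtime from their current range (here, one scale per token). The sketch below is pure Python and rests on simplifying assumptions of our own: symmetric quantization, int4 restricted to [-7, 7], and a hypothetical group size of 4. It is not the torchao or ExecuTorch implementation, and all names are hypothetical.

```python
from typing import List, Tuple

GROUP_SIZE = 4  # hypothetical group size, chosen small for readability


def quantize_symmetric(vals: List[float], qmax: int) -> Tuple[List[int], float]:
    """Symmetric quantization of vals to integers in [-qmax, qmax] with one scale."""
    amax = max(abs(v) for v in vals)
    scale = amax / qmax if amax > 0 else 1.0
    return [max(-qmax, min(qmax, round(v / scale))) for v in vals], scale


def quantize_weight_row_int4(row: List[float]) -> List[Tuple[List[int], float]]:
    """Offline step: groupwise int4 weights (restricted to [-7, 7]), one scale per group."""
    return [quantize_symmetric(row[i:i + GROUP_SIZE], 7) for i in range(0, len(row), GROUP_SIZE)]


def dot_8da4w(x: List[float], w_groups: List[Tuple[List[int], float]]) -> float:
    """Runtime step: quantize one token's activations to int8 from their current range,
    accumulate an integer dot product per group, then rescale by both scales."""
    xq, x_scale = quantize_symmetric(x, 127)
    acc = 0.0
    for g, (wq, w_scale) in enumerate(w_groups):
        xs = xq[g * GROUP_SIZE:(g + 1) * GROUP_SIZE]
        int_dot = sum(a * b for a, b in zip(xs, wq))  # integer-only arithmetic
        acc += int_dot * x_scale * w_scale
    return acc


w_row = [0.3, -0.2, 0.1, 0.4, -0.6, 0.05, 0.2, -0.1]
x = [0.5, -1.0, 0.25, 2.0, 1.5, -0.5, 0.75, 0.1]
y = dot_8da4w(x, quantize_weight_row_int4(w_row))  # approximates the float dot product 0.39
```

Per-group weight scales limit how far one outlier weight degrades its neighbors. Dynamic activation scales adapt to each input without calibration data.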

Success stories

Our kernels have been used to achieve SOTA inference performance on

License

torchao is released under the BSD 3 license.
