Update torchao api reference and add contributor guide
Summary:
1. Updated the torchao API reference for quantization to include the APIs we want to expose, renamed torchao/quantization/linear_activation_weight_observer.py,
and removed safe_int_mm and int_scaled_matmul from quant_primitives.py
2. Added pytorch#391 to the torchao docs

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Nov 12, 2024
1 parent f96e5ec commit 0654ea1
Showing 16 changed files with 807 additions and 56 deletions.
4 changes: 3 additions & 1 deletion docs/source/api_ref_dtypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ torchao.dtypes

to_nf4
to_affine_quantized_intx
to_affine_quantized_floatx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
AffineQuantizedTensor

..
Expand Down
9 changes: 3 additions & 6 deletions docs/source/api_ref_intro.rst
@@ -1,16 +1,13 @@
``torchao`` API Reference
=========================

This section introduces the torchao API reference.
Dive into the details of how torchao integrates with PyTorch to
optimize your machine learning models.
This section introduces the torchao API reference. Dive into the details of how torchao integrates with PyTorch to optimize your machine learning models.

.. toctree::
:glob:
:maxdepth: 1
:caption: Python API Reference

api_ref_sparsity
api_ref_quantization
api_ref_dtypes
api_ref_kernel
api_ref_quantization
api_ref_sparsity
39 changes: 32 additions & 7 deletions docs/source/api_ref_quantization.rst
Expand Up @@ -9,15 +9,40 @@ torchao.quantization
.. autosummary::
:toctree: generated/
:nosignatures:

SmoothFakeDynQuantMixin
SmoothFakeDynamicallyQuantizedLinear
swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference
Int4WeightOnlyGPTQQuantizer
Int4WeightOnlyQuantizer
autoquant

quantize_
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
int4_weight_only
int8_weight_only
float8_weight_only
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
uintx_weight_only
fpx_weight_only

to_linear_activation_quantized
to_linear_activation_weight_observed

swap_linear_with_smooth_fq_linear
smooth_fq_linear_to_inference

choose_qparams_affine
choose_qparams_affine_with_min_max
choose_qparams_affine_floatx
quantize_affine
quantize_affine_floatx
dequantize_affine
dequantize_affine_floatx
choose_qparams_and_quantize_affine_hqq
fake_quantize_affine
fake_quantize_affine_cachemask

safe_int_mm
int_scaled_matmul

MappingType
ZeroPointDomain
TorchAODType
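
The primitives listed above (`choose_qparams_affine`, `quantize_affine`, `dequantize_affine`) follow the usual affine quantization recipe: pick a scale and zero point, map floats to integers, then map back. The following is a minimal pure-Python sketch of that recipe for illustration only. The helper names and signatures here are hypothetical and are not the torchao API, which operates on tensors and takes additional arguments such as block size and mapping type.

```python
def choose_qparams(values, qmin=-128, qmax=127):
    """Asymmetric min/max mapping: return (scale, zero_point). Hypothetical helper."""
    lo = min(min(values), 0.0)  # include 0.0 so it is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = int(round(qmin - lo / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(quantized, scale, zero_point):
    """x_hat = (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in quantized]


data = [-1.0, -0.25, 0.0, 0.5, 2.0]
scale, zero_point = choose_qparams(data)
q = quantize(data, scale, zero_point)
restored = dequantize(q, scale, zero_point)
```

Under these assumptions the round trip loses at most half a quantization step per element, and 0.0 is recovered exactly because the zero point is itself an integer in the quantized range.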

674 changes: 674 additions & 0 deletions docs/source/contributor_guide.rst

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions docs/source/index.rst
@@ -1,9 +1,7 @@
Welcome to the torchao Documentation
=======================================

**torchao** is an open-source library that provides the functionality
to quantize and prune your models using native PyTorch. Our documentation is under development
with more content coming soon.
`**torchao** <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please check out the torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlights and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials.

..
.. grid:: 3
Expand Down Expand Up @@ -81,13 +79,19 @@ with more content coming soon.
:maxdepth: 1
:caption: API Reference

api_ref_sparsity
api_ref_intro
api_ref_quantization
api_ref_dtypes
api_ref_quantization
api_ref_sparsity
..
api_ref_kernel
.. toctree::
:glob:
:maxdepth: 1
:caption: Contributor Guide

contributor_guide

.. toctree::
:glob:
:maxdepth: 1
Expand Down
4 changes: 3 additions & 1 deletion test/integration/test_integration.py
Expand Up @@ -34,8 +34,10 @@
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
)
from torchao.quantization.quant_primitives import (
from torchao.quantization import (
safe_int_mm,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
dequantize_affine,
Expand Down
4 changes: 3 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Expand Up @@ -31,10 +31,12 @@
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
int_scaled_matmul,
quantize_affine,
quantize_affine_floatx,
)
from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
)
Expand Down
7 changes: 7 additions & 0 deletions torchao/kernel/__init__.py
@@ -0,0 +1,7 @@
from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm

__all__ = [
"safe_int_mm",
"int_scaled_matmul",
]
81 changes: 56 additions & 25 deletions torchao/quantization/__init__.py
Expand Up @@ -24,6 +24,10 @@
PerTensor,
PerToken,
)
from torchao.kernel import (
safe_int_mm,
int_scaled_matmul,
)
from .linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
Expand Down Expand Up @@ -70,52 +74,79 @@
compute_error,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .linear_activation_weight_observed_tensor import (
to_linear_activation_weight_observed,
)

__all__ = [
"swap_conv2d_1x1_to_linear",
# top level API - auto
"autoquant",
"DEFAULT_AUTOQUANT_CLASS_LIST",
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qparams_affine",

# top level API - manual
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"uintx_weight_only",
"fpx_weight_only",
"LinearActivationQuantizedTensor",

# smooth quant - subject to change
"swap_conv2d_1x1_to_linear",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
"swap_linear_with_smooth_fq_linear",
"smooth_fq_linear_to_inference",
"set_smooth_fq_attribute",
"compute_error",

# building blocks
"to_linear_activation_quantized",
"to_weight_tensor_with_linear_activation_scale_metadata",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"Int8DynActInt4WeightGPTQQuantizer",
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightLinear",
"WeightOnlyInt8QuantLinear",
"TwoStepQuantizer",
"Quantizer",
"ZeroPointDomain",
"MappingType",
"AffineQuantizedMinMaxObserver",
"AffineQuantizedObserverBase",

# quant primitive ops
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
"quantize_affine",
"quantize_affine_floatx",
"dequantize_affine",
"dequantize_affine_floatx",
"choose_qparams_and_quantize_affine_hqq",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",

# operators/kernels
"safe_int_mm",
"int_scaled_matmul",

# dataclasses and types
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"PerTensor",
"PerAxis",
"PerGroup",
"PerRow",
"PerToken",

"LinearActivationQuantizedTensor",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"Int8DynActInt4WeightGPTQQuantizer",
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightLinear",
"WeightOnlyInt8QuantLinear",
"TwoStepQuantizer",
"Quantizer",
]
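
Long hand-maintained `__all__` lists are prone to two silent mistakes: a missing comma, which makes Python concatenate adjacent string literals into one name, and a misspelled entry. Neither fails at import time. The sketch below shows one way to catch both with a small check. The `undefined_exports` helper and the `demo` module are hypothetical and exist only for this illustration.

```python
import types


def undefined_exports(module):
    """Return names listed in module.__all__ that the module does not define."""
    return [
        name
        for name in getattr(module, "__all__", [])
        if not hasattr(module, name)
    ]


# Hypothetical module used only for this illustration.
demo = types.ModuleType("demo")
demo.alpha = demo.beta = demo.gamma = object()
demo.__all__ = [
    "alpha"  # missing comma: merges with the next literal into "alphabeta"
    "beta",
    "gamma",
]

missing = undefined_exports(demo)
```

Here `missing` contains the single merged name, and `demo.__all__` has two entries instead of three. A check like this could run in a unit test against a package's `__init__` module.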
3 changes: 2 additions & 1 deletion torchao/quantization/autoquant.py
Expand Up @@ -24,13 +24,14 @@
PerRow,
PerTensor,
)
from .quant_primitives import safe_int_mm
from torchao.kernel import safe_int_mm
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)


__all__ = [
"AutoQuantizableLinearWeight",
"autoquant",
Expand Down
4 changes: 2 additions & 2 deletions torchao/quantization/linear_activation_quantized_tensor.py
Expand Up @@ -19,7 +19,7 @@
class LinearActivationQuantizedTensor(TorchAOBaseTensor):
"""
Applies activation quantization for linear operator, this is used to support
dynamic quantization or static quantization, user can pass in a `input_quant_func`
dynamic quantization, user can pass in an `input_quant_func`
that is used to quantize the activation
Args:
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
self.quant_kwargs = quant_kwargs

def __repr__(self):
return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs}))"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func}, quant_kwargs={self.quant_kwargs})"

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_quant_func, self.quant_kwargs]
Expand Down
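
The `__repr__` change above replaces a hard-coded class name with `self.__class__.__name__`, so subclasses report their own name without overriding `__repr__`. A minimal sketch of the effect, using hypothetical stand-in classes rather than the real tensor subclasses:

```python
class BaseTensorLike:
    """Hypothetical stand-in for a tensor wrapper class."""

    def __init__(self, weight, quant_kwargs=None):
        self.weight = weight
        self.quant_kwargs = quant_kwargs or {}

    def __repr__(self):
        # Resolved at call time, so subclasses get the right name for free.
        return (
            f"{self.__class__.__name__}"
            f"({self.weight!r}, quant_kwargs={self.quant_kwargs!r})"
        )


class StaticTensorLike(BaseTensorLike):
    """Subclass that inherits __repr__ unchanged."""


base_repr = repr(BaseTensorLike([1, 2]))
sub_repr = repr(StaticTensorLike([1, 2]))
```

With a hard-coded name, `sub_repr` would have misreported the subclass as the base class.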
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

__all__ = [
"LinearActivationWeightObservedTensor",
"to_linear_activation_weight_observed",
]

aten = torch.ops.aten
Expand Down Expand Up @@ -147,6 +148,8 @@ def _(func, types, args, kwargs):
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

to_linear_activation_weight_observed = LinearActivationWeightObservedTensor.from_float


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
Expand Down
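
The new `to_linear_activation_weight_observed` name above is a module-level alias for a classmethod. Binding `SomeClass.from_float` to a plain name gives callers a function-style entry point while construction still goes through the class. A minimal sketch of the pattern, with a hypothetical `ObservedWeight` class standing in for the real tensor subclass:

```python
class ObservedWeight:
    """Hypothetical stand-in that stores a weight and two observers."""

    def __init__(self, weight, input_observer, weight_observer):
        self.weight = weight
        self.input_observer = input_observer
        self.weight_observer = weight_observer

    @classmethod
    def from_float(cls, weight, input_observer=None, weight_observer=None):
        return cls(weight, input_observer, weight_observer)


# Module-level alias: a bound classmethod that callers use like a function.
to_observed_weight = ObservedWeight.from_float

w = to_observed_weight(
    [0.1, 0.2], input_observer="in_obs", weight_observer="w_obs"
)
```

One consequence of this pattern is that the alias stays bound to the class it was created from, so a subclass would need its own alias.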
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
Expand Up @@ -385,7 +385,7 @@ def insert_observers_(
def convert_to_linear_observer(linear_module: nn.Linear):
# Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter
linear_module.weight = nn.Parameter(
LinearActivationWeightObservedTensor.from_float(
to_linear_activation_weight_observed(
linear_module.weight,
input_observer=input_observer,
weight_observer=weight_observer,
Expand Down
5 changes: 3 additions & 2 deletions torchao/quantization/quant_primitives.py
Expand Up @@ -24,8 +24,6 @@
)

__all__ = [
"safe_int_mm",
"int_scaled_matmul",
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
Expand All @@ -36,6 +34,9 @@
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
"choose_qparams_and_quantize_affine_hqq",
"MappingType",
"ZeroPointDomain",
"TorchAODType",
]


Expand Down
4 changes: 3 additions & 1 deletion torchao/quantization/utils.py
Expand Up @@ -9,12 +9,14 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode

from torchao.kernel import (
int_scaled_matmul,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
dequantize_affine,
int_scaled_matmul,
quantize_affine,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.quant_kwargs = quant_kwargs

def __repr__(self):
return f"LinearActivationQuantizedTensor({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"
return f"{self.__class__.__name__}({self.original_weight_tensor}, {self.input_quant_func_static}, scale={self.scale}, zero_point={self.zero_point}, quant_kwargs={self.quant_kwargs})"

def __tensor_flatten__(self):
tensor_data = ["original_weight_tensor", "scale"]
Expand Down
