Initial PARQ addition and testing (#1738)
* Initial PARQ addition and testing

* Fix errors due to torch version differences

* Revert torchao/float8/config.py

* Fix custom decorator

* Reformat parq.py

* Undo third_party/cutlass change
lisjin authored Mar 5, 2025
1 parent 03b83ec commit ffb4350
Showing 13 changed files with 880 additions and 0 deletions.
100 changes: 100 additions & 0 deletions test/prototype/test_parq.py
@@ -0,0 +1,100 @@
import unittest

import torch

from torchao.prototype.parq.optim import (
    ProxHardQuant,
    ProxPARQ,
    QuantOptimizer,
)
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer

_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def split_param_groups(model):
    params_no_quant, params_quant = [], []
    for p in model.parameters():
        if p.dim() > 1:
            params_quant.append(p)
        else:
            params_no_quant.append(p)
    return params_no_quant, params_quant


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 256)
        self.linear1 = torch.nn.Linear(256, 128)
        self.linear2 = torch.nn.Linear(128, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        for module in (self.linear1, self.linear2):
            torch.nn.init.xavier_uniform_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def example_inputs(self):
        return torch.randint(1, 10, (1, 256))

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        return x


class TestPARQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.model = M().to(_DEVICE)
        self.params_no_quant, self.params_quant = split_param_groups(self.model)

    def test_2bit_unif_quantizer_hard_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 2},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = UnifQuantizer()
        prox_map = ProxHardQuant()
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        x = self.model.example_inputs().to(_DEVICE)
        out = self.model(x)
        out.sum().backward()
        optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 4)

    def test_ternarybit_lsbq_parq_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 0},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = LSBQuantizer()
        prox_map = ProxPARQ(anneal_start=0, anneal_end=2)
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        for _ in range(3):
            x = self.model.example_inputs().to(_DEVICE)
            out = self.model(x)
            out.sum().backward()
            optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 3)


if __name__ == "__main__":
    unittest.main()
47 changes: 47 additions & 0 deletions torchao/prototype/parq/README.md
@@ -0,0 +1,47 @@
# PARQ: Piecewise-Affine Regularized Quantization

PARQ is a QAT method based on a convex regularization framework. It converges to hard quantization (i.e., the straight-through estimator, STE) at its asymptotic limit.

This library applies QAT without modifying model-level code. It interfaces with the optimizer only, allowing a user to choose which parameters should be quantized via parameter groups. It separates QAT into the following components.

* quantization method: computing the best set of discrete, quantized values
* proximal mapping: projection of weights onto quantized values (sketched below)
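
As a quick sketch of the proximal mapping (this restates the PARQ prox map implemented in `torchao/prototype/parq/optim/parq.py` in this commit; the symbols below are illustrative, not library names): for a weight $p$ lying in the quantization interval $[q_{\mathrm{lo}}, q_{\mathrm{hi}}]$ with center $c = (q_{\mathrm{lo}} + q_{\mathrm{hi}})/2$,

$$
\operatorname{prox}(p) = \operatorname{clamp}\!\left(c + \frac{p - c}{\gamma},\; q_{\mathrm{lo}},\; q_{\mathrm{hi}}\right),
$$

where the inverse slope $\gamma$ is annealed from 1 (identity map) toward 0 (hard quantization to the nearest value in the interval) over the annealing period.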

## QAT arguments

| | description | choices |
| --- | --- | --- |
| `quant-bits` | bit-width for quantized weights | 0 (ternary), 1-4 |
| `quant-method` | method for determining quantized values | `lsbq`, `uniform` |
| `quant-proxmap` | proximal mapping to project weights onto quantized values | `hard`, `parq`, `binaryrelax` |
| `anneal-start` | start epoch for QAT annealing period | (0, `total_steps` - 1) |
| `anneal-end` | end epoch for QAT annealing period | (`anneal_start`, `total_steps`) |
| `anneal-steepness` | sigmoid steepness for PARQ inverse slope schedule | 25-100 |
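
The arguments above map onto library objects roughly as follows. This is an illustrative sketch, not part of the package: the `build_quant_objects` helper and its parameter names are assumptions, and `quant-bits` is set per parameter group rather than here (see the next section).

```python
from torchao.prototype.parq.optim import ProxBinaryRelax, ProxHardQuant, ProxPARQ
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer


def build_quant_objects(quant_method, quant_proxmap, anneal_start, anneal_end, steepness):
    # `quant-method` selects the Quantizer
    quantizer = {"lsbq": LSBQuantizer, "uniform": UnifQuantizer}[quant_method]()
    # `quant-proxmap` selects the ProxMap; only PARQ and Binary Relax use annealing
    if quant_proxmap == "hard":
        prox_map = ProxHardQuant()
    elif quant_proxmap == "parq":
        prox_map = ProxPARQ(anneal_start, anneal_end, steepness)
    else:  # "binaryrelax"
        prox_map = ProxBinaryRelax(anneal_start, anneal_end)
    return quantizer, prox_map
```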

## Optimizer-only interface

The `QuantOptimizer` wrapper takes any `torch.optim.Optimizer` object. It is also initialized with a `Quantizer` and `ProxMap` object. Integration into new training pipelines is simple:
```python
import torch

from torchao.prototype.parq.optim import ProxPARQ, QuantOptimizer
from torchao.prototype.parq.quant import LSBQuantizer

# split params into quantizable and non-quantizable params
params_quant, params_no_wd, params_wd = split_param_groups(model)  # user-defined
param_groups = [
    {"params": params_quant, "quant_bits": 2},
    {"params": params_no_wd, "weight_decay": 0},
    {"params": params_wd},
]

# create PyTorch optimizer
base_optimizer = torch.optim.SGD(  # user-defined
    param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4
)

# create quantizer and proximal map objects
quantizer = LSBQuantizer()
prox_map = ProxPARQ(anneal_start=..., anneal_end=..., steepness=...)

optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
```
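
`split_param_groups` above is user-defined. A minimal sketch is shown below; its criteria (quantize multi-dimensional weights, keep embeddings in full precision, skip weight decay on one-dimensional parameters such as biases) are illustrative assumptions, not requirements of the library.

```python
def split_param_groups(model):
    params_quant, params_no_wd, params_wd = [], [], []
    for name, p in model.named_parameters():
        if p.dim() > 1 and "embedding" not in name:
            params_quant.append(p)  # weight matrices: quantized group
        elif p.dim() > 1:
            params_wd.append(p)  # e.g. embeddings: full precision, weight decay
        else:
            params_no_wd.append(p)  # biases, norm scales: no weight decay
    return params_quant, params_no_wd, params_wd
```

During training, call `optimizer.step()` after each backward pass as usual; `QuantOptimizer` applies the quantizer and prox map on top of the base optimizer's update (see `test/prototype/test_parq.py` in this commit for a minimal loop).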
8 changes: 8 additions & 0 deletions torchao/prototype/parq/__init__.py
@@ -0,0 +1,8 @@
from .optim import (  # noqa: F401
    ProxBinaryRelax,
    ProxHardQuant,
    ProxMap,
    ProxPARQ,
    QuantOptimizer,
)
from .quant import LSBQuantizer, Quantizer, UnifQuantizer  # noqa: F401
4 changes: 4 additions & 0 deletions torchao/prototype/parq/optim/__init__.py
@@ -0,0 +1,4 @@
from .binarelax import ProxBinaryRelax # noqa: F401
from .parq import ProxPARQ # noqa: F401
from .proxmap import ProxHardQuant, ProxMap # noqa: F401
from .quantopt import QuantOptimizer # noqa: F401
45 changes: 45 additions & 0 deletions torchao/prototype/parq/optim/binarelax.py
@@ -0,0 +1,45 @@
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


class ProxBinaryRelax(ProxMap):
"""Prox-map of Binary Relax, Q may not be evenly spaced."""

def __init__(self, anneal_start: int, anneal_end: int) -> None:
self.anneal_start = anneal_start
self.anneal_end = anneal_end

@torch.no_grad()
def apply_(
self,
p: Tensor,
q: Tensor,
Q: Tensor,
step_count: int,
dim: Optional[int] = None,
) -> None:
if step_count < self.anneal_start:
return

if q is None:
# hard quantization to the nearest point in Q
Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
if dim is None:
q = Q[torch.bucketize(p, Q_mid)]
else:
q = Q.gather(1, channel_bucketize(p, Q_mid))

if step_count >= self.anneal_end:
p.copy_(q)
return
else:
# linear annealing of relaxation coefficient
theta = (step_count - self.anneal_start) / (
self.anneal_end - self.anneal_start
)
p.mul_(1 - theta).add_(q, alpha=theta)
95 changes: 95 additions & 0 deletions torchao/prototype/parq/optim/parq.py
@@ -0,0 +1,95 @@
import math
from functools import partial
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
    try:
        return partial(
            torch.amp.custom_fwd, device_type="cuda", cast_inputs=cast_inputs
        )
    except AttributeError:
        return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)


def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
    """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2).
    s is the steepness of the sigmoid-like function, almost linear for s < 1.
    'mirror' means decreasing instead of increasing like the true sigmoid,
    'normalized' means value 1 at starting point t1 and 0 at end point t2."""
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
    st = 1 / (1 + math.exp(s * (ft - 0.5)))  # scaled and shifted mirror sigmoid
    s1 = 1 / (1 + math.exp(-0.5 * s))  # st value when t = t1 -> ft = 0
    s2 = 1 / (1 + math.exp(0.5 * s))  # st value when t = t2 -> ft = 1
    return (st - s2) / (s1 - s2)  # shift and scale to range (0, 1]


class ProxPARQ(ProxMap):
    def __init__(
        self, anneal_start: int, anneal_end: int, steepness: float = 10
    ) -> None:
        assert anneal_start < anneal_end, "PARQ annealing: start before end."
        assert steepness > 0, "PARQ annealing steepness should be positive."
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end
        self.steepness = steepness

    @torch.no_grad()
    @amp_custom_fwd(cast_inputs=torch.float32)
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> float:
        """Prox-map of PARQ with gradual annealing to hard quantization."""

        if step_count < self.anneal_start:
            inv_slope = 1.0
        elif step_count >= self.anneal_end:
            inv_slope = 0.0
            if q is None:
                # hard quantization to the nearest point in Q
                Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
                if dim is None:
                    q = Q[torch.bucketize(p, Q_mid)]
                else:
                    q = Q.gather(1, channel_bucketize(p, Q_mid))
            p.copy_(q)
        else:
            inv_slope = normalized_mirror_sigmoid(
                step_count, self.anneal_start, self.anneal_end, self.steepness
            )
            # it is important to clamp idx - 1 before clamping idx itself
            # idx_lower[k] == idx_upper[k] iff p[k] > Q.max() or p[k] <= Q.min()
            if dim is None:
                idx = torch.bucketize(p, Q)  # locate quant interval
                idx_lower = (idx - 1).clamp_(min=0)  # index of lower bound
                idx_upper = idx.clamp(max=Q.numel() - 1)  # index of upper bound
                q_lower = Q[idx_lower]  # lower boundary of interval
                q_upper = Q[idx_upper]  # upper boundary of interval
                center = (q_lower + q_upper) / 2  # center of interval
                # concise implementation of piecewise-affine prox map
                q = (center + (p - center) / inv_slope).clamp_(min=q_lower, max=q_upper)
            else:
                idx = channel_bucketize(p, Q)
                idx_lower = (idx - 1).clamp_(min=0)
                idx_upper = idx.clamp(max=Q.size(1) - 1)
                q_lower = Q.gather(1, idx_lower)
                q_upper = Q.gather(1, idx_upper)
                center = (q_lower + q_upper) / 2
                q = torch.minimum(
                    torch.maximum(center + (p - center) / inv_slope, q_lower), q_upper
                )
            # in-place update of model parameters
            p.copy_(q)
        return inv_slope
43 changes: 43 additions & 0 deletions torchao/prototype/parq/optim/proxmap.py
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize


# Create an abstract class to provide proximal-mapping interface
class ProxMap(ABC):
    @abstractmethod
    def apply_(self, p: Tensor, q: Tensor, Q: Tensor, step_count: int) -> None:
        """Provide interface for proximal mapping (modify p in-place):
            prox_map.apply_(p, q, Q, step_count)
        Inputs:
            p (Tensor): tensor to be quantized
            q (Tensor): None or hard quantized tensor of same size as p
            Q (Tensor): set of target quantization values
            step_count: trigger iteration-dependent mapping if needed
        """


class ProxHardQuant(ProxMap):
    """Prox-map of hard quantization, Q may not be evenly spaced."""

    @torch.no_grad()
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> None:
        if q is None:
            # quantize to the nearest point in Q
            Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
            if dim is None:
                q = Q[torch.bucketize(p, Q_mid)]
            else:
                q = Q.gather(1, channel_bucketize(p, Q_mid))
        p.copy_(q)