Initial PARQ addition and testing (#1738)
* Initial PARQ addition and testing
* Fix errors due to torch version differences
* Revert torchao/float8/config.py
* Fix custom decorator
* Reformat parq.py
* Undo third_party/cutlass change
Showing 13 changed files with 880 additions and 0 deletions.
@@ -0,0 +1,100 @@
import unittest

import torch

from torchao.prototype.parq.optim import (
    ProxHardQuant,
    ProxPARQ,
    QuantOptimizer,
)
from torchao.prototype.parq.quant import LSBQuantizer, UnifQuantizer

_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def split_param_groups(model):
    params_no_quant, params_quant = [], []
    for p in model.parameters():
        if p.dim() > 1:
            params_quant.append(p)
        else:
            params_no_quant.append(p)
    return params_no_quant, params_quant


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 256)
        self.linear1 = torch.nn.Linear(256, 128)
        self.linear2 = torch.nn.Linear(128, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        for module in (self.linear1, self.linear2):
            torch.nn.init.xavier_uniform_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def example_inputs(self):
        return torch.randint(1, 10, (1, 256))

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.sigmoid(x)
        return x


class TestPARQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.model = M().to(_DEVICE)
        self.params_no_quant, self.params_quant = split_param_groups(self.model)

    def test_2bit_unif_quantizer_hard_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 2},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = UnifQuantizer()
        prox_map = ProxHardQuant()
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        x = self.model.example_inputs().to(_DEVICE)
        out = self.model(x)
        out.sum().backward()
        optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 4)

    def test_ternarybit_lsbq_parq_prox(self):
        self.model.reset_parameters()
        param_groups = [
            {"params": self.params_no_quant},
            {"params": self.params_quant, "quant_bits": 0},
        ]
        base_optimizer = torch.optim.AdamW(param_groups)
        quantizer = LSBQuantizer()
        prox_map = ProxPARQ(anneal_start=0, anneal_end=2)
        optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)

        for _ in range(3):
            x = self.model.example_inputs().to(_DEVICE)
            out = self.model(x)
            out.sum().backward()
            optimizer.step()

        for child in self.model.children():
            if isinstance(child, torch.nn.Linear):
                self.assertEqual(child.weight.unique().numel(), 3)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,47 @@
# PARQ: Piecewise-Affine Regularized Quantization

PARQ is a QAT method based on a convex regularization framework. It converges to hard quantization (i.e., STE) in its asymptotic limit.

This library applies QAT without modifying model-level code. Instead, it interfaces only with the optimizer, allowing the user to choose which parameters should be quantized via parameter groups. It separates QAT into the following components:

* quantization method: computing the best set of discrete, quantized values
* proximal mapping: projection of weights onto quantized values

## QAT arguments

| argument | description | choices |
| --- | --- | --- |
| `quant-bits` | bit width for quantized weights | 0 (ternary), 1-4 |
| `quant-method` | method for determining quantized values | `lsbq`, `uniform` |
| `quant-proxmap` | proximal mapping to project weights onto quantized values | `hard`, `parq`, `binaryrelax` |
| `anneal-start` | start epoch of the QAT annealing period | (0, `total_steps` - 1) |
| `anneal-end` | end epoch of the QAT annealing period | (`anneal_start`, `total_steps`) |
| `anneal-steepness` | sigmoid steepness for the PARQ inverse-slope schedule | 25-100 |

## Optimizer-only interface

The `QuantOptimizer` wrapper takes any `torch.optim.Optimizer` object. It is also initialized with a `Quantizer` and a `ProxMap` object. Integration into new training pipelines is simple:
```python
from parq.optim import ProxPARQ, QuantOptimizer
from parq.quant import LSBQuantizer

# split params into quantizable and non-quantizable params
params_quant, params_no_wd, params_wd = split_param_groups(model)  # user-defined
param_groups = [
    {"params": params_quant, "quant_bits": 2},
    {"params": params_no_wd, "weight_decay": 0},
    {"params": params_wd},
]

# create PyTorch optimizer
base_optimizer = torch.optim.SGD(  # user-defined
    param_groups, lr=0.1, momentum=0.9, weight_decay=1e-4
)

# create quantizer and proximal map objects
quantizer = LSBQuantizer()
prox_map = ProxPARQ(anneal_start=..., anneal_end=..., steepness=...)

optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
```
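
The README snippet stops after constructing the optimizer. Below is a rough sketch of how it might then be driven, mirroring the forward/backward/`step()` pattern used in the tests added by this commit; `model`, `train_loader`, `loss_fn`, and `num_epochs` are user-supplied placeholders, and it is assumed that `QuantOptimizer` forwards `zero_grad()` like an ordinary `torch.optim.Optimizer`.

```python
# Sketch only: model, train_loader, loss_fn and num_epochs are placeholders
# for user code; they are not defined by this commit.
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()  # assumed to behave like a standard optimizer
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()  # base optimizer update plus projection onto quantized values
```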
@@ -0,0 +1,8 @@
from .optim import (  # noqa: F401
    ProxBinaryRelax,
    ProxHardQuant,
    ProxMap,
    ProxPARQ,
    QuantOptimizer,
)
from .quant import LSBQuantizer, Quantizer, UnifQuantizer  # noqa: F401
@@ -0,0 +1,4 @@
from .binarelax import ProxBinaryRelax  # noqa: F401
from .parq import ProxPARQ  # noqa: F401
from .proxmap import ProxHardQuant, ProxMap  # noqa: F401
from .quantopt import QuantOptimizer  # noqa: F401
@@ -0,0 +1,45 @@
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


class ProxBinaryRelax(ProxMap):
    """Prox-map of Binary Relax, Q may not be evenly spaced."""

    def __init__(self, anneal_start: int, anneal_end: int) -> None:
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end

    @torch.no_grad()
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> None:
        if step_count < self.anneal_start:
            return

        if q is None:
            # hard quantization to the nearest point in Q
            Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
            if dim is None:
                q = Q[torch.bucketize(p, Q_mid)]
            else:
                q = Q.gather(1, channel_bucketize(p, Q_mid))

        if step_count >= self.anneal_end:
            p.copy_(q)
            return
        else:
            # linear annealing of relaxation coefficient
            theta = (step_count - self.anneal_start) / (
                self.anneal_end - self.anneal_start
            )
            p.mul_(1 - theta).add_(q, alpha=theta)
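
Read as an update rule, `ProxBinaryRelax.apply_` interpolates linearly between the floating-point parameter p and its hard-quantized counterpart q (the nearest point in Q) over the annealing window. With t denoting `step_count`, and t_s, t_e denoting `anneal_start` and `anneal_end`, the code above amounts to:

```latex
% Relaxation schedule implemented by ProxBinaryRelax.apply_
p \;\leftarrow\;
\begin{cases}
p, & t < t_s,\\[4pt]
(1 - \theta)\,p + \theta\,q,
  \quad \theta = \dfrac{t - t_s}{t_e - t_s}, & t_s \le t < t_e,\\[4pt]
q, & t \ge t_e,
\end{cases}
```

so the weights are untouched before `anneal_start` and fully quantized from `anneal_end` onward.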
@@ -0,0 +1,95 @@
import math
from functools import partial
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize
from .proxmap import ProxMap


def amp_custom_fwd(cast_inputs: Optional[torch.types._dtype] = None):
    try:
        return partial(
            torch.amp.custom_fwd, device_type="cuda", cast_inputs=cast_inputs
        )
    except AttributeError:
        return partial(torch.cuda.amp.custom_fwd, cast_inputs=cast_inputs)


def normalized_mirror_sigmoid(t: float, t1: float, t2: float, s: float) -> float:
    """Sigmoid-like function decreasing from 1 to 0 over interval [t1, t2).
    s is steepness of the sigmoid-like function, almost linear for s < 1.
    'mirror' means decreasing instead of increasing as true sigmoid,
    'normalized' means value 1 at starting point t1 and 0 at end point t2."""
    assert t >= t1 and t < t2, "Normalized sigmoid: ensure t1 <= t < t2"
    ft = (t - t1) / (t2 - t1)  # fraction of progress from t1 to t2
    st = 1 / (1 + math.exp(s * (ft - 0.5)))  # scaled and shifted mirror sigmoid
    s1 = 1 / (1 + math.exp(-0.5 * s))  # st value when t = t1 -> ft = 0
    s2 = 1 / (1 + math.exp(0.5 * s))  # st value when t = t2 -> ft = 1
    return (st - s2) / (s1 - s2)  # shift and scale to range (0, 1]


class ProxPARQ(ProxMap):
    def __init__(
        self, anneal_start: int, anneal_end: int, steepness: float = 10
    ) -> None:
        assert anneal_start < anneal_end, "PARQ annealing: start before end."
        assert steepness > 0, "PARQ annealing steepness should be positive."
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end
        self.steepness = steepness

    @torch.no_grad()
    @amp_custom_fwd(cast_inputs=torch.float32)
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> float:
        """Prox-map of PARQ with gradual annealing to hard quantization."""

        if step_count < self.anneal_start:
            inv_slope = 1.0
        elif step_count >= self.anneal_end:
            inv_slope = 0.0
            if q is None:
                # hard quantization to the nearest point in Q
                Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
                if dim is None:
                    q = Q[torch.bucketize(p, Q_mid)]
                else:
                    q = Q.gather(1, channel_bucketize(p, Q_mid))
            p.copy_(q)
        else:
            inv_slope = normalized_mirror_sigmoid(
                step_count, self.anneal_start, self.anneal_end, self.steepness
            )
            # it is important to clamp idx - 1 first and then clamp idx itself:
            # idx_lower[k] == idx_upper[k] iff p[k] > Q.max() or p[k] <= Q.min()
            if dim is None:
                idx = torch.bucketize(p, Q)  # locate quant interval
                idx_lower = (idx - 1).clamp_(min=0)  # index of lower bound
                idx_upper = idx.clamp(max=Q.numel() - 1)  # index of upper bound
                q_lower = Q[idx_lower]  # lower boundary of interval
                q_upper = Q[idx_upper]  # upper boundary of interval
                center = (q_lower + q_upper) / 2  # center of interval
                # concise implementation of piecewise-affine prox map
                q = (center + (p - center) / inv_slope).clamp_(min=q_lower, max=q_upper)
            else:
                idx = channel_bucketize(p, Q)
                idx_lower = (idx - 1).clamp_(min=0)
                idx_upper = idx.clamp(max=Q.size(1) - 1)
                q_lower = Q.gather(1, idx_lower)
                q_upper = Q.gather(1, idx_upper)
                center = (q_lower + q_upper) / 2
                q = torch.minimum(
                    torch.maximum(center + (p - center) / inv_slope, q_lower), q_upper
                )
            # in-place update of model parameters
            p.copy_(q)
        return inv_slope
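
Within one quantization interval, the annealing branch above applies a piecewise-affine map whose slope is set by the annealed inverse slope gamma = `inv_slope` in (0, 1]. With interval endpoints q_l, q_u (from the clamped bucket indices) and center c, each entry of p is mapped as:

```latex
% Piecewise-affine prox map applied per coordinate by ProxPARQ.apply_
q \;=\; \min\!\Big(\max\!\Big(c + \frac{p - c}{\gamma},\; q_\ell\Big),\; q_u\Big),
\qquad c = \frac{q_\ell + q_u}{2}.
```

At gamma = 1 this is the identity inside the interval; as gamma decreases along the normalized mirror sigmoid, values are pushed toward the nearest interval endpoint, and from `anneal_end` onward the map reduces to hard rounding onto Q.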
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from ..utils import channel_bucketize


# Create an abstract class to provide proximal-mapping interface
class ProxMap(ABC):
    @abstractmethod
    def apply_(self, p: Tensor, q: Tensor, Q: Tensor, step_count: int) -> None:
        """Provide interface for proximal mapping (modify p in-place):
            prox_map.apply_(p, q, Q, step_count)
        Inputs:
            p (Tensor): tensor to be quantized
            q (Tensor): None or hard quantized tensor of same size as p
            Q (Tensor): set of target quantization values
            step_count: trigger iteration-dependent mapping if needed
        """


class ProxHardQuant(ProxMap):
    """Prox-map of hard quantization, Q may not be evenly spaced."""

    @torch.no_grad()
    def apply_(
        self,
        p: Tensor,
        q: Tensor,
        Q: Tensor,
        step_count: int,
        dim: Optional[int] = None,
    ) -> None:
        if q is None:
            # quantize to the nearest point in Q
            Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2
            if dim is None:
                q = Q[torch.bucketize(p, Q_mid)]
            else:
                q = Q.gather(1, channel_bucketize(p, Q_mid))
        p.copy_(q)
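
As a small, self-contained illustration of the nearest-point rounding used by the `q is None` branches throughout this commit, the snippet below quantizes a toy tensor against a sorted 1-D target grid via midpoint boundaries. The values of `Q` and `p` are made up for the example and are not taken from the library.

```python
import torch

# Made-up sorted quantization targets and a toy weight tensor (illustration only).
Q = torch.tensor([-0.75, -0.25, 0.25, 0.75])
p = torch.tensor([-1.00, -0.30, 0.10, 0.60, 2.00])

# Midpoints between adjacent targets act as decision boundaries,
# exactly as in ProxHardQuant.apply_ with dim=None.
Q_mid = (Q[..., :-1] + Q[..., 1:]) / 2   # tensor([-0.5, 0.0, 0.5])
q = Q[torch.bucketize(p, Q_mid)]         # nearest target for each entry
print(q)  # tensor([-0.7500, -0.2500,  0.2500,  0.7500,  0.7500])
```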