195 changes: 137 additions & 58 deletions test/prototype/test_parq.py
@@ -21,10 +21,12 @@
from torchao.prototype.parq.quant import (
Int4UnifTorchaoQuantizer,
LSBQuantizer,
StretchedUnifTorchaoQuantizer,
TernaryUnifQuantizer,
UnifQuantizer,
UnifTorchaoQuantizer,
)
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
@@ -35,11 +37,11 @@
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
MappingType,
_is_linear,
int4_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_6,
@@ -74,6 +76,59 @@ def build_param_groups(model, b: int = 2, group_size: Optional[int] = None):
]


def compare_quantized_models(
model: nn.Module,
m_ref: nn.Module,
quantizer: UnifTorchaoQuantizer,
b: int,
group_size: int,
):
for n, module in model.named_children():
if not _is_linear(module):
continue

# simulate grouping from QuantOptimizer.step
p = module.weight
original_shape = p.shape
p = p.view(-1, group_size)

q, Q = quantizer.quantize(p, b=b, dim=-1)
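# q: weights snapped to the quantization grid; Q: the grid values themselves (naming per PARQ's quantizer API)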

# compare to AffineQuantizedTensor instance
q = q.view(original_shape)
ref = getattr(m_ref, n).weight.dequantize()
torch.testing.assert_close(q, ref, atol=0, rtol=0)


def compare_parq_convert(
model: nn.Module,
m_ref: nn.Module,
optimizer: QuantOptimizer,
config: AOBaseConfig,
):
# do not update model weights, just quantize
optimizer.zero_grad()
optimizer.step()

orig_model = copy.deepcopy(model) # save copy of PARQ quantized model

# equivalent to torchao's convert step
model.eval()
optimizer.restore_latent_params()
quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))

for n, module in model.named_modules():
if not _is_linear(module):
continue

p_orig = getattr(orig_model, n).weight # PARQ weight
p = module.weight.dequantize() # PARQ weight after quantize_
p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_

torch.testing.assert_close(p_orig, p_ref, atol=0, rtol=0)
torch.testing.assert_close(p, p_ref, atol=0, rtol=0)


class M(nn.Module):
def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
super().__init__()
@@ -143,59 +198,6 @@ class TestUnifTorchaoQuantizer(common_utils.TestCase):
def setUp(self):
torch.manual_seed(123)

def compare_quantized_models(
self,
model: nn.Module,
m_ref: nn.Module,
quantizer: UnifTorchaoQuantizer,
b: int,
group_size: int,
):
for n, module in model.named_children():
if not _is_linear(module):
continue

# simulate grouping from QuantOptimizer.step
p = module.weight
original_shape = p.shape
p = p.view(-1, group_size)

q, Q = quantizer.quantize(p, b=b, dim=-1)

# compare to AffineQuantizedTensor instance
q = q.view(original_shape)
ref = getattr(m_ref, n).weight.dequantize()
torch.testing.assert_close(q, ref, atol=0, rtol=0)

def compare_parq_convert(
self,
model: nn.Module,
m_ref: nn.Module,
optimizer: QuantOptimizer,
config: AOBaseConfig,
):
# do not update model weights, just quantize
optimizer.zero_grad()
optimizer.step()

orig_model = copy.deepcopy(model) # save copy of PARQ quantized model

# equivalent to torchao's convert step
model.eval()
optimizer.restore_latent_params()
quantize_(model, config, filter_fn=optimizer.get_filter_fn(model))

for n, module in model.named_modules():
if not _is_linear(module):
continue

p_orig = getattr(orig_model, n).weight # PARQ weight
p = module.weight.dequantize() # PARQ weight after quantize_
p_ref = getattr(m_ref, n).weight.dequantize() # native quantize_

torch.testing.assert_true(p_orig, p_ref, atol=0, rtol=0)
torch.testing.assert_true(p, p_ref, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@common_utils.parametrize("group_size", [32, 256])
def test_int4_weight_only(self, group_size: int = 32):
@@ -209,7 +211,7 @@ def test_int4_weight_only(self, group_size: int = 32):
quantize_(m_ref, config)

b = 4
self.compare_quantized_models(
compare_quantized_models(
model, m_ref, Int4UnifTorchaoQuantizer(), b, group_size
)

@@ -229,7 +231,7 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
)

quantizer = UnifTorchaoQuantizer()
self.compare_quantized_models(model, m_ref, quantizer, b, group_size)
compare_quantized_models(model, m_ref, quantizer, b, group_size)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -251,7 +253,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
ProxHardQuant(),
quant_per_channel=True,
)
self.compare_parq_convert(model, m_ref, optimizer, config)
compare_parq_convert(model, m_ref, optimizer, config)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@@ -273,7 +275,84 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
ProxHardQuant(),
quant_per_channel=True,
)
self.compare_parq_convert(model, m_ref, optimizer, config)
compare_parq_convert(model, m_ref, optimizer, config)


class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase):
Contributor Author:
New test case that ensures equivalence between PARQ's original UnifQuantizer implementation and the new StretchedUnifTorchaoQuantizer
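In isolation, the equivalence check boils down to the sketch below (synthetic grouped weights stand in for the model's linear layers; the `quantize()` call and its `(q, Q)` return pair follow the test body that follows):

```python
import torch

from torchao.prototype.parq.quant import (
    StretchedUnifTorchaoQuantizer,
    UnifQuantizer,
)

b, group_size = 2, 32
# synthetic stand-in for a linear layer's weight, grouped one row per group
# as QuantOptimizer.step would do
p = torch.randn(512, 512).view(-1, group_size)

q_ref, Q_ref = UnifQuantizer().quantize(p, b=b, dim=-1)
q, Q = StretchedUnifTorchaoQuantizer(b).quantize(p, b=b, dim=-1)

# both paths should agree bit-for-bit on the quantized weights q
# and the quantization levels Q
torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)
```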

def setUp(self):
torch.manual_seed(123)

@common_utils.parametrize("b", [2, 3])
@common_utils.parametrize("group_size", [32, 256])
def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

quantizer_ref = UnifQuantizer()
quantizer = StretchedUnifTorchaoQuantizer(b)

for n, module in model.named_children():
if not _is_linear(module):
continue

# simulate grouping from QuantOptimizer.step
p = module.weight
p = p.view(-1, group_size)

q_ref, Q_ref = quantizer_ref.quantize(p, b=b, dim=-1)
q, Q = quantizer.quantize(p, b=b, dim=-1)

torch.testing.assert_close(q, q_ref, atol=0, rtol=0)
torch.testing.assert_close(Q, Q_ref, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@common_utils.parametrize("b", [2, 3])
@common_utils.parametrize("group_size", [32, 512])
def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

quantizer = StretchedUnifTorchaoQuantizer(b)

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
quantize_(
m_ref,
StretchedIntxWeightOnlyConfig(
b=b,
quant_min=quantizer.quant_min,
quant_max=quantizer.quant_max,
granularity=PerGroup(group_size),
),
)

compare_quantized_models(model, m_ref, quantizer, b, group_size)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
@common_utils.parametrize("b", [2, 3])
def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
model = M(m=512, n=512).to(_DEVICE)
model.reset_parameters()

quantizer = StretchedUnifTorchaoQuantizer(b)

m_ref = copy.deepcopy(model).eval().to(_DEVICE)
config = StretchedIntxWeightOnlyConfig(
b=b,
quant_min=quantizer.quant_min,
quant_max=quantizer.quant_max,
granularity=PerGroup(group_size),
)
quantize_(m_ref, config)

base_optimizer = torch.optim.AdamW(build_param_groups(model, b, group_size))
optimizer = QuantOptimizer(
base_optimizer,
quantizer,
ProxHardQuant(),
quant_per_channel=True,
)
compare_parq_convert(model, m_ref, optimizer, config)


class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
1 change: 1 addition & 0 deletions torchao/prototype/parq/quant/__init__.py
@@ -13,5 +13,6 @@
)
from .uniform_torchao import ( # noqa: F401
Int4UnifTorchaoQuantizer,
StretchedUnifTorchaoQuantizer,
UnifTorchaoQuantizer,
)