Add back QAT module swap API
Summary: Recent refactor into tensor subclasses (#585) broke
some existing use cases that rely on DDP and FSDP1, since the
new flow only supports FSDP2 currently. This commit adds back
the module swap API for now to provide a backdoor for these
use cases. In the long term, we still plan to deprecate the
module swap flow.
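
Example usage (a minimal sketch of the restored module swap flow, mirroring
the new tests in this commit; the toy model, dimensions, optimizer, and
single training step are illustrative and not part of the change):

    import torch
    import torch.nn as nn

    from torchao.quantization.prototype.qat._module_swap_api import (
        Int8DynActInt4WeightQATQuantizerModuleSwap,
    )

    # Toy model; linear in_features must be divisible by the group size.
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

    # Prepare: swap nn.Linear modules for QAT linears that fake-quantize
    # dynamic int8 activations and grouped int4 weights during training.
    quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=16)
    model = quantizer.prepare(model)

    # Fine-tune as usual; per the summary above, this flow is intended as a
    # fallback for DDP/FSDP1 setups that the tensor subclass flow cannot serve.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss = model(torch.randn(8, 256)).sum()
    loss.backward()
    optimizer.step()

    # Convert: replace the QAT linears with actually quantized linears.
    model = quantizer.convert(model)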

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_module_swap
python test/quantization/test_qat.py -k test_qat_4w_quantizer_module_swap

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim
andrewor14 committed Aug 27, 2024
1 parent 37276d6 commit 4f2ef2f
Showing 4 changed files with 437 additions and 171 deletions.
71 changes: 68 additions & 3 deletions test/quantization/test_qat.py
@@ -146,7 +146,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
-from torchao.quantization.prototype.qat.api import (
+from torchao.quantization.prototype.qat._module_swap_api import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -178,7 +178,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
-from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATLinear
+from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -229,6 +229,35 @@ def test_qat_8da4w_quantizer(self):
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_module_swap(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap

group_size = 16
torch.manual_seed(self.SEED)
m = M()
m2 = copy.deepcopy(m)
subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
module_swap_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=group_size)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -495,7 +524,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
-from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
+from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
@@ -559,6 +588,42 @@ def test_qat_4w_quantizer_gradients(self):
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_4w_quantizer_module_swap(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.prototype.qat._module_swap_api import Int4WeightOnlyQATQuantizerModuleSwap

group_size = 32
inner_k_tiles = 8
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
subclass_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
module_swap_quantizer = Int4WeightOnlyQATQuantizerModuleSwap(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
subclass_model = subclass_quantizer.prepare(m)
module_swap_model = module_swap_quantizer.prepare(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = [i.to(device).to(dtype) for i in m.example_inputs()]
x2 = copy.deepcopy(x)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)

# Convert QAT model and compare model values
subclass_model = subclass_quantizer.convert(subclass_model)
module_swap_model = module_swap_quantizer.convert(module_swap_model)
subclass_out = subclass_model(*x)
module_swap_out = module_swap_model(*x2)
torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -7,6 +7,9 @@
int8_dynamic_activation_int4_weight_fake_quantize,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)

from ._module_swap_api import (
Int8DynActInt4WeightQATLinear,
)

