Relax QAT dtype assertion (#692)
This assertion was originally added for perf reasons specific to 8da4w,
but the autograd.Function has since been adapted for more general
use. A few users are hitting this assertion error.

More context: pytorch/torchtune#1333
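
For illustration, here is a minimal, self-contained sketch of a fake-quantize autograd.Function in the spirit of the patched code. The name _FakeQuantizeSketch and the per-tensor quantization math are hypothetical simplifications, not the torchao API; the point is only that, with the dtype assertion removed, bf16 inputs flow through forward and backward without an explicit cast to fp32.

import torch

class _FakeQuantizeSketch(torch.autograd.Function):
    """Hypothetical stand-in for the QAT fake-quantize autograd.Function."""

    @staticmethod
    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
        # Per-tensor fake quantization for illustration only; the real code
        # is per-block and calls fake_quantize_affine_cachemask.
        # Note: no `assert input.dtype == torch.float32` here.
        q = torch.round(input / scales) + zero_points
        mask = (q >= quant_min) & (q <= quant_max)
        q = torch.clamp(q, quant_min, quant_max)
        fq = (q - zero_points) * scales
        ctx.save_for_backward(mask)
        return fq.to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Straight-through estimator: pass gradients only where not clamped.
        return grad_output * mask, None, None, None, None

# bf16 inputs no longer need to be cast to fp32 before fake-quantizing.
x = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)
scales = torch.full((), 0.1, dtype=torch.bfloat16)
zero_points = torch.zeros((), dtype=torch.int32)
out = _FakeQuantizeSketch.apply(x, scales, zero_points, -128, 127)
out.sum().backward()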
andrewor14 authored Aug 16, 2024
1 parent f2c908b commit b523f9f
Showing 2 changed files with 0 additions and 13 deletions.
6 changes: 0 additions & 6 deletions test/quantization/test_qat.py
@@ -422,9 +422,6 @@ def test_qat_4w_primitives(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
-@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -453,9 +450,6 @@ def test_qat_4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
-@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
7 changes: 0 additions & 7 deletions torchao/quantization/prototype/qat/utils.py
@@ -36,13 +36,6 @@ def forward(
block_size: List[int],
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
-# Note: for bf16 inputs, casting them to fp32 has the unexpected
-# side effect of reducing memory footprint significantly, presumably
-# because bf16 * fp32 kernels are not as memory efficient
-assert input.dtype == torch.float32
-assert scales.dtype == torch.float32
-assert zero_points.dtype == torch.int32
-
(fq, mask) = fake_quantize_affine_cachemask(
input,
block_size,