Skip to content

Commit

Permalink
Refactor quantize_activation_per_token_absmax to use general quant …
Browse files Browse the repository at this point in the history
…primitives

Summary:
att

Test Plan:
OSS CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 1, 2024
1 parent e3ed90f commit 0f3db8c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
28 changes: 28 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,34 @@ def test_choose_qparams_tensor_sym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_activation_per_token_abs_max(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.randn(10, 10)
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)

mapping_type = MappingType.SYMMETRIC
block_size = list(input.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)

self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(scale, scale_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_activation_per_token_abs_max_zero_input(self):
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
input = torch.zeros(10, 10)
# make sure it still works
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_group_sym(self):
Expand Down
24 changes: 13 additions & 11 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,21 +323,23 @@ def dynamically_quantize_per_tensor(


def quantize_activation_per_token_absmax(t):
n_bits = 8
# if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]

scales = t.abs().amax(dim=-1, keepdim=True)
if scales.dtype == torch.float16:
scales = (
scales.float()
) # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
q_max = 2 ** (n_bits - 1) - 1
scales = scales.clamp(min=1e-5).div(q_max)
mapping_type = MappingType.SYMMETRIC
block_size = list(t.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
dtype = torch.int8
eps = 1e-5
# Note: the original smoothquant does not clamp to qmin/qmax here,
# but some of the tests with bfloat16 ended up with a flipped sign
# if we don't clamp. TODO(future) look into this further.
t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)
return t, scales
quant_min = -127
quant_max = 127
scale, zero_point = choose_qparams_affine(t, mapping_type, block_size, dtype, quant_min, quant_max, eps, scale_dtype=torch.float)

quantized = quantize_affine(t, block_size, scale, zero_point, dtype, quant_min, quant_max)

return quantized, scale


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
Expand Down
Binary file modified tutorials/quantize_vit/quant.json.gz
Binary file not shown.

0 comments on commit 0f3db8c

Please sign in to comment.