diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index bd4bcce1aa..7fdeb65d96 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -323,21 +323,23 @@ def dynamically_quantize_per_tensor(
 
 
 def quantize_activation_per_token_absmax(t):
-    n_bits = 8
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
+    mapping_type = MappingType.SYMMETRIC
+    block_size = list(t.shape)
+    for i in range(len(block_size) - 1):
+        block_size[i] = 1
+    dtype = torch.int8
+    eps = 1e-5
+    scale, zero_point = choose_qparams_affine(t, mapping_type, block_size, dtype, eps=eps, scale_dtype=torch.float)
 
-    scales = t.abs().amax(dim=-1, keepdim=True)
-    if scales.dtype == torch.float16:
-        scales = (
-            scales.float()
-        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
-    q_max = 2 ** (n_bits - 1) - 1
-    scales = scales.clamp(min=1e-5).div(q_max)
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
     # if we don't clamp. TODO(future) look into this further.
-    t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)
-    return t, scales
+    quant_min = -127
+    quant_max = 127
+    quantized = quantize_affine(t, block_size, scale, zero_point, dtype, quant_min, quant_max)
+
+    return quantized, scale
 
 
 def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
diff --git a/tutorials/quantize_vit/quant.json.gz b/tutorials/quantize_vit/quant.json.gz
index a207cefc5f..0b3200eeb5 100644
Binary files a/tutorials/quantize_vit/quant.json.gz and b/tutorials/quantize_vit/quant.json.gz differ
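
Reviewer note (not part of the diff): a minimal usage sketch of the refactored helper, assuming it keeps its current name and location in torchao/quantization/quant_primitives.py. The tensor shape and printed values are illustrative only; the expectations come from the diff itself (int8 output, float scale via scale_dtype=torch.float, per-token block size, clamp to [-127, 127]).

    # Illustrative sketch only: exercises the refactored per-token absmax path.
    import torch

    from torchao.quantization.quant_primitives import (
        quantize_activation_per_token_absmax,
    )

    # [B, N, K] activation: the block size computed inside the helper becomes
    # [1, 1, K], i.e. one scale per token.
    t = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    quantized, scale = quantize_activation_per_token_absmax(t)

    print(quantized.dtype)                             # torch.int8
    print(scale.dtype)                                 # float scale (scale_dtype=torch.float)
    print(scale.numel())                               # 2 * 4 = 8 scales, one per token
    print(int(quantized.min()), int(quantized.max()))  # values stay within [-127, 127]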