Refactor quantize_activation_per_token_absmax to use general quant primitives

Summary:
att

Test Plan:
OSS CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 1, 2024
1 parent e3ed90f commit a31cf04
Showing 2 changed files with 12 additions and 10 deletions.
torchao/quantization/quant_primitives.py: 12 additions & 10 deletions
@@ -323,21 +323,23 @@ def dynamically_quantize_per_tensor(


 def quantize_activation_per_token_absmax(t):
-    n_bits = 8
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
+    mapping_type = MappingType.SYMMETRIC
+    block_size = list(t.shape)
+    for i in range(len(block_size) - 1):
+        block_size[i] = 1
+    dtype = torch.int8
+    eps = 1e-5
+    scale, zero_point = choose_qparams_affine(t, mapping_type, block_size, dtype, eps=eps, scale_dtype=torch.float)
-    scales = t.abs().amax(dim=-1, keepdim=True)
-    if scales.dtype == torch.float16:
-        scales = (
-            scales.float()
-        )  # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
-    q_max = 2 ** (n_bits - 1) - 1
-    scales = scales.clamp(min=1e-5).div(q_max)
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
     # if we don't clamp. TODO(future) look into this further.
-    t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)
-    return t, scales
+    quant_min = -127
+    quant_max = 127
+    quantized = quantize_affine(t, block_size, scale, zero_point, dtype, quant_min, quant_max)
+
+    return quantized, scale


 def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
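For context, here is a minimal usage sketch of the refactored function (not part of the commit). It assumes torchao is installed and that the function is imported from torchao.quantization.quant_primitives, the file changed above; the reference scale at the end follows the removed manual path and is only an approximation, since the new code derives its quantization parameters through choose_qparams_affine.

import torch
from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax

# Activations shaped [B, N, K]; each token (a row along the last dim) gets its own scale.
t = torch.randn(2, 16, 64)

quantized, scale = quantize_activation_per_token_absmax(t)

print(quantized.dtype, quantized.shape)  # torch.int8, same shape as t
print(scale.shape)                       # one scale per token ([B, N, 1] per the comment in the diff)

# Rough reference following the removed manual path: per-token absmax divided by 127.
# The refactored code computes its scale via choose_qparams_affine, so the values
# may not match this hand-rolled estimate bit for bit.
ref_scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127
print((scale.reshape(ref_scale.shape) - ref_scale).abs().max())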
Binary file modified tutorials/quantize_vit/quant.json.gz
Binary file not shown.
