Add StretchedUnifTorchaoQuantizer #2576
Conversation
    quant_min=quant_min,
    quant_max=quant_max,
)
data, scale, zero_point = _layout.post_process(
IIUC, the zero_point isn't some fixed fp value, but can vary slightly based on the ranges.

So what the CPU kernel can support is an fp32 scale times an int8 LUT value. So if we want the grid {-3.5, -1.5, 1.5, 3.5}, we instead use the LUT = {-7, -3, 3, 7} with s_table = 0.5.

In addition to having this LUT for the whole tensor, we can have an fp32 scale s at a per-group granularity. Dequantization is then:

`w_dequantized = s * s_table * LUT[idx]`

It's not clear to me that the affine scheme here is representable in that way. IIUC, you have:

`w_dequantized = s * (qval - z)`

So `qval - z` could define the LUT, but it looks like we'd have a different LUT for every group of group_size values because z changes every group_size values? Is that right?
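For concreteness, here is a minimal sketch of the per-group LUT dequantization described above; the LUT values, group size, and shapes are illustrative and not taken from the kernel.

```python
import torch

# Shared across the whole tensor: an int8 LUT plus one fp32 table scale.
lut = torch.tensor([-7, -3, 3, 7], dtype=torch.int8)
s_table = 0.5  # 0.5 * {-7, -3, 3, 7} = {-3.5, -1.5, 1.5, 3.5}
group_size = 32

idx = torch.randint(0, 4, (2, 64))          # 2-bit indices into the LUT
s = torch.rand(2, 64 // group_size) + 0.1   # per-group fp32 scales

# w_dequantized = s * s_table * LUT[idx]
w_dequantized = (
    s.repeat_interleave(group_size, dim=1) * s_table * lut[idx].to(torch.float32)
)
```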
Thanks for the comments @metascroy!

Just to clarify from our chat earlier, `zero_point=-0.5` is the same across all groups. (I flipped the sign since it's standard to add `zero_point` during quantization.)

`w_quantized = torch.round(x / s + zero_point)`
`w_dequantized = s * (w_quantized - zero_point)`

For the 2-bit case, we set `s` so that `x / s` is restricted to the range [-1.5, 1.5]. Since `zero_point=-0.5`, `w_quantized` lies in the grid {-2, -1, 0, 1}.

It seems like we don't need to use LUT format since this is well supported by an affine scheme. Maybe it would be worth supporting for latency comparisons though (and to avoid a float `zero_point`).
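As a minimal sketch of the round trip described here (the tensor values and the per-row scale choice are illustrative):

```python
import torch

b = 2
quant_max = 2 ** (b - 1) - 0.5   # 1.5
zero_point = -0.5

x = torch.randn(4, 32)
# Choose s so that x / s falls in [-1.5, 1.5] per row (illustrative absmax scale).
s = x.abs().amax(dim=-1, keepdim=True) / quant_max

w_quantized = torch.round(x / s + zero_point)   # values on the grid {-2, -1, 0, 1}
w_dequantized = s * (w_quantized - zero_point)  # values in s * {-1.5, -0.5, 0.5, 1.5}
```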
It is well supported by an affine scheme where zero_point is a float, but we do not have CPU kernel support for this.

But if `zero_point` is always -0.5, then `w_quantized - zero_point` is just some value in [-1.5, 1.5], and this could define an LUT, so I think we can hook into the kernel in that way. We just need the LUT to be integer, so we can define the LUT as [-3, -1, 1, 3] and then divide the scales in half.
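As a sanity check on that mapping, here is a small sketch (shapes and scales are illustrative) showing that the affine dequantization with a float zero_point matches the integer LUT [-3, -1, 1, 3] once the scales are halved:

```python
import torch

zero_point = -0.5
s = torch.rand(4, 1) + 0.1                                     # per-group fp32 scales
w_quantized = torch.randint(-2, 2, (4, 32)).to(torch.float32)  # grid {-2, -1, 0, 1}

# Affine scheme with a float zero_point.
affine = s * (w_quantized - zero_point)

# Integer LUT indexed by (qval + 2), with the scales divided in half.
lut = torch.tensor([-3.0, -1.0, 1.0, 3.0])
lut_based = (s / 2) * lut[(w_quantized + 2).long()]

assert torch.allclose(affine, lut_based)
```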
compare_parq_convert(model, m_ref, optimizer, config)


class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase):
New test case that ensures equivalence between PARQ's original `UnifQuantizer` implementation and the new `StretchedUnifTorchaoQuantizer`.
q_abs = input_float.abs()
max_val = torch.minimum(
    b * q_abs.mean(dim=reduction_dims, keepdim=True),
    torch.amax(q_abs, dim=reduction_dims, keepdim=True),
).clamp_(min=eps)

scale = max_val / quant_max
scale = scale.to(dtype=scale_dtype, device=input_float.device)
zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype)
Here's the logic for initializing the scale based on multiples of per-group absolute value means. I also manually set the zero point to be the same across groups.
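A standalone sketch of that initialization, assuming per-group reduction over the last dimension and using the bit width `b` as the multiple of the absolute-value mean (the setup is illustrative; variable names mirror the diff above):

```python
import torch

b = 2
quant_max = 2 ** (b - 1) - 0.5       # 1.5 for the stretched 2-bit grid
eps = torch.finfo(torch.float32).eps
reduction_dims = (-1,)               # per-group reduction (illustrative)

input_float = torch.randn(8, 32)     # one group per row, for simplicity

q_abs = input_float.abs()
max_val = torch.minimum(
    b * q_abs.mean(dim=reduction_dims, keepdim=True),
    torch.amax(q_abs, dim=reduction_dims, keepdim=True),
).clamp_(min=eps)

scale = max_val / quant_max
zero_point = torch.full_like(scale, -0.5)   # same zero point across all groups
```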
Looks good to me. We can translate the affine scheme to LUT when we prepare the data for the kernels.
* Add StretchedUnifTorchaoQuantizer
* Fix tinygemm test case
* Test equivalence to PARQ UnifQuantizer; custom choose_qparams, quantize, dequantize
* Remove dequantize_stretched_affine
This PR adds a new stretched uniform quantizer for PARQ, which empirically performs well for 2- and 3-bit QAT. Main differences:

* `quant_min=-2**(b - 1) + 0.5` and `quant_max=2**(b - 1) - 0.5` values
* `min_val`, `max_val` are computed by taking a multiple of the mean over absolute values (instead of absmax)

As in #2091, I also compare the resulting PARQ quantized weights with those quantized with torchao's module swap + `quantize_` API. To support this, I created a new tensor subclass `StretchedAffineQuantizedTensor` and config `StretchedIntxWeightOnlyConfig` to handle floating point `quant_min`, `quant_max`, and `zero_point` values.
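For reference, a small sketch (not the PR's code) of the stretched bounds and the resulting dequantization grid for `b = 2` and `b = 3`, assuming `zero_point = -0.5`:

```python
import torch

zero_point = -0.5
for b in (2, 3):
    quant_min = -(2 ** (b - 1)) + 0.5
    quant_max = 2 ** (b - 1) - 0.5
    # Integer codes run from quant_min + zero_point to quant_max + zero_point,
    # i.e. 2**b evenly spaced values.
    codes = torch.arange(quant_min + zero_point, quant_max + zero_point + 1)
    grid = (codes - zero_point).tolist()
    print(f"b={b}: quant_min={quant_min}, quant_max={quant_max}, grid={grid}")

# b=2: grid [-1.5, -0.5, 0.5, 1.5]
# b=3: grid [-3.5, -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, 3.5]
```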