Migrate to config for Int8DynamicActivationIntxWeightConfig #1836
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1836
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit fc46e34 with merge base ada4c02.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg @jerryzh168 are we OK with adding tensor_impl_ctr_kwargs to from_hp_to_intx? It can be used to propagate a bias when constructing the weight tensor subclass via from_plain.
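For reference, a minimal sketch of what the proposal amounts to (the parameter list is heavily abridged and `from_hp_to_intx_sketch` and its quantization math are stand-ins, not the actual torchao API): the only change of interest is an optional dict forwarded to the tensor-impl constructor, so a layout's `from_plain` can receive extras such as a packed bias.

```python
from typing import Optional

import torch


def from_hp_to_intx_sketch(
    input_float: torch.Tensor,
    _layout,
    tensor_impl_ctr,
    tensor_impl_ctr_kwargs: Optional[dict] = None,
):
    # Stand-in for the real quantization math, which produces these tensors.
    data = input_float.round().to(torch.int8)
    scale = torch.ones(input_float.shape[0])
    zero_point = torch.zeros(input_float.shape[0])
    # The new optional dict is simply forwarded to the tensor-impl constructor,
    # e.g. tensor_impl_ctr_kwargs={"bias": bias} for a layout that packs the bias.
    return tensor_impl_ctr(
        data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
    )
```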
Force-pushed from e332b54 to 4b3a742
Force-pushed from f138c3d to fc46e34
Mostly nits
Not terribly familiar with this code, but passes the gut test
```python
if tensor_impl_ctr_kwargs is None:
    tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
    data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
```
Don't know which style AO uses, no strong preference. Suggested change:

```python
tensor_impl = tensor_impl_ctr(
    data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
)
```
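(Both versions behave the same; `tensor_impl_ctr_kwargs or {}` just folds the None check into the call site.)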
I'd like to hear from @drisspg or someone from torchao on this change. Not so much on the style preference, but more on whether they're OK with adding tensor_impl_ctr_kwargs to the to_affine_quantized_intx signature.
```python
quantized_model_reference = copy.deepcopy(model)
quantize_(
    quantized_model_reference,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        granularity=granularity,
        has_weight_zeros=has_weight_zeros,
        layout=reference_layout,
    ),
)

with torch.no_grad():
    result = quantized_model(activations)
    expected_result = quantized_model_reference(activations)
```
nit: We can factor out the creation of expected_results since it's just PlainLayout in both cases (different models)
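One possible shape for that refactor, as a sketch only (the helper name is made up, and it assumes the quantize_, int8_dynamic_activation_intx_weight, and PlainLayout imports the test file already has):

```python
import copy

import torch


def _reference_result(model, activations, weight_dtype, granularity, has_weight_zeros):
    # Build the reference the same way for both test cases: quantize a copy of
    # the model with the plain layout and run it once under no_grad.
    reference_model = copy.deepcopy(model)
    quantize_(
        reference_model,
        int8_dynamic_activation_intx_weight(
            weight_dtype=weight_dtype,
            granularity=granularity,
            has_weight_zeros=has_weight_zeros,
            layout=PlainLayout(),
        ),
    )
    with torch.no_grad():
        return reference_model(activations)
```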
```python
    and layout.target == Target.ATEN
)
weight_dtype: torch.dtype = torch.int4
granularity: Union[PerRow, PerGroup] = PerRow()
```
Why not `granularity: Union[PerRow, PerGroup] = PerGroup(128)`, like int8_dynamic_activation_intx_weight?
PerRow is a safer default because it doesn't depend on the input data size. I expect users should always specify this parameter anyway.
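A standalone illustration of the point (shapes made up; not torchao code):

```python
import torch

weight = torch.randn(32, 200)  # inner dim 200 is not a multiple of 128

# PerRow: one scale per output row, valid for any weight shape.
per_row_scales = weight.abs().amax(dim=1) / 7.0  # int4 symmetric range [-8, 7]

# PerGroup(128): only applies when the inner dim divides evenly by the group
# size, so a fixed default like PerGroup(128) silently depends on model shapes.
group_size = 128
print(weight.shape[1] % group_size == 0)  # False -> PerGroup(128) would not fit here
```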
```python
)


@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
def _int8_dynamic_activation_intx_weigh_transform(
```
Suggested change (typo fix):

```python
def _int8_dynamic_activation_intx_weight_transform(
```
```python
tensor_impl_ctr_kwargs = None
if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
    # We need to create a new layout object for each module because when
    # granulairty is PerRow, the layout objects cannot share the group_size
```
Suggested change (typo fix):

```python
    # granularity is PerRow, the layout objects cannot share the group_size
```
```python
if weight_tensor.tensor_impl.get_layout().has_bias:
    assert (
        bias is None
    ), "bias should be None because it is already packed with the weights (has_bias=True)"
```
nit: `if: assert`; also fine with leaving it as-is for legibility. Suggested change:

```python
assert (
    not weight_tensor.tensor_impl.get_layout().has_bias or bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
```
```python
if torch.backends.kleidiai.is_available():
    if isinstance(granularity, PerGroup):
        scale_dtype = (
            torch.bfloat16
        )  # KleidiAI kernel requires bfloat16 scale_dtype
```
Seems like we always use float32 in to_affine_quantized_intx. Is this intentional?
KleidiAI tests pass with this. The scale_dtype was only used by the Python-based quantization that computes qvals, scales, and zeros; it is not what was passed to the kernel itself.
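A simplified illustration of that point (toy shapes, not the actual torchao code path): casting the scales to bfloat16 changes the values produced by the Python-side quantization math, but the packed kernel only ever sees whatever the packing routine hands it.

```python
import torch

w = torch.randn(8, 128)
group_size = 32
w_grouped = w.reshape(8, -1, group_size)  # (8, 4, 32)

# scale_dtype only affects this reference computation of scales and qvals.
scale_dtype = torch.bfloat16
scales = (w_grouped.abs().amax(dim=-1) / 7.0).to(scale_dtype)  # int4 symmetric range
qvals = torch.clamp(
    torch.round(w_grouped / scales.unsqueeze(-1)), -8, 7
).to(torch.int8)
```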
This PR: