Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to config for Int8DynamicActivationIntxWeightConfig #1836

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

metascroy
Copy link
Contributor

This PR:

  • Migrates to Int8DynamicActivationIntxWeightConfig
  • Merges PackedLinearInt8DynamicActivationIntxWeightLayout to use the same quantizer, and merges the tests

Copy link

pytorch-bot bot commented Mar 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1836

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fc46e34 with merge base ada4c02 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@metascroy metascroy requested a review from digantdesai March 5, 2025 01:21
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 5, 2025
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@drisspg @jerryzh168 are we ok adding tensor_impl_ctr_kwargs to from_hp_to_intx.

It can be used to propagate a bias when constructing the weight tensor subclass via from_plain.

@drisspg drisspg added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Mar 5, 2025
Copy link
Contributor

@Jack-Khuu Jack-Khuu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly nits

Not terribly familiar with this code, but passes the gut test

Comment on lines +280 to +284
if tensor_impl_ctr_kwargs is None:
tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know which style AO uses, no strong pref

Suggested change
if tensor_impl_ctr_kwargs is None:
tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to hear from @drisspg or someone from torchao on this change.

Not so much on the style preference, but more so on whether they're OK adding tensor_impl_ctr_kwargs to the to_affine_quantized_intx signature.

Comment on lines +126 to +139
quantized_model_reference = copy.deepcopy(model)
quantize_(
quantized_model_reference,
int8_dynamic_activation_intx_weight(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
layout=reference_layout,
),
)

with torch.no_grad():
result = quantized_model(activations)
expected_result = quantized_model_reference(activations)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We can factor out the creation of expected_results since it's just PlainLayout in both cases (different models)

and layout.target == Target.ATEN
)
weight_dtype: torch.dtype = torch.int4
granularity: Union[PerRow, PerGroup] = PerRow()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not
granularity: Union[PerRow, PerGroup] = PerGroup(128),

like int8_dynamic_activation_intx_weight?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PerRow is safer default because it doesn't depend on input data size. I expect users should always specify this parameter

)

@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
def _int8_dynamic_activation_intx_weigh_transform(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _int8_dynamic_activation_intx_weigh_transform(
def _int8_dynamic_activation_intx_weight_transform(

tensor_impl_ctr_kwargs = None
if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
# We need to create a new layout object for each module because when
# granulairty is PerRow, the layout objects cannot share the group_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# granulairty is PerRow, the layout objects cannot share the group_size
# granularity is PerRow, the layout objects cannot share the group_size

Comment on lines +317 to +320
if weight_tensor.tensor_impl.get_layout().has_bias:
assert (
bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if: assert; also fine with leaving it as-is for legibility

Suggested change
if weight_tensor.tensor_impl.get_layout().has_bias:
assert (
bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
assert (
not weight_tensor.tensor_impl.get_layout().has_bias or bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"

Comment on lines -631 to -635
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
scale_dtype = (
torch.bfloat16
) # KleidiAI kernel requires bfloat16 scale_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we always use float32 in to_affine_quantized_intx. Is this intentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KleidiAI tests pass with this. This was only used for python-based quantization that computes qvals, scales, zeros, not by what was passed to the kernel itself.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants