Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WOQ int8 test with Inductor Freeze #362

Merged

Conversation

leslie-fang-intel
Copy link
Collaborator

Summary

Add the WOQ int8 test with inductor freeze

Test Plan

python -u -m pytest -s -v test/integration/test_integration.py -k test_int8_weight_only_quant_with_freeze

Copy link

pytorch-bot bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/362

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 34f38ca with merge base bc2f8b7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 14, 2024
@leslie-fang-intel
Copy link
Collaborator Author

Hi @jerryzh168, I found current TorchAO testing are all without turning on of torch._inductor.config.freezing. However, most of the Inductor CPP backend optimization like mkldnn linear or gemm template are based on the turning on of this flag. Could we add the test case with the flag turning on? cc @jgong5 @Valentine233

@jerryzh168
Copy link
Contributor

jerryzh168 commented Jun 14, 2024

Hi @jerryzh168, I found current TorchAO testing are all without turning on of torch._inductor.config.freezing. However, most of the Inductor CPP backend optimization like mkldnn linear or gemm template are based on the turning on of this flag. Could we add the test case with the flag turning on? cc @jgong5 @Valentine233

sure, what is the effect of that? is this required to get speedup for cpu?

@jerryzh168
Copy link
Contributor

is this issue related? pytorch/pytorch#122813

@leslie-fang-intel
Copy link
Collaborator Author

is this issue related? pytorch/pytorch#122813

I believe enabling freezing will make it easier for us to implement the first optimization mentioned by @jgong5 in pytorch/pytorch#122813 (comment). However, further work is needed to optimize the int4 WOQ kernel.

woq mostly benefits second+ tokens, i.e., sequence length = 1. The case you provided is for the first token which is compute-bound and would incur overhead in both weight dequant and activation type casting compared with fp computation. Of course, there are two areas to improve: 1) the activations can be kept in bf16 without the need of casting to/from fp32; 2) the woq kernel needs further optimization to leverage bf16 AMX accelerated compute, currently it is with fp32 fma.

@leslie-fang-intel
Copy link
Collaborator Author

leslie-fang-intel commented Jun 15, 2024

The motivation to add this UT is:

  • Previously, we found freeze flag failed to work with TorchAO as reported in [TorchAO] fail to do fake_tensor_prop with freezing pass pytorch#123522, since for a Linear module we only see one parameter of quantized weight instead of 3: int_weight, scale, zp.
  • And recently, we found we can see 3 parameters as int_weight, scale, zp for a linear module, probably related to the implementation of parametrize in
    if TORCH_VERSION_AFTER_2_4:
    quantize(model, get_apply_int8wo_quant(), filter_fn)
    unwrap_tensor_subclass(model, filter_fn)

Looks like the new added UT test_int8_weight_only_quant_with_freeze_0_cpu fails in the preCI environment, but it passes on my local system with latest PyTorch. @jerryzh168 do you have any idea about these failures? Is the preCI still using a legacy PyTorch which fails to meet the requirement of TORCH_VERSION_AFTER_2_4?

@jerryzh168
Copy link
Contributor

are you referring to this error: `RuntimeError: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #1 'mat2'

While executing %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %convert_element_type), kwargs = {})`

this seems like an implementation problem in AffineQuantizedTensor, but you should be able to repro I feel, since it failed in all torch versions (2.2.2, 2.3 and nightly)

@leslie-fang-intel
Copy link
Collaborator Author

leslie-fang-intel commented Jun 15, 2024

The failure with torch nightly is

FAILED test/integration/test_integration.py::TestSubclass::test_int8_weight_only_quant_with_freeze_2_cpu - ImportError: This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs

which is different as 2.2 or 2.3

are you referring to this error: `RuntimeError: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument https://github.com/pytorch/ao/pull/1 'mat2'

While executing %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%view, %convert_element_type), kwargs = {})`

@leslie-fang-intel leslie-fang-intel force-pushed the leslie/test_ao_with_inductor_freeze branch from fac758a to 34f38ca Compare June 15, 2024 01:43
@leslie-fang-intel
Copy link
Collaborator Author

leslie-fang-intel commented Jun 15, 2024

Hi @jerryzh168, rebase to fix the nightly failure and skip the test with pytorch before 2.4. Could you help to approve for CI running again?

BTW: Why unwrap_tensor_subclass requires PyTorch 2.4 and after, it looks like necessary to run with Inductor freeze.

@jerryzh168
Copy link
Contributor

BTW: Why unwrap_tensor_subclass requires PyTorch 2.4 and after, it looks like necessary to run with Inductor freeze.

this requires a fix: pytorch/pytorch#124888 that is only available in 2.4+

@msaroufim msaroufim requested a review from jerryzh168 June 15, 2024 02:28
@leslie-fang-intel
Copy link
Collaborator Author

leslie-fang-intel commented Jun 16, 2024

Hi @jerryzh168, rebase to fix the nightly failure and skip the test with pytorch before 2.4. Could you help to approve for CI running again?

Looks all the UT are green now.

@msaroufim msaroufim self-requested a review June 18, 2024 06:01
@msaroufim msaroufim merged commit f5b6ec9 into pytorch:main Jun 18, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants