
[DRAFT] use xnnpack quantization in eager/aoti #698

Closed
wants to merge 28 commits

Conversation

@metascroy (Contributor) commented May 6, 2024

The CustomQuantHandler is just a temporary construct so that we can run the new linear class conveniently in eager mode. It is not the final API.

python torchchat.py generate stories15M --quantize '{"_custom":{}, "precision":{"dtype": "fp32"}}' --prompt "Once upon a time"

Using device=cpu Apple M1 Pro
Loading model...
Time to load model: 0.06 seconds
Quantizing the model with: {'_custom': {}, 'precision': {'dtype': 'fp32'}}
Time to quantize model: 0.09 seconds
Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and explore the world around him. One day, he found a shiny object on the ground. It was a little box with a lock on it. 
Timmy wanted to open the box, but he didn't know how. He asked his friend, Billy, for help. "Can you help me open this box, Billy?" Timmy asked. 
Billy replied, "Sure, Timmy. I'll help you with that." Together, they turned a knob and the box opened. Inside was a treasure map that led to a secret cave. 
Timmy and Billy were excited to go on an adventure and find the cave. As they walked through the cave, Timmy asked Billy, "Do you think we can find the treasure?" Billy replied, "Yes, we can!" 
Finally, they arrived at
Max Sequence Length Reached. Ending Conversation.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) May 6, 2024
@metascroy requested a review from digantdesai May 6, 2024 18:12
export_aoti_util.py (outdated review thread, resolved)
@mikekgfb (Contributor) commented May 6, 2024

What is this PR trying to do? BTW, we might start collecting these source transformation operations into a subdirectory, like @larryliu0820 had done for the ExecuTorch release version. (Not sure we'd want to do it this very instant, but as a directional statement: we're starting to collect all sorts of util-y files associated with some flavor of source transformation.) Also, if it's supporting eager execution, aoti may be a misnomer?

@metascroy changed the title from "Linear class with init + forward separation" to "[DRAFT] linear class with init + forward separation" May 7, 2024
@metascroy (Contributor Author)

@mikekgfb The ultimate goal of this PR is to get XNNPACK 4-bit quantization working on desktop in eager + AOTI (as opposed to using int4pack_mm_kernel_), but it is still pretty rough right now. (On the aoti misnomer: that file no longer exists in the latest commits because I'm starting with eager.) cc @digantdesai

I definitely like the idea of collecting transformations in a common place. I'll do some reorganization before landing this, and I'll look to @larryliu0820's work in ExecuTorch for inspiration.

@metascroy changed the title from "[DRAFT] linear class with init + forward separation" to "[DRAFT] use xnnpack quantization in eager/aoti" May 7, 2024
return run(input, prepacked_op_context);
}

at::Tensor prepack_and_run_qd8_f32_qb4w(
@metascroy (Contributor Author) May 9, 2024

@digantdesai I haven't split prepack and run into separate functions yet because I first want to get the end-to-end flow working.

Let me know if something is obviously wrong in this function. Here is the output of running it on an example:

import torch
from quantize import group_quantize_tensor_symmetric, convert_to_qc4w

input_channels = 512
output_channels = 5
group_size = 256
batch_size = 3

W = torch.randn(output_channels, input_channels)
w_int, s, z = group_quantize_tensor_symmetric(W, group_size, torch.float32)
w_packed = convert_to_qc4w(w_int)

w_int_dq = (w_int.reshape(-1, group_size) * s.reshape(-1,1)).reshape(
        output_channels, input_channels
)

inp = torch.randn(batch_size, input_channels)

torch.ops.load_library("build/libcustom_linear.dylib")
res1 = torch.ops.torchchat.prepack_and_run_qd8_f32_qb4w.default(w_packed, s, inp, group_size)
res2 = torch.ops.aten.linear.default(inp, W)
res3 = torch.ops.aten.linear.default(inp, w_int_dq)
res1
tensor([[-14.9157, -47.5983,  11.0697, -32.9488,  -8.4086],
        [ 13.8821, -24.7717,   9.0824,  18.2017,   3.9529],
        [ 26.6977,   3.9705, -32.1581,  22.4687,  -3.1330]])

res2
tensor([[-20.5561, -44.4960,  14.6975, -34.0947,  -6.6856],
        [ 16.5121, -24.3467,   6.5836,  20.2640,   2.1489],
        [ 27.7293,   2.4617, -33.1060,  22.3646,  -0.7434]])     
        
res3
tensor([[-15.0061, -47.3699,  11.1573, -32.8825,  -8.1975],
        [ 13.8134, -24.6430,   9.1434,  18.3097,   3.9621],
        [ 26.7795,   3.9283, -32.1443,  22.4861,  -3.3032]])

There are a couple of bits I wasn't sure about, so I just picked the options that gave the best numeric results (but let me know if they're not correct):

  • I set input_channels for the operator to the logical number of input channels (2 times the number of columns in the packed kernel, due to packing); see the shape sketch after this list.

  • I set block_size equal to the group_size, but based on https://fburl.com/z945pcpz I first thought it was the number of groups per row, because there is one scale per group.
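
To make the packing arithmetic concrete, here is a minimal shape check for the example above (a sketch under my own assumptions: convert_to_qc4w packs two 4-bit values per byte along the input-channel dimension, and there is one scale per group per output channel):

# Hypothetical shape check for the example above (assumes qc4w packing
# halves the number of columns and one scale per group per output channel).
output_channels, input_channels, group_size = 5, 512, 256

packed_cols = input_channels // 2              # 256 bytes per row after packing
groups_per_row = input_channels // group_size  # 512 / 256 = 2
num_scales = output_channels * groups_per_row  # 5 * 2 = 10

print(packed_cols, groups_per_row, num_scales)  # 256 2 10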

A reviewer replied, quoting "A couple bits I wasn't sure on":

Looks reasonable.
input_channels = K, independent of how the weights are packed.
group_size = the number of input channels per group, so num_of_scales = output_channels * (input_channels / group_size). fburl.com/z945pcpz seems wrong.

res1 = torch.ops.torchchat.prepack_and_run_qd8_f32_qb4w.default(w_packed, s, inp, group_size)
res2 = torch.ops.aten.linear.default(inp, W)
res3 = torch.ops.aten.linear.default(inp, w_int_dq)

Looks OK? We typically compare res1 and res3, but with some more q/dq on the activation side; this is decent, though.

res1 and res2 are off, but we are comparing 4-bit against fp32, so that's expected. I don't think I've done a rigorous comparison like this, so who knows.
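
For illustration, a minimal sketch of the kind of activation-side q/dq mentioned above (my own addition, assuming a simple per-row asymmetric int8 round-trip; the exact qd8 dynamic quantization in XNNPACK may differ):

import torch

def fake_qdq_per_row_int8(x):
    # Per-row asymmetric int8 quantize/dequantize of the activations,
    # used only as a rough stand-in for qd8 dynamic quantization.
    x_min = x.min(dim=1, keepdim=True).values.clamp(max=0.0)
    x_max = x.max(dim=1, keepdim=True).values.clamp(min=0.0)
    scale = ((x_max - x_min) / 255.0).clamp(min=1e-8)
    zero_point = (-128.0 - x_min / scale).round()
    x_q = (x / scale + zero_point).round().clamp(-128, 127)
    return (x_q - zero_point) * scale

# Then compare against res1 with, for example:
# res3_qdq = torch.ops.aten.linear.default(fake_qdq_per_row_int8(inp), w_int_dq)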

at::Tensor weight,
at::Tensor weight_scales,
at::Tensor input,
int64_t group_size) {
@metascroy (Contributor Author)

TODO: infer group_size from weight_scales + weight
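
One possible way to infer it, as a sketch under the same packing assumptions as in the example above (weight is the qc4w-packed tensor with K/2 columns per output channel; weight_scales holds one scale per group per output channel):

import torch

def infer_group_size(weight: torch.Tensor, weight_scales: torch.Tensor) -> int:
    # Hypothetical helper: recover group_size from tensor shapes alone.
    output_channels = weight.shape[0]
    input_channels = 2 * weight.shape[1]               # undo 4-bit nibble packing
    groups_per_row = weight_scales.numel() // output_channels
    return input_channels // groups_per_row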

TORCH_CHECK(status == xnn_status_success, "Operator xnn_setup_fully_connected_nc_qd8_f32_qb4w failed with status ", status, ".");


status = xnn_run_operator(fc_op->get(), /*threadpool=*/nullptr);

A reviewer commented:

You will need a pthreadpool pointer for perf down the line.

@metascroy (Contributor Author)

For this, I’ll just copy the thread pool used in linear.cpp?

Comment on lines 157 to 158
m.def("create_fully_connected_nc_qd8_f32_qb4w", create_fully_connected_nc_qd8_f32_qb4w);
m.def("create_convert_nc_f32_qd8", create_convert_nc_f32_qd8);

A reviewer asked:

Why register these?

@metascroy (Contributor Author) May 9, 2024

So they can be called from the init function in Python. The returned operators are then saved and passed to run in forward.
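
Roughly, the intended pattern on the Python side could look like the sketch below (op names and signatures under torch.ops.torchchat are placeholders for whatever this PR registers, not the final API):

import torch
import torch.nn as nn

class PrepackedQB4WLinear(nn.Module):
    # Sketch of the prepack-in-__init__ / run-in-forward split described above.
    # Requires the custom op library from this PR to be loaded first, e.g.
    # torch.ops.load_library("build/libcustom_linear.dylib").
    def __init__(self, w_packed, scales, group_size):
        super().__init__()
        # Prepack once at init time and keep the returned opaque context.
        self.op_context = torch.ops.torchchat.create_fully_connected_nc_qd8_f32_qb4w(
            w_packed, scales, group_size
        )

    def forward(self, x):
        # Run the prepacked operator on every forward call.
        return torch.ops.torchchat.run(x, self.op_context)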

@metascroy (Contributor Author)

I renamed to prepack/run to make the intention clearer.

#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>

int main() {
@metascroy (Contributor Author)

This segfaults.

@pytorch-bot (bot) commented May 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/698

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 3dccf9f with merge base e3af9ee:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

return w_int8, scales, zeros


# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/operators/node_visitor.py?lines=451

A reviewer asked:

OSS link?

std::vector<xnn_dynamic_quantization_params> quantization_params(
batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);

auto threadpool = caffe2::pthreadpool_();

A reviewer commented:

so clean :)

@Jack-Khuu (Contributor)

Closing stale PR

@Jack-Khuu closed this Jun 24, 2024