Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[DRAFT] use xnnpack quantization in eager/aoti #698

Closed
wants to merge 28 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update main.cpp
metascroy committed May 10, 2024

Verified

This commit was signed with the committer’s verified signature.
commit 92bc3050d9ed9b86084edacfabbf218ce0f980e0
64 changes: 49 additions & 15 deletions _custom_linear/main.cpp
Original file line number Diff line number Diff line change
@@ -1,38 +1,72 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Common.h>
// #include <ATen/native/xnnpack/Common.h>



#include <xnnpack.h>
// #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
// #include <c10/util/ArrayRef.h>
// #include <limits>
// #include <memory>

namespace at::native::xnnpack {

struct Deleter final {
void operator()(const xnn_operator_t op) const {
xnn_delete_operator(op);
}
};

using Operator = std::unique_ptr<xnn_operator, Deleter>;

} // namespace at::native::xnnpack


int main() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This segfaults.


xnn_status status;
// status = xnn_initialize(/*allocator=*/nullptr);
// TORCH_CHECK(status == xnn_status_success);
status = xnn_initialize(/*allocator=*/nullptr);
TORCH_CHECK(status == xnn_status_success);

auto w_col = 384;
auto input_channels = w_col*2;
auto output_channels = 32000;
auto group_size = 32;
auto n_groups = 24;
int n_groups = input_channels / group_size;
TORCH_CHECK(n_groups * group_size == input_channels);

// auto options = torch::TensorOptions().dtype(torch::kByte);
// auto weight = torch::ones({output_channels, w_col}, options);
// auto weight_scales = torch::ones({output_channels, n_groups});
auto options = torch::TensorOptions().dtype(torch::kByte);
auto weight = torch::ones({output_channels, w_col}, options);
auto weight_scales = torch::ones({output_channels, n_groups});

auto weight_data = std::vector<uint8_t>();
for (int i = 0; i < output_channels * w_col; ++i) {
weight_data.push_back(1);

auto weight_vector = std::vector<uint8_t>(weight.numel(), 0);
for (int i = 0; i < weight.numel(); ++i) {
weight_vector[i] = weight.const_data_ptr<uint8_t>()[i];
}

auto weight_scales = std::vector<float>();
for (int i = 0; i < output_channels * n_groups; ++i) {
weight_data.push_back(1.0);
auto weight_scales_vector = std::vector<float>(weight_scales.numel(), 0);
for (int i = 0; i < weight_scales.numel(); ++i) {
weight_scales_vector[i] = weight_scales.const_data_ptr<float>()[i];
}

// auto weight_ptr = (void*)weight_vector.data();
// auto weight_scales_ptr = weight_scales_vector.data();

auto weight_ptr = (void*)weight.const_data_ptr();
auto weight_scales_ptr = weight_scales.const_data_ptr<float>();


TORCH_CHECK(weight_ptr != nullptr);
TORCH_CHECK(weight_scales_ptr != nullptr);

const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
const uint8_t weight_zero_point = 8;



xnn_operator_t fc_op = nullptr;
status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels, /*size_t input_channels*/
@@ -41,8 +75,8 @@ int main() {
output_channels, /*size_t output_stride*/
group_size, /*size_t block_size*/
weight_zero_point, /*uint8_t kernel_zero_point*/
weight_scales.data(), /*const float* kernel_scale*/
(void*)weight_data.data(), /*const void* kernel*/
weight_scales_ptr, /*const float* kernel_scale*/
weight_ptr, /*const void* kernel*/
nullptr, /*const float* bias*/
output_min, /*float output_min*/
output_max, /*float output_max*/