[DRAFT] use xnnpack quantization in eager/aoti #698

Closed · wants to merge 28 commits
Changes from 7 commits
10 changes: 10 additions & 0 deletions _custom_linear/CMakeLists.txt
@@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.17)
project(custom_linear)

set(CMAKE_CXX_STANDARD 17)

find_package(Torch REQUIRED)

add_library(custom_linear SHARED custom_linear.cpp)
target_include_directories(custom_linear PRIVATE "${TORCHCHAT_ROOT}/..")
target_link_libraries(custom_linear PRIVATE "${TORCH_LIBRARIES}")
32 changes: 32 additions & 0 deletions _custom_linear/_custom_linear.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
from typing import Optional
torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")

class _CustomLinear(nn.Module):
    def _prepare(self) -> None:
        self.weight.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = False

        # self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(self.weight, self.bias)
        self.packed_weight_bias = torch.ops.torchchat.prepack.default(self.weight, self.bias, None, None)

    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.weight = weight
        self.bias = bias
        self._prepare()

    def forward(self, x):
        if x.dtype != torch.float32:
            raise RuntimeError(f"x has dtype {x.dtype}, expected float32")
        # return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)
        return torch.ops.torchchat.run.default(x, self.packed_weight_bias)

def _replace_linear_with_custom_linear(module: nn.Module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, _CustomLinear(child.weight, child.bias))
        else:
            _replace_linear_with_custom_linear(child)
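
For context, a minimal eager-mode sketch of how this wrapper might be used. It assumes build.sh has already produced _custom_linear/build/libcustom_linear.dylib and that the script runs from the torchchat root; the toy model and tolerance are illustrative assumptions, not part of this PR.

import torch
import torch.nn as nn

# Importing the module also loads libcustom_linear.dylib via the load_library call above.
from _custom_linear._custom_linear import _replace_linear_with_custom_linear

# Toy float32 model; any model containing nn.Linear layers works the same way.
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 8)).eval()
x = torch.randn(2, 64, dtype=torch.float32)

with torch.no_grad():
    ref = model(x)                             # reference output via nn.Linear
    _replace_linear_with_custom_linear(model)  # swap every nn.Linear for _CustomLinear
    out = model(x)                             # now dispatches to torch.ops.torchchat.run

print(torch.allclose(ref, out, atol=1e-4))     # expected: True, up to fp32 rounding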
5 changes: 5 additions & 0 deletions _custom_linear/build.sh
@@ -0,0 +1,5 @@
rm -rf build
mkdir build
# cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" -S . -B build
cmake -DCMAKE_PREFIX_PATH="/Users/scroy/repos/pytorch/torch/share/cmake" -DTORCHCHAT_ROOT="${PWD}/.." -S . -B build
cmake --build build
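
The CMAKE_PREFIX_PATH above is machine-specific; the commented-out invocation queries it from the installed PyTorch instead. The same query from Python, for reference:

import torch.utils
print(torch.utils.cmake_prefix_path)  # e.g. <site-packages>/torch/share/cmake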
166 changes: 166 additions & 0 deletions _custom_linear/custom_linear.cpp
@@ -0,0 +1,166 @@
#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Linear.h>
#include <ATen/native/xnnpack/OpContext.h>
#include <torchchat/_custom_linear/custom_linear.h>


// Used for: make_zero_points_and_scales_tensor
// #include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>


#include <ATen/native/quantized/cpu/XnnpackUtils.h>


c10::intrusive_ptr<at::native::xnnpack::LinearOpContext> prepack(
torch::Tensor weight,
c10::optional<torch::Tensor> bias,
const c10::optional<at::Scalar>& output_min,
const c10::optional<at::Scalar>& output_max) {
return at::native::xnnpack::XNNPackLinearOpContext::create_context(
std::move(weight), std::move(bias), output_min, output_max);
}

torch::Tensor run(const torch::Tensor& input, const c10::intrusive_ptr<at::native::xnnpack::LinearOpContext>& op_context) {
return op_context->run(input);
}

torch::Tensor prepack_and_run(
const torch::Tensor& input,
torch::Tensor weight,
c10::optional<torch::Tensor> bias,
const c10::optional<at::Scalar>& output_min,
const c10::optional<at::Scalar>& output_max) {
auto prepacked_op_context = prepack(weight, bias, output_min, output_max);
return run(input, prepacked_op_context);
}

at::Tensor prepack_and_run_qd8_f32_qb4w(
Comment from @metascroy (Contributor, Author), May 9, 2024:
@digantdesai I haven't split prepack and run into separate ops yet because I first want to get the end-to-end flow working.

Let me know if something is obviously wrong in this function. Here is the output of running it on an example:

import torch
from quantize import group_quantize_tensor_symmetric, convert_to_qc4w

input_channels = 512
output_channels = 5
group_size = 256
batch_size = 3

W = torch.randn(output_channels, input_channels)
w_int, s, z = group_quantize_tensor_symmetric(W, group_size, torch.float32)
w_packed = convert_to_qc4w(w_int)

w_int_dq = (w_int.reshape(-1, group_size) * s.reshape(-1,1)).reshape(
        output_channels, input_channels
)

inp = torch.randn(batch_size, input_channels)

torch.ops.load_library("build/libcustom_linear.dylib")
res1 = torch.ops.torchchat.prepack_and_run_qd8_f32_qb4w.default(w_packed, s, inp, group_size)
res2 = torch.ops.aten.linear.default(inp, W)
res3 = torch.ops.aten.linear.default(inp, w_int_dq)
res1
tensor([[-14.9157, -47.5983,  11.0697, -32.9488,  -8.4086],
        [ 13.8821, -24.7717,   9.0824,  18.2017,   3.9529],
        [ 26.6977,   3.9705, -32.1581,  22.4687,  -3.1330]])

res2
tensor([[-20.5561, -44.4960,  14.6975, -34.0947,  -6.6856],
        [ 16.5121, -24.3467,   6.5836,  20.2640,   2.1489],
        [ 27.7293,   2.4617, -33.1060,  22.3646,  -0.7434]])     
        
res3
tensor([[-15.0061, -47.3699,  11.1573, -32.8825,  -8.1975],
        [ 13.8134, -24.6430,   9.1434,  18.3097,   3.9621],
        [ 26.7795,   3.9283, -32.1443,  22.4861,  -3.3032]])

A couple bits I wasn't sure on and just picked the options that gave the best numeric results (but let me know if not correct):

  • I set input_channels for the operator to be the logical number of input channels (2 times the number of cols in kernel due to packing).

  • I set block_size equal to the group_size, but from https://fburl.com/z945pcpz, I first thought it was the number of groups per row because there is one scale per group.

Reviewer reply, quoting "A couple bits I wasn't sure on":

Looks reasonable.
input_channels = K, independent of how the weights are packed.
group_size = number of input channels per group, so num_of_scales = output_channels * (input_channels / group_size); for the example above that is 5 * (512 / 256) = 10 scales. fburl.com/z945pcpz seems wrong.

res1 = torch.ops.torchchat.prepack_and_run_qd8_f32_qb4w.default(w_packed, s, inp, group_size)
res2 = torch.ops.aten.linear.default(inp, W)
res3 = torch.ops.aten.linear.default(inp, w_int_dq)

Looks OK? esp. we typically compare res1 and res3 (with some more q/dqs on the activation side), but this is decent.

res1 and res2 are off, but we are comparing 4b vs f32; I haven't done a rigorous comparison like this, so who knows.

at::Tensor weight,
at::Tensor weight_scales,
at::Tensor input) {

xnn_status status;

status = xnn_initialize(/*allocator=*/nullptr);
TORCH_CHECK(status == xnn_status_success);


const float output_min = -std::numeric_limits<float>::infinity();
const float output_max = std::numeric_limits<float>::infinity();
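// A kernel zero point of 8 matches convert_to_qc4w in quantize.py, which shifts the signed
// int4 range [-8, 7] into the unsigned nibble range [0, 15] by adding 8.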
const uint8_t weight_zero_point = 8;


// weight is the packed qc4w kernel (two 4-bit values per byte), so the logical number of
// input channels is 2x its column count; block_size is the quantization group size, i.e.
// input_channels divided by the number of scale groups per output row.
auto input_channels = 2 * weight.size(1);
auto output_channels = weight.size(0);
auto block_size = input_channels / weight_scales.size(1);

std::cout << "input_channels: " << input_channels << std::endl;
std::cout << "output_channels: " << output_channels << std::endl;
std::cout << "block_size: " << block_size << std::endl;

// Create FC
xnn_operator_t fc_op = nullptr;
status = xnn_create_fully_connected_nc_qd8_f32_qb4w(
input_channels, /*size_t input_channels*/
output_channels, /*size_t output_channels*/
input_channels, /*size_t input_stride*/
output_channels, /*size_t output_stride*/
block_size, /*size_t block_size*/
weight_zero_point, /*uint8_t kernel_zero_point*/
weight_scales.const_data_ptr<float>(), /*const float* kernel_scale*/
weight.const_data_ptr(), /*const void* kernel*/ /* <--------- THIS IS OUTPUT OF PREPACK */
nullptr, /*const float* bias*/
output_min, /*float output_min*/
output_max, /*float output_max*/
0, /*uint32_t flags*/
nullptr, /*xnn_code_cache_t code_cache*/
nullptr, /*xnn_weights_cache_t weights_cache*/
&fc_op /*xnn_operator_t* fully_connected_op_out*/
);
TORCH_CHECK(status == xnn_status_success, status);
TORCH_CHECK(fc_op != nullptr);

// std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fc_op(fc_op, xnn_delete_operator);
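// NOTE: in this draft neither fc_op nor convert_op (created below) is ever deleted, so they
// leak on every call; re-enabling the unique_ptr above (plus one for convert_op) or calling
// xnn_delete_operator explicitly would fix that.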



// Create, reshape, setup, and run convert
TORCH_CHECK(input.dim() == 2);
auto batch_size = input.size(0);

// Holds output of convert
std::vector<int8_t> output_convert(batch_size * input_channels + XNN_EXTRA_BYTES);
std::vector<xnn_dynamic_quantization_params> quantization_params(batch_size + XNN_EXTRA_QUANTIZATION_PARAMS);
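// output_convert holds the dynamically quantized (qd8) activations; quantization_params
// receives one scale/zero-point pair per batch row, written by the convert op and consumed
// by the fully connected op. The XNN_EXTRA_* constants add the slack XNNPACK may read past
// the end of these buffers.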

xnn_operator_t convert_op = nullptr;
status = xnn_create_convert_nc_f32_qd8(
0, /*uint32_t flags*/
&convert_op /*xnn_operator_t* convert_op_out*/
);
TORCH_CHECK(status == xnn_status_success);
TORCH_CHECK(convert_op != nullptr);

status = xnn_reshape_convert_nc_f32_qd8(
convert_op, /*xnn_operator_t convert_op*/
batch_size, /*size_t batch_size*/
input_channels, /*size_t channels*/
input_channels, /*size_t input_stride*/
input_channels, /*size_t output_stride*/
nullptr /*pthreadpool_t threadpool*/
);
TORCH_CHECK(status == xnn_status_success);


status = xnn_setup_convert_nc_f32_qd8(
convert_op, /*xnn_operator_t convert_op*/
input.const_data_ptr<float>(), /*const float* input*/
output_convert.data(), /*int8_t* output*/
quantization_params.data() /*struct xnn_dynamic_quantization_params* quantization_params*/
);
TORCH_CHECK(status == xnn_status_success);

status = xnn_run_operator(convert_op, /*threadpool=*/nullptr);
TORCH_CHECK(status == xnn_status_success);



// Reshape, setup, and run FC
status = xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
fc_op, /*xnn_operator_t fully_connected_op*/
batch_size, /*size_t batch_size*/
nullptr /*pthreadpool_t threadpool*/ // TODO: set to something sensible
);
TORCH_CHECK(status == xnn_status_success);

// Create tensor to hold output
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto output_tensor = torch::empty({batch_size, output_channels}, options);

status = xnn_setup_fully_connected_nc_qd8_f32_qb4w(
fc_op, /*xnn_operator_t fully_connected_op*/
output_convert.data(), /*const int8_t* input*/
output_tensor.data_ptr<float>(), /*float* output*/
quantization_params.data() /*const struct xnn_dynamic_quantization_params* quantization_params*/
);
TORCH_CHECK(status == xnn_status_success);


status = xnn_run_operator(fc_op, /*threadpool=*/nullptr);
TORCH_CHECK(status == xnn_status_success);

std::cout << "RETURNING." << std::endl;

return output_tensor;
}




TORCH_LIBRARY(torchchat, m) {
m.def("prepack", prepack);
m.def("run", run);
m.def("prepack_and_run", prepack_and_run);
m.def("prepack_and_run_qd8_f32_qb4w", prepack_and_run_qd8_f32_qb4w);
}
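
For reference, a small sanity-check sketch for the fp32 ops registered above (the library path and tolerance are assumptions; run from the torchchat root after build.sh):

import torch

torch.ops.load_library("_custom_linear/build/libcustom_linear.dylib")

x = torch.randn(3, 64)
w = torch.randn(16, 64)
b = torch.randn(16)

# One-shot prepack + run ...
y_fused = torch.ops.torchchat.prepack_and_run.default(x, w, b, None, None)

# ... or prepack once and reuse the packed weights across calls.
packed = torch.ops.torchchat.prepack.default(w, b, None, None)
y_packed = torch.ops.torchchat.run.default(x, packed)

y_ref = torch.ops.aten.linear.default(x, w, b)
print(torch.allclose(y_fused, y_ref, atol=1e-4),
      torch.allclose(y_packed, y_ref, atol=1e-4))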
21 changes: 21 additions & 0 deletions _custom_linear/custom_linear.h
@@ -0,0 +1,21 @@
#pragma once

#include <torch/library.h>
#include <torch/script.h>
#include <ATen/native/xnnpack/Linear.h>
#include <ATen/native/xnnpack/OpContext.h>

c10::intrusive_ptr<at::native::xnnpack::LinearOpContext> prepack(
torch::Tensor weight,
c10::optional<torch::Tensor> bias,
const c10::optional<at::Scalar>& output_min,
const c10::optional<at::Scalar>& output_max);

torch::Tensor run(const torch::Tensor& input, const c10::intrusive_ptr<at::native::xnnpack::LinearOpContext>& op_context);

torch::Tensor prepack_and_run(
const torch::Tensor& input,
torch::Tensor weight,
c10::optional<torch::Tensor> bias,
const c10::optional<at::Scalar>& output_min,
const c10::optional<at::Scalar>& output_max);
92 changes: 92 additions & 0 deletions _custom_linear/quantize.py
@@ -0,0 +1,92 @@
# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/test/ops/linear.py?lines=348

import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig

# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
def get_group_qparams_symmetric(w, n_bit, groupsize, precision):
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    max_val_abs = torch.max(-min_val_neg, max_val_pos)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))

    # max_int - min_int is just 2**n_bit - 1

    scales = max_val_abs / (float(max_int - min_int) / 2)  # i.e. 2 * max(abs(x)) / (int range)
    scales = torch.max(
        scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
    )
    # TODO: make sure abs(scales) is not too small?
    zeros = torch.full_like(scales, 0)
    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
        precision
    ).reshape(w.shape[0], -1)

# Note: not using torchao.quantization.quant_primitives because it runs into op registration issues
# Does 4-bit quantization
def group_quantize_tensor_symmetric(w, group_size, precision):
    n_bit = 4
    scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
    max_int = 2 ** (n_bit - 1) - 1
    min_int = -(2 ** (n_bit - 1))
    # TODO: currently we don't know how to express torch.int4, we'll
    # add torch.int4 to core later
    w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
        w, scales, zeros, min_int, max_int, torch.int8, group_size
    )

    return w_int8, scales, zeros


# https://www.internalfb.com/code/fbsource/[f1458254b3caba86fb497abbfe15c74c4e8ca38d]/fbcode/executorch/backends/xnnpack/operators/node_visitor.py?lines=451
def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
    """
    Pack a channelwise 4-bit quantized tensor (int8 values in [-8, 7]) into qc4w layout:
    two 4-bit values per byte, shifted to [0, 15] by the zero point of 8.
    """

    import torch.nn.functional as F

    # Assert we got a properly quantized tensor.
    min, max = inp.min().item(), inp.max().item()
    assert (
        max <= 7 and min >= -8
    ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]"

    # Assuming we have a 2d tensor
    if inp.ndim != 2:
        inp = inp.squeeze()
    assert (
        inp.ndim == 2
    ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"

    # pad ic
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Shape after padding
    oc, ic = inp.shape
    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"

    # Adjust inp tensor for zp (uint8 wraparound maps [-8, 7] onto [0, 15])
    inp = inp.to(dtype=torch.uint8) + 8

    # Prepare the result tensor: low nibble from even columns, high nibble from odd columns
    inp = inp.contiguous().view(-1)
    return (inp[1::2] << 4 | inp[::2]).view(oc, int(ic / 2))
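
A small round-trip sketch of the helpers above (run from inside _custom_linear/, as in the example in the review thread; the sizes are arbitrary assumptions):

import torch
from quantize import group_quantize_tensor_symmetric, convert_to_qc4w

oc, ic, group_size = 8, 128, 32
w = torch.randn(oc, ic)

w_int, scales, zeros = group_quantize_tensor_symmetric(w, group_size, torch.float32)
assert w_int.shape == (oc, ic) and w_int.dtype == torch.int8
assert scales.shape == (oc, ic // group_size)            # one scale per group per row
assert -8 <= int(w_int.min()) and int(w_int.max()) <= 7  # 4-bit symmetric range

# Dequantize group-wise and compare against the original weights.
w_dq = (w_int.reshape(-1, group_size) * scales.reshape(-1, 1)).reshape(oc, ic)
print("max abs error:", (w - w_dq).abs().max().item())

# Pack two 4-bit values per byte, shifted by the zero point of 8.
packed = convert_to_qc4w(w_int)
assert packed.shape == (oc, ic // 2) and packed.dtype == torch.uint8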
20 changes: 20 additions & 0 deletions quantize.py
@@ -155,6 +155,25 @@ def quantized_model(self) -> nn.Module:
        return self.model_.to(device=self.device)


class CustomHandler(QuantHandler):
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        pass

    def convert_for_runtime(self) -> nn.Module:
        pass

    def quantized_model(self) -> nn.Module:
        self.model_ = self.model_.to(device=self.device)

        from _custom_linear._custom_linear import _replace_linear_with_custom_linear
        _replace_linear_with_custom_linear(self.model_)
        return self.model_

#########################################################################
##### Quantization Primitives ######

@@ -1059,4 +1078,5 @@ def quantized_model(self) -> nn.Module:
    "linear:hqq": WeightOnlyInt4HqqQuantHandler,
    "precision": PrecisionHandler,
    "executor": ExecutorHandler,
    "_custom": CustomHandler,
}
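
And a sketch of using the new handler directly from Python (the toy model is an assumption; the "_custom" entry above is what lets a quantization config select this handler):

import torch
import torch.nn as nn
from quantize import CustomHandler  # torchchat's top-level quantize.py

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
quantized = CustomHandler(model, device="cpu").quantized_model()

with torch.no_grad():
    out = quantized(torch.randn(1, 32, dtype=torch.float32))
print(out.shape)  # torch.Size([1, 4])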