Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/dequantize_blockwise_bnb4.cuh"
"quantization/dequantize_blockwise_bnb4.cu"
"quantization/matmul_bnb4.cc"
"quantization/matmul_bnb4.cuh"
"quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
Expand Down
57 changes: 57 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Do not modify directly.*
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
* <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
Expand Down Expand Up @@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>

MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'.
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.

Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>quant_type</tt> : int (required)</dt>
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional quantized data for weight</dd>
<dt><tt>absmax</tt> : T1</dt>
<dd>quantization constants</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>tensor. The output tensor has the same rank as the input. </dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>


### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>

Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
Expand Down
2 changes: 2 additions & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
Expand Down Expand Up @@ -849,6 +850,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
Expand Down Expand Up @@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
#endif
Expand Down
202 changes: 202 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <algorithm>
#include <cmath>

namespace onnxruntime {
namespace contrib {

#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif

typedef enum Bnb_DataType_t {
FP4 = 0,
NF4 = 1,
} Bnb_DataType_t;

FORCEINLINE uint8_t QuantizeOneFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12

// we do a binary search
// the pivots are divided by 12 (the FP4 absmax)
// since we assum input data is in [-1.0, 1.0]

// !be careful here, its easy to make a mistake
// that is difficult to noice if you add an extra
// zero somewhere!

uint8_t sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if (x > 0.29166667f) {
if (x > 0.583333f) {
if (x > 0.8333333f) {
return 0b0011 + sign;
} else {
return 0b0010 + sign;
}
} else if (x > 0.4166667f) {
return 0b101 + sign;
} else {
return 0b100 + sign;
}
} else if (x > 0.0859375f) {
if (x > 0.20833333f) {
return 0b0111 + sign;
} else {
return 0b0110 + sign;
}
} else if (x > 0.00260417f) {
return 0b0001 + sign;
} else {
return 0b0000 + sign;
}
}

FORCEINLINE uint8_t QuantizeOneNF4(float x) {
if (x > 0.03979014977812767f) {
if (x > 0.3893125355243683f) { // 1
if (x > 0.6427869200706482f) { // 11
if (x > 0.8614784181118011f) { // 111
return 0b1111;
} else {
return 0b1110;
}
} else if (x > 0.5016634166240692f) { // 110
return 0b1101;
} else {
return 0b1100;
}
} else if (x > 0.2035212516784668f) { // 10
if (x > 0.2920137718319893f) { // 101
return 0b1011;
} else {
return 0b1010;
}
} else if (x > 0.1202552504837513f) { // 100
return 0b1001;
} else {
return 0b1000;
}
} else if (x > -0.33967943489551544f) { // 0
if (x > -0.13791173323988914f) { // 01
if (x > -0.045525018125772476f) { // 011
return 0b0111;
} else {
return 0b0110;
}
} else if (x > -0.23460740596055984f) { // 010
return 0b0101;
} else {
return 0b0100;
}
} else if (x > -0.6106329262256622f) { // 00
if (x > -0.4599952697753906f) { // 001
return 0b0011;
} else {
return 0b0010;
}
} else if (x > -0.8480964004993439f) { // 000
return 0b0001;
} else {
return 0b0000;
}
}

template <int32_t DATA_TYPE>
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
if constexpr (DATA_TYPE == FP4)
return QuantizeOneFP4(x);
else
return QuantizeOneNF4(x);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
float local_absmax = 0.0f;

int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size;
int32_t dst_offset = block_idx * block_size / 2;

for (int32_t idx = 0; idx < block_len; idx++) {
const float v = static_cast<float>(src[src_offset + idx]);
local_absmax = fmaxf(local_absmax, fabsf(v));
}

absmax_block = static_cast<T>(local_absmax);
const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;

for (int32_t idx = 0; idx < block_len; idx += 2) {
const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);

const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);

dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
}
}

static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};

static float nf4_qaunt_map[16] = {-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};

template <typename T, int32_t DATA_TYPE>
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
if constexpr (DATA_TYPE == FP4)
return static_cast<T>(fp4_qaunt_map[x]);
else
return static_cast<T>(nf4_qaunt_map[x]);
}

template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size / 2;
int32_t dst_offset = block_idx * block_size;

for (int32_t idx = 0; idx < block_len; idx += 2) {
const uint8_t val = src[src_offset + idx / 2];

dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
}
}

} // namespace contrib
} // namespace onnxruntime
Loading