diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index fe3e577b4fc36..de1458c120016 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -54,6 +54,11 @@ set(contrib_ops_excluded_files "quantization/attention_quantization_impl.cuh" "quantization/dequantize_blockwise.cuh" "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5805333a0868c..1a76c18a6a8e0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -47,6 +47,7 @@ Do not modify directly.* * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat @@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
size of the groups (blocks) used for weight quantization. It must be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
The output tensor. It has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index dea71d81f8df5..aef76203920df 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -454,6 +454,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| @@ -849,6 +850,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index f77e403f26dde..f9d9b13f0fedc 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 0000000000000..cb8e97a592d8c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? 
static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 0000000000000..5ddb77e5b5ee3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
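For reference, here is a minimal NumPy sketch of the block format implemented by `blockwise_quant_block_bnb4.h` above (NF4 shown). The nearest-neighbor lookup is equivalent to the threshold cascade in `QuantizeOneNF4` up to tie-breaking (the thresholds there are midpoints of adjacent codebook entries, e.g. (0.7229568 + 1.0) / 2 = 0.8614784); all names here are illustrative:

```python
import numpy as np

# 16-entry NF4 codebook, copied from nf4_qaunt_map above.
NF4_MAP = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0], dtype=np.float32)

def quantize_block_nf4(block):
    """One block: scale into [-1, 1] by 1/absmax, pick the nearest code, pack two per byte."""
    absmax = np.abs(block).max()
    scaled = block / absmax if absmax else block
    codes = np.abs(scaled[:, None] - NF4_MAP[None, :]).argmin(axis=1).astype(np.uint8)
    packed = (codes[0::2] << 4) | codes[1::2]   # first value goes in the high nibble
    return packed, absmax

def dequantize_block_nf4(packed, absmax):
    codes = np.empty(packed.size * 2, dtype=np.uint8)
    codes[0::2] = packed >> 4
    codes[1::2] = packed & 0x0F
    return NF4_MAP[codes] * absmax

block = np.random.randn(64).astype(np.float32)  # block_size = 64 (even, so values pair up)
packed, absmax = quantize_block_nf4(block)
err = np.abs(dequantize_block_nf4(packed, absmax) - block).max()
print(f"max round-trip error: {err:.4f}")
```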
+ +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size 
== 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..2f3ede49c3650 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + constexpr bool transb = true; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, 
K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index c52f869d6a9d2..e762a80cb0e2f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -118,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); @@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 0000000000000..e58723f0b31e1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
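As a shape reference for the CPU `Compute` above (which, per its TODO, dequantizes all of B into a temporary buffer and then calls `MlasGemmBatch`): B is stored transposed as `[N, K]`, so the kernel computes `Y = A @ B^T`. A tiny NumPy sketch with made-up sizes:

```python
import numpy as np

M, K, N = 2, 8, 4                               # illustrative sizes
A = np.random.randn(M, K).astype(np.float32)
B_t = np.random.randn(N, K).astype(np.float32)  # dequantized weight in its stored [N, K] layout
Y = A @ B_t.T                                   # what helper.Compute(a->Shape(), {N_, K_}, false, true) describes
assert Y.shape == (M, N)
```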
+ +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? 
TILE_SIZE * 2 : n - i * 2; + + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; + vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; + #else + // half multiplication not supported + vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); + vals[j * 2 + 1] = + static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); + #endif + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half *absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 0000000000000..4aef3ab699f9c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..bd5b6e0a8a1ce --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
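A quick sanity check on the launch geometry of `DequantizeBnb4` above: each thread block covers `tile_size = 1024` output values. Assuming the upstream template configuration of 64 threads each loading 8 packed bytes (the `kDequantizeBlockwise` template arguments are not visible in this excerpt, so treat 64/8 as an assumption), the numbers are consistent:

```python
numel = 4096 * 4096                          # e.g. a 4096x4096 weight
tile_size = 1024                             # values per thread block, as in DequantizeBnb4
grid = (numel + tile_size - 1) // tile_size
threads, bytes_per_thread = 64, 8            # assumed kDequantizeBlockwise configuration
assert threads * bytes_per_thread * 2 == tile_size  # two 4-bit values per byte
print(grid)                                  # 16384 thread blocks
```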
+ +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + SafeInt(K_), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + 
reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 0000000000000..1d9aa75ff3701 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + #else + // half multiplication not supported + local_B[k * 2] 
= + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * + static_cast(local_absmax)); + local_B[k * 2 + 1] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * + static_cast(local_absmax)); + #endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_C += static_cast(local_A[k] * local_B[k]); + #else + // half multiplication not supported + local_C += static_cast(local_A[k]) * static_cast(local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 16 : 32; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 0000000000000..743234282fbf3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
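The fused kernel above is a GEMV specialization; `TryMatMulBnb4` returns false and lets the caller fall back to the dequantize-plus-cuBLAS path otherwise. A small predicate restating its early-exit conditions:

```python
def try_matmul_bnb4_applicable(m: int, k: int, block_size: int) -> bool:
    """Mirrors the checks in TryMatMulBnb4: single-row input, K divisible by
    the quantization block size, and block_size a multiple of 32 up to 4096."""
    return m == 1 and k % block_size == 0 and block_size % 32 == 0 and block_size <= 4096

assert try_matmul_bnb4_applicable(1, 4096, 64)
assert not try_matmul_bnb4_applicable(2, 4096, 64)  # batched input falls back to cublasGemmHelper
```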
+ +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5e5eee568fa21..681a728f823da 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); }); + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + +Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. +Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 52ea677d5141d..04dfa9b51e112 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -6,6 +6,7 @@ #include #include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" namespace pybind11 { @@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise( tp.get()); } +template +void QuantizeMatMulBnb4Blockwise( + py::array_t dst, + py::array_t src, + py::array_t absmax, + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info absmax_buf = absmax.request(); + + contrib::QuantizeBlockwiseBnb4( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(absmax_buf.ptr), + block_size, + quant_type, + N, + K, + tp.get()); +} + void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); } } // namespace python diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu new file mode 100644 index 0000000000000..3504ce1bebe8c --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. 
+ +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const uint8_t* quant_; + const T* absmax_; + T* quant_map_buffer_; + int n_; + int k_; +}; + +template +class DequantizeBnb4 : public IKernelExplorer { + public: + DequantizeBnb4( + int quant_type, + DeviceArray& output, + DeviceArray& quant, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.quant_type_ = quant_type; + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.quant_, + params_.absmax_, + 64, + params_.n_ * params_.k_, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeBnb4, half); + REGISTER_OP(DequantizeBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu new file mode 100644 index 0000000000000..e4cd83565357a --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. 
+ +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" +#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const T* a_; + const uint8_t* b_; + const T* absmax_; + T* quant_map_buffer_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatBnb4 : public IKernelExplorer { + public: + MatrixFloatBnb4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int quant_type, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.quant_type_ = quant_type; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + contrib::cuda::TryMatMulBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.a_, + params_.b_, + params_.absmax_, + params_.m_, + params_.n_, + params_.k_, + 64, + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(MatrixFloatBnb4, half); + REGISTER_OP(MatrixFloatBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py new file mode 100644 index 0000000000000..140151aadcc0f --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class DequantizeBnb4Metric(ke.BandwidthMetric): + quant_type: str + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}" + ) + + +def profile_dequantize_int4_func(qt, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(n, k).astype(dtype) + quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + quant_d = ke.DeviceArray(quant) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k) + duration_ms = my_op.Profile() + total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype) + + ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k)) + + +def profile_with_args(qt, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_dequantize_int4_func(qt, n, k, dtype, func) + + +def profile(): + for qt in quant_types: + for dt in dtypes: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, n, k, dt, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py new file mode 100644 index 0000000000000..4a9489050fd61 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +def dtype_to_funcs_cublas(dtype): + type_map = { + "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), + "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class MatrixMulMetric(ke.BandwidthMetric): + m: int + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +@dataclass +class MatrixFpBnb4Metric(MatrixMulMetric): + quant_type: str + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + + my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt)) + + +def profile_gemm_func(m, n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.rand(k, n).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + f = getattr(ke, func) + my_op = f(output_d, a_d, b_d, m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) + + +def profile_with_args(qt, m, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func) + + for func in dtype_to_funcs_cublas(dtype): + profile_gemm_func(m, n, k, dtype, func) + + +def profile(): + dims_m = [1] + for qt in quant_types: + for dt in dtypes: + for m in dims_m: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, m, n, k, dt, False) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", 
action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py new file mode 100644 index 0000000000000..951746a089305 --- /dev/null +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx +from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto + +from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + +from .onnx_model import ONNXModel +from .quant_utils import attribute_to_kwarg + +logger = logging.getLogger(__name__) + + +class MatMulBnb4Quantizer: + """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type""" + + ################## + # quantization types, must be consistent with native code type + # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h + + # 4b floating point with bias of 3 + FP4 = 0 + + # 4b NormalFloat + NF4 = 1 + + def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None): + nodes_to_exclude = nodes_to_exclude or [] + assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4] + self.model = ONNXModel(model) + self.quant_type = quant_type + self.block_size = block_size + self.nodes_to_exclude = set(nodes_to_exclude) + + @staticmethod + def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: + """4b quantize fp32/fp16 weight""" + + if len(fpweight.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + # need to copy since the transposed weight still has the original memory layout + # Linear4bit quantizes its weight data which is the transposed weight + fpweight_t = fpweight.transpose().copy() + + rows, cols = fpweight.shape + numel = rows * cols + block_size = self.block_size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=fpweight.dtype) + # block wise quantization, fpweight_t is flattened and divided into blocks + quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows) + + return (packed, absmax) + + def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + + if node.op_type != "MatMul": + return node # only care about MatMul for now + + logger.debug(f"start to quantize {node.name} ...") + if node.name in self.nodes_to_exclude: + logger.debug(f"exclude to quantize {node.name} as specified by 
nodes_to_exclude...") + return node + + inputB = node.input[1] # noqa: N806 + B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + if B is None: + logger.debug("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + B_array = onnx.numpy_helper.to_array(B) # noqa: N806 + if len(B_array.shape) != 2: + logger.debug("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + + packed, absmax = self.bnb4_block_quant(B_array) + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Bnb4" + for input in Bs_graph.input: + if input.name == inputB: + Bs_graph.input.remove(input) + break + + absmax_tensor = onnx.numpy_helper.from_array(absmax) + absmax_tensor.name = B.name + "_absmax" + + Bs_graph.initializer.extend([B_quant, absmax_tensor]) + + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["block_size"] = self.block_size + kwargs["quant_type"] = self.quant_type + + matmul_bnb4_node = onnx.helper.make_node( + "MatMulBnb4", + inputs=[node.input[0], B_quant.name, absmax_tensor.name], + outputs=[node.output[0]], + name=node.name + "_Bnb4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.debug(f"complete quantization of {node.name} ...") + + return matmul_bnb4_node + + def _process_subgraph(self, graph_stack: List[GraphProto]): + new_nodes = [] + graph = graph_stack[-1] + + for node in graph.node: + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(graph_attrs): + kwargs = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # recursive call to take care of sub-graph + graph_stack.append(attr.g) + kv = {attr.name: self._process_subgraph(graph_stack)} + elif attr.type == onnx.AttributeProto.GRAPHS: + value = [] + for subgraph in attr.graphs: + # recursive call to take care of sub-graph + graph_stack.append(subgraph) + value.extend([self._process_subgraph(graph_stack)]) + kv = {attr.name: value} + else: + kv = attribute_to_kwarg(attr) + kwargs.update(kv) + node = onnx.helper.make_node( # noqa: PLW2901 + node.op_type, node.input, node.output, name=node.name, **kwargs + ) + + new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack)) + + graph.ClearField("node") + graph.node.extend(new_nodes) + graph_stack.pop() + return graph + + def process(self): + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + + +def parse_args(): + parser = argparse.ArgumentParser( + description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices. + +A weight matrix is partitioned into blocks, where each block is a contiguous +subset inside the flattened transposed weight matrix. Each block is quantized +into a set of 4b integers with an absolute value scaling factor. 
+""" + ) + + parser.add_argument("--input_model", required=True, help="Path to the input model file") + parser.add_argument("--output_model", required=True, help="Path to the output model file") + parser.add_argument( + "--quant_type", + required=False, + default=1, + options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + help="Quantization data type. 0: FP4, 1: NF4", + ) + parser.add_argument( + "--block_size", + required=False, + default=64, + description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", + ) + parser.add_argument("-v", "--verbose", required=False, action="store_true") + parser.set_defaults(verbose=False) + parser.add_argument( + "--nodes_to_exclude", + nargs="+", + type=str, + required=False, + default=[], + help="Specify the nodes to be excluded from quantization with node names", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + input_model_path = args.input_model + output_model_path = args.output_model + + if os.path.exists(output_model_path): + logger.error(f"file {output_model_path} already exists") + raise Exception(f"file {output_model_path} already exists") + + model = onnx.load(input_model_path) + quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude) + quant.process() + quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc new file mode 100644 index 0000000000000..e739b17d5885f --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#ifndef ORT_MINIMAL_BUILD
+
+#include "core/common/span_utils.h"
+#include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas_q4.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/session/inference_session.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/framework/test_utils.h"
+#include "test/optimizer/graph_transform_test_builder.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "core/util/qmath.h"
+#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
+
+#include <memory>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+namespace onnxruntime {
+namespace test {
+
+void QuantizeDequantizeBnb4(std::vector<float>& raw_vals,  // N X K
+                            std::vector<uint8_t>& quant_vals,
+                            std::vector<float>& absmax,
+                            int32_t quant_type,
+                            int32_t N,
+                            int32_t K,
+                            int32_t block_size) {
+  OrtThreadPoolParams to;
+  auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
+                                          concurrency::ThreadPoolType::INTRA_OP);
+
+  contrib::QuantizeBlockwiseBnb4<float>(
+      quant_vals.data(),
+      raw_vals.data(),
+      absmax.data(),
+      block_size,
+      quant_type,
+      N,
+      K,
+      tp.get());
+
+  contrib::DequantizeBlockwiseBnb4<float>(
+      raw_vals.data(),
+      quant_vals.data(),
+      absmax.data(),
+      block_size,
+      quant_type,
+      N,
+      K,
+      tp.get());
+}
+
+void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) {
+  RandomValueGenerator random{1234};
+  std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
+  // quantizer expects transposed weights, N X K
+  std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({N, K}), 0.0f, 0.25f));
+
+  int64_t numel = N * K;
+  int64_t quantized_numel = (numel + 1) / 2;
+  int64_t total_block_count = (numel + block_size - 1) / block_size;
+  std::vector<uint8_t> input1_vals(quantized_numel);
+  std::vector<float> absmax(total_block_count);
+
+  QuantizeDequantizeBnb4(input1_f_vals,
+                         input1_vals,
+                         absmax,
+                         static_cast<int32_t>(quant_type),
+                         static_cast<int32_t>(N),
+                         static_cast<int32_t>(K),
+                         static_cast<int32_t>(block_size));
+
+  std::vector<float> expected_vals(M * N);
+  for (int64_t m = 0; m < M; m++) {
+    for (int64_t n = 0; n < N; n++) {
+      float sum = 0.0f;
+      for (int64_t k = 0; k < K; k++) {
+        sum += input0_vals[m * K + k] * input1_f_vals[n * K + k];
+      }
+      expected_vals[m * N + n] = sum;
+    }
+  }
+
+  OpTester test("MatMulBnb4", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("quant_type", quant_type);
+  if (use_float16) {
+    test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
+    test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
+    test.AddInput<MLFloat16>("absmax", {total_block_count}, ToFloat16(absmax), true);
+
+    test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
+    test.SetOutputAbsErr("Y", 0.02f);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  } else {
+    test.AddInput<float>("A", {M, K}, input0_vals, false);
+    test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
+    test.AddInput<float>("absmax", {total_block_count}, absmax, true);
+
+    test.AddOutput<float>("Y", {M, N}, expected_vals);
+
+    test.Run();
+  }
+}
+
+TEST(MatMulBnb4, Float32) {
+  for (auto qt : {0, 1}) {
+    for (auto M : {1, 2, 100}) {
+      for (auto N : {1, 2, 32, 288}) {
+        for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
+          for (auto block_size : {16, 32, 64, 128}) {
+            RunTest(qt, M, N, K, block_size, false);
+          }
+        }
+      }
+    }
+  }
+}
+
+#if defined(USE_CUDA)
+TEST(MatMulBnb4, Float16) {
+  for (auto qt : {0, 1}) {
+    for (auto M : {1, 2, 100}) {
+      for (auto N : {1, 2, 32, 288}) {
+        for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
+          for (auto block_size : {16, 32, 64, 128}) {
+            RunTest(qt, M, N, K, block_size, true);
+          }
+        }
+      }
+    }
+  }
+}
+
+#endif
+}  // namespace test
+}  // namespace onnxruntime
+
+#endif  // ORT_MINIMAL_BUILD
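
The expected values above come from a quantize-dequantize round trip through the native kernels. As a reference for what that dequantization does, here is a hedged NumPy sketch (my own illustration, not the kernel's code): each packed byte holds two 4b indices into the FP4/NF4 lookup table (the tables are listed in the test file that follows), and every decoded value is scaled by its block's absmax.

```python
import numpy as np

# NF4 lookup table as listed in the tests below; FP4 works the same way
nf4_map = np.array([-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
                    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
                    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
                    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
                    0.7229568362236023, 1.0], dtype=np.float32)


def dequantize_bnb4_sketch(packed: np.ndarray, absmax: np.ndarray, block_size: int) -> np.ndarray:
    # unpack two 4b indices per byte: even positions in the high nibble, odd in the low
    idx = np.empty(packed.size * 2, dtype=np.uint8)
    idx[::2] = packed >> 4
    idx[1::2] = packed & 0x0F
    values = nf4_map[idx]
    # rescale each block by its absmax
    for block_id, scale in enumerate(absmax):
        values[block_id * block_size : (block_id + 1) * block_size] *= scale
    return values
```
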
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py
new file mode 100644
index 0000000000000..88432d75c653e
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import tempfile
+import unittest
+from importlib.util import find_spec
+from pathlib import Path
+from typing import Dict, Tuple, Union
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper
+from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count
+
+from onnxruntime.quantization import quant_utils
+
+quant_maps = {
+    0: [
+        0.00000000,
+        5.208333333e-03,
+        0.66666667,
+        1.00000000,
+        0.33333333,
+        0.50000000,
+        0.16666667,
+        0.25000000,
+        -0.00000000,
+        -5.208333333e-03,
+        -0.66666667,
+        -1.00000000,
+        -0.33333333,
+        -0.50000000,
+        -0.16666667,
+        -0.25000000,
+    ],
+    1: [
+        -1.0,
+        -0.6961928009986877,
+        -0.5250730514526367,
+        -0.39491748809814453,
+        -0.28444138169288635,
+        -0.18477343022823334,
+        -0.09105003625154495,
+        0.0,
+        0.07958029955625534,
+        0.16093020141124725,
+        0.24611230194568634,
+        0.33791524171829224,
+        0.44070982933044434,
+        0.5626170039176941,
+        0.7229568362236023,
+        1.0,
+    ],
+}
+
+
+class TestOpMatMulBnb4(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.")
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._tmp_model_dir.cleanup()
+
+    def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray:
+        rows, cols = shape
+        line = np.zeros(shape)
+        line = line.reshape(-1)
+        quant_map = np.array(quant_maps[quant_type], dtype=np.float32)
+
+        # fill with exactly representable quant_map values, cycling through all 16 codes
+        v = 0
+        for i in range(line.shape[0]):
+            line[i] = quant_map[v]
+            v += 1
+            if v >= 16:
+                v = 0
+
+        # bnb quantization quantizes weight.T after flattening
+        line = line.reshape(cols, rows).transpose()
+        return line.reshape(shape)
+
+    def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds:
+        input_data_list = []
+        for _i in range(n):
+            inputs = {}
+            for name, shape in name2shape.items():
+                inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)})
+            input_data_list.extend([inputs])
+        dr = TestDataFeeds(input_data_list)
+        return dr
+
+    def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None:
+        #    (input)
+        #       |
+        #     MatMul
+        #       |
+        #    (output)
+        input_name = "input"
+        output_name = "output"
+        initializers = []
+
+        def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str):
+            weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32)
+            initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name))
+            return onnx.helper.make_node(
+                "MatMul",
+                [input_name, weight_name],
+                [output_name],
+            )
+
+        # for this to work, (in_features * out_features) % block_size == 0
+        in_features = 52
+        out_features = 288
+        # make MatMul node
+        matmul_node = make_matmul(
+            input_name,
+            [in_features, out_features],
+            "linear1.weight",
+            output_name,
+        )
+
+        # make graph
+        input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features])
+        output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features])
+        graph_name = "matmul_bnb4_test"
+        graph = helper.make_graph(
+            [matmul_node],
+            graph_name,
+            [input_tensor],
+            [output_tensor],
+            initializer=initializers,
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+        model.ir_version = 7  # use stable onnx ir version
+
+        onnx.save(model, output_model_path)
+
+    def quant_test(self, quant_type: int, block_size: int):
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute())
+        self.construct_model_matmul(model_fp32_path, quant_type)
+        data_reader = self.input_feeds(1, {"input": [100, 52]})
+
+        model_bnb4_path = str(
+            Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute()
+        )
+
+        # quantize fp32 model to bnb4 model
+        from onnxruntime.quantization import matmul_bnb4_quantizer
+
+        model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
+        quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size)
+        quant.process()
+        quant.model.save_model_to_file(model_bnb4_path, False)
+
+        quant_nodes = {"MatMulBnb4": 1}
+        check_op_type_count(self, model_bnb4_path, **quant_nodes)
+
+        data_reader.rewind()
+        check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next())
+
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because the training package doesn't have quantize_matmul_bnb4"
+    )
+    def test_quantize_matmul_bnb4_fp4(self):
+        np.random.seed(13)
+        self.quant_test(0, 64)
+
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because the training package doesn't have quantize_matmul_bnb4"
+    )
+    def test_quantize_matmul_bnb4_nf4(self):
+        np.random.seed(13)
+        self.quant_test(1, 64)
+
+
+if __name__ == "__main__":
+    unittest.main()
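
The reference quantizer in the next test file packs two 4b code indices into each byte, with the even-positioned element in the high nibble. A tiny standalone sketch of that packing step, using illustrative index values of my own:

```python
import numpy as np

# four hypothetical 4b quantization indices
idx = np.array([3, 7, 12, 0], dtype=np.uint8)

# even positions go to the high nibble, odd positions to the low nibble
packed = np.bitwise_or(np.left_shift(idx[::2], 4), idx[1::2])
print([hex(b) for b in packed])  # ['0x37', '0xc0']
```
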
diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py
new file mode 100644
index 0000000000000..9e9d05fae027d
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import unittest
+from importlib.util import find_spec
+
+import numpy as np
+import numpy.typing as npt
+
+quant_enums = {"FP4": 0, "NF4": 1}
+
+
+def quantize_block_fp4(block: npt.ArrayLike):
+    # quantize a block of float32 values to uint8 by simulating a binary search using pivots;
+    # could have used (block[:,None] - quant_map).argmin(axis=1), but there are some mismatches due to
+    # floating point precision
+    # block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0
+
+    # pivots to find the quantization index
+    # only half of the pivots are needed since the other half is symmetric
+    pivots = np.array(
+        [0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32
+    )
+    # indices are not 0,1,2,3,4,5,6,7 because FP4 is a floating point data type
+    pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8)
+
+    # signs of the block
+    signs = (block < 0).astype(np.uint8) * 8
+
+    # find the uint8 quantization index
+    # argmax finds the first occurrence of True
+    quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs
+
+    return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
+
+
+def quantize_block_nf4(block: npt.ArrayLike):
+    pivots = np.array(
+        [
+            -0.8480964004993439,
+            -0.6106329262256622,
+            -0.4599952697753906,
+            -0.33967943489551544,
+            -0.23460740596055984,
+            -0.13791173323988914,
+            -0.045525018125772476,
+            0.03979014977812767,
+            0.1202552504837513,
+            0.2035212516784668,
+            0.2920137718319893,
+            0.3893125355243683,
+            0.5016634166240692,
+            0.6427869200706482,
+            0.8614784181118011,
+            1.0,
+        ],
+        dtype=np.float32,
+    )
+
+    quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8)
+
+    return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
+
+
+def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str):
+    if len(matrix_float.shape) != 2:
+        raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
+
+    numel = matrix_float.size
+    num_blocks = (numel + block_size - 1) // block_size
+    quantized_numel = (numel + 1) // 2
+
+    packed = np.zeros(quantized_numel, dtype=np.uint8)
+    absmax = np.zeros(num_blocks, dtype=matrix_float.dtype)
+
+    flattened_matrix_float = matrix_float.flatten()
+    for block_idx in range(num_blocks):
+        block_len = min(block_size, numel - block_idx * block_size)
+        block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len])
+
+        block_absmax = np.max(np.abs(block))
+        reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0
+        absmax[block_idx] = block_absmax
+
+        # pad odd-length blocks so two 4b values can be packed per byte
+        if block_len % 2 != 0:
+            block = np.append(block, 0.0)
+            block_len += 1
+
+        block *= reciprocal_absmax
+        start = block_idx * block_size // 2
+        end = start + block_len // 2
+        if quant_type == "FP4":
+            packed[start:end] = quantize_block_fp4(block)
+        else:
+            packed[start:end] = quantize_block_nf4(block)
+
+    return (packed, absmax)
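
To see the reference path in isolation, a small hedged sketch (my own example; `quantize_blockwise_bnb4_ref` is the function defined just above, and no native kernel is required):

```python
import numpy as np

# tiny 2x4 matrix; a real weight would be much larger
matrix = np.array([[0.5, -1.0, 0.25, 0.0],
                   [1.0, -0.5, 0.75, -0.25]], dtype=np.float32)

packed, absmax = quantize_blockwise_bnb4_ref(matrix, block_size=4, quant_type="NF4")
print(packed.shape, absmax.shape)  # (4,) (2,): 8 values -> 4 bytes, 2 blocks -> 2 scales
```
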
np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + + quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k) + return (packed, absmax) + + +class TestQuantizeBlockwiseBnb4(unittest.TestCase): + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_blockwise_bnb4(self): + for quant_type in ["FP4", "NF4"]: + for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type) + quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type) + quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type) + assert np.allclose(quant_value_ref, quant_value) + assert np.allclose(absmax_ref, absmax) + + +if __name__ == "__main__": + unittest.main()
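
Finally, a note on where the pivot arrays above come from: each pivot is the midpoint between two adjacent quantization map values, so comparing against the pivots reproduces nearest-value rounding. A hedged sketch (my own derivation) that recovers the NF4 pivots from the map listed in test_op_matmul_bnb4.py:

```python
import numpy as np

nf4_map = np.array([-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
                    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
                    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
                    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
                    0.7229568362236023, 1.0], dtype=np.float32)

# midpoints between adjacent map values; the trailing 1.0 catches everything above the last midpoint
pivots = np.append((nf4_map[1:] + nf4_map[:-1]) / 2, np.float32(1.0))
print(pivots[:3])  # [-0.8480964 -0.6106329 -0.4599953], matching quantize_block_nf4's pivots
```
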