diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index fe3e577b4fc36..de1458c120016 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
+ "quantization/dequantize_blockwise_bnb4.cuh"
+ "quantization/dequantize_blockwise_bnb4.cu"
+ "quantization/matmul_bnb4.cc"
+ "quantization/matmul_bnb4.cuh"
+ "quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5805333a0868c..1a76c18a6a8e0 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -47,6 +47,7 @@ Do not modify directly.*
* com.microsoft.Inverse
* com.microsoft.Irfft
* com.microsoft.LongformerAttention
+ * com.microsoft.MatMulBnb4
* com.microsoft.MatMulFpQ4
* com.microsoft.MatMulInteger16
* com.microsoft.MatMulIntegerToFloat
@@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MatMulBnb4**
+
+  MatMulBnb4 is a MatMul whose weight is quantized to 4 bits using either the FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It performs matrix multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul), with these differences:
+  1. Input B is a 2D constant matrix. Its input feature count and output feature count are specified by attributes 'K' and 'N'.
+  2. Input B is quantized to 4 bits with the quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized block-wise with the block size specified by attribute 'block_size'.
+  'block_size' is not arbitrary: it must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
+  3. Input B's quantization constants (scales) are specified by input 'absmax'.
+
+  Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
+  Input absmax is stored in the same type as the original type of B (float32 or float16) with shape: [(N * K + block_size - 1) / block_size].
+
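+  A minimal sketch (assuming numpy) of the buffer shapes this packing implies: B holds two 4-bit codes per byte and absmax holds one scale per quantization block.
+
+```python
+import numpy as np
+
+N, K, block_size = 8, 32, 16  # example sizes; block_size must be a power of 2 and >= 16
+numel = N * K                 # B is transposed and flattened before packing
+packed = np.zeros((numel + 1) // 2, dtype=np.uint8)                           # two 4-bit codes per byte
+absmax = np.zeros((numel + block_size - 1) // block_size, dtype=np.float32)   # one scale per block
+print(packed.shape, absmax.shape)  # (128,) (16,)
+```
+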
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- **K** : int (required)
+  size of each input feature
+- **N** : int (required)
+  size of each output feature
+- **block_size** : int (required)
+  size of the quantization block (number of weight elements per block). It must be a power of 2 and not smaller than 16.
+- **quant_type** : int (required)
+  quantization data type. 0 for FP4, 1 for NF4.
+
+
+#### Inputs
+
+
+- **A** : T1
+  The input tensor, not quantized
+- **B** : T2
+  1-dimensional quantized weight data
+- **absmax** : T1
+  quantization constants
+
+
+#### Outputs
+
+
+- **Y** : T1
+  The output tensor. It has the same rank as the input A.
+
+
+#### Type Constraints
+
+
+- **T1** : tensor(float), tensor(float16)
+  Constrain input and output types to float/half_float tensors.
+- **T2** : tensor(uint8)
+  Constrain quantized weight types to uint8.
+
+
+
### **com.microsoft.MatMulFpQ4**
Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index dea71d81f8df5..aef76203920df 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -454,6 +454,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
@@ -849,6 +850,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index f77e403f26dde..f9d9b13f0fedc 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo,
#endif
diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
new file mode 100644
index 0000000000000..cb8e97a592d8c
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+
+namespace onnxruntime {
+namespace contrib {
+
+#if defined(_MSC_VER)
+#define FORCEINLINE __forceinline
+#else
+#define FORCEINLINE __attribute__((always_inline)) inline
+#endif
+
+typedef enum Bnb_DataType_t {
+ FP4 = 0,
+ NF4 = 1,
+} Bnb_DataType_t;
+
+FORCEINLINE uint8_t QuantizeOneFP4(float x) {
+ // FP4 with bias of 3
+ // first bit is a sign
+ // subnormals
+ // 0b000 = 0
+ // 0b001 = 0.0625
+ // 0b110 = 2
+ // 0b111 = 3
+ // 0b100 = 4
+ // 0b101 = 6
+ // 0b010 = 8
+ // 0b011 = 12
+
+ // we do a binary search
+ // the pivots are divided by 12 (the FP4 absmax)
+  // since we assume the input data is in [-1.0, 1.0]
+
+  // !be careful here, it's easy to make a mistake
+  // that is difficult to notice if you add an extra
+  // zero somewhere!
+
+ uint8_t sign = x < 0 ? 0b1000 : 0b0000;
+ x = fabsf(x);
+ if (x > 0.29166667f) {
+ if (x > 0.583333f) {
+ if (x > 0.8333333f) {
+ return 0b0011 + sign;
+ } else {
+ return 0b0010 + sign;
+ }
+ } else if (x > 0.4166667f) {
+ return 0b101 + sign;
+ } else {
+ return 0b100 + sign;
+ }
+ } else if (x > 0.0859375f) {
+ if (x > 0.20833333f) {
+ return 0b0111 + sign;
+ } else {
+ return 0b0110 + sign;
+ }
+ } else if (x > 0.00260417f) {
+ return 0b0001 + sign;
+ } else {
+ return 0b0000 + sign;
+ }
+}
+
+FORCEINLINE uint8_t QuantizeOneNF4(float x) {
+ if (x > 0.03979014977812767f) {
+ if (x > 0.3893125355243683f) { // 1
+ if (x > 0.6427869200706482f) { // 11
+ if (x > 0.8614784181118011f) { // 111
+ return 0b1111;
+ } else {
+ return 0b1110;
+ }
+ } else if (x > 0.5016634166240692f) { // 110
+ return 0b1101;
+ } else {
+ return 0b1100;
+ }
+ } else if (x > 0.2035212516784668f) { // 10
+ if (x > 0.2920137718319893f) { // 101
+ return 0b1011;
+ } else {
+ return 0b1010;
+ }
+ } else if (x > 0.1202552504837513f) { // 100
+ return 0b1001;
+ } else {
+ return 0b1000;
+ }
+ } else if (x > -0.33967943489551544f) { // 0
+ if (x > -0.13791173323988914f) { // 01
+ if (x > -0.045525018125772476f) { // 011
+ return 0b0111;
+ } else {
+ return 0b0110;
+ }
+ } else if (x > -0.23460740596055984f) { // 010
+ return 0b0101;
+ } else {
+ return 0b0100;
+ }
+ } else if (x > -0.6106329262256622f) { // 00
+ if (x > -0.4599952697753906f) { // 001
+ return 0b0011;
+ } else {
+ return 0b0010;
+ }
+ } else if (x > -0.8480964004993439f) { // 000
+ return 0b0001;
+ } else {
+ return 0b0000;
+ }
+}
+
+template <int32_t DATA_TYPE>
+FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
+ if constexpr (DATA_TYPE == FP4)
+ return QuantizeOneFP4(x);
+ else
+ return QuantizeOneNF4(x);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
+ float local_absmax = 0.0f;
+
+ int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+ int32_t src_offset = block_idx * block_size;
+ int32_t dst_offset = block_idx * block_size / 2;
+
+ for (int32_t idx = 0; idx < block_len; idx++) {
+    const float v = static_cast<float>(src[src_offset + idx]);
+ local_absmax = fmaxf(local_absmax, fabsf(v));
+ }
+
+  absmax_block = static_cast<T>(local_absmax);
+ const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
+
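+  // pack two 4-bit codes per byte: the even-indexed element goes to the high nibble, the odd one to the low nibble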
+ for (int32_t idx = 0; idx < block_len; idx += 2) {
+    const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
+    const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
+
+    const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
+    const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
+
+ dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
+ }
+}
+
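+// dequantization lookup tables indexed by the 4-bit code; values are normalized to [-1, 1]
+// and are scaled by the per-block absmax at dequantization time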
+static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
+ 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
+ -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
+ -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
+
+static float nf4_qaunt_map[16] = {-1.0f,
+ -0.6961928009986877f,
+ -0.5250730514526367f,
+ -0.39491748809814453f,
+ -0.28444138169288635f,
+ -0.18477343022823334f,
+ -0.09105003625154495f,
+ 0.0f,
+ 0.07958029955625534f,
+ 0.16093020141124725f,
+ 0.24611230194568634f,
+ 0.33791524171829224f,
+ 0.44070982933044434f,
+ 0.5626170039176941f,
+ 0.7229568362236023f,
+ 1.0f};
+
+template <typename T, int32_t DATA_TYPE>
+FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
+ if constexpr (DATA_TYPE == FP4)
+    return static_cast<T>(fp4_qaunt_map[x]);
+  else
+    return static_cast<T>(nf4_qaunt_map[x]);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
+ int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+ int32_t src_offset = block_idx * block_size / 2;
+ int32_t dst_offset = block_idx * block_size;
+
+ for (int32_t idx = 0; idx < block_len; idx += 2) {
+ const uint8_t val = src[src_offset + idx / 2];
+
+    dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
+    if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
+ }
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h
new file mode 100644
index 0000000000000..5ddb77e5b5ee3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h
@@ -0,0 +1,143 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "blockwise_quant_block_bnb4.h"
+
+#include <cstdint>
+
+#include "core/common/safeint.h"
+#include "core/framework/float16.h"
+#include "core/platform/threadpool.h"
+#include <cstddef>
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+void QuantizeBlockwiseBnb4(
+ uint8_t* dst, // shape: [(N * K + 1) / 2]
+ const T* src, // shape: [N, K]
+ T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ int32_t numel = N * K;
+ int32_t total_block_count = (numel + block_size - 1) / block_size;
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ total_block_count,
+ [&](ptrdiff_t block_idx) {
+        QuantizeBlockBnb4<T, block_size, DATA_TYPE>(
+ src,
+ dst,
+ absmax[block_idx],
+ static_cast(block_idx),
+ numel);
+ },
+ 0);
+}
+
+#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type)                        \
+  if (quant_type == FP4)                                                             \
+    QuantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool);  \
+  else                                                                               \
+    QuantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
+
+template <typename T>
+void QuantizeBlockwiseBnb4(
+ uint8_t* dst, // shape: [(N * K + 1) / 2]
+ const T* src, // shape: [N, K]
+ T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t block_size,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ if (block_size == 16) {
+ QuantizeBlockwiseBn4DataTyped(16, quant_type);
+ } else if (block_size == 32) {
+ QuantizeBlockwiseBn4DataTyped(32, quant_type);
+ } else if (block_size == 64) {
+ QuantizeBlockwiseBn4DataTyped(64, quant_type);
+ } else if (block_size == 128) {
+ QuantizeBlockwiseBn4DataTyped(128, quant_type);
+ } else if (block_size == 256) {
+ QuantizeBlockwiseBn4DataTyped(256, quant_type);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef QuantizeBlockwiseBn4DataTyped
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+void DequantizeBlockwiseBnb4(
+ T* dst, // shape: [N, K]
+ const uint8_t* src, // shape: [(N * K + 1) / 2)]
+ const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ int32_t numel = N * K;
+ int32_t total_block_count = (numel + block_size - 1) / block_size;
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ total_block_count,
+ [&](ptrdiff_t block_idx) {
+        DequantizeBlockBnb4<T, block_size, DATA_TYPE>(
+ src,
+ dst,
+ absmax[block_idx],
+ static_cast(block_idx),
+ numel);
+ },
+ 0);
+}
+
+#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type)                        \
+  if (quant_type == FP4)                                                               \
+    DequantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool);  \
+  else                                                                                 \
+    DequantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
+
+template <typename T>
+void DequantizeBlockwiseBnb4(
+ T* dst, // shape: [N, K]
+ const uint8_t* src, // shape: [(N * K + 1) / 2)]
+ const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t block_size,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ if (block_size == 16) {
+ DequantizeBlockwiseBn4DataTyped(16, quant_type);
+ } else if (block_size == 32) {
+ DequantizeBlockwiseBn4DataTyped(32, quant_type);
+ } else if (block_size == 64) {
+ DequantizeBlockwiseBn4DataTyped(64, quant_type);
+ } else if (block_size == 128) {
+ DequantizeBlockwiseBn4DataTyped(128, quant_type);
+ } else if (block_size == 256) {
+ DequantizeBlockwiseBn4DataTyped(256, quant_type);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef DequantizeBlockwiseBn4DataTyped
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
new file mode 100644
index 0000000000000..2f3ede49c3650
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/common.h"
+#include "dequantize_blockwise_bnb4.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class MatMulBnb4 final : public OpKernel {
+ public:
+ MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) {
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
+ ORT_ENFORCE(
+ quant_type_ == FP4 || quant_type_ == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t quant_type_;
+};
+
+Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
+ concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
+
+  const Tensor* a = ctx->Input<Tensor>(0);
+  const Tensor* b_quant = ctx->Input<Tensor>(1);
+  const Tensor* absmax = ctx->Input<Tensor>(2);
+
+  const float* a_data = a->Data<float>();
+  const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
+  const float* absmax_data = absmax->Data<float>();
+
+ AllocatorPtr allocator;
+ auto status = ctx->GetTempSpaceAllocator(&allocator);
+ ORT_RETURN_IF_ERROR(status);
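+  // dequantize the whole weight into a temporary float buffer, then run a regular (batched) SGEMM on it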
+  auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
+ DequantizeBlockwiseBnb4(
+ tmp_b_data_ptr.get(),
+ b_quant_data,
+ absmax_data,
+ static_cast(block_size_),
+ static_cast(quant_type_),
+ static_cast(N_),
+ static_cast(K_),
+ thread_pool);
+
+ constexpr bool transa = false;
+ constexpr bool transb = true;
+ TensorShape b_shape({N_, K_});
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
+
+ Tensor* y = ctx->Output(0, helper.OutputShape());
+
+ // Bail out early if the output is going to be empty
+ if (y->Shape().Size() == 0) return Status::OK();
+
+  auto* y_data = y->MutableData<float>();
+
+ const size_t max_len = helper.OutputOffsets().size();
+ const size_t M = static_cast(helper.M());
+ const size_t N = static_cast(helper.N());
+ const size_t K = static_cast(helper.K());
+ const size_t lda = helper.Lda(transa);
+ const size_t ldb = helper.Ldb(transb);
+
+ // TODO: implement with native kernel
+  std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
+ for (size_t i = 0; i < max_len; i++) {
+ data[i].BIsPacked = false;
+ data[i].A = a_data + helper.LeftOffsets()[i];
+ data[i].lda = lda;
+ data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i];
+ data[i].ldb = ldb;
+ data[i].C = y_data + helper.OutputOffsets()[i];
+ data[i].ldc = N;
+ data[i].alpha = 1.f;
+ data[i].beta = 0.0f;
+ }
+ MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool);
+
+ return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+ MatMulBnb4,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+      .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+      .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+ MatMulBnb4);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index c52f869d6a9d2..e762a80cb0e2f 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -118,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
@@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
new file mode 100644
index 0000000000000..e58723f0b31e1
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cub/cub.cuh>
+#include <cuda_fp16.h>
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
+#include "dequantize_blockwise_bnb4.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <class T>
+Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ T host_quant_map[16];
+ switch (quant_type) {
+ case FP4:
+      for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
+ break;
+ case NF4:
+      for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
+ break;
+ }
+ CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream));
+
+ return Status::OK();
+}
+
+template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream);
+
+template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream);
+
+template <class T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
+__global__ void kDequantizeBlockwise(
+ const T* quant_map,
+ T* output,
+ const uint8_t* quant_data,
+ const T* absmax,
+ const int block_size,
+ const int n) {
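+  // each thread block consumes TILE_SIZE packed bytes per iteration: every thread loads NUM_PER_TH
+  // bytes with cub::BlockLoad, decodes each byte into two values via the quant_map lookup, scales
+  // them by the owning block's absmax, and stores 2 * NUM_PER_TH outputs with cub::BlockStore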
+ const int n_load = (gridDim.x * TILE_SIZE);
+ int valid_items_load = 0;
+ int valid_items_store = 0;
+ const int base_idx = (blockIdx.x * TILE_SIZE);
+
+ T vals[NUM_PER_TH * 2];
+ uint8_t qvals[NUM_PER_TH];
+ T local_abs_max = T(0.0f);
+
+  typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
+  typedef cub::BlockStore<T, THREADS, NUM_PER_TH * 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
+
+ __shared__ typename LoadChar::TempStorage loadchar;
+ __shared__ typename StoreT::TempStorage storet;
+
+ for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
+ valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i;
+ valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2;
+
+ local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]);
+
+ __syncthreads();
+ LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
+
+ #pragma unroll NUM_PER_TH
+ for (int j = 0; j < NUM_PER_TH; j++) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+ vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
+ vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
+ #else
+ // half multiplication not supported
+      vals[j * 2] = static_cast<T>(static_cast<float>(quant_map[qvals[j] >> 4]) * static_cast<float>(local_abs_max));
+      vals[j * 2 + 1] =
+          static_cast<T>(static_cast<float>(quant_map[qvals[j] & 0x0F]) * static_cast<float>(local_abs_max));
+ #endif
+ }
+
+ __syncthreads();
+ StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store);
+ }
+}
+
+template <class T>
+Status DequantizeBnb4(
+ const T* quant_map,
+ T* output,
+ const uint8_t* quant_data,
+ const T* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream) {
+ int tile_size = 1024;
+  kDequantizeBlockwise<T, 512, 64, 8><<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>(
+ quant_map,
+ output,
+ quant_data,
+ absmax,
+ block_size / 2,
+ numel);
+
+ return Status::OK();
+}
+
+template Status DequantizeBnb4(
+ const float* quant_map,
+ float* output,
+ const uint8_t* quant_data,
+ const float* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream);
+
+template Status DequantizeBnb4(
+ const half* quant_map,
+ half* output,
+ const uint8_t* quant_data,
+    const half* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh
new file mode 100644
index 0000000000000..4aef3ab699f9c
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <class T>
+Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream);
+
+template <class T>
+Status DequantizeBnb4(
+ const T* quant_map,
+ T* output,
+ const uint8_t* quant_data,
+ const T* absmax,
+ int block_size,
+ int numel,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
new file mode 100644
index 0000000000000..bd5b6e0a8a1ce
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
+#include "matmul_bnb4.cuh"
+#include "dequantize_blockwise_bnb4.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class MatMulBnb4 final : public CudaKernel {
+ public:
+ MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) {
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
+ ORT_ENFORCE(
+ quant_type_ == FP4 || quant_type_ == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+ }
+
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t quant_type_;
+};
+
+template <typename T>
+Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
+  const Tensor* a = ctx->Input<Tensor>(0);
+  const Tensor* b_quant = ctx->Input<Tensor>(1);
+  const Tensor* absmax = ctx->Input<Tensor>(2);
+
+  const auto* a_data = a->Data<T>();
+  const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
+  const auto* absmax_data = absmax->Data<T>();
+
+  typedef typename ToCudaType<T>::MappedType CudaT;
+
+ // TODO: find a better way to create the quant_map without using a buffer
+ // don't want to use malloc directly so asking from the caller
+ // can create a __device__ static array for float but doesn't work for half
+  IAllocatorUniquePtr<T> quant_map_buffer = GetScratchBuffer<T>(16, ctx->GetComputeStream());
+ auto* quant_map_buffer_data = quant_map_buffer.get();
+ ORT_RETURN_IF_ERROR(SetBnbQuantMap(
+      SafeInt<int>(quant_type_),
+      reinterpret_cast<CudaT*>(quant_map_buffer_data),
+      static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
+
+ constexpr bool transa = false;
+ constexpr bool transb = true;
+ MatMulComputeHelper helper;
+ TensorShape b_shape({N_, K_});
+ ORT_RETURN_IF_ERROR(
+ helper.Compute(a->Shape(), b_shape, transa, transb));
+
+ Tensor* Y = ctx->Output(0, helper.OutputShape());
+ // Bail out early if the output is going to be empty
+ if (Y->Shape().Size() == 0) return Status::OK();
+
+ bool is_4bit_done = TryMatMulBnb4(
+      reinterpret_cast<const CudaT*>(quant_map_buffer_data),
+      reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+      reinterpret_cast<const CudaT*>(a_data),
+      b_quant_data,
+      reinterpret_cast<const CudaT*>(absmax_data),
+      SafeInt<int>(helper.M()),
+      SafeInt<int>(helper.N()),
+      SafeInt<int>(helper.K()),
+      SafeInt<int>(block_size_),
+      static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
+
+ if (!is_4bit_done) {
+    IAllocatorUniquePtr<T> b_dequant_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
+ auto* b_dequant_data = b_dequant_ptr.get();
+ ORT_RETURN_IF_ERROR(DequantizeBnb4(
+        reinterpret_cast<const CudaT*>(quant_map_buffer_data),
+        reinterpret_cast<CudaT*>(b_dequant_data),
+        b_quant_data,
+        reinterpret_cast<const CudaT*>(absmax_data),
+        SafeInt<int>(block_size_),
+        SafeInt<int>(N_ * K_),
+        static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
+
+    const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
+    const CudaT zero = ToCudaType<T>::FromFloat(0.f);
+
+ CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
+ GetCublasHandle(ctx),
+ CUBLAS_OP_T,
+ CUBLAS_OP_N,
+        SafeInt<int>(helper.N()),
+        SafeInt<int>(helper.M()),
+        SafeInt<int>(helper.K()),
+        &alpha,
+        reinterpret_cast<const CudaT*>(b_dequant_data),
+        SafeInt<int>(K_),
+        reinterpret_cast<const CudaT*>(a_data),
+        helper.Lda(transa),
+        &zero,
+        reinterpret_cast<CudaT*>(Y->MutableData<T>()),
+ helper.Ldc(),
+ GetDeviceProp()));
+ }
+
+ return Status::OK();
+}
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ MatMulBnb4,
+ kMSDomain,
+ 1,
+ float,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+ MatMulBnb4);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ MatMulBnb4,
+ kMSDomain,
+ 1,
+ MLFloat16,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+ MatMulBnb4);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
new file mode 100644
index 0000000000000..1d9aa75ff3701
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
@@ -0,0 +1,192 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cstdint>
+
+#include <cub/cub.cuh>
+#include <cuda_fp16.h>
+#include <type_traits>
+#include "matmul_bnb4.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define num_values_4bit 32
+template <typename T, int THREADS, int BITS>
+__global__ void kgemm_4bit_inference_naive(
+ int M,
+ int N,
+ int K,
+ const T* __restrict__ A,
+ const uint8_t* B,
+ const T* absmax,
+ const T* datatype,
+ T* out,
+ int lda,
+ int ldb,
+ int ldc,
+ int block_size) {
+ // per threadblock:
+ // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
+ // 4 warps -> 4 loads per iter
+ // 1x32 * 32x4 -> 1x4 outputs per thread block
+  typedef cub::WarpReduce<float> WarpReduce;
+ __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
+
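+  // each warp accumulates the dot product for one row of B (one output element);
+  // the 32 lanes' partial sums are combined with WarpReduce at the end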
+ const int warp_idx = threadIdx.x / 32;
+ const int warp_lane = threadIdx.x % 32;
+ const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
+ const int num_values_8bit = num_values_4bit / 2;
+ float local_C = 0.0f;
+
+ uint8_t local_B_4bit[num_values_8bit];
+ T local_B[num_values_4bit / 4];
+ T local_A[num_values_4bit / 4];
+ __shared__ T quant_map[16];
+ T local_absmax = T(0.0f);
+
+ for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]);
+ __syncthreads();
+
+ // A: [1, K]
+ // B: [N, K]
+ for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
+ int inner_idx_halved = inner_idx / 2;
+ int offset_B = ldb * row_B;
+ int absidx = ((2 * offset_B) + inner_idx) / block_size;
+ local_absmax = __ldg(&(absmax[absidx]));
+
+ if (row_B < N) {
+ if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
+ // this is the most important for performance considerations
+        reinterpret_cast<int4*>(local_B_4bit)[0] =
+            reinterpret_cast<const int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
+ } else {
+ #pragma unroll
+ for (int j = 0; j < (num_values_8bit); j++)
+ if ((inner_idx_halved) + j < (K / 2))
+ local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
+ else
+ local_B_4bit[j] = 0b01110111;
+ }
+ } else {
+ #pragma unroll
+ for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111;
+ }
+
+ for (int i = 0; i < 4; i++) {
+ #pragma unroll
+ for (int k = 0; k < num_values_8bit / 4; k++) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+ local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
+ local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
+ #else
+ // half multiplication not supported
+        local_B[k * 2] =
+            static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) *
+                           static_cast<float>(local_absmax));
+        local_B[k * 2 + 1] =
+            static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) *
+                           static_cast<float>(local_absmax));
+ #endif
+ }
+
+ if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
+ // this is also relatively important for performance
+ if (BITS == 16) {
+          reinterpret_cast<float4*>(local_A)[0] =
+              reinterpret_cast<const float4*>(A)[inner_idx / (num_values_4bit / 4) + i];
+        } else {
+          reinterpret_cast<float4*>(local_A)[0] =
+              reinterpret_cast<const float4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
+          reinterpret_cast<float4*>(local_A)[1] =
+              reinterpret_cast<const float4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
+ }
+ } else {
+ #pragma unroll
+ for (int k = 0; k < num_values_4bit / 4; k++) {
+ if (inner_idx + (i * num_values_4bit / 4) + k < K)
+ local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
+ else
+ local_A[k] = T(0.0f);
+ }
+ }
+
+ // accumulate in float; small performance hit for Ampere, but lower error for outputs
+ #pragma unroll
+ for (int k = 0; k < num_values_4bit / 4; k++) {
+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
+        local_C += static_cast<float>(local_A[k] * local_B[k]);
+ #else
+ // half multiplication not supported
+        local_C += static_cast<float>(local_A[k]) * static_cast<float>(local_B[k]);
+ #endif
+ }
+ }
+ }
+
+ local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
+
+ if (row_B < N && warp_lane == 0) out[row_B] = T(local_C);
+}
+
+template <class T>
+bool TryMatMulBnb4(
+ const T* quant_map,
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream) {
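+  // the specialized kernel is a GEMV: it requires a single row of A (m == 1) and K divisible by block_size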
+ if (k % block_size != 0 || m > 1) {
+ return false;
+ }
+ // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32]
+ if (block_size % 32 != 0 || block_size > 4096) {
+ return false;
+ }
+
+ int lda = k;
+ int ldb = (k + 1) / 2;
+ int ldc = n;
+ int num_blocks = (n + 3) / 4;
+
+  constexpr int bits = std::is_same_v<T, half> ? 16 : 32;
+  kgemm_4bit_inference_naive<T, 128, bits><<<num_blocks, 128, 0, stream>>>(
+ m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size);
+
+ return true;
+}
+
+template bool TryMatMulBnb4(
+ const float* quant_map,
+ float* output,
+ const float* a_data,
+ const uint8_t* b_data_quant,
+ const float* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+template bool TryMatMulBnb4(
+ const half* quant_map,
+ half* output,
+ const half* a_data,
+ const uint8_t* b_data_quant,
+ const half* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh
new file mode 100644
index 0000000000000..743234282fbf3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <class T>
+bool TryMatMulBnb4(
+ const T* quant_map,
+ T* output,
+ const T* a_data,
+ const uint8_t* b_data_quant,
+ const T* absmax,
+ int m,
+ int n,
+ int k,
+ int block_size,
+ cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 5e5eee568fa21..681a728f823da 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
});
+ static const char* MatMulBnb4_ver1_doc = R"DOC(
+MatMulBnb4 is a MatMul whose weight is quantized to 4 bits using either the FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It performs matrix multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul), with these differences:
+  1. Input B is a 2D constant matrix. Its input feature count and output feature count are specified by attributes 'K' and 'N'.
+  2. Input B is quantized to 4 bits with the quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized block-wise with the block size specified by attribute 'block_size'.
+  'block_size' is not arbitrary: it must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
+  3. Input B's quantization constants (scales) are specified by input 'absmax'.
+
+Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
+Input absmax is stored in the same type as the original type of B (float32 or float16) with shape: [(N * K + block_size - 1) / block_size].
+
+)DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(MatMulBnb4_ver1_doc)
+ .Attr("K", "size of each input feature", AttributeProto::INT)
+ .Attr("N", "size of each output feature", AttributeProto::INT)
+      .Attr("block_size", "size of the quantization block (number of weight elements per block). It must be a power of 2 and not smaller than 16.", AttributeProto::INT)
+ .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
+ .Input(0, "A", "The input tensor, not quantized", "T1")
+      .Input(1, "B", "1-dimensional quantized weight data", "T2")
+      .Input(2, "absmax", "quantization constants", "T1")
+      .Output(0, "Y", "The output tensor. It has the same rank as the input A.", "T1")
+ .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
+ .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ // Type inference
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ // Shape inference
+ int64_t in_features = getAttribute(ctx, "K", -1);
+ int64_t out_features = getAttribute(ctx, "N", -1);
+ MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
+ });
+
#ifdef ENABLE_ATEN
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
.SetDomain(kPytorchAtenDomain)
diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc
index 52ea677d5141d..04dfa9b51e112 100644
--- a/onnxruntime/python/onnxruntime_pybind_quant.cc
+++ b/onnxruntime/python/onnxruntime_pybind_quant.cc
@@ -6,6 +6,7 @@
#include
#include "contrib_ops/cpu/quantization/dequantize_blockwise.h"
+#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
#include "core/util/thread_utils.h"
namespace pybind11 {
@@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise(
tp.get());
}
+template <typename T>
+void QuantizeMatMulBnb4Blockwise(
+    py::array_t<uint8_t> dst,
+    py::array_t<T> src,
+    py::array_t<T> absmax,
+ int32_t block_size,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K) {
+ OrtThreadPoolParams to;
+ auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
+ concurrency::ThreadPoolType::INTRA_OP);
+
+ py::buffer_info dst_buf = dst.request();
+ py::buffer_info src_buf = src.request();
+ py::buffer_info absmax_buf = absmax.request();
+
+  contrib::QuantizeBlockwiseBnb4<T>(
+      static_cast<uint8_t*>(dst_buf.ptr),
+      static_cast<const T*>(src_buf.ptr),
+      static_cast<T*>(absmax_buf.ptr),
+ block_size,
+ quant_type,
+ N,
+ K,
+ tp.get());
+}
+
void CreateQuantPybindModule(py::module& m) {
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise);
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise);
+  m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<float>);
+  m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<MLFloat16>);
}
} // namespace python
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu
new file mode 100644
index 0000000000000..3504ce1bebe8c
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file serves as a simple example for adding a tunable op to onnxruntime.
+
+#include <cuda_fp16.h>
+#include <pybind11/pybind11.h>
+#include <string>
+
+#include <cstdint>
+
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+#include "python/tools/kernel_explorer/device_array.h"
+#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+// Extend the OpParams so that all specializations have the same parameter passing interface
+template <typename T>
+struct DequantizeBnb4Params : cuda::tunable::OpParams {
+ std::string Signature() const override { return std::to_string(n_); }
+
+ int quant_type_;
+ T* output_;
+ const uint8_t* quant_;
+ const T* absmax_;
+ T* quant_map_buffer_;
+ int n_;
+ int k_;
+};
+
+template <typename T>
+class DequantizeBnb4 : public IKernelExplorer {
+ public:
+ DequantizeBnb4(
+ int quant_type,
+ DeviceArray& output,
+ DeviceArray& quant,
+ DeviceArray& absmax,
+ DeviceArray& quant_map_buffer,
+ int n, int k) {
+ params_.tuning_ctx = TuningContext();
+ params_.stream = Stream();
+ params_.quant_type_ = quant_type;
+    params_.output_ = static_cast<T*>(output.ptr());
+    params_.quant_ = static_cast<const uint8_t*>(quant.ptr());
+    params_.absmax_ = static_cast<const T*>(absmax.ptr());
+    params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
+ params_.n_ = n;
+ params_.k_ = k;
+ }
+
+ void Run() override {
+ ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
+ params_.quant_type_,
+ params_.quant_map_buffer_,
+ params_.StreamHandle()));
+ ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4(
+ params_.quant_map_buffer_,
+ params_.output_,
+ params_.quant_,
+ params_.absmax_,
+ 64,
+ params_.n_ * params_.k_,
+ params_.StreamHandle()));
+ }
+
+ private:
+  // A DequantizeBnb4 op is a callable that can process const DequantizeBnb4Params<T>*
+  using ParamsT = DequantizeBnb4Params<T>;
+ ParamsT params_{};
+};
+
+#define REGISTER_OP(name, type)                                                               \
+  py::class_<name<type>>(m, #name "_" #type)                                                  \
+      .def(py::init<int, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
+      .def("SetRepeats", &name<type>::SetRepeats)                                             \
+      .def("Profile", &name<type>::Profile)                                                   \
+      .def("Run", &name<type>::Run);
+
+KE_REGISTER(m) {
+ REGISTER_OP(DequantizeBnb4, half);
+ REGISTER_OP(DequantizeBnb4, float);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu
new file mode 100644
index 0000000000000..e4cd83565357a
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This file serves as a simple example for adding a tunable op to onnxruntime.
+
+#include <cuda_fp16.h>
+#include <pybind11/pybind11.h>
+#include <string>
+
+#include <cstdint>
+
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
+#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
+#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
+#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh"
+
+namespace py = pybind11;
+
+namespace onnxruntime {
+
+// Extend the OpParams so that all specializations have the same parameter passing interface
+template <typename T>
+struct MatrixFloatBnb4Params : cuda::tunable::OpParams {
+ std::string Signature() const override { return std::to_string(n_); }
+
+ int quant_type_;
+ T* output_;
+ const T* a_;
+ const uint8_t* b_;
+ const T* absmax_;
+ T* quant_map_buffer_;
+ int m_;
+ int n_;
+ int k_;
+};
+
+template <typename T>
+class MatrixFloatBnb4 : public IKernelExplorer {
+ public:
+ MatrixFloatBnb4(DeviceArray& output,
+ DeviceArray& a,
+ DeviceArray& b,
+ DeviceArray& absmax,
+ DeviceArray& quant_map_buffer,
+ int quant_type, int m, int n, int k) {
+ params_.tuning_ctx = TuningContext();
+ params_.stream = Stream();
+    params_.output_ = static_cast<T*>(output.ptr());
+    params_.a_ = static_cast<const T*>(a.ptr());
+    params_.b_ = static_cast<const uint8_t*>(b.ptr());
+    params_.absmax_ = static_cast<const T*>(absmax.ptr());
+    params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
+ params_.quant_type_ = quant_type;
+ params_.m_ = m;
+ params_.n_ = n;
+ params_.k_ = k;
+ }
+
+ void Run() override {
+ ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
+ params_.quant_type_,
+ params_.quant_map_buffer_,
+ params_.StreamHandle()));
+ contrib::cuda::TryMatMulBnb4(
+ params_.quant_map_buffer_,
+ params_.output_,
+ params_.a_,
+ params_.b_,
+ params_.absmax_,
+ params_.m_,
+ params_.n_,
+ params_.k_,
+ 64,
+ params_.StreamHandle());
+ }
+
+ private:
+  // A MatrixFloatBnb4 op is a callable that can process const MatrixFloatBnb4Params<T>*
+  using ParamsT = MatrixFloatBnb4Params<T>;
+ ParamsT params_{};
+};
+
+#define REGISTER_OP(name, type)                                                                                    \
+  py::class_<name<type>>(m, #name "_" #type)                                                                       \
+      .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int, int>())   \
+      .def("SetRepeats", &name<type>::SetRepeats)                                                                  \
+      .def("Profile", &name<type>::Profile)                                                                        \
+      .def("Run", &name<type>::Run);
+
+KE_REGISTER(m) {
+ REGISTER_OP(MatrixFloatBnb4, half);
+ REGISTER_OP(MatrixFloatBnb4, float);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py
new file mode 100644
index 0000000000000..140151aadcc0f
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py
@@ -0,0 +1,92 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import sys
+from dataclasses import dataclass
+
+import kernel_explorer as ke
+import numpy as np
+from utils import dtype_to_bytes
+
+
+def dtype_to_funcs(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+quant_enums = {"FP4": 0, "NF4": 1}
+
+
+dtypes = ["float16", "float32"]
+quant_types = ["FP4", "NF4"]
+
+
+@dataclass
+class DequantizeBnb4Metric(ke.BandwidthMetric):
+ quant_type: str
+ n: int
+ k: int
+
+ def report(self):
+ return (
+ f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
+ f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}"
+ )
+
+
+def profile_dequantize_bnb4_func(qt, n, k, dtype, func):
+ np.random.seed(0)
+ block_size = 64
+ numel = n * k
+ output = np.random.rand(n, k).astype(dtype)
+ quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
+ absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
+ quant_map_buffer = np.zeros(16).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ quant_d = ke.DeviceArray(quant)
+ absmax_d = ke.DeviceArray(absmax)
+ quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
+ f = getattr(ke, func)
+ my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k)
+ duration_ms = my_op.Profile()
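+    # bytes touched: packed 4-bit weights (numel / 2 bytes) plus the absmax scales read and the full output written in dtype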
+ total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype)
+
+ ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k))
+
+
+def profile_with_args(qt, n, k, dtype, sort):
+ with ke.benchmark(sort):
+ for func in dtype_to_funcs(dtype):
+            profile_dequantize_bnb4_func(qt, n, k, dtype, func)
+
+
+def profile():
+ for qt in quant_types:
+ for dt in dtypes:
+ for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
+ profile_with_args(qt, n, k, dt, True)
+ print()
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ group = parser.add_argument_group("profile with args")
+ group.add_argument("n", type=int)
+ group.add_argument("k", type=int)
+ group.add_argument("quant_type", choices=quant_types)
+ group.add_argument("dtype", choices=dtypes)
+ group.add_argument("--sort", action="store_true")
+
+ if len(sys.argv) == 1:
+ profile()
+ else:
+ args = parser.parse_args()
+ profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort)
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py
new file mode 100644
index 0000000000000..4a9489050fd61
--- /dev/null
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py
@@ -0,0 +1,136 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import sys
+from dataclasses import dataclass
+
+import kernel_explorer as ke
+import numpy as np
+from utils import dtype_to_bytes
+
+
+def dtype_to_funcs(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+def dtype_to_funcs_cublas(dtype):
+ type_map = {
+ "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))),
+ "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))),
+ }
+ return type_map[dtype]
+
+
+quant_enums = {"FP4": 0, "NF4": 1}
+
+
+dtypes = ["float16", "float32"]
+quant_types = ["FP4", "NF4"]
+
+
+@dataclass
+class MatrixMulMetric(ke.BandwidthMetric):
+ m: int
+ n: int
+ k: int
+
+ def report(self):
+ return (
+ f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
+ )
+
+
+@dataclass
+class MatrixFpBnb4Metric(MatrixMulMetric):
+ quant_type: str
+
+ def report(self):
+ return (
+ f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
+ f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
+ )
+
+
+def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func):
+ np.random.seed(0)
+ block_size = 64
+ numel = n * k
+ output = np.random.rand(m, n).astype(dtype)
+ a = np.random.rand(m, k).astype(dtype)
+ b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
+ absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
+ quant_map_buffer = np.zeros(16).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ a_d = ke.DeviceArray(a)
+ b_d = ke.DeviceArray(b)
+ absmax_d = ke.DeviceArray(absmax)
+ quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
+ f = getattr(ke, func)
+
+ my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k)
+ duration_ms = my_op.Profile()
+ total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
+
+ ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt))
+
+
+def profile_gemm_func(m, n, k, dtype, func):
+ np.random.seed(0)
+ output = np.random.rand(m, n).astype(dtype)
+ a = np.random.rand(m, k).astype(dtype)
+ b = np.random.rand(k, n).astype(dtype)
+
+ output_d = ke.DeviceArray(output)
+ a_d = ke.DeviceArray(a)
+ b_d = ke.DeviceArray(b)
+ f = getattr(ke, func)
+ my_op = f(output_d, a_d, b_d, m, n, k)
+ duration_ms = my_op.Profile()
+ total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
+
+ ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k))
+
+
+def profile_with_args(qt, m, n, k, dtype, sort):
+ with ke.benchmark(sort):
+ for func in dtype_to_funcs(dtype):
+ profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func)
+
+ for func in dtype_to_funcs_cublas(dtype):
+ profile_gemm_func(m, n, k, dtype, func)
+
+
+def profile():
+ dims_m = [1]
+ for qt in quant_types:
+ for dt in dtypes:
+ for m in dims_m:
+ for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
+ profile_with_args(qt, m, n, k, dt, False)
+ print()
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ group = parser.add_argument_group("profile with args")
+ group.add_argument("m", type=int)
+ group.add_argument("n", type=int)
+ group.add_argument("k", type=int)
+ group.add_argument("quant_type", choices=quant_types)
+ group.add_argument("dtype", choices=dtypes)
+ group.add_argument("--sort", action="store_true")
+
+ if len(sys.argv) == 1:
+ profile()
+ else:
+ args = parser.parse_args()
+ profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort)
diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
new file mode 100644
index 0000000000000..951746a089305
--- /dev/null
+++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
@@ -0,0 +1,240 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import argparse
+import logging
+import os
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
+import onnx
+from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
+
+from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
+
+from .onnx_model import ONNXModel
+from .quant_utils import attribute_to_kwarg
+
+logger = logging.getLogger(__name__)
+
+
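+# A minimal usage sketch (hypothetical file names), assuming the quantized model is written out
+# through the wrapped ONNXModel:
+#
+#   model = onnx.load("model_fp32.onnx")
+#   quantizer = MatMulBnb4Quantizer(model, quant_type=MatMulBnb4Quantizer.NF4, block_size=64)
+#   quantizer.process()
+#   quantizer.model.save_model_to_file("model_bnb4.onnx")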
+class MatMulBnb4Quantizer:
+ """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type"""
+
+ ##################
+ # quantization types, must be consistent with native code type
+ # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h
+
+ # 4b floating point with bias of 3
+ FP4 = 0
+
+ # 4b NormalFloat
+ NF4 = 1
+
+ def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None):
+ nodes_to_exclude = nodes_to_exclude or []
+ assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4]
+ self.model = ONNXModel(model)
+ self.quant_type = quant_type
+ self.block_size = block_size
+ self.nodes_to_exclude = set(nodes_to_exclude)
+
+ @staticmethod
+ def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
+ for gid in range(len(graph_path) - 1, -1, -1):
+ graph = graph_path[gid]
+ for tensor in graph.initializer:
+ if tensor.name == name:
+ return tensor, graph
+ return None, None
+
+ def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray:
+ """4b quantize fp32/fp16 weight"""
+
+ if len(fpweight.shape) != 2:
+ raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
+ # need to copy since the transposed weight still has the original memory layout
+ # Linear4bit quantizes its weight data which is the transposed weight
+ fpweight_t = fpweight.transpose().copy()
+
+ rows, cols = fpweight.shape
+ numel = rows * cols
+ block_size = self.block_size
+ num_blocks = (numel + block_size - 1) // block_size
+ quantized_numel = (numel + 1) // 2
+
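+ # packed stores two 4-bit values per uint8; absmax holds one scale per block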
+ packed = np.zeros(quantized_numel, dtype="uint8")
+ absmax = np.zeros(num_blocks, dtype=fpweight.dtype)
+ # blockwise quantization: fpweight_t is flattened and divided into blocks of block_size elements
+ quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows)
+
+ return (packed, absmax)
+
+ def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
+ """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
+
+ if node.op_type != "MatMul":
+ return node # only care about MatMul for now
+
+ logger.debug(f"start to quantize {node.name} ...")
+ if node.name in self.nodes_to_exclude:
+ logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
+ return node
+
+ inputB = node.input[1] # noqa: N806
+ B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806
+ if B is None:
+ logger.debug("MatMul doesn't have const weight. Skip to quantize")
+ return node # only care about constant weight
+
+ B_array = onnx.numpy_helper.to_array(B) # noqa: N806
+ if len(B_array.shape) != 2:
+ logger.debug("MatMul weight is not 2D. Skip to quantize")
+ return node # can only process 2-D matrix
+
+ packed, absmax = self.bnb4_block_quant(B_array)
+ B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
+ B_quant.name = B.name + "_Bnb4"
+ for input in Bs_graph.input:
+ if input.name == inputB:
+ Bs_graph.input.remove(input)
+ break
+
+ absmax_tensor = onnx.numpy_helper.from_array(absmax)
+ absmax_tensor.name = B.name + "_absmax"
+
+ Bs_graph.initializer.extend([B_quant, absmax_tensor])
+
+ kwargs = {}
+ rows, cols = B_array.shape
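+ # the original MatMul weight B is (K, N): rows are input features (K), cols are output features (N)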
+ kwargs["K"] = rows
+ kwargs["N"] = cols
+ kwargs["block_size"] = self.block_size
+ kwargs["quant_type"] = self.quant_type
+
+ matmul_bnb4_node = onnx.helper.make_node(
+ "MatMulBnb4",
+ inputs=[node.input[0], B_quant.name, absmax_tensor.name],
+ outputs=[node.output[0]],
+ name=node.name + "_Bnb4" if node.name else "",
+ domain="com.microsoft",
+ **kwargs,
+ )
+
+ logger.debug(f"complete quantization of {node.name} ...")
+
+ return matmul_bnb4_node
+
+ def _process_subgraph(self, graph_stack: List[GraphProto]):
+ new_nodes = []
+ graph = graph_stack[-1]
+
+ for node in graph.node:
+ graph_attrs = [
+ attr
+ for attr in node.attribute
+ if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+ ]
+ if len(graph_attrs):
+ kwargs = {}
+ for attr in node.attribute:
+ if attr.type == onnx.AttributeProto.GRAPH:
+ # recursive call to take care of sub-graph
+ graph_stack.append(attr.g)
+ kv = {attr.name: self._process_subgraph(graph_stack)}
+ elif attr.type == onnx.AttributeProto.GRAPHS:
+ value = []
+ for subgraph in attr.graphs:
+ # recursive call to take care of sub-graph
+ graph_stack.append(subgraph)
+ value.extend([self._process_subgraph(graph_stack)])
+ kv = {attr.name: value}
+ else:
+ kv = attribute_to_kwarg(attr)
+ kwargs.update(kv)
+ node = onnx.helper.make_node( # noqa: PLW2901
+ node.op_type, node.input, node.output, name=node.name, **kwargs
+ )
+
+ new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack))
+
+ graph.ClearField("node")
+ graph.node.extend(new_nodes)
+ graph_stack.pop()
+ return graph
+
+ def process(self):
+ # use a stack to keep track of sub-graphs
+ graph_stack = [self.model.graph()]
+ opset_import = self.model.opset_import()
+
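+ # MatMulBnb4 lives in the com.microsoft domain; add that opset import if the model lacks it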
+ has_ms_domain = False
+ for opset in opset_import:
+ if opset.domain == "com.microsoft":
+ has_ms_domain = True
+ if not has_ms_domain:
+ opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
+
+ self._process_subgraph(graph_stack)
+ self.model.clean_initializers()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices.
+
+A weight matrix is partitioned into blocks, where each block is a contiguous
+subset of the flattened, transposed weight matrix. Each block is quantized
+into a set of 4-bit codes together with an absolute-max (absmax) scaling factor.
+"""
+ )
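+
+ # Example invocation (hypothetical file names):
+ #   python matmul_bnb4_quantizer.py --input_model model.onnx --output_model model.bnb4.onnx \
+ #       --quant_type 1 --block_size 64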
+
+ parser.add_argument("--input_model", required=True, help="Path to the input model file")
+ parser.add_argument("--output_model", required=True, help="Path to the output model file")
+ parser.add_argument(
+ "--quant_type",
+ required=False,
+ type=int,
+ default=1,
+ choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
+ help="Quantization data type. 0: FP4, 1: NF4",
+ )
+ parser.add_argument(
+ "--block_size",
+ required=False,
+ type=int,
+ default=64,
+ help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
+ )
+ parser.add_argument("-v", "--verbose", required=False, action="store_true")
+ parser.set_defaults(verbose=False)
+ parser.add_argument(
+ "--nodes_to_exclude",
+ nargs="+",
+ type=str,
+ required=False,
+ default=[],
+ help="Specify the nodes to be excluded from quantization with node names",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.verbose:
+ logger.setLevel(logging.DEBUG)
+
+ input_model_path = args.input_model
+ output_model_path = args.output_model
+
+ if os.path.exists(output_model_path):
+ logger.error(f"file {output_model_path} already exists")
+ raise Exception(f"file {output_model_path} already exists")
+
+ model = onnx.load(input_model_path)
+ quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude)
+ quant.process()
+ quant.model.save_model_to_file(output_model_path, True)
diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc
new file mode 100644
index 0000000000000..e739b17d5885f
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc
@@ -0,0 +1,151 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef ORT_MINIMAL_BUILD
+
+#include "core/common/span_utils.h"
+#include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas_q4.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/session/inference_session.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/framework/test_utils.h"
+#include "test/optimizer/graph_transform_test_builder.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+#include "core/util/qmath.h"
+#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
+
+#include <chrono>
+#include <random>
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+namespace onnxruntime {
+namespace test {
+
+void QuantizeDequantizeBnb4(std::vector<float>& raw_vals,  // N X K
+ std::vector<uint8_t>& quant_vals,
+ std::vector<float>& absmax,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K,
+ int32_t block_size) {
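+ // quantize then immediately dequantize in place so raw_vals holds the round-tripped
+ // values that the MatMulBnb4 kernel is expected to reproduce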
+ OrtThreadPoolParams to;
+ auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
+ concurrency::ThreadPoolType::INTRA_OP);
+
+ contrib::QuantizeBlockwiseBnb4(
+ quant_vals.data(),
+ raw_vals.data(),
+ absmax.data(),
+ block_size,
+ quant_type,
+ N,
+ K,
+ tp.get());
+
+ contrib::DequantizeBlockwiseBnb4(
+ raw_vals.data(),
+ quant_vals.data(),
+ absmax.data(),
+ block_size,
+ quant_type,
+ N,
+ K,
+ tp.get());
+}
+
+void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) {
+ RandomValueGenerator random{1234};
+ std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
+ // quantizer expects transposed weights, N X K
+ std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({N, K}), 0.0f, 0.25f));
+
+ int64_t numel = N * K;
+ int64_t quantized_numel = (numel + 1) / 2;
+ int64_t total_block_count = (numel + block_size - 1) / block_size;
+ std::vector<uint8_t> input1_vals(quantized_numel);
+ std::vector<float> absmax(total_block_count);
+
+ QuantizeDequantizeBnb4(input1_f_vals,
+ input1_vals,
+ absmax,
+ static_cast<int32_t>(quant_type),
+ static_cast<int32_t>(N),
+ static_cast<int32_t>(K),
+ static_cast<int32_t>(block_size));
+
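+ // reference GEMM computed with the round-tripped weights (B is stored transposed, N x K)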
+ std::vector<float> expected_vals(M * N);
+ for (int64_t m = 0; m < M; m++) {
+ for (int64_t n = 0; n < N; n++) {
+ float sum = 0.0f;
+ for (int64_t k = 0; k < K; k++) {
+ sum += input0_vals[m * K + k] * input1_f_vals[n * K + k];
+ }
+ expected_vals[m * N + n] = sum;
+ }
+ }
+
+ OpTester test("MatMulBnb4", 1, kMSDomain);
+ test.AddAttribute("K", K);
+ test.AddAttribute("N", N);
+ test.AddAttribute("block_size", block_size);
+ test.AddAttribute("quant_type", quant_type);
+ if (use_float16) {
+ test.AddInput("A", {M, K}, ToFloat16(input0_vals), false);
+ test.AddInput("B", {quantized_numel}, input1_vals, true);
+ test.AddInput("absmax", {total_block_count}, ToFloat16(absmax), true);
+
+ test.AddOutput("Y", {M, N}, ToFloat16(expected_vals));
+ test.SetOutputAbsErr("Y", 0.02f);
+
+ std::vector