diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index 580b5087f3fa3..f6f380c8211f6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,6 +7,43 @@ namespace onnxruntime { namespace contrib { namespace cuda { + +/////////////////////////////////////////////////////////////////////////////// +// A more general block-wise dequantization implementation that supports +// different block sizes and block orientations (row-wise/column-wise). +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct Shape2D { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix +}; + +/** + * @brief Blockwise quantization constants + * @tparam ElementT source data type, e.g. fp32/fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +struct BlkQuantTraits { + // number of qbit elements to pack into whole bytes + static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); + + using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; + + using ThreadBlk = Shape2D; +}; + template Status Dequantize4Bits( T* output, @@ -19,6 +56,18 @@ Status Dequantize4Bits( int block_size, cudaStream_t stream); +template +Status Dequantize8Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const ZeroT* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports @@ -45,6 +94,17 @@ Status DequantizeBlockwise4b( int columns, cudaStream_t stream); +template +Status DequantizeBlockwise8b( + T* dst, + const uint8_t* qelements, + const T* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu similarity index 70% rename from onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu rename to onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu index 7fb0619a799dc..cea1834fa1b62 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu @@ -1,4 +1,3 @@ -// Modifications: scaling is moved from masked softmax to the gemm before that. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
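As a quick orientation for the traits moved into dequantize_blockwise.cuh above, here is a minimal compile-time sketch (not part of this change). It assumes the header's own includes are sufficient for host compilation and that `half` comes from <cuda_fp16.h>; it only exercises kPackSize and QuantBlk, whose semantics are stated in the header comments.

#include <cuda_fp16.h>
#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh"

namespace onnxruntime::contrib::cuda {
// 8-bit values already occupy whole bytes; two 4-bit values pack into one byte.
static_assert(BlkQuantTraits<half, 32, 8, true>::kPackSize == 1);
static_assert(BlkQuantTraits<half, 32, 4, true>::kPackSize == 2);
// Column-wise: one quantization block spans 32 rows of a single column.
static_assert(BlkQuantTraits<half, 32, 8, true>::QuantBlk::kRow == 32);
static_assert(BlkQuantTraits<half, 32, 8, true>::QuantBlk::kColumn == 1);
// Row-wise: one quantization block spans 32 columns of a single row.
static_assert(BlkQuantTraits<half, 32, 8, false>::QuantBlk::kRow == 1);
static_assert(BlkQuantTraits<half, 32, 8, false>::QuantBlk::kColumn == 32);
}  // namespace onnxruntime::contrib::cuda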
@@ -20,7 +19,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { - __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * zp; @@ -68,25 +66,28 @@ __global__ void Dequantize4BitsKernelReOrder( int groups_per_K, int groups_per_threadblock, int total_groups) { - int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + constexpr int bits = 4; + constexpr int element_per_thread = 32 / bits; // Process 8 elements per thread using uint32_t load + constexpr int element_per_byte = 8 / bits; + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); if (group_id >= total_groups) { return; } - const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int zero_point_shape_x = (groups_per_K + (element_per_byte - 1)) / element_per_byte; const int scales_shape_x = groups_per_K; int n_idx = group_id / scales_shape_x; int kb_idx = group_id % scales_shape_x; - int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + int element_offset = group_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); T* output_i = output + element_offset; - uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); - const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * 8) & (block_size - 1)); - for (int i = 0; i < 8; i++) { + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / element_per_byte)); + const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + for (int i = 0; i < element_per_thread; i++) { int32_t rid = reorder_idx_with_off[i]; T scale = *(scale_data + n_idx * scales_shape_x + rid); - uint8_t zp = 8; + uint8_t zp = 8; // Default zero point is 1 << (bits - 1) if (zero_points) { - zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = zero_points[n_idx * zero_point_shape_x + rid / element_per_byte]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); } @@ -130,7 +131,7 @@ __global__ void Dequantize4BitsKernel( } zero_point_value = static_cast(zp); } else { - zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); + zero_point_value = zero_points ? 
*(zero_points + block_id) : static_cast(8); } output = output + element_offset; @@ -151,35 +152,45 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; int total_groups = n * groups_per_K; // total elemenets in quant_data - int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_grid = CeilDiv(total_groups, groups_per_threadblock); + dim3 grid_dim(groups_per_grid); + dim3 block_dim(GridDim::maxThreadsPerBlock); + if (!reorder_idx || std::is_same_v) { - Dequantize4BitsKernel<<>>( + // Launch standard kernel + Dequantize4BitsKernel<<>>( output, quant_data, scales_data, zero_points, block_size, - groups_per_K, + groups_per_K, // Pass groups_per_K for potential ZP indexing if needed groups_per_threadblock, total_groups); } else { - // static_assert(std::is_same_v, "ZeroT must be uint8_t"); - Dequantize4BitsKernelReOrder<<>>( - output, - quant_data, - scales_data, - (const uint8_t*)zero_points, - reorder_idx, - block_size, - groups_per_K, - groups_per_threadblock, - total_groups); + // Launch reorder kernel (requires uint8_t zero points as per original structure) + if constexpr (std::is_same_v) { + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::INVALID_ARGUMENT, + "Reorder kernel currently expects uint8_t zero points."); + } } - return Status::OK(); + return CUDA_CALL(cudaGetLastError()); // Check for launch errors } template Status Dequantize4Bits( @@ -224,60 +235,23 @@ template Status Dequantize4Bits( int n, int block_size, cudaStream_t stream); -/////////////////////////////////////////////////////////////////////////////// -// A more general block-wise dequantization implementation that supports -// different block sizes and block orientations (row-wise/column-wise). template < - int Row_, ///< rows of a matrix - int Column_ ///< columns of a matrix - > -struct Shape2D { - static int const kRow = Row_; ///< rows of a matrix - static int const kColumn = Column_; ///< columns of a matrix - static int const kCount = Row_ * Column_; ///< total number of elements in a matrix -}; - -/** - * @brief Blockwise quantization constants - * @tparam ElementT source data type, e.g. fp32/fp16 - * @tparam block_size number of elemenets quantized together - * @tparam qbits number of bits in each quantized element - * @tparam Columnwise true: elements in a block come from one single column - * false: elements in a block come from one single row - */ -template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -struct BlkQuantTraits { - // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 
4 : 0; - static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); - - using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - using ThreadBlk = Shape2D; -}; - -template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -__global__ -void dequantizeThread(ElementT* dst, - const uint8_t* weights, - const ElementT* scales, - const uint8_t* zero_points, - int rows, - int columns, - int thrd_row_blks) { + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ void dequantizeThread4b(ElementT* dst, + const uint8_t* weights, + const ElementT* scales, + const uint8_t* zero_points, + int rows, + int columns, + int thrd_row_blks) { using QuantBlk = typename BlkQuantTraits::QuantBlk; using ThreadBlk = typename BlkQuantTraits::ThreadBlk; - // !! 4b specific code - static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(qbits == 4, "Only 4b block quantization is supported by this kernel specialization!!"); const auto block_idx = blockIdx.x * blockDim.x + threadIdx.x; const auto row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; @@ -299,12 +273,12 @@ void dequantizeThread(ElementT* dst, // for 4b quant, kPackSize = 2, so we have 2 scales and 2 offsets const ElementT scale_buf[2] = { scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow], - ((r/QuantBlk::kRow) < (meta_rows - 1)) + ((r / QuantBlk::kRow) < (meta_rows - 1)) ? scales[(c / QuantBlk::kColumn) * row_blks + r / QuantBlk::kRow + 1] : static_cast(0.0f)}; const uint8_t zp_pair = (zero_points == nullptr) - ? 0x88 - : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; + ? 0x88 + : zero_points[(c / QuantBlk::kColumn) * ((row_blks + 1) / 2) + (r / QuantBlk::kRow) / 2]; const uint16_t zp_buf[2] = {(uint16_t)(zp_pair & 0x0f), (uint16_t)((zp_pair >> 4) & 0x0f)}; const ElementT adjust_buf[2] = {(-scale_buf[0]) * static_cast(zp_buf[0]), (-scale_buf[1]) * static_cast(zp_buf[1])}; @@ -315,7 +289,7 @@ void dequantizeThread(ElementT* dst, const auto scale0 = scale_buf[(i - r) / QuantBlk::kRow]; const auto adjust0 = adjust_buf[(i - r) / QuantBlk::kRow]; - const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow];; + const auto scale1 = scale_buf[(i + 1 - r) / QuantBlk::kRow]; const auto adjust1 = adjust_buf[(i + 1 - r) / QuantBlk::kRow]; const auto vi = q_ptr[i / 2]; @@ -333,7 +307,8 @@ void dequantizeThread(ElementT* dst, static_assert(std::is_same::value, "Only float and half are supported!"); const uint8_t vi0 = vi & 0xf; const uint8_t vi1 = vi >> 4; - dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0;; + dst[j * rows + i] = static_cast(vi0) * scale0 + adjust0; + ; dst[j * rows + (i + 1)] = static_cast(vi1) * scale1 + adjust1; } } @@ -351,13 +326,13 @@ void dequantizeThread(ElementT* dst, } template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> -static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales, - const uint8_t* zero_points, int32_t rows, int32_t columns, - cudaStream_t stream) { + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize4b_generic(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { using ThreadBlk = typename BlkQuantTraits::ThreadBlk; // Thread partitioning @@ -366,7 +341,7 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const 
ElementT* sc const auto total_thrd_blks = thrd_row_blks * thrd_col_blks; const auto grids = (total_thrd_blks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; - dequantizeThread<<>>( + dequantizeThread4b<<>>( dst, weights, scales, @@ -376,7 +351,6 @@ static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* sc thrd_row_blks); } - template Status DequantizeBlockwise4b( @@ -392,41 +366,37 @@ DequantizeBlockwise4b( switch (block_size) { case 16: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 32: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 64: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 128: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, - columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, - rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); case 256: if (columnwise) { - dequantize(dst, src, scales, zero_points, rows, - columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } else { - dequantize(dst, src, scales, zero_points, - rows, columns, stream); + dequantize4b_generic(dst, src, scales, zero_points, rows, columns, stream); } return Status::OK(); default: @@ -436,8 +406,8 @@ DequantizeBlockwise4b( } } -template -Status DequantizeBlockwise4b( +// Template instantiations for 4-bit blockwise +template Status DequantizeBlockwise4b( float* dst, const uint8_t* src, const float* scales, @@ -448,8 +418,7 @@ Status DequantizeBlockwise4b( int columns, cudaStream_t stream); -template -Status DequantizeBlockwise4b( +template Status DequantizeBlockwise4b( half* dst, const uint8_t* src, const half* scales, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu new file mode 100644 index 0000000000000..e90ed85b22f02 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_8bits.cu @@ -0,0 +1,465 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
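For orientation, a hedged host-side sketch of how the DequantizeBlockwise8b entry point declared in the header is expected to be driven; the buffer sizes follow the indexing used by dequantizeThread8b further down (column-major [rows, columns] weights/output, one scale and one optional zero-point byte per block). Function name and include paths are illustrative, not taken from this change.

#include <cuda_runtime_api.h>
#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh"

// Dequantize a column-wise, 8-bit blockwise-quantized weight matrix into fp32.
onnxruntime::common::Status DequantizeWeightsFp32(
    float* dst,                   // [rows, columns], column-major, rows * columns floats
    const uint8_t* q_weights,     // rows * columns bytes, same column-major layout
    const float* scales,          // columns * ((rows + 31) / 32) entries
    const uint8_t* zero_points,   // same count as scales, or nullptr for the default of 128
    int rows, int columns, cudaStream_t stream) {
  constexpr int block_size = 32;
  constexpr bool columnwise = true;
  // Returns FAIL for unsupported block sizes; kernel launch errors surface later.
  return onnxruntime::contrib::cuda::DequantizeBlockwise8b(
      dst, q_weights, scales, zero_points, block_size, columnwise, rows, columns, stream);
}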
+ +#include +#include +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) +__device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + // Formula: val = (quant - zp) * scale = quant * scale - zp * scale + half zp_adjust = -scale * zp; + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[2]; // Store 4 half values + + // Extract 4 uint8_t values from uint32_t + half v0 = __ushort2half_rn(static_cast(values_quant & 0xFF)); + half v1 = __ushort2half_rn(static_cast((values_quant >> 8) & 0xFF)); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __ushort2half_rn(static_cast((values_quant >> 16) & 0xFF)); + half v3 = __ushort2half_rn(static_cast((values_quant >> 24) & 0xFF)); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + // Write 4 half values (equivalent to float2) + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +// Processes 4 elements (since each is 8 bits, 4 fit in uint32_t) +__device__ __forceinline__ void DequantizeFourElements(uint32_t values_quant, float scale, float zp, float* output) { + // Assuming ZP is symmetric or already adjusted if needed. Standard formula: val = (quant - zp) * scale = quant * scale - zp * scale + float zp_adjust = -scale * zp; + + // Extract 4 uint8_t values from uint32_t + output[0] = float(values_quant & 0xFF) * scale + zp_adjust; + output[1] = float((values_quant >> 8) & 0xFF) * scale + zp_adjust; + output[2] = float((values_quant >> 16) & 0xFF) * scale + zp_adjust; + output[3] = float((values_quant >> 24) & 0xFF) * scale + zp_adjust; +} + +// REVIEW: Deprecate reorder_idx (Recommend to reorder scales and zero points during model conversion instead of using reorder_idx). +// Reorder index is a 1D array of size [K] to support desc_act used in GPTQ quantization. +// However, it impacts inference performance of the kernel since it is not optimized for coalescing memory access. +template +__global__ void Dequantize8BitsKernelReOrder( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, // Assuming uint8_t zero points for reorder case + const int32_t* reorder_idx, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + constexpr int element_per_thread = 4; // Process 4 elements (uint8_t) per thread using uint32_t load + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); + if (group_id >= total_groups) { + return; + } + + // element_offset corresponds to the start of the 4 elements processed by this thread iteration + int element_offset = group_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + T* output_i = output + element_offset; + + // shape of scales and zero_points is [N, groups_per_K]. Compute the 2D indices below. 
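+  // For example, with groups_per_K = 4, group_id = 9 maps to n_idx = 2 (index along N)
+  // and kb_idx = 1 (block index along K).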
+ int n_idx = group_id / groups_per_K; + int kb_idx = group_id % groups_per_K; + + // Read 4 uint8_t values packed into a uint32_t + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset)); + + // Adjust reorder index pointer to the start of the 4 indices for this thread iteration + const int32_t* g_idx = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + for (int i = 0; i < element_per_thread; i++) { + // Typical value of g_idx is in the range of [0, groups_per_K) for reordering groups. + // No range check here so it might have out-of-bound access if the reorder_idx is not valid. + int32_t rid = g_idx[i]; + ptrdiff_t scale_zp_offset = n_idx * groups_per_K + rid; + T scale = *(scale_data + scale_zp_offset); + + uint8_t zp = 128; // Default zero point + if (zero_points) { + zp = zero_points[scale_zp_offset]; + } + + // Extract the i-th uint8_t value + uint8_t q_val = (quant_value >> (8 * i)) & 0xFF; + + if constexpr (std::is_same_v) { + T zp_T = __ushort2half_rn(zp); + T zp_adjust = -scale * zp_T; + output_i[i] = __ushort2half_rn(q_val) * scale + zp_adjust; + } else { + T zp_T = static_cast(zp); + T zp_adjust = -scale * zp_T; + output_i[i] = static_cast(q_val) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize8BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_threadblock, + int total_groups) { + constexpr int element_per_thread = 4; // Process 4 elements (uint8_t) per thread using uint32_t load + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * element_per_thread) / block_size); + if (block_id >= total_groups) { + return; + } + + // element_offset corresponds to the start of the 4 elements processed by this thread iteration + int element_offset = block_id * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); + + // Read 4 uint8_t values packed into a uint32_t + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset)); + T scale = *(scale_data + block_id); // One scale per block + + T zero_point_value; + if constexpr (std::is_same_v) { + // Assuming one uint8_t zero point per block. Default 128 for uint8 asymmetric. + uint8_t zp = 128; + if (zero_points) { + zp = zero_points[block_id]; // Direct lookup, no packing + } + // Convert uint8_t zp to T (float/half) + if constexpr (std::is_same_v) { + zero_point_value = __uint2half_rn(zp); + } else { + zero_point_value = static_cast(zp); + } + } else { // ZeroT is T (float or half) + // Default 0 for float/half zero point + zero_point_value = zero_points ? *(zero_points + block_id) : static_cast(0.0f); + } + + output = output + element_offset; // Point output to the start of the 4 elements + DequantizeFourElements(quant_value, scale, zero_point_value, output); +} + +template +Status Dequantize8Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const ZeroT* zero_points, // Shape: [N, K_blocks] or [N * K_blocks] + const int32_t* reorder_idx, // If provided, ZeroT is expected to be uint8_t + int k, // Original dimension before padding + int n, // Other dimension + int block_size, + cudaStream_t stream) { + ORT_ENFORCE(k % block_size == 0, "k must be a multiple of block_size"); // K shall be padded to multiple of block_size. 
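+  // Launch-shape example for the computation below (assuming the default
+  // GridDim::maxThreadsPerBlock of 256): k = 4096, n = 4096, block_size = 32 gives
+  // groups_per_K = 128, total_groups = 524288, groups_per_threadblock = 256 * 4 / 32 = 32,
+  // and a grid of 16384 thread blocks.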
+ + constexpr int element_per_thread = 4; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // Total number of blocks + + assert(block_size <= GridDim::maxThreadsPerBlock * element_per_thread); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_grid = CeilDiv(total_groups, groups_per_threadblock); + + dim3 grid_dim(groups_per_grid); + dim3 block_dim(GridDim::maxThreadsPerBlock); + + DUMP_TENSOR_INIT(); + if (!reorder_idx || std::is_same_v) { + DUMP_STRING("Launch standard kernel for Dequantize8Bits"); + Dequantize8BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_threadblock, + total_groups); + } else { + if constexpr (std::is_same_v) { + DUMP_STRING("Launch reorder kernel for Dequantize8Bits"); + Dequantize8BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::INVALID_ARGUMENT, + "Reorder kernel currently expects uint8_t zero points."); + } + } + + return CUDA_CALL(cudaGetLastError()); // Check for launch errors +} + +// Template instantiations for 8-bit +template Status Dequantize8Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize8Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); + +// Generic dequantization kernel for 8 bits +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +__global__ void dequantizeThread8b(ElementT* dst, + const uint8_t* weights, // Quantized data (uint8_t) + const ElementT* scales, + const uint8_t* zero_points, // Assuming uint8_t zero points + int rows, + int columns, + int thread_row_blocks) { // Number of thread blocks along row dimension + + using QuantBlk = typename BlkQuantTraits::QuantBlk; + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + static_assert(qbits == 8, "Only 8b block quantization is supported by this kernel specialization!"); + + const auto thread_idx_global = blockIdx.x * blockDim.x + threadIdx.x; + + // Total blocks along row dim for scales/zp + const auto total_row_blks = (rows + QuantBlk::kRow - 1) / QuantBlk::kRow; + + // Total blocks along col dim for scales/zp + // const auto total_col_blks = (columns + QuantBlk::kColumn - 1) / QuantBlk::kColumn; + + // Total number of blocks to process + // const auto total_quant_blocks = total_row_blks * total_col_blks; + + // Iterate over the quantization blocks assigned to this thread + // Each thread might process multiple QuantBlks + // This loop structure assumes 1D grid/block launch. A 2D launch might map threads differently. 
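+  // Worked example (assuming ThreadBlk equals QuantBlk for 8 bits, since kPackSize == 1):
+  // column-wise block_size = 32 and rows = 128 give thread_row_blocks = 4, so block_idx 9
+  // maps to r_start = 1 * 32 = 32 and c_start = 9 / 4 = 2.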
+ const auto block_idx = thread_idx_global; // Assuming 1 thread processes 1 ThreadBlk here + + // Calculate row and column block indices for this thread + // Map 1D block_idx back to 2D block indices (row_blk, col_blk) + const auto r_blk_idx_thread = static_cast(block_idx % thread_row_blocks); // Thread block index along rows + const auto c_blk_idx_thread = static_cast(block_idx / thread_row_blocks); // Thread block index along columns + + // Calculate starting row and column for this thread's work item (ThreadBlk) + int32_t r_start = r_blk_idx_thread * ThreadBlk::kRow; + int32_t c_start = c_blk_idx_thread * ThreadBlk::kColumn; + + // Check if this thread is out of bounds for the overall work + if (c_start >= columns) { + return; + } + + // Determine the actual end row/column considering matrix boundaries + int32_t r_end = std::min(r_start + ThreadBlk::kRow, rows); + int32_t c_end = std::min(c_start + ThreadBlk::kColumn, columns); + + // Process elements within the assigned ThreadBlk + for (int32_t c = c_start; c < c_end; ++c) { + // Calculate the block index for scale/zp lookup based on the current column 'c' + const auto scale_zp_col_blk_idx = c / QuantBlk::kColumn; + + // Calculate base pointer for this column in the quantized weights matrix + // Assuming weights stored column-major: shape [rows, columns] -> layout [columns, rows] + // Each element is uint8_t. + // const uint8_t* q_col_ptr = weights + static_cast(scale_zp_col_blk_idx) * rows; + + for (int32_t r = r_start; r < r_end; ++r) { + // Calculate the block index for scale/zp lookup based on current row 'r' + const auto scale_zp_row_blk_idx = r / QuantBlk::kRow; + const auto scale_zp_flat_idx = scale_zp_col_blk_idx * total_row_blks + scale_zp_row_blk_idx; + + // Get scale and zero point for this block + const ElementT scale = scales[scale_zp_flat_idx]; + const uint8_t zp_uint8 = (zero_points == nullptr) ? 128 : zero_points[scale_zp_flat_idx]; + + // Get the quantized value (uint8_t) + // Assuming weights are stored col-major for block quantization (e.g. 
[cols, rows/block_size, block_size]) + // Row-major logical layout for weights access: index = c * rows + r + const size_t q_val_idx = static_cast(c) * rows + r; + const uint8_t q_val = weights[q_val_idx]; + + // Dequantize + if constexpr (std::is_same::value) { + const half zp_half = __uint2half_rn(zp_uint8); + const half adjust = -scale * zp_half; + dst[q_val_idx] = __uint2half_rn(q_val) * scale + adjust; + } else { // Float + static_assert(std::is_same::value, "Only float and half are supported!"); + const float zp_float = static_cast(zp_uint8); + const float adjust = -scale * zp_float; + dst[q_val_idx] = static_cast(q_val) * scale + adjust; + } + } + } +} + +// Launcher function for the generic 8-bit kernel +template < + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> +static void dequantize8b_generic(ElementT* dst, const uint8_t* weights, const ElementT* scales, + const uint8_t* zero_points, int32_t rows, int32_t columns, + cudaStream_t stream) { + using ThreadBlk = typename BlkQuantTraits::ThreadBlk; + + const auto thread_row_blocks = (rows + ThreadBlk::kRow - 1) / ThreadBlk::kRow; + const auto thread_col_blocks = (columns + ThreadBlk::kColumn - 1) / ThreadBlk::kColumn; + const auto thread_total_blocks = thread_row_blocks * thread_col_blocks; + + const auto grids = (thread_total_blocks + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; + dequantizeThread8b<<>>( + dst, + weights, + scales, + zero_points, + rows, + columns, + thread_row_blocks); +} + +template +Status +DequantizeBlockwise8b( + T* dst, + const uint8_t* src, // Quantized uint8_t data + const T* scales, + const uint8_t* zero_points, // Assuming uint8_t zero points + int block_size, + bool columnwise, // Orientation of elements within a block + int rows, + int columns, + cudaStream_t stream) { + // Use the generic launcher, passing qbits=8 + switch (block_size) { + case 16: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 32: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 64: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 128: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + case 256: + if (columnwise) { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } else { + dequantize8b_generic(dst, src, scales, zero_points, rows, columns, stream); + } + return Status::OK(); + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
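+      // Other sizes would need a new case above, since block_size is a compile-time
+      // template parameter of dequantize8b_generic.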
+ return Status(::onnxruntime::common::ONNXRUNTIME, ::onnxruntime::common::FAIL, + "Unsupported block size for 8b blockwise quantization."); + } +} + +// Template instantiations for 8-bit blockwise +template Status DequantizeBlockwise8b( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +template Status DequantizeBlockwise8b( + half* dst, + const uint8_t* src, + const half* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu similarity index 93% rename from onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu rename to onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu index ce6c07fbed2bc..5d634b8a929f1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_4bits.cu @@ -1,4 +1,3 @@ -// Modifications: scaling is moved from masked softmax to the gemm before that. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. @@ -89,7 +88,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +119,7 @@ __device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, h sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +143,7 @@ __device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, h } #endif -__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements4b(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); @@ -178,7 +177,7 @@ constexpr int kWarpSize = GPU_WARP_SIZE; // Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], // i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) template -__global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatInt4Kernel( +__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt4Kernel( T* output, const T* a_data, const uint8_t* b_data_quant, @@ -238,7 +237,7 @@ __global__ void 
__launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatIn if constexpr (has_zero_point) { \ zp = b_zp_vec[t_meta_k + k_per_iter / block_size * i]; \ } \ - AccumulateEightElements(value, scale, zp, a_data + k_id + i * k_per_iter, sums); \ + AccumulateEightElements4b(value, scale, zp, a_data + k_id + i * k_per_iter, sums); \ } \ b_data_quant += k_per_iter / 2 * kUnroll; \ t_meta_k += k_per_iter / block_size * kUnroll; \ @@ -258,7 +257,7 @@ __global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatIn if constexpr (has_zero_point) { zp = b_zp_vec[t_meta_k]; } - AccumulateEightElements(value, scale, zp, a_data + k_id, sums); + AccumulateEightElements4b(value, scale, zp, a_data + k_id, sums); } float sum = (float)(sums[0] + sums[1] + sums[2] + sums[3] + sums[4] + sums[5] + sums[6] + sums[7]); @@ -283,7 +282,7 @@ bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream) { if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { return false; @@ -291,8 +290,8 @@ bool TryMatMul4Bits( dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; - int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + - (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); + size_t shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + + static_cast(zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); if (shared_mem_size > shared_mem_per_block) { return false; } @@ -333,7 +332,7 @@ template bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream); template bool TryMatMul4Bits( @@ -346,7 +345,7 @@ template bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, cudaStream_t stream); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu new file mode 100644 index 0000000000000..85ace5d39b24b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_8bits.cu @@ -0,0 +1,469 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +constexpr int kColsPerThreadBlock = 8; +constexpr int kElementsPerThreadPerIteration = 8; +constexpr int kWarpSize = GPU_WARP_SIZE; // Typically 32 +constexpr uint8_t kDefaultZeroPoint = 128; +constexpr int kKernelAlgo = 0; // Choose algorithm here: 0 (unroll), 1 (simple loop), 2 (block size iteration) +constexpr bool kUseCUB = true; +constexpr bool kUseFloatInPartialSum = false; // Use float to accumulate partial sum of 8 half elements in a thread. Default is false like 4 bits kernel. 
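To make the indexing below easier to follow, here is a hedged scalar sketch of the value each warp ultimately produces (one output element C[m][n_id]). It assumes only what the comments in this file state: B is stored as [N, K/block_size, block_size], there is one scale and one optional zero point per block, and the default zero point is 128. It is an illustration, not code from this change.

#include <cstdint>

// Scalar reference for one output element: dot(A[m, :], dequantized B column n_id).
static float MatMul8BitsReference(const float* a_row,          // A[m, :], length k
                                  const uint8_t* b_quant,      // B as [N, k/block_size, block_size]
                                  const float* scales,         // [N, k/block_size]
                                  const uint8_t* zero_points,  // [N, k/block_size] or nullptr
                                  int n_id, int k, int block_size) {
  const int blocks_per_k = k / block_size;
  float sum = 0.0f;
  for (int kk = 0; kk < k; ++kk) {
    const int kb = kk / block_size;
    const float scale = scales[n_id * blocks_per_k + kb];
    const float zp = zero_points ? static_cast<float>(zero_points[n_id * blocks_per_k + kb]) : 128.0f;
    const float b_val = (static_cast<float>(b_quant[n_id * k + kk]) - zp) * scale;  // (q - zp) * scale
    sum += a_row[kk] * b_val;
  }
  return sum;
}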
+ +__device__ __forceinline__ void AccumulateEightElements8b(uint64_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { + // --- Dequantization Setup --- + // Convert scale and zero point to half format suitable for half2 operations + half2 scale_h2 = __half2half2(scale); // Broadcast scale to half2 + half zp_h = __ushort2half_rn(zp); // Convert uint8 zp to half + half2 zp_h2 = __half2half2(zp_h); // Broadcast zp to half2 + + // --- Extract 8 uint8_t values from the 64-bit input --- + uint8_t q[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + q[i] = (values_quant >> (i * 8)) & 0xFF; + } + + // --- Dequantize 8 values into 4 half2 vectors: b_vec = (q - zp) * scale --- + // Convert uint8 q values to half2 vectors {q0,q1}, {q2,q3}, {q4,q5}, {q6,q7} + half2 q_01 = __halves2half2(__ushort2half_rn(q[0]), __ushort2half_rn(q[1])); + half2 q_23 = __halves2half2(__ushort2half_rn(q[2]), __ushort2half_rn(q[3])); + half2 q_45 = __halves2half2(__ushort2half_rn(q[4]), __ushort2half_rn(q[5])); + half2 q_67 = __halves2half2(__ushort2half_rn(q[6]), __ushort2half_rn(q[7])); + + // Calculate q - zp + half2 diff_01 = __hsub2(q_01, zp_h2); + half2 diff_23 = __hsub2(q_23, zp_h2); + half2 diff_45 = __hsub2(q_45, zp_h2); + half2 diff_67 = __hsub2(q_67, zp_h2); + + // Calculate b_vec = (q - zp) * scale + half2 b_vec0 = __hmul2(diff_01, scale_h2); // {b0, b1} + half2 b_vec1 = __hmul2(diff_23, scale_h2); // {b2, b3} + half2 b_vec2 = __hmul2(diff_45, scale_h2); // {b4, b5} + half2 b_vec3 = __hmul2(diff_67, scale_h2); // {b6, b7} + + // --- Load Input A (8 half values as 4 half2 vectors) --- + // Directly cast 'a' pointer to read half2 vectors. + // This assumes 'a' is properly aligned for half2 reads. + const half2* a_half2 = reinterpret_cast(a); + half2 a_vec0 = a_half2[0]; // {a0, a1} + half2 a_vec1 = a_half2[1]; // {a2, a3} + half2 a_vec2 = a_half2[2]; // {a4, a5} + half2 a_vec3 = a_half2[3]; // {a6, a7} + + // --- Accumulate: sums += a * b_vec using half2 FMA --- + // Cast sums pointer to half2* for vectorized accumulation. 
+ half2* sums_half2 = reinterpret_cast(sums); + sums_half2[0] = __hfma2(a_vec0, b_vec0, sums_half2[0]); // {s0+=a0*b0, s1+=a1*b1} + sums_half2[1] = __hfma2(a_vec1, b_vec1, sums_half2[1]); // {s2+=a2*b2, s3+=a3*b3} + sums_half2[2] = __hfma2(a_vec2, b_vec2, sums_half2[2]); // {s4+=a4*b4, s5+=a5*b5} + sums_half2[3] = __hfma2(a_vec3, b_vec3, sums_half2[3]); // {s6+=a6*b6, s7+=a7*b7} +} + +// --- Keep Original Float Version --- +__device__ __forceinline__ void AccumulateEightElements8b(uint64_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + float zp_adjust = -scale * float(zp); + + // Extract and dequantize 8 float values + float v[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + uint8_t q_val = (values_quant >> (i * 8)) & 0xFF; + v[i] = float(q_val) * scale + zp_adjust; + } + + // Accumulate using fmaf for potentially better precision/performance + sums[0] = fmaf(v[0], a_vec_0.x, sums[0]); + sums[1] = fmaf(v[1], a_vec_0.y, sums[1]); + sums[2] = fmaf(v[2], a_vec_0.z, sums[2]); + sums[3] = fmaf(v[3], a_vec_0.w, sums[3]); + sums[4] = fmaf(v[4], a_vec_1.x, sums[4]); + sums[5] = fmaf(v[5], a_vec_1.y, sums[5]); + sums[6] = fmaf(v[6], a_vec_1.z, sums[6]); + sums[7] = fmaf(v[7], a_vec_1.w, sums[7]); +} + +// kernel for 8bits quantized GEMM, i.e., computing C(M, N) = A(M, K) x B(K, N) +// B(K, N) is quantized with 8bits and block_size bs and stored as [N, K/bs, bs] +// kColsPerThreadBlock (C) = 8 is the number of columns each thread block computes per row of A +// kElementsPerThreadPerIteration (E) = 8 is the number of elements each thread computes in one iteration along K +// Constraints: N % C == 0, K % E == 0 +// The thread block size is (kWarpSize, C) = (32, 8) +// Grid size is (Ceil(N / C), M) +// Each thread block computes kColsPerThreadBlock columns for a specific row m_id +template +__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloat8bKernel( + T* output, // Base pointer for Output C [M, N] + const T* a_data, // Base pointer for A [M, K] + const uint8_t* b_data_quant, // Base pointer for B [N, K/bs, bs] + const T* scales_data, // Base pointer for scales [N, K/bs] + const uint8_t* zero_points, // Base pointer for ZPs [N, K/bs] + int m, // Number of rows in A and Output C + int n, // Number of columns in B and Output C (Constraint: N % C == 0) + int k, // Number of columns in A / rows in B (Constraint: K % E == 0) + int blocks_per_K) { // blocks_per_K = K/bs + + const int n_block_id = blockIdx.x; // Block column index in the range of [0, Ceil(N / C)) + const int m_id = blockIdx.y; // Block row index (identifies the row of A and C) [0, M) + + // Check if this block is needed for the M dimension + if (m_id >= m) return; + + const int lane_id = threadIdx.x; // Thread index in warp (0..31) + const int warp_id = threadIdx.y; // Warp index 0..7 in the range of [0, C-1) + const int n_block_head = n_block_id * kColsPerThreadBlock; // Head column index for this block [0, N) + const int n_id = n_block_head + warp_id; // Global output column index this warp computes + + // Ensure n_id does not go out of bounds (already checked by TryMatMul8Bits, but safer) + if (n_id >= n) return; + + extern __shared__ char shared_buffer[]; + + // Load scales to shared_buffer + T* b_scale_vec_shared = (T*)shared_buffer; + for (int i = threadIdx.y * kWarpSize + threadIdx.x; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + // Boundary check needed 
if N is not perfectly divisible by kColsPerThreadBlock * blocks_per_K, + // though the N constraint N % C == 0 helps simplify this for scales/ZPs. + int current_n = n_block_head + (i / blocks_per_K); + int current_k_block = i % blocks_per_K; + if (current_n < n) { // Check if the column is valid + b_scale_vec_shared[i] = scales_data[static_cast(current_n) * blocks_per_K + current_k_block]; + } + } + + // Load zero points if they exist (logic remains the same, depends on n_block_id) + [[maybe_unused]] uint8_t* b_zp_vec_shared = nullptr; + [[maybe_unused]] const uint8_t* b_zp_vec_thread = nullptr; // Thread's ZP pointer + if constexpr (has_zero_point) { + b_zp_vec_shared = reinterpret_cast(b_scale_vec_shared + kColsPerThreadBlock * blocks_per_K); + for (int i = threadIdx.y * kWarpSize + threadIdx.x; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + int current_n = n_block_head + (i / blocks_per_K); + int current_k_block = i % blocks_per_K; + if (current_n < n) { // Check if the column is valid + b_zp_vec_shared[i] = zero_points[static_cast(current_n) * blocks_per_K + current_k_block]; + } + } + b_zp_vec_thread = b_zp_vec_shared + warp_id * blocks_per_K; + } + + __syncthreads(); // Ensure scales and ZPs are loaded + + // Point a_data to the correct row based on m_id + const T* a_row_data = a_data + static_cast(m_id) * k; + + // Each thread calculates its part of the dot product along K. + // Point to the start of the elements this thread is responsible for in the current row of A. + const int lane_offset = lane_id * kElementsPerThreadPerIteration; + const T* a_thread_data_base = a_row_data + lane_offset; // Base pointer for this thread in row m_id + + // Pointer to the start of B data for the specific column n_id this warp handles. + // Layout of B is [N, K/bs, bs]. + const uint8_t* b_base_ptr_n = b_data_quant + static_cast(n_id) * blocks_per_K * block_size; + + // Pointer to the start of scales for the specific column (n_id) this warp handles (from shared mem). + const T* b_scale_vec_thread = b_scale_vec_shared + warp_id * blocks_per_K; + + T sums[kElementsPerThreadPerIteration] = {static_cast(0.0f)}; // Initialize sums to zero + + if constexpr (kKernelAlgo == 0 || kKernelAlgo == 1) { + // Note that k_per_iter (typical value is 256) is multiple of block_size (typical value is 16, 32, 64, 128, or 256). + constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration; + + int k_id = 0; + // Pointer to B data for this thread's starting element in K, for column n_id. + // B layout: [N, K/bs, bs]. Access is effectively [n_id, k_block, k_within_block] + // Pointer for thread should start at its `lane_offset` within the K dimension for column `n_id`. 
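+    // e.g. lane 3 starts at byte offset 3 * 8 = 24 within column n_id and then advances by
+    // k_per_iter = 256 bytes per outer iteration (one byte per 8-bit element).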
+ const uint8_t* b_data_quant_thread = b_base_ptr_n + lane_offset; + + if constexpr (kKernelAlgo == 0) { // Algorithm 0: Unrolling + int k_start_block = lane_offset / block_size; // Block index in K for thread start + +#define UnRollReduction(unroll_size) \ + do { \ + constexpr int kUnroll = unroll_size; \ + constexpr int kElementsPerUnrollIter = k_per_iter * kUnroll; \ + for (; k_id + kElementsPerUnrollIter <= k; k_id += kElementsPerUnrollIter) { \ + _Pragma("unroll") for (int i = 0; i < kUnroll; ++i) { \ + const uint8_t* current_b_ptr = b_data_quant_thread + k_id + i * k_per_iter; \ + /* Assume alignment allows uint64_t load */ \ + uint64_t value = *reinterpret_cast(current_b_ptr); \ + /* Requires k_per_iter % block_size == 0 */ \ + int current_meta_k = k_start_block + (k_id / block_size) + i * (k_per_iter / block_size); \ + T scale = b_scale_vec_thread[current_meta_k]; \ + uint8_t zp = kDefaultZeroPoint; \ + if constexpr (has_zero_point) { \ + zp = b_zp_vec_thread[current_meta_k]; \ + } \ + /* Pass pointer to A for the current k segment */ \ + AccumulateEightElements8b(value, scale, zp, a_thread_data_base + k_id + i * k_per_iter, sums); \ + } \ + } \ + } while (false) + + UnRollReduction(16); + UnRollReduction(4); + UnRollReduction(1); + +#undef UnRollReduction + } else { // Algorithm 1: Simple loop + for (; k_id + k_per_iter <= k; k_id += k_per_iter) { + const uint8_t* current_b_ptr = b_data_quant_thread + k_id; + uint64_t value = *reinterpret_cast(current_b_ptr); + + int current_meta_k = (lane_offset + k_id) / block_size; + T scale = b_scale_vec_thread[current_meta_k]; + uint8_t zp = kDefaultZeroPoint; + if constexpr (has_zero_point) { + zp = b_zp_vec_thread[current_meta_k]; + } + /* Pass pointer to A for the current k segment */ + AccumulateEightElements8b(value, scale, zp, a_thread_data_base + k_id, sums); + } + } + + // Handle the tail elements (less than k_per_iter) if k is not multiple of k_per_iter + // Since k % kElementsPerThreadPerIteration == 0 is enforced, the tail is simpler. + // Each thread processes its remaining elements if its start offset is < k. + if (lane_offset + k_id < k) { // Check if this thread has any elements left + const uint8_t* current_b_ptr = b_data_quant_thread + k_id; + uint64_t value = *reinterpret_cast(current_b_ptr); + + int current_meta_k = (lane_offset + k_id) / block_size; + T scale = b_scale_vec_thread[current_meta_k]; + uint8_t zp = kDefaultZeroPoint; + if constexpr (has_zero_point) { + zp = b_zp_vec_thread[current_meta_k]; + } + /* Pass pointer to A for the current k segment */ + AccumulateEightElements8b(value, scale, zp, a_thread_data_base + k_id, sums); + } + } else { // Algorithm 2: block size iteration. + for (int block_k_idx = 0; block_k_idx < blocks_per_K; ++block_k_idx) { + int k_start_block = block_k_idx * block_size; + // B data pointer for the start of this block, for column n_id + const uint8_t* b_block_ptr_n = b_base_ptr_n + k_start_block; + + // Get scale/zp for this block (already loaded for the warp) + T scale = b_scale_vec_thread[block_k_idx]; + uint8_t zp = kDefaultZeroPoint; + if constexpr (has_zero_point) { + zp = b_zp_vec_thread[block_k_idx]; + } + + // Each thread `lane_id` handles elements starting at `lane_offset` within the K dimension. + // Calculate the base K index for this thread *within the current block*. 
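+      // e.g. with block_size = 256 every lane covers 8 elements of the block (lane 5 -> 40..47);
+      // with block_size = 32 only lanes 0..3 fall inside a block and the rest skip it via the
+      // bounds check below.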
+ int k_offset_in_block = lane_offset; // Offset relative to block start + + // Check if the *start* of the 8 elements this thread would process is within the current block + // AND within the bounds of K dimension. + // Since K % 8 == 0 and thread handles 8 elements, if the thread's work starts within K, + // it finishes within K. We only need to check if the *block* has enough elements + // for this thread's starting offset. + if (k_offset_in_block < block_size && k_start_block + k_offset_in_block < k) { + // Calculate absolute K index for A pointer + int k_abs_idx = k_start_block + k_offset_in_block; + + // Since K % 8 == 0 is enforced, we don't need partial element handling. + // We are guaranteed to process 8 elements if the start is valid. + const uint8_t* current_b_ptr = b_block_ptr_n + k_offset_in_block; // Offset within the block + uint64_t value = *reinterpret_cast(current_b_ptr); + + // Pointer to A data corresponding to this thread's elements within this block + const T* current_a_ptr = a_row_data + k_abs_idx; // Use a_row_data + absolute k index + + AccumulateEightElements8b(value, scale, zp, current_a_ptr, sums); + } + } + } + + // Sum the 8 partial sums within each thread first. + float total_sum_thread = 0.0f; + + if constexpr (std::is_same_v) { + if constexpr (kUseFloatInPartialSum) { +// Convert 8 elements to float, then accumulate. +#pragma unroll + for (int i = 0; i < kElementsPerThreadPerIteration; ++i) { + total_sum_thread += __half2float(sums[i]); + } + } else { + // Accumulate 8 elements in half, then convert sum to float once. + T temp_sum = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < kElementsPerThreadPerIteration; ++i) { + temp_sum += sums[i]; + } + total_sum_thread = __half2float(temp_sum); + } + } else { +#pragma unroll + for (int i = 0; i < kElementsPerThreadPerIteration; ++i) { + total_sum_thread += sums[i]; + } + } + + if constexpr (!kUseCUB) { + for (int i = kWarpSize / 2; i > 0; i = i / 2) { + total_sum_thread += __shfl_down_sync(0xFFFFFFFF, total_sum_thread, i); + } + + if (lane_id == 0) { + // Calculate output index: output[m_id, n_id] + output[static_cast(m_id) * n + n_id] = static_cast(total_sum_thread); + } + } else { + // Use CUB for efficient warp reduction + using BlockReduce = cub::WarpReduce; + + // Shared memory for CUB reduction storage (one per warp) + __shared__ typename BlockReduce::TempStorage temp_storage[kColsPerThreadBlock]; + total_sum_thread = BlockReduce(temp_storage[warp_id]).Sum(total_sum_thread); + + if (lane_id == 0) { + // Write the final result for the element C[m_id, n_id] + output[static_cast(m_id) * n + n_id] = static_cast(total_sum_thread); + } + } +} + +template +bool TryMatMul8Bits( + T* output, // Output C [M, N] + const T* a_data, // Input A [M, K] + const uint8_t* b_data_quant, // Input B Quantized [N, K/bs, bs] + const T* scales_data, // Scales [N, K/bs] + const uint8_t* zero_points, // Zero Points [N, K/bs] (can be nullptr) + int m, // Rows of A and C (M >= 1) + int n, // Columns of B and C + int k, // Columns of A / Rows of B + int block_size, // Quantization block size for B + size_t shared_mem_per_block, // Available shared memory + cudaStream_t stream) { + // Constraints Check + // N must be a multiple of kColsPerThreadBlock (8) for warps to align with columns. + // K must be a multiple of kElementsPerThreadPerIteration (8) for full uint64_t reads/processing per thread iter. 
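+  // e.g. N = 4096, K = 4096 satisfies both constraints; shapes that fail any check below simply
+  // return false, and the caller falls back to the dequantize + GEMM path in matmul_nbits.cc.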
+ if (n % kColsPerThreadBlock != 0 || k % kElementsPerThreadPerIteration != 0) { + return false; + } + + // Ensure k_per_iter is multiple of block_size for algo 0. + if constexpr (kKernelAlgo == 0) { + constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration; + if (k_per_iter % block_size != 0) { + return false; + } + } + + if constexpr (kKernelAlgo == 1 || kKernelAlgo == 2) { + if (k % block_size != 0) { + // The indexing `(lane_offset + k_id) / block_size` in Algo 1 and the block iteration + // in Algo 2 rely on K being compatible with block_size for correct scale/zp lookup. + // While blocks_per_K handles rounding up, the core loops assume alignment. + // If K is not multiple of block_size, the last block is partial, potentially + // causing issues with scale/zp indexing. + // Let's enforce K % block_size == 0 for simplicity/correctness guarantee here. + return false; + } + } + + // Grid and Thread Block Configuration + dim3 threads(kWarpSize, kColsPerThreadBlock); // (32, 8) + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + + int blocks_per_K = (k + block_size - 1) / block_size; // K / block_size rounded up + + // Shared memory needed for scales and optionally zero points for the columns handled by the block + size_t shared_mem_size = (sizeof(T) + (zero_points != nullptr ? sizeof(uint8_t) : 0)) * blocks_per_K * kColsPerThreadBlock; + + // Add shared memory for CUB reduction storage if used + if constexpr (kUseCUB) { + shared_mem_size += static_cast(kColsPerThreadBlock) * sizeof(typename cub::WarpReduce::TempStorage); + } + + // Check if required shared memory exceeds limits + if (shared_mem_size > shared_mem_per_block) { + return false; + } + + // Macro simplifies dispatching for different block sizes and presence of zero_points +#define MatMulFloat8bKernelDispatch(bs) \ + if (nullptr != zero_points) { \ + /* Launch kernel with zero points */ \ + MatMulFloat8bKernel<<>>( \ + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); \ + } else { \ + /* Launch kernel without zero points */ \ + MatMulFloat8bKernel<<>>( \ + output, a_data, b_data_quant, scales_data, nullptr /*zero_points*/, m, n, k, blocks_per_K); \ + } + + // Dispatch based on block_size value + if (16 == block_size) { + MatMulFloat8bKernelDispatch(16); + } else if (32 == block_size) { + MatMulFloat8bKernelDispatch(32); + } else if (64 == block_size) { + MatMulFloat8bKernelDispatch(64); + } else if (128 == block_size) { + MatMulFloat8bKernelDispatch(128); + } else if (256 == block_size) { + MatMulFloat8bKernelDispatch(256); + } else { + // Unsupported block size. + return false; + } + +#undef MatMulFloat8bKernelDispatch + + // Here we do not use cudaGetLastError() to check kernel launch errors. That will be done later. 
+ return true; +} + +// Template instantiations +template bool TryMatMul8Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + +template bool TryMatMul8Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1cec6f6a12f1c..33265744f3a7d 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -8,6 +8,7 @@ #include "core/common/status.h" #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -23,6 +24,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); + const Tensor* bias = ctx->Input(5); + if (bias != nullptr) { + ORT_THROW("MatMulNBits does not support bias in CUDA kernel"); + } const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); @@ -40,80 +45,133 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { helper.Compute(a->Shape(), b_shape, transa, transb)); Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty - if (Y->Shape().Size() == 0) return Status::OK(); - - bool is_4bit_done = (reorder_idx_data == nullptr) && - (!zero_points || !zero_points->IsDataType()) && - TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - - if (is_4bit_done) { + if (Y->Shape().Size() == 0) return Status::OK(); + + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + bool done = (nbits_ == 8) ? 
TryMatMul8Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + GetDeviceProp().sharedMemPerBlock, + static_cast(ctx->GetComputeStream()->GetHandle())) + : TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + GetDeviceProp().sharedMemPerBlock, + static_cast(ctx->GetComputeStream()->GetHandle())); + if (done) { + return Status::OK(); + } } int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - if (reorder_idx) { - ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); - } - // column-wise block - if ((zero_points && zero_points->IsDataType())) { - ORT_RETURN_IF_ERROR(Dequantize4Bits( + + if (nbits_ == 8) { + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + if (zero_points && zero_points->IsDataType()) { + ORT_RETURN_IF_ERROR(Dequantize8Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const CudaT*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + ORT_RETURN_IF_ERROR(Dequantize8Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } + } else { // row-wise block + ORT_RETURN_IF_ERROR(DequantizeBlockwise8b( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - (const CudaT*)zero_points_data, - reorder_idx_data, - SafeInt(K_padded), - SafeInt(N_), + (const uint8_t*)zero_points_data, SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), static_cast(ctx->GetComputeStream()->GetHandle()))); + } + } else { // 4 bits + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const CudaT*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } else { + ORT_RETURN_IF_ERROR(Dequantize4Bits( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } } else { - ORT_RETURN_IF_ERROR(Dequantize4Bits( + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), (const uint8_t*)zero_points_data, - reorder_idx_data, - SafeInt(K_padded), - SafeInt(N_), SafeInt(block_size_), + column_wise_quant_blk_, + 
SafeInt(K_), + SafeInt(N_), static_cast(ctx->GetComputeStream()->GetHandle()))); } - } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( - reinterpret_cast(b_data), - blob_data, - reinterpret_cast(scales_data), - (const uint8_t*)zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), - SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); } -#if 0 -cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); -T* b_data_cpu = new T[K_ * N_]; -cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); -delete[] b_data_cpu; -#endif + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); const CudaT alpha = ToCudaType::FromFloat(1.f); const CudaT zero = ToCudaType::FromFloat(0.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index 9ccbe4c4d97a8..fe7098b92cba8 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -19,7 +19,21 @@ bool TryMatMul4Bits( int n, int k, int block_size, - int shared_mem_per_block, + size_t shared_mem_per_block, + cudaStream_t stream); + +template +bool TryMatMul8Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, cudaStream_t stream); } // namespace cuda diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 070f8a3330e8c..b29fc5181eb46 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #ifndef ORT_MINIMAL_BUILD +#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) #include @@ -59,21 +60,16 @@ struct TestOptions8Bits { } template -void RunTest8Bits(const TestOptions8Bits& opts, - std::vector>&& explicit_eps = {}) { +void RunTest8Bits(const TestOptions8Bits& opts) { SCOPED_TRACE(opts); - explicit_eps.push_back(DefaultCpuExecutionProvider()); - - static_assert(std::is_same_v, "unexpected type for T1"); - const int64_t M = opts.M, K = opts.K, N = opts.N; RandomValueGenerator random{1234}; - std::vector input0_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); - std::vector input1_f_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + std::vector input0_fp32_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); + std::vector input1_fp32_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); int q_rows, q_cols; MlasBlockwiseQuantizedShape(static_cast(opts.block_size), /* columnwise */ true, @@ -96,7 +92,7 @@ void RunTest8Bits(const TestOptions8Bits& opts, input1_vals.data(), scales.data(), opts.has_zero_point ? zp.data() : nullptr, - input1_f_vals.data(), + input1_fp32_vals.data(), static_cast(opts.block_size), true, static_cast(K), @@ -106,7 +102,7 @@ void RunTest8Bits(const TestOptions8Bits& opts, // Note that raw_vals is NxK after dequant MlasDequantizeBlockwise( - input1_f_vals.data(), + input1_fp32_vals.data(), input1_vals.data(), scales.data(), opts.has_zero_point ? 
zp.data() : nullptr, @@ -129,7 +125,7 @@ void RunTest8Bits(const TestOptions8Bits& opts, for (int64_t n = 0; n < N; n++) { float sum = 0.0f; for (int64_t k = 0; k < K; k++) { - sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k]; } expected_vals[m * N + n] = sum + (bias.has_value() ? (*bias)[n] : 0.0f); } @@ -141,9 +137,19 @@ void RunTest8Bits(const TestOptions8Bits& opts, test.AddAttribute("block_size", opts.block_size); test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", opts.accuracy_level); - test.AddInput("A", {M, K}, input0_vals, false); + if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, input0_fp32_vals, false); + } else { + test.AddInput("A", {M, K}, FloatsToMLFloat16s(input0_fp32_vals), false); + } + test.AddInput("B", {q_cols, q_rows}, input1_vals, true); - test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + + if constexpr (std::is_same::value) { + test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); + } else { + test.AddInput("scales", {static_cast(q_scale_size)}, FloatsToMLFloat16s(scales), true); + } if (opts.has_zero_point) { test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); @@ -154,12 +160,20 @@ void RunTest8Bits(const TestOptions8Bits& opts, test.AddOptionalInputEdge(); if (bias.has_value()) { - test.AddInput("bias", bias_shape, *bias, true); + if constexpr (std::is_same::value) { + test.AddInput("bias", bias_shape, *bias, true); + } else { + test.AddInput("bias", bias_shape, FloatsToMLFloat16s(*bias), true); + } } else { test.AddOptionalInputEdge(); } - test.AddOutput("Y", {M, N}, expected_vals); + if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, expected_vals); + } else { + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); + } if (opts.output_abs_error.has_value()) { test.SetOutputAbsErr("Y", *opts.output_abs_error); @@ -169,11 +183,21 @@ void RunTest8Bits(const TestOptions8Bits& opts, test.SetOutputRelErr("Y", *opts.output_rel_error); } - if (!explicit_eps.empty()) { - test.ConfigEps(std::move(explicit_eps)); - } - + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.emplace_back(DefaultCudaExecutionProvider()); + test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); + execution_providers.clear(); +#else + if constexpr (std::is_same::value) { + if (MlasIsQNBitGemmAvailable(8, 32, SQNBIT_CompInt8)) { + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + test.ConfigEps(std::move(execution_providers)); + test.RunWithConfig(); + } + } +#endif } template @@ -202,6 +226,8 @@ void TestMatMul8BitsTyped() { RunTest8Bits(opts); } +// CUDA does not support bias for MatMulNBits +#if not defined(USE_CUDA) { TestOptions8Bits opts = base_opts; opts.has_bias = true; @@ -214,16 +240,11 @@ void TestMatMul8BitsTyped() { opts.has_bias = true; RunTest8Bits(opts); } +#endif } - } // namespace -#if defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML) -TEST(MatMulNBits, Float32_8b_Accuracy4) { - if (!MlasIsQNBitGemmAvailable(8, 32, SQNBIT_CompInt8)) { - GTEST_SKIP() << "Skipping test because MlasIsQNBitGemmAvailable(8, 32, SQNBIT_CompInt8) is false"; - } - +TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -231,9 +252,13 @@ TEST(MatMulNBits, Float32_8b_Accuracy4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); 
TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -251,9 +276,18 @@ TEST(MatMulNBits, Float32_8b_Accuracy4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); } -#endif // defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML) + +#ifdef USE_CUDA +TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); +} +#endif + } // namespace test } // namespace onnxruntime +#endif #endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/python/quantization/test_op_matmul_8bits.py b/onnxruntime/test/python/quantization/test_op_matmul_8bits.py new file mode 100644 index 0000000000000..6354b7c5fcf0d --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_8bits.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime import get_available_providers +from onnxruntime.quantization import quant_utils + + +@unittest.skipIf( + "CUDAExecutionProvider" not in get_available_providers(), reason="CUDA is not available, skipping tests." 
+) +class TestOpMatMul8Bits(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmul8bits.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_weight_data(self, shape: tuple[int, ...]) -> np.ndarray: + return np.random.normal(0, 0.01, size=shape).astype(np.float32) + + def input_feeds( + self, + n: int, + name2shape: dict[str, int | tuple[int, ...]], + low: int = -1, + high: int = 2, + dtype: type = np.float32, + ) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(low, high, shape).astype(dtype)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, k: int = 32, n: int = 64) -> None: + """Create a simple onnx model with one MatMul node like (input) --> MatMul --> (output).""" + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul( + input_name, weight_shape: int | tuple[int, ...], weight_name: str, output_name: str, node_name: str + ): + weight_data = self.fill_weight_data(weight_shape) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + node_name, + ) + + in_features = k + out_features = n + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + "MatMul_0", + ) + + # make graph + input_tensor = onnx.helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, [-1, in_features]) + output_tensor = onnx.helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_8bits_test" + graph = onnx.helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + # blocked quantization requires DQ op set >= 21 + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 21)]) + model.ir_version = 10 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test( + self, + model_fp32_path: str, + data_reader: TestDataFeeds, + block_size: int, + is_symmetric: bool, + quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator, + op_types_to_quantize: tuple[str, ...] = ("MatMul",), + quant_axes: tuple[tuple[str, int], ...] 
= (("MatMul", 0), ("Gather", 1)), + rtol: float = 0.01, + atol: float = 0.05, + config: str = "default", + suffix: str = "", + ): + use_qdq = quant_format == quant_utils.QuantFormat.QDQ + name_prefix = "QDQ" if use_qdq else "QOperator" + model_int8_path = str( + Path(self._tmp_model_dir.name) + .joinpath(f"{name_prefix}_bs{block_size}_{is_symmetric}{suffix}.onnx") + .absolute() + ) + + # Quantize fp32 model to int8 model + from onnxruntime.quantization import matmul_nbits_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + + assert config in ["default", "hqq"] + if config == "default": + quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig( + block_size=block_size, + is_symmetric=is_symmetric, + quant_format=quant_format, + op_types_to_quantize=op_types_to_quantize, + quant_axes=quant_axes, + bits=8, + ) + else: + quant_config = matmul_nbits_quantizer.HQQWeightOnlyQuantConfig( + block_size=block_size, + bits=8, + quant_format=quant_format, + op_types_to_quantize=op_types_to_quantize, + quant_axes=quant_axes, + ) + + quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, algo_config=quant_config) + quant.process() + quant.model.save_model_to_file(model_int8_path, False) + + if "Gather" in op_types_to_quantize: + quant_nodes = {"GatherBlockQuantized": 1} + else: + quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1} + check_op_type_count(self, model_int8_path, **quant_nodes) + + if use_qdq: + dq_qtype = onnx.TensorProto.INT8 if is_symmetric else onnx.TensorProto.UINT8 + dqnode_io_qtypes = ( + { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ] + } + if is_symmetric + else { + "DequantizeLinear": [ + ["i", 0, dq_qtype], + ["i", 2, dq_qtype], + ] + } + ) + check_qtype_by_node_type(self, model_int8_path, dqnode_io_qtypes) + for op in quant.model.opset_import(): + if op.domain in [None, "", "ai.onnx"] and op.version < 21: + self.fail(f"In QDQ format {op.domain} opset should be >= 21") + + data_reader.rewind() + + try: + check_model_correctness( + self, + model_fp32_path, + model_int8_path, + data_reader.get_next(), + rtol, + atol, + providers=["CUDAExecutionProvider"], + ) + except Exception as exception: + if "8b quantization not yet supported on this hardware platform!" in exception.args[0]: + # Currently we don't have int8 quantization support on all platforms, has to tolerate this exception + pass + else: + raise exception + + def test_quantize_matmul_8bits(self): + np.random.seed(13) + for k in [32, 40, 256, 512, 512, 1024, 1040]: + for n in [8, 256]: + model_fp32_path = str( + Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_k_{k}_n_{n}.onnx").absolute() + ) + self.construct_model_matmul(model_fp32_path, k=k, n=n) + for m in [1, 2]: + data_reader = self.input_feeds(m, {"input": (m, k)}) + for config in ["default", "hqq"]: + for block_size in [16, 128, 256]: + if block_size <= k: + self.quant_test( + model_fp32_path, + data_reader, + block_size, + True, + atol=0.01, + rtol=0.01, + config=config, + suffix=f"_m_{m}_n_{n}_k_{k}", + ) + + +if __name__ == "__main__": + unittest.main()