diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index f8f5546ae9465..47e7779d93b33 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -31,6 +31,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
   ${MLAS_SRC_DIR}/eltwise.cpp
   ${MLAS_SRC_DIR}/erf.cpp
   ${MLAS_SRC_DIR}/compute.cpp
+  ${MLAS_SRC_DIR}/dequantize.cpp
   ${MLAS_SRC_DIR}/quantize.cpp
   ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp
   ${MLAS_SRC_DIR}/qladd.cpp
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 3575e30721af7..91182a4ca9c44 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1223,6 +1223,21 @@ MlasQuantizeLinearS4(
     int8_t ZeroPoint
     );
 
+//
+// Linear dequantization routines.
+//
+
+template<typename InputType>
+void
+MLASCALL
+MlasDequantizeLinear(
+    const InputType* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    InputType ZeroPoint
+    );
+
 /**
  * @brief Requantize a block of the intermediate buffer to the output buffer,
  *        optionally adding the supplied bias
diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp
new file mode 100644
index 0000000000000..175d3f668ac39
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/dequantize.cpp
@@ -0,0 +1,395 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    dequantize.cpp
+
+Abstract:
+
+    This module implements routines to dequantize buffers.
+
+    The dequantization formula as specified in the ONNX operator documentation is:
+
+        Output = (Input - ZeroPoint) * Scale
+
+--*/
+
+#include "mlasi.h"
+
+//
+// DequantizeLinear reference implementation using the C++ runtime.
+//
+
+template<typename InputType>
+static
+MLAS_FORCEINLINE
+void
+MlasDequantizeLinearRefImpl(
+    const InputType* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    InputType ZeroPoint
+    )
+/*++
+
+Routine Description:
+
+    This routine dequantizes the input buffer using the supplied quantization
+    parameters.
+
+Arguments:
+
+    Input - Supplies the input buffer with quantized data.
+
+    Output - Supplies the output buffer.
+
+    N - Supplies the number of elements to process.
+
+    Scale - Supplies the quantization scale.
+
+    ZeroPoint - Supplies the quantization zero point value.
+
+Return Value:
+
+    None.
+
+--*/
+{
+    int32_t ZeroPointS32 = static_cast<int32_t>(ZeroPoint);
+
+    for (size_t n = 0; n < N; n++) {
+        Output[n] = static_cast<float>(static_cast<int32_t>(Input[n]) - ZeroPointS32) * Scale;
+    }
+}
+
+#if defined(MLAS_SSE2_INTRINSICS)
+// Implementation for Intel SSE2. Refer to the Intel Intrinsics Guide:
+// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
+
+void
+MLASCALL
+MlasDequantizeLinearS8Kernel(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint
+    )
+{
+    const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale);
+    const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast<int16_t>(ZeroPoint)); // Broadcast zp to 8 int16s
+    const __m128i Zeros = _mm_setzero_si128();
+
+    while (N >= 16) {
+        // Load a vector of 16 int8s: [0 ... 15]
+        __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(Input));
+
+        // Sign-extend into 2 vectors of 8 int16s
+        __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8);          // 0xFF for every negative byte in VectorS8
+        __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7]
+        __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15]
+
+        // Subtract the zero-points in the int16 domain.
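+        // Working in int16 keeps this subtraction exact: an int8 value minus an
+        // int8 zero point always lies in [-255, 255], well within the int16 range.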
+        VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector);
+        VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector);
+
+        // Sign-extend into 4 vectors of 4 int32s
+        __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0);
+        __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3]
+        __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7]
+
+        __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1);
+        __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11]
+        __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15]
+
+        // Cast each int32x4 to float and multiply by the scale vector.
+        __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector);
+        __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector);
+        __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector);
+        __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector);
+
+        // Store each float32x4 into the output.
+        _mm_storeu_ps(Output + 0, VectorF32_0);
+        _mm_storeu_ps(Output + 4, VectorF32_1);
+        _mm_storeu_ps(Output + 8, VectorF32_2);
+        _mm_storeu_ps(Output + 12, VectorF32_3);
+
+        Input += 16;
+        Output += 16;
+        N -= 16;
+    }
+
+    // Handle leftover elements (< 16) with the scalar reference implementation.
+    MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasDequantizeLinearU8Kernel(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint
+    )
+{
+    const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale);
+    const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast<int16_t>(ZeroPoint)); // Broadcast zp to 8 int16s
+    const __m128i Zeros = _mm_setzero_si128();
+
+    while (N >= 16) {
+        // Load a vector of 16 uint8s: [0 ... 15]
+        __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(Input));
+
+        // Zero-extend into 2 vectors of 8 uint16s
+        __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7]
+        __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15]
+
+        // Subtract the zero-points as uint16s. Due to two's complement, negative results can be reinterpreted as int16.
+        __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector);
+        __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector);
+
+        // Sign-extend into 4 vectors of 4 int32s
+        __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0);
+        __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3]
+        __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7]
+
+        __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1);
+        __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11]
+        __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15]
+
+        // Cast each int32x4 to float and multiply by the scale vector.
+        __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector);
+        __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector);
+        __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector);
+        __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector);
+
+        // Store each float32x4 into the output.
+        _mm_storeu_ps(Output + 0, VectorF32_0);
+        _mm_storeu_ps(Output + 4, VectorF32_1);
+        _mm_storeu_ps(Output + 8, VectorF32_2);
+        _mm_storeu_ps(Output + 12, VectorF32_3);
+
+        Input += 16;
+        Output += 16;
+        N -= 16;
+    }
+
+    // Handle leftover elements (< 16) with the scalar reference implementation.
+    MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasDequantizeLinear<int8_t>(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint
+    )
+{
+#if defined(MLAS_TARGET_AMD64)
+    GetMlasPlatform().DequantizeLinearS8Kernel(
+#else
+    MlasDequantizeLinearS8Kernel(
+#endif
+        Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasDequantizeLinear<uint8_t>(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint
+    )
+{
+#if defined(MLAS_TARGET_AMD64)
+    GetMlasPlatform().DequantizeLinearU8Kernel(
+#else
+    MlasDequantizeLinearU8Kernel(
+#endif
+        Input, Output, N, Scale, ZeroPoint);
+}
+#elif defined(MLAS_NEON64_INTRINSICS)
+// Implementation for ARM64 NEON. Refer to the ARM intrinsics guide:
+// https://developer.arm.com/architectures/instruction-sets/intrinsics/
+
+void
+MLASCALL
+MlasDequantizeLinearS8Kernel(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint
+    )
+{
+    const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale);
+    const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16 bits)
+
+    while (N >= 16) {
+        // Load a vector of 16 int8s: [0 ... 15]
+        int8x16_t VectorS8 = vld1q_s8(Input);
+
+        // Sign-extend into 2 vectors of 8 int16s
+        int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8));  // [0 ... 7]
+        int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15]
+
+        // Subtract the zero-points in the int16 domain.
+        VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector);
+        VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector);
+
+        // Sign-extend into 4 vectors of 4 int32s
+        int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0));  // [0 ... 3]
+        int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7]
+        int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1));  // [8 ... 11]
+        int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15]
+
+        // Cast each int32x4 to float and multiply by the scale vector.
+        float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector);
+        float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector);
+        float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector);
+        float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector);
+
+        // Store each float32x4 into the output.
+        vst1q_f32(Output + 0, VectorF32_0);
+        vst1q_f32(Output + 4, VectorF32_1);
+        vst1q_f32(Output + 8, VectorF32_2);
+        vst1q_f32(Output + 12, VectorF32_3);
+
+        N -= 16;
+        Input += 16;
+        Output += 16;
+    }
+
+    // Handle leftover elements (< 16) with the scalar reference implementation.
+    MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint);
+}
+
+void
+MLASCALL
+MlasDequantizeLinearU8Kernel(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint
+    )
+{
+    const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale);
+    const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s
+
+    while (N >= 16) {
+        // Load a vector of 16 uint8s: [0 ... 15]
+        uint8x16_t VectorU8 = vld1q_u8(Input);
+
+        // Subtract zero-point.
+        // The vsubl_u8 instruction zero-extends its arguments to uint16 first.
+        // The reinterpret from uint16x8 to int16x8 is actually a NOP.
+        int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector));  // [0 ... 7]
+        int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15]
+
+        // Sign-extend into 4 vectors of 4 int32s
+        int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0));  // [0 ... 3]
+        int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7]
+        int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1));  // [8 ... 11]
+        int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15]
+
+        // Cast each int32x4 to float and multiply by the scale vector.
+        float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector);
+        float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector);
+        float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector);
+        float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector);
+
+        // Store each float32x4 into the output.
+        vst1q_f32(Output + 0, VectorF32_0);
+        vst1q_f32(Output + 4, VectorF32_1);
+        vst1q_f32(Output + 8, VectorF32_2);
+        vst1q_f32(Output + 12, VectorF32_3);
+
+        N -= 16;
+        Input += 16;
+        Output += 16;
+    }
+
+    // Handle leftover elements (< 16) with the scalar reference implementation.
+    MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasDequantizeLinear<int8_t>(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint
+    )
+{
+    MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint);
+}
+
+template<>
+void
+MLASCALL
+MlasDequantizeLinear<uint8_t>(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint
+    )
+{
+    MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint);
+}
+#else
+// Implementation that uses the scalar reference implementation.
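+// This path is taken on targets without the SSE2 or NEON64 kernels above.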
+
+template<typename InputType>
+void
+MLASCALL
+MlasDequantizeLinear(
+    const InputType* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    InputType ZeroPoint
+    )
+{
+    MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint);
+}
+
+template
+void
+MLASCALL
+MlasDequantizeLinear<int8_t>(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint
+    );
+
+template
+void
+MLASCALL
+MlasDequantizeLinear<uint8_t>(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint
+    );
+
+#endif
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 0af3cd2e33b02..0879d1b0ba510 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -747,6 +747,24 @@ void
     float Scale,
     int8_t ZeroPoint);
 
+typedef
+void
+(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)(
+    const uint8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    uint8_t ZeroPoint);
+
+typedef
+void
+(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)(
+    const int8_t* Input,
+    float* Output,
+    size_t N,
+    float Scale,
+    int8_t ZeroPoint);
+
 template
 struct MLAS_QUANT_KERNEL
 {
@@ -903,6 +921,8 @@ extern "C" {
     MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel;
     MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel;
 #if defined(MLAS_TARGET_AMD64)
+    MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel;
+    MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel;
     MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3;
     MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3;
     MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F;
@@ -1246,6 +1266,8 @@ struct MLAS_PLATFORM {
     MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel;
     MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel;
     MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel;
+    MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel;
+    MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel;
     uint32_t NchwcBlockSize;
     uint32_t PreferredBufferAlignment;
     int32_t MaximumThreadCount;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 45d3a876beb86..45bba5363d4f2 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -285,6 +285,8 @@ Return Value:
     this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
     this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
     this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
+    this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel;
+    this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel;
 #ifndef __APPLE__
 #ifndef FORCE_GENERIC_ALGORITHMS
     this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse;
diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
index adb2aee171f39..c691be6ffd0e8 100644
--- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
+++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include
 #include "core/framework/element_type_lists.h"
 #include "core/framework/float8.h"
@@ -301,14 +302,31 @@ struct DequantizeLinearApply {
    * @param[in] zero_point same shape as scale
    */
   void op(size_t M, size_t K, size_t N, const T* input,
-          const OutT* scale, OutT* output, const T* zero_point) {
+          const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) {
     for (size_t m = 0; m < M; m++) {
       for (size_t k = 0; k < K; k++) {
+#if defined(ORT_CLIENT_PACKAGE_BUILD)
+        // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads.
+        // Make this the default behavior after more testing.
+        if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>) {
+          ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool);
+          input += N;
+          output += N;
+        } else {
+          auto zp = zero_point ? static_cast<int32_t>(zero_point[k]) : 0;
+          auto sc = static_cast<float>(scale[k]);
+          for (size_t n = 0; n < N; n++) {
+            *output++ = static_cast<OutT>(static_cast<float>(static_cast<int32_t>(*input++) - zp) * sc);
+          }
+        }
+#else
+        ORT_UNUSED_PARAMETER(thread_pool);
         auto zp = zero_point ? static_cast<int32_t>(zero_point[k]) : 0;
         auto sc = static_cast<float>(scale[k]);
         for (size_t n = 0; n < N; n++) {
           *output++ = static_cast<OutT>(static_cast<float>(static_cast<int32_t>(*input++) - zp) * sc);
         }
+#endif  // defined(ORT_CLIENT_PACKAGE_BUILD)
       }
     }
   }
@@ -327,7 +345,8 @@
    * @param[in] zero_point same shape as scale
    */
   void op(size_t M, size_t K, size_t N, size_t quant_block_size,
-          const T* input, const OutT* scale, OutT* output, const T* zero_point) {
+          const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) {
+    ORT_UNUSED_PARAMETER(thread_pool);
     if (zero_point) {
       for (size_t m = 0; m < M; m++) {
         for (size_t bd = 0; bd < K; bd += quant_block_size) {
@@ -368,7 +387,8 @@ template
 struct DequantizeLinearApply {
   // per-tensor/layer or per-axis quantization
   void op(size_t M, size_t K, size_t N,
-          const T* input, const OutT* scale, OutT* output, const T* zero_point) {
+          const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) {
+    ORT_UNUSED_PARAMETER(thread_pool);
     size_t input_index = 0;
 
     for (size_t m = 0; m < M; m++) {
@@ -394,7 +414,8 @@ struct DequantizeLinearApply {
   // Blocked quantization
   // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise.
   void op(size_t M, size_t K, size_t N, size_t quant_block_size,
-          const T* input, const OutT* scale, OutT* output, const T* zero_point) {
+          const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) {
+    ORT_UNUSED_PARAMETER(thread_pool);
     size_t input_index = 0;
 
     if (zero_point) {
@@ -440,36 +461,36 @@ struct DequantizeLinearApply {
 
 #if !defined(DISABLE_FLOAT8_TYPES)
 
-#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \
-  template \
-  struct DequantizeLinearApply { \
-    /* Per-tensor/layer or per-axis quantization */ \
-    void op(size_t M, size_t K, size_t N, \
-            const T* input, const OutT* scale, OutT* output, const T*) { \
-      for (size_t m = 0; m < M; m++) { \
-        for (size_t bd = 0; bd < K; bd++) { \
-          auto sc = scale[bd]; \
-          for (size_t bs = 0; bs < N; bs++, input++) { \
-            *output++ = static_cast<OutT>(input->ToFloat() * sc); \
-          } \
-        } \
-      } \
-    } \
-    /* Blocked quantization */ \
-    void op(size_t M, size_t K, size_t N, size_t quant_block_size, \
-            const T* input, const OutT* scale, OutT* output, const T*) { \
-      for (size_t m = 0; m < M; m++) { \
-        for (size_t bd = 0; bd < K; bd += quant_block_size) { \
-          for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \
-            for (size_t bs = 0; bs < N; bs++, input++) { \
-              auto sc = static_cast<float>(scale[bs]); \
-              *output++ = static_cast<OutT>(input->ToFloat() * sc); \
-            } \
-          } \
-          scale += N; \
-        } \
-      } \
-    } \
+#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \
+  template \
+  struct DequantizeLinearApply { \
+    /* Per-tensor/layer or per-axis quantization */ \
+    void op(size_t M, size_t K, size_t N, \
+            const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \
+      for (size_t m = 0; m < M; m++) { \
+        for (size_t bd = 0; bd < K; bd++) { \
+          auto sc = scale[bd]; \
+          for (size_t bs = 0; bs < N; bs++, input++) { \
+            *output++ = static_cast<OutT>(input->ToFloat() * sc); \
+          } \
+        } \
+      } \
+    } \
+    /* Blocked quantization */ \
+    void op(size_t M, size_t K, size_t N, size_t quant_block_size, \
+            const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \
+      for (size_t m = 0; m < M; m++) { \
+        for (size_t bd = 0; bd < K; bd += quant_block_size) { \
+          for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \
+            for (size_t bs = 0; bs < N; bs++, input++) { \
+              auto sc = static_cast<float>(scale[bs]); \
+              *output++ = static_cast<OutT>(input->ToFloat() * sc); \
+            } \
+          } \
+          scale += N; \
+        } \
+      } \
+    } \
   };
 
 DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN)
@@ -513,6 +534,7 @@ Status DequantizeLinear<T>::Compute(OpKernelContext* ctx) const {
   const auto to = x_scale.GetElementType();
   const T* input = x.Data<T>();
   constexpr bool is_4bit = boost::mp11::mp_contains, T>::value;
+  concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
 
   if (to == ONNX_NAMESPACE::TensorProto::FLOAT) {
     const float* scale = x_scale.Data<float>();
@@ -522,12 +544,12 @@
           static_cast<size_t>(broadcast_dim),
           static_cast<size_t>(process_block_size),
           static_cast<size_t>(block_size_),
-          input, scale, output, zero_point);
+          input, scale, output, zero_point, thread_pool);
     } else {
       DequantizeLinearApply<T, float>().op(static_cast<size_t>(process_block_count),
           static_cast<size_t>(broadcast_dim),
           static_cast<size_t>(process_block_size),
-          input, scale, output, zero_point);
+          input, scale, output, zero_point, thread_pool);
     }
   } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) {
     const MLFloat16* scale = x_scale.Data<MLFloat16>();
@@ -537,12 +559,12 @@ Status DequantizeLinear<T>::Compute(OpKernelContext* ctx) const {
           static_cast<size_t>(broadcast_dim),
           static_cast<size_t>(process_block_size),
           static_cast<size_t>(block_size_),
-          input, scale, output, zero_point);
+          input, scale, output, zero_point, thread_pool);
     } else {
       DequantizeLinearApply<T, MLFloat16>().op(static_cast<size_t>(process_block_count),
           static_cast<size_t>(broadcast_dim),
           static_cast<size_t>(process_block_size),
-          input, scale, output, zero_point);
+          input, scale, output, zero_point, thread_pool);
     }
   } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) {
     ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet.");
diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h
index 0172902bdf4e2..f7d5cdb98aa1d 100644
--- a/onnxruntime/core/util/qmath.h
+++ b/onnxruntime/core/util/qmath.h
@@ -1001,4 +1001,53 @@ struct BlockedQuantizeLinear {
 
 #endif
 
+/**
+ * @brief Run MlasDequantizeLinear in parallel, with the provided thread pool.
+ */
+
+template <typename InputQuantType>
+void ParDequantizeLinearStd(const InputQuantType* input,
+                            float* output,
+                            size_t num_elems,
+                            float scale,
+                            InputQuantType zero_point,
+                            concurrency::ThreadPool* thread_pool) {
+  constexpr std::ptrdiff_t block_size = 128;
+  const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size;
+  const TensorOpCost unit_cost{static_cast<double>(block_size * sizeof(InputQuantType)),
+                               static_cast<double>(block_size * sizeof(float)),
+                               static_cast<double>(block_size) * 2.0};
+  concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    auto begin_idx = begin * block_size;
+    auto end_idx = std::min(static_cast<std::ptrdiff_t>(num_elems), end * block_size);
+    MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point);
+  });
+}
+
+// Note: this does not use an MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear.
+template <typename InputQuantType>
+void ParDequantizeLinearStd(const InputQuantType* input,
+                            MLFloat16* output,
+                            size_t num_elems,
+                            MLFloat16 scale,
+                            InputQuantType zero_point,
+                            concurrency::ThreadPool* thread_pool) {
+  constexpr std::ptrdiff_t block_size = 128;
+  const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size;
+  const TensorOpCost unit_cost{static_cast<double>(block_size * sizeof(InputQuantType)),
+                               static_cast<double>(block_size * sizeof(MLFloat16)),
+                               static_cast<double>(block_size) * 2.0};
+
+  const int32_t zp_s32 = static_cast<int32_t>(zero_point);
+  const float sc_f32 = scale.ToFloat();
+
+  concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+    auto begin_idx = begin * block_size;
+    auto end_idx = std::min(static_cast<std::ptrdiff_t>(num_elems), end * block_size);
+    for (; begin_idx != end_idx; ++begin_idx) {
+      output[begin_idx] = MLFloat16(static_cast<float>(static_cast<int32_t>(input[begin_idx]) - zp_s32) * sc_f32);
+    }
+  });
+}
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp
new file mode 100644
index 0000000000000..b994981364947
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 4e7a6356a5129..8fdbf0060eaa0 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,6 +33,32 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Uint8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. 
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider});
+}
+
+// scalar zero & scale with int8 (large enough input to execute the MLAS vectorized loop)
+TEST(DequantizeLinearOpTest, Int8_Large) {
+  OpTester test("DequantizeLinear", 10);
+  std::vector<int64_t> dims{1, 1039};  // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs
+  test.AddInput<int8_t>("x", dims, std::vector<int8_t>(1039, 1));
+  test.AddInput<float>("x_scale", {}, {1.0f});
+  test.AddInput<int8_t>("x_zero_point", {}, {1});
+  test.AddOutput<float>("y", dims, std::vector<float>(1039, 0.0f));
+  // Disable Tensorrt EP due to error: node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1.
+  // Disable WebGPU EP because it requires dims.Size() to be a multiple of 4. Fails with error: needs at least component size 4.
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider});
+}
+
 // scalar zero & scale with int4
 TEST(DequantizeLinearOpTest, Int4) {
   OpTester test("DequantizeLinear", 21);