diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 7b631895d8d95..d091995ccf508 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -829,7 +829,7 @@ endif()
 foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
   target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
-  onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
+  onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET} safeint_interface)
   target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions})
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp
index 04c6540d1783b..310f805a55988 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp
@@ -48,10 +48,10 @@ QNBitGemmPackQuantBDataSize_Lasx(
         BlkSumSize += SafeInt<size_t>(BlkSumAlignment) - 1;
         PackedQuantBDataSize += ScaleSize + BlkSumSize;
-        return PackedQuantBDataSize.Value();
+        return static_cast<size_t>(PackedQuantBDataSize);
     } else {
         SafeInt<size_t> PackedQuantBDataSize = SafeInt<size_t>(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
-        return PackedQuantBDataSize.Value();
+        return static_cast<size_t>(PackedQuantBDataSize);
     }
 }
@@ -73,7 +73,7 @@ SQ4BitGemmPackQuantBData_Lasx(
     const size_t BlockCountK = MlasDivRoundup(K, BlkLen);
     const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen);
-    const SafeInt<size_t> Iterations = SafeInt<size_t>(N) * BlockCountK;  // one iteration per block
+    const size_t Iterations = SafeInt<size_t>(N) * BlockCountK;  // one iteration per block
 
     size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64);
@@ -105,14 +105,14 @@ SQ4BitGemmPackQuantBData_Lasx(
     //
     MlasTrySimpleParallel(
-        ThreadPool, Iterations.Value(),
+        ThreadPool, Iterations,
         [&](ptrdiff_t tid) {
             const size_t n = tid / BlockCountK;
             const size_t k_blk = tid % BlockCountK;
 
-            const SafeInt<size_t> data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
-            const std::byte* QuantBData = QuantBDataBegin + data_offset.Value();
-            std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value();
+            const size_t data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize;
+            const std::byte* QuantBData = QuantBDataBegin + data_offset;
+            std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset;
 
             for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) {
                 for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) {
@@ -163,8 +163,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum_Lasx(
     }
 
     if (QuantBScaleBegin) {
-        SafeInt<size_t> offset = SafeInt<size_t>(N) * BlockCountK;
-        std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale);
+        size_t offset = SafeInt<size_t>(N) * BlockCountK;
+        std::copy(QuantBScaleBegin, QuantBScaleBegin + offset, packed_quant_b.PackedQuantBScale);
     }
 
     if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) {
@@ -272,14 +272,14 @@ ComputeDotProducts_BlkLen32Plus_CompFp32_lasx(
     float scale_v[NCols];
     UnrolledLoop<NCols>([&](size_t i) {
-        SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
-        scale_v[i] = *(s + scale_offset.Value());
+        size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
+        scale_v[i] = *(s + scale_offset);
     });
 
     std::byte* b_blk_data_col_ptr[NCols];
     UnrolledLoop<NCols>([&](size_t i) {
-        SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
-        b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
+        size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
+        b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
     });
 
     // not ready for "Manual conversion to float" in neon yet.
@@ -427,14 +427,14 @@ ComputeDotProducts_BlkLen16_CompFp32_lasx(
     float scale_v[NCols];
     UnrolledLoop<NCols>([&](size_t i) {
-        SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
-        scale_v[i] = *(s + scale_offset.Value());
+        size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * i;
+        scale_v[i] = *(s + scale_offset);
     });
 
     std::byte* b_blk_data_col_ptr[NCols];
     UnrolledLoop<NCols>([&](size_t i) {
-        SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * i;
-        b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value());
+        size_t data_offset = SafeInt<size_t>(StrideQuantBData) * i;
+        b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset);
     });
 
     if constexpr (HasZeroPoint) {
@@ -551,7 +551,7 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(
     float* SumPtr = CRowPtr;
 
-    int64_t nblk = (CountN) - NCols4;
+    int64_t nblk = static_cast<int64_t>(CountN - NCols4);
     while (nblk >= 0) {
         ComputeDotProducts_BlkLen16_CompFp32_lasx<NCols4, HasZeroPoint>(
             BlkLen16,
@@ -560,13 +560,13 @@ SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx(
             BiasPtr
         );
 
-        SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
-        SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
-        QuantBDataColPtr += data_offset.Value();
-        QuantBScaleColPtr += scale_offset.Value();
+        size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
+        size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
+        QuantBDataColPtr += data_offset;
+        QuantBScaleColPtr += scale_offset;
         if constexpr (HasZeroPoint) {
-            SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
-            QuantBZeroPointColPtr += zeropoint_offset.Value();
+            size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
+            QuantBZeroPointColPtr += zeropoint_offset;
         }
 
         BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
@@ -650,13 +650,13 @@ SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx(
             );
         }
 
-        SafeInt<size_t> data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
-        SafeInt<size_t> scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
-        QuantBDataColPtr += data_offset.Value();
-        QuantBScaleColPtr += scale_offset.Value();
+        size_t data_offset = SafeInt<size_t>(StrideQuantBData) * NCols4;
+        size_t scale_offset = SafeInt<size_t>(StrideQuantBScale) * NCols4;
+        QuantBDataColPtr += data_offset;
+        QuantBScaleColPtr += scale_offset;
         if constexpr (HasZeroPoint) {
-            SafeInt<size_t> zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
-            QuantBZeroPointColPtr += zeropoint_offset.Value();
+            size_t zeropoint_offset = SafeInt<size_t>(StrideQuantBZeroPoint) * NCols4;
+            QuantBZeroPointColPtr += zeropoint_offset;
         }
 
        BiasPtr += BiasPtr != nullptr ? NCols4 : 0;
@@ -768,18 +768,18 @@ Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx(
         for (size_t k = 0; k < BlockCountK; k++) {
             // count # of tiles plus blks of the current tile from top
             const size_t tile_count = col / GemmFloatKernelWidth16;
-            SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
-            float* dst_ptr = FpData + offset.Value();
+            size_t offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16;
+            float* dst_ptr = FpData + offset;
             if (col % GemmFloatKernelWidth16 >= NCols8) {  // for the second half to 16 width tile
                 dst_ptr += NCols8;
             }
-            SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
-            SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
-            SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
-            const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
-            const float* scale_ptr = QuantBScale + b_scale_offset.Value();
-            const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
+            size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
+            size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
+            size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
+            const std::byte* b_data_ptr = QuantBData + b_data_offset;
+            const float* scale_ptr = QuantBScale + b_scale_offset;
+            const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
             bool is_lower = (k % 2) == 0;
 
             __m256i weight_16_epi16[NCols8];
@@ -910,19 +910,19 @@ Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx(
         const size_t cols = std::min(NCols8, CountN - col);
         for (size_t k = 0; k < BlockCountK; k++) {
             // count # of tiles plus blks of the current tile from top
-            const size_t tile_count = col / GemmFloatKernelWidth16;
-            SafeInt<size_t> offset = SafeInt<size_t>(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16;
-            float* dst_ptr = FpData + offset.Value();
+            const SafeInt<size_t> tile_count = col / GemmFloatKernelWidth16;
+            size_t offset = (tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16;
+            float* dst_ptr = FpData + offset;
             if (col % GemmFloatKernelWidth16 >= NCols8) {  // for the second half to 16 width tile
                 dst_ptr += NCols8;
             }
-            SafeInt<size_t> b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
-            SafeInt<size_t> b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
-            SafeInt<size_t> b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
-            const std::byte* b_data_ptr = QuantBData + b_data_offset.Value();
-            const float* scale_ptr = QuantBScale + b_scale_offset.Value();
-            const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value();
+            size_t b_data_offset = SafeInt<size_t>(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes;
+            size_t b_scale_offset = SafeInt<size_t>(col) * BlockCountK + k;
+            size_t b_zp_offset = SafeInt<size_t>(col) * zp_col_stride_in_bytes + k / 2;
+            const std::byte* b_data_ptr = QuantBData + b_data_offset;
+            const float* scale_ptr = QuantBScale + b_scale_offset;
+            const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset;
             bool is_lower = (k % 2) == 0;
 
             for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) {
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h
index 508bcba8a2de7..f2b61a7f6e25d 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h
@@ -68,7 +68,7 @@ GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk)
     } else {
         scale_dst_offset += SafeInt<size_t>(k_sub_or_blk) * 4 + t;
     }
-    return scale_dst_offset.Value();
+    return static_cast<size_t>(scale_dst_offset);
 }
 
 static size_t
@@ -87,7 +87,7 @@ GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub)
             scale_dst_offset += SafeInt<size_t>(t) * blks_per_sub + b;
         }
     }
-    return scale_dst_offset.Value();
+    return static_cast<size_t>(scale_dst_offset);
 }
 
 static void
@@ -106,23 +106,23 @@ ComputePackBlkSum_Lasx(
         const size_t n = tid / BlockCountK;
         const size_t k_blk = tid % BlockCountK;
 
-        const SafeInt<size_t> src_blk_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
-        float QuantBScale = QuantBScaleBegin[src_blk_offset.Value()];
+        const size_t src_blk_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
+        float QuantBScale = QuantBScaleBegin[src_blk_offset];
         uint8_t zp = 8;
         if (QuantBZPBegin) {
             size_t ZPCountK = MlasDivRoundup(BlockCountK, 2);
-            SafeInt<size_t> src_zp_offset = SafeInt<size_t>(ZPCountK) * n + k_blk / 2;
+            size_t src_zp_offset = SafeInt<size_t>(ZPCountK) * n + k_blk / 2;
             bool low_zp = k_blk % 2 == 0;
-            const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset.Value();
+            const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset;
             const std::byte low_mask{0X0f};
             zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4));
         }
 
         float result = -QuantBScale * zp;
-        const SafeInt<size_t> dst_offset = (SafeInt<size_t>(n / 16) * BlockCountK + k_blk) * 16 + n % 16;
-        BlockSumBegin[dst_offset.Value()] = result;
+        size_t dst_offset = (SafeInt<size_t>(n / 16) * BlockCountK + k_blk) * 16 + n % 16;
+        BlockSumBegin[dst_offset] = result;
 
         if (BlkLen == 16) {
         } else if (BlkLen >= SubBlkLen) {
@@ -162,8 +162,8 @@ PackQuantB(
         const size_t n = tid / SubBlkCountK;
         const size_t k_subblk = tid % SubBlkCountK;
 
-        const SafeInt<size_t> src_data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize;
-        const std::byte* QuantBData = QuantBDataBegin + src_data_offset.Value();
+        size_t src_data_offset = SafeInt<size_t>(n) * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize;
+        const std::byte* QuantBData = QuantBDataBegin + src_data_offset;
 
         size_t PackBytePairCount = SubBlkBytePairCount;
         size_t PackDataSize = SubBlkDataSize;
@@ -192,7 +192,7 @@ PackQuantB(
             PackDataSize = BlkDataSize;
             const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen;
             for (size_t k = 0; k < k_blks_remaining; k++) {
-                const SafeInt<size_t> k_blk = SafeInt<size_t>(k_subblk) * SubBlkLen / BlkLen + k;
+                size_t k_blk = SafeInt<size_t>(k_subblk) * SubBlkLen / BlkLen + k;
                 if (BlkLen == 16) {  // not to do the compute order layout yet
                     std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset;
@@ -202,7 +202,7 @@ PackQuantB(
                     assert(SubBlkLen == 128);
                 } else {
                     int blks_per_sub = (int)(SubBlkLen / BlkLen);
-                    const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub);
+                    const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
                     std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
                     pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize);
                 }
@@ -218,8 +218,8 @@ PackQuantB(
                 pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
             } else {
                 int blks_per_sub = (int)(SubBlkLen / BlkLen);
-                const SafeInt<size_t> k_blk = SafeInt<size_t>(k_subblk) * blks_per_sub;
-                const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub);
+                size_t k_blk = SafeInt<size_t>(k_subblk) * blks_per_sub;
+                const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub);
                 std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2;
                 pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize);
             }
@@ -311,29 +311,29 @@ Q8PackQuantB(
         const size_t c_4 = c & (~3), c_res = c & 3;
         const size_t r_subblk = tid % SubBlkCountK;
 
-        const SafeInt<size_t> data_offset = SafeInt<size_t>(c) * StrideN + r_subblk * SubBlkLen;
-        const std::byte* src = QuantBDataBegin + data_offset.Value();
+        size_t data_offset = SafeInt<size_t>(c) * StrideN + r_subblk * SubBlkLen;
+        const std::byte* src = QuantBDataBegin + data_offset;
 
         if (c_4 + 4 <= N) {  // full 4 cols
             if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) {  // remainder blocks
-                const SafeInt<size_t> subblk_data_offset = SafeInt<size_t>(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize;
+                const size_t subblk_data_offset = SafeInt<size_t>(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize;
                 std::byte* dest =
-                    PackedQuantBDataBegin + subblk_data_offset.Value();
+                    PackedQuantBDataBegin + subblk_data_offset;
                 for (size_t i = 0; i < RemainderBlockCountK; i++) {
                     std::copy(src, src + BlkSize, dest);
                     src += BlkSize;
                     dest += BlkSize * 4;
                 }
             } else {  // full subblock
-                const SafeInt<size_t> subblk_data_offset = SafeInt<size_t>(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize;
+                const size_t subblk_data_offset = SafeInt<size_t>(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize;
                 std::byte* dest =
-                    PackedQuantBDataBegin + subblk_data_offset.Value();
+                    PackedQuantBDataBegin + subblk_data_offset;
                 std::copy(src, src + SubBlkSize, dest);
             }
         } else {  // remainder cols
-            const SafeInt<size_t> remain_data_offset = SafeInt<size_t>(c) * StrideN + r_subblk * SubBlkSize;
+            const size_t remain_data_offset = SafeInt<size_t>(c) * StrideN + r_subblk * SubBlkSize;
             std::byte* dest =
-                PackedQuantBDataBegin + remain_data_offset.Value();
+                PackedQuantBDataBegin + remain_data_offset;
             std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest);
         }
     }
@@ -352,8 +352,8 @@ Q8ComputePackBlkSum(
     const size_t BlockCountK
 )
 {
-    SafeInt<size_t> size = SafeInt<size_t>(N) * BlockCountK;
-    std::vector<float> QuantBScaleBeginCopy(size.Value());
+    size_t size = SafeInt<size_t>(N) * BlockCountK;
+    std::vector<float> QuantBScaleBeginCopy(size);
     std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin());
 
     MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) {
@@ -361,30 +361,30 @@ Q8ComputePackBlkSum(
         const size_t n = tid / BlockCountK;
         const size_t n_4 = n & (~3), n_res = n & 3;
         const size_t k_blk = tid % BlockCountK;
 
-        const SafeInt<size_t> src_blk_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
-        const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset.Value()];
+        const size_t src_blk_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
+        const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset];
         uint8_t zp = 128;
         if (QuantBZPBegin) {
-            const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset.Value();
+            const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset;
             zp = (uint8_t)(*QuantBZP);
         }
 
-        const SafeInt<size_t> dst_offset = (SafeInt<size_t>(n / 16) * BlockCountK + k_blk) * 16 + n % 16;
-        *(BlockSumBegin + dst_offset.Value()) = -QuantBScale * zp;
+        const size_t dst_offset = (SafeInt<size_t>(n / 16) * BlockCountK + k_blk) * 16 + n % 16;
+        *(BlockSumBegin + dst_offset) = -QuantBScale * zp;
 
         if (n_4 + 4 > N) {
-            SafeInt<size_t> ptr_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
-            *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale;
+            size_t ptr_offset = SafeInt<size_t>(n) * BlockCountK + k_blk;
+            *(QuantBScaleBegin + ptr_offset) = QuantBScale;
         } else if (BlkLen >= SubBlkLen) {
-            SafeInt<size_t> ptr_offset = SafeInt<size_t>(n_4) * BlockCountK + k_blk * 4 + n_res;
-            *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale;
+            size_t ptr_offset = SafeInt<size_t>(n_4) * BlockCountK + k_blk * 4 + n_res;
+            *(QuantBScaleBegin + ptr_offset) = QuantBScale;
         } else {
             size_t blks_per_sub = SubBlkLen / BlkLen;
             size_t remainder_blk = BlockCountK % blks_per_sub;
             size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub);
             size_t k_subblk = k_blk / blks_per_sub;
             size_t k_blk_res = k_blk % blks_per_sub;
 
-            SafeInt<size_t> dest_offset;
+            size_t dest_offset;
             if (remainder_blk && k_subblk == sub_blk_count_k - 1) {  // remainder blocks
                 dest_offset = SafeInt<size_t>(n_4) * BlockCountK + k_blk * 4 + n_res;
@@ -392,7 +392,7 @@ Q8ComputePackBlkSum(
                 dest_offset = SafeInt<size_t>(n_4) * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res;
             }
-            *(QuantBScaleBegin + dest_offset.Value()) = QuantBScale;
+            *(QuantBScaleBegin + dest_offset) = QuantBScale;
         }
     });
 }
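
Note on the pattern changed throughout this patch: the SafeInt library pulled in through the new safeint_interface dependency exposes the wrapped value via its implicit conversion operator rather than a .Value() accessor, which is presumably why every .Value() call is replaced with a plain size_t assignment or a static_cast<size_t>. A minimal standalone sketch of the resulting behavior (not part of the patch; toy sizes are hypothetical, and SafeInt.hpp is assumed to be on the include path):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>

    #include "SafeInt.hpp"  // header provided by the safeint_interface target

    int main()
    {
        const size_t N = 4096, BlockCountK = 128, BlkDataSize = 16;  // hypothetical sizes

        // The arithmetic stays overflow-checked while the expression is a
        // SafeInt<size_t>; the assignment then goes through SafeInt's implicit
        // conversion to size_t, so no .Value() accessor is needed.
        size_t data_offset = SafeInt<size_t>(N) * BlockCountK * BlkDataSize;
        std::cout << data_offset << '\n';  // 8388608

        // An overflowing product does not wrap silently: with the default
        // error policy, SafeInt throws SafeIntException instead.
        try {
            size_t huge = SafeInt<size_t>(SIZE_MAX) * 2;
            (void)huge;
        } catch (const SafeIntException&) {
            std::cout << "overflow detected\n";
        }
    }

The index math is therefore unchanged by the diff: each right-hand side is still computed through SafeInt<size_t> before converting, so the overflow checks are preserved.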
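The dst_offset expression in ComputePackBlkSum_Lasx and Q8ComputePackBlkSum, (SafeInt<size_t>(n / 16) * BlockCountK + k_blk) * 16 + n % 16, encodes the packed block-sum layout: columns are grouped into tiles of 16, and within a tile the 16 sums for one k_blk are contiguous. A standalone illustration with hypothetical toy sizes (not from the patch):

    #include <cstddef>
    #include <cstdio>

    // Column n, block row k_blk -> offset into the packed BlockSum buffer,
    // mirroring the dst_offset expression above (without SafeInt).
    static size_t BlkSumDstOffset(size_t n, size_t k_blk, size_t BlockCountK)
    {
        return ((n / 16) * BlockCountK + k_blk) * 16 + n % 16;
    }

    int main()
    {
        const size_t BlockCountK = 3;  // hypothetical

        // Within a 16-column tile, sums for a fixed k_blk sit side by side:
        std::printf("%zu %zu\n",
                    BlkSumDstOffset(0, 1, BlockCountK),   // 16
                    BlkSumDstOffset(1, 1, BlockCountK));  // 17

        // The next tile of 16 columns starts after BlockCountK * 16 entries:
        std::printf("%zu\n", BlkSumDstOffset(16, 0, BlockCountK));  // 48
    }

In the patched code only the leftmost factor is wrapped in SafeInt<size_t>, but that is enough to keep the whole expression checked, since every subsequent + and * is evaluated on the SafeInt operand.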