diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c0ab948b41fff..6e88812b4822c 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -467,6 +467,8 @@ else() list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h) list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp) set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/sgemm_sve.cpp) + set_source_files_properties(${MLAS_SRC_DIR}/sve/sgemm_sve.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve -O3 -ffast-math -funroll-loops") list(APPEND mlas_private_compile_definitions MLAS_USE_SVE) endif() diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 84c26eb005c39..d4827e5f93083 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -17,6 +17,10 @@ Module Name: #include "mlasi.h" +#if defined(MLAS_USE_SVE) +#include "sve/mlasi_sve.h" +#endif + // // Define the number of rows from matrix A to transpose to a local buffer. // @@ -259,8 +263,15 @@ Return Value: do { -#if defined(MLAS_NEON_INTRINSICS) - vst4q_f32(D, vld4q_f32(b)); +#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + SVE_LOAD_STORE(D, b); + } + else{ + vst4q_f32(D, vld4q_f32(b)); + } +#elif defined(MLAS_NEON_INTRINSICS) + vst4q_f32(D, vld4q_f32(b)); #else MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]); @@ -303,8 +314,14 @@ Return Value: float* d = D; const float* b = B; -#if defined(MLAS_NEON_INTRINSICS) - vst4q_f32(d, ZeroFloat32x4x4); +#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + SVE_ZERO_INITIALIZE(d); + } else { + vst4q_f32(d, ZeroFloat32x4x4); + } +#elif defined(MLAS_NEON_INTRINSICS) + vst4q_f32(d, ZeroFloat32x4x4); #else MlasStoreAlignedFloat32x4(d, ZeroFloat32x4); MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4); @@ -486,6 +503,21 @@ Return Value: x -= 4; } +#elif defined(MLAS_USE_SVE) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + SVE_TRANSPOSE(D,b,ldb,x); + } + else + { + while (x >= 4) { + + MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb); + + D += 16 * 4; + b += 4; + x -= 4; + } + } #else while (x >= 4) { @@ -564,8 +596,15 @@ Return Value: const float* b = B; if ((CountY & 8) != 0) { - - MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb); + #if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()){ + MlasSveTransposePackBNx4<8>(&d[0], &b[0], ldb);} + else{ + MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb); + } + #else + MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb); + #endif d += 8; b += ldb * 8; @@ -584,7 +623,15 @@ Return Value: if ((CountY & 4) != 0) { - MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb); + #if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()){ + MlasSveTransposePackBNx4<4>(&d[0], &b[0], ldb);} + else{ + MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb); + } + #else + MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb); + #endif d += 4; b += ldb * 4; @@ -631,7 +678,19 @@ Return Value: if ((CountY & 1) != 0) { -#if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) + { + SCATTER_STORE(&d[0],&b[0]); + } + else{ + MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); + + 
MlasStoreLaneFloat32x4<0>(&d[0], t0); + MlasStoreLaneFloat32x4<1>(&d[16], t0); + MlasStoreLaneFloat32x4<2>(&d[32], t0); + MlasStoreLaneFloat32x4<3>(&d[48], t0);} +#elif defined(MLAS_NEON_INTRINSICS) MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]); MlasStoreLaneFloat32x4<0>(&d[0], t0); @@ -1004,8 +1063,7 @@ Return Value: #endif MLAS_FORCEINLINE -float* -MlasSgemmKernelLoop( +float* MlasSgemmKernelLoop( const float* A, const float* B, float* C, @@ -1059,18 +1117,41 @@ Return Value: { while (CountM > 0) { - size_t RowsHandled; + size_t RowsHandled = 0; #if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); + #else + if (ZeroMode) { + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + RowsHandled = MlasSgemmKernelZero_sve(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } else { + RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } +#else RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); +#endif + } else { + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + RowsHandled = MlasSgemmKernelAdd_sve(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } else { + RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } +#else RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); - } #endif + } + +#endif // platform check + C += ldc * RowsHandled; A += lda * RowsHandled; CountM -= RowsHandled; @@ -1079,6 +1160,7 @@ Return Value: return C; } + void MlasSgemmOperation( CBLAS_TRANSPOSE TransA, diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index 67a4bf453dd05..531425eab5404 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -1,5 +1,4 @@ /*++ - Copyright 2025 FUJITSU LIMITED Module Name: @@ -14,9 +13,10 @@ Module Name: #pragma once -#include "../mlasi.h" #include // SVE intrinsic header +#include "../mlasi.h" + #ifndef __clang__ #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") @@ -28,93 +28,72 @@ Module Name: #define MLAS_SVE_TARGET #endif +#define PACKED_B_BLOCK_WIDTH 16 typedef svfloat32_t MLAS_SVFLOAT32; typedef svint32_t MLAS_SVINT32; typedef svuint32_t MLAS_SVUINT32; typedef svbool_t MLAS_SVBOOL; // function decarations -MLAS_FORCEINLINE -MLAS_SVFLOAT32 -MlasSveComputeExpVector( - MLAS_SVBOOL Pred, - MLAS_SVFLOAT32 Vector -); +MLAS_FORCEINLINE MLAS_SVFLOAT32 +MlasSveComputeExpVector(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector); -void -MLASCALL -MlasSveComputeExpF32Kernel( - const float* Input, - float* Output, - size_t N -); +void MLASCALL +MlasSveComputeExpF32Kernel(const float* Input, float* Output, size_t N); MLAS_FORCEINLINE MLAS_SVFLOAT32 -MlasSveComputeSumExpVector( - MLAS_SVBOOL Pred, - MLAS_SVFLOAT32 Vector, - MLAS_SVFLOAT32 NegativeMaximumVector -); +MlasSveComputeSumExpVector(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector, MLAS_SVFLOAT32 NegativeMaximumVector); -float -MLASCALL -MlasSveComputeSumExpF32Kernel( - const float* Input, - float* Output, - size_t N, - const float* NegativeMaximum -); +float MLASCALL +MlasSveComputeSumExpF32Kernel(const float* Input, float* Output, size_t N, const float* NegativeMaximum); float MLASCALL -MlasSveReduceMaximumF32Kernel( - const 
float* Input, - size_t N -); +MlasSveReduceMaximumF32Kernel(const float* Input, size_t N); -void -MLASCALL -MlasSveReduceMinimumMaximumF32Kernel( - const float* Input, - float* Min, - float* Max, - size_t N -); +void MLASCALL +MlasSveReduceMinimumMaximumF32Kernel(const float* Input, float* Min, float* Max, size_t N); -void -MLASCALL -MlasSveComputeSoftmaxOutputF32Kernel( - float* Output, - size_t N, - const float* Parameters -); +void MLASCALL +MlasSveComputeSoftmaxOutputF32Kernel(float* Output, size_t N, const float* Parameters); -void -MLASCALL -MlasSveComputeLogSoftmaxOutputF32Kernel( - const float* Input, - float* Output, - size_t N, - const float* Parameters -); +void MLASCALL +MlasSveComputeLogSoftmaxOutputF32Kernel(const float* Input, float* Output, size_t N, const float* Parameters); -void -MLASCALL -MlasSveErfKernel( - const float* Input, - float* Output, - size_t N -); +void MLASCALL +MlasSveErfKernel(const float* Input, float* Output, size_t N); -void -MLASCALL -MlasSveLogisticKernel( - const float* Input, - float* Output, - size_t N -); +void MLASCALL +MlasSveLogisticKernel(const float* Input, float* Output, size_t N); + +// MLAS API for SVE intrinsics +size_t MLASCALL +MlasSgemmKernelAdd_sve(const float* A, const float* B, float* C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha); + +size_t MLASCALL +MlasSgemmKernelZero_sve(const float* A, const float* B, float* C, size_t CountK, size_t CountM, size_t CountN, size_t lda, size_t ldc, float alpha); + +void MLAS_SVE_TARGET MLASCALL +SVE_ZERO_INITIALIZE(float* d); + +void MLAS_SVE_TARGET MLASCALL +SVE_LOAD_STORE(float* D, const float* b); -//MLAS API for SVE intrinsics +void MLAS_SVE_TARGET MLASCALL +SCATTER_STORE(float* d, const float* b); + +void MLAS_SVE_TARGET MLASCALL +SVE_TRANSPOSE(float*& D, const float*& b, size_t ldb, size_t& x); + +MLAS_SVE_TARGET +inline int +VL() +{ + static int fp32Lanes = svcntw(); // evaluated only once, the first time it's called + return fp32Lanes; +} + +// MLAS API for SVE intrinsics MLAS_SVE_TARGET MLAS_FORCEINLINE @@ -185,8 +164,8 @@ MLAS_SVE_TARGET MLAS_FORCEINLINE MLAS_SVINT32 MlasSveAddInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) -{ - return svadd_s32_m(Pred, Vector1, Vector2); +{ + return svadd_s32_m(Pred, Vector1, Vector2); } MLAS_SVE_TARGET @@ -243,26 +222,26 @@ MLAS_SVINT32 MlasSveBlendInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2, MLAS_SVINT32 Selection) { return MlasSveOrInt32( - Pred, - MlasSveAndInt32(Pred, Vector2, Selection), + Pred, + MlasSveAndInt32(Pred, Vector2, Selection), MlasSveAndNotInt32(Pred, Selection, Vector1) ); } -template +template MLAS_SVE_TARGET -MLAS_FORCEINLINE -MLAS_SVUINT32 -MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector) + MLAS_FORCEINLINE + MLAS_SVUINT32 + MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector) { return svlsl_n_u32_z(Pred, Vector, ShiftCount); } -template +template MLAS_SVE_TARGET -MLAS_FORCEINLINE -MLAS_SVINT32 -MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector) + MLAS_FORCEINLINE + MLAS_SVINT32 + MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector) { return svlsl_n_s32_z(Pred, Vector, ShiftCount); } @@ -347,11 +326,10 @@ MlasSveStoreFloat32(MLAS_SVBOOL Pred, float* Buffer, MLAS_SVFLOAT32 Vector) svst1_f32(Pred, Buffer, Vector); } -template +template MLAS_SVE_TARGET -MLAS_FORCEINLINE -void -MlasSveStoreLaneFloat32(float* Buffer, MLAS_SVFLOAT32 Vector) + MLAS_FORCEINLINE void + MlasSveStoreLaneFloat32(float* Buffer, 
MLAS_SVFLOAT32 Vector) { svbool_t Pred = svwhilelt_b32(Lane, Lane + 1); svst1_f32(Pred, Buffer, Vector); @@ -366,11 +344,10 @@ MlasSveStoreLowHalfFloat32(float* Buffer, MLAS_SVFLOAT32 Vector) svst1_f32(Pred, Buffer, Vector); } -template +template MLAS_SVE_TARGET -MLAS_FORCEINLINE -float -MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector) + MLAS_FORCEINLINE float + MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector) { float TmpBuffer[1]; svbool_t Pred = svwhilelt_b32(Lane, Lane + 1); @@ -415,7 +392,7 @@ MLAS_FORCEINLINE MLAS_SVFLOAT32 MlasSveMultiplyFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) { - return svmul_f32_m(Pred, Vector1, Vector2); + return svmul_f32_m(Pred, Vector1, Vector2); } MLAS_SVE_TARGET @@ -429,7 +406,7 @@ MlasSveExpFloat32(MLAS_SVUINT32 Vector) MLAS_SVE_TARGET MLAS_FORCEINLINE MLAS_SVFLOAT32 -MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2) +MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2) { return svscale_f32_m(Pred, Vector1, Vector2); } @@ -439,7 +416,7 @@ MLAS_FORCEINLINE MLAS_SVFLOAT32 MlasSveRoundINTFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) { - return svrintm_f32_z(Pred, Vector); + return svrintm_f32_z(Pred, Vector); } MLAS_SVE_TARGET @@ -482,10 +459,10 @@ MlasSveGreaterThanFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT // Compare Vector1 and Vector2, return a predicate vector svbool_t cmp_mask = svcmpgt_f32(Pred, Vector1, Vector2); - //Convert predicate to uint32_t mask + // Convert predicate to uint32_t mask svuint32_t mask_bits = svdup_u32_z(cmp_mask, 0xFFFFFFFF); - //Reinterpret to float32 + // Reinterpret to float32 return svreinterpret_f32_u32(mask_bits); } @@ -496,7 +473,7 @@ MlasSveAndFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vecto { return MlasSveReinterpretAsFloat32( MlasSveAndInt32( - Pred, + Pred, MlasSveReinterpretAsInt32(Vector1), MlasSveReinterpretAsInt32(Vector2) ) @@ -551,7 +528,7 @@ MLAS_SVFLOAT32 MlasSveBlendFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Selection) { return MlasSveOrFloat32( - Pred, + Pred, MlasSveAndFloat32(Pred, Vector2, Selection), MlasSveAndFloat32(Pred, Vector1, Selection) ); @@ -613,8 +590,8 @@ MLAS_SVFLOAT32 MlasSvePowerOf2Float32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) { MLAS_SVINT32 emm0 = MlasSveAddInt32( - Pred, - MlasSveCastToInt32(Pred, Vector), + Pred, + MlasSveCastToInt32(Pred, Vector), MlasSveBroadcastInt32(127) ); return MlasSveReinterpretAsFloat32(MlasSveShiftLeftInt32<23>(Pred, emm0)); @@ -628,6 +605,14 @@ MlasSveSelect(svbool_t Pred, MLAS_SVFLOAT32 TrueValue, MLAS_SVFLOAT32 FalseValue return svsel_f32(Pred, TrueValue, FalseValue); } +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSvedupFloat32(float Vector) +{ + return svdup_f32(Vector); +} + MLAS_SVE_TARGET MLAS_FORCEINLINE MLAS_SVBOOL @@ -636,6 +621,33 @@ MlasSveCompareLessThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) return svcmplt_f32(Pred, A, B); } +MLASCALL +inline void +Transpose_SVE512_4x4(float* D, const float* B, size_t ldb) +{ + const static int VL = svcntw(); + MLAS_SVBOOL p = svwhilelt_b32(0, VL / 4); + MLAS_SVBOOL p3 = svwhilelt_b32(0, VL / 2); + MLAS_SVBOOL p1 = svnot_b_z(svwhilelt_b32(0, VL), p); + p1 = svand_b_z(p3, p3, p1); + p3 = svrev_b32(p1); + MLAS_SVBOOL p4 = svrev_b32(p); + + MLAS_SVFLOAT32 t0 = MlasSveLoadFloat32(p, &B[ldb * 0]); + MLAS_SVFLOAT32 t1 = MlasSveLoadFloat32(p, &B[ldb * 1]); + MLAS_SVFLOAT32 t2 = MlasSveLoadFloat32(p, &B[ldb * 
2]); + MLAS_SVFLOAT32 t3 = MlasSveLoadFloat32(p, &B[ldb * 3]); + + MLAS_SVFLOAT32 t02 = MlasSveInterleaveLowFloat32(t0, t2); + MLAS_SVFLOAT32 t13 = MlasSveInterleaveLowFloat32(t1, t3); + MLAS_SVFLOAT32 t0123 = MlasSveInterleaveLowFloat32(t02, t13); // This zips the first half together + + MlasSveStoreFloat32(p, D, t0123); + MlasSveStoreFloat32(p1, &D[12], t0123); + MlasSveStoreFloat32(p3, &D[24], t0123); + MlasSveStoreFloat32(p4, &D[36], t0123); +} + MLAS_SVE_TARGET MLAS_FORCEINLINE MLAS_SVBOOL @@ -644,10 +656,330 @@ MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) return svcmpgt_f32(Pred, A, B); } +MLASCALL +inline void +Transpose_SVE256_4x4(float* D, const float* B, size_t ldb) +{ + const static int VL = svcntw(); + MLAS_SVBOOL p = svwhilelt_b32(0, VL / 2); + + MLAS_SVFLOAT32 t0 = MlasSveLoadFloat32(p, &B[ldb * 0]); + MLAS_SVFLOAT32 t1 = MlasSveLoadFloat32(p, &B[ldb * 1]); + MLAS_SVFLOAT32 t2 = MlasSveLoadFloat32(p, &B[ldb * 2]); + MLAS_SVFLOAT32 t3 = MlasSveLoadFloat32(p, &B[ldb * 3]); + + MLAS_SVBOOL p1 = svnot_b_z(svwhilelt_b32((int)0, VL), p); + MLAS_SVFLOAT32 t02 = MlasSveInterleaveLowFloat32(t0, t2); + MLAS_SVFLOAT32 t13 = MlasSveInterleaveLowFloat32(t1, t3); + MLAS_SVFLOAT32 first_t0123 = MlasSveInterleaveLowFloat32(t02, t13); // This zips the first half together + MLAS_SVFLOAT32 second_t0123 = MlasSveInterleaveHighFloat32(t02, t13); // This zips the second half together + + MlasSveStoreFloat32(p, D, first_t0123); + MlasSveStoreFloat32(p1, &D[12], first_t0123); + MlasSveStoreFloat32(p, &D[32], second_t0123); + MlasSveStoreFloat32(p1, &D[44], second_t0123); +} + +MLASCALL +inline void +Transpose_SVE128_4x4(float* D, const float* B, size_t ldb) +{ + const static int VL = svcntw(); + MLAS_SVBOOL p = svwhilelt_b32((int)0, VL); + + MLAS_SVFLOAT32 v1 = MlasSveLoadFloat32(p, &B[ldb * 0]); + MLAS_SVFLOAT32 v2 = MlasSveLoadFloat32(p, &B[ldb * 1]); + MLAS_SVFLOAT32 v4 = MlasSveLoadFloat32(p, &B[ldb * 2]); + MLAS_SVFLOAT32 v5 = MlasSveLoadFloat32(p, &B[ldb * 3]); + + MLAS_SVFLOAT32 v3 = MlasSveInterleaveLowFloat32(v1, v4); + v1 = MlasSveInterleaveHighFloat32(v1, v4); + + v4 = MlasSveInterleaveLowFloat32(v2, v5); + v2 = MlasSveInterleaveHighFloat32(v2, v5); + + v5 = MlasSveInterleaveLowFloat32(v3, v4); + v3 = MlasSveInterleaveHighFloat32(v3, v4); + + v4 = MlasSveInterleaveLowFloat32(v1, v2); + v1 = MlasSveInterleaveHighFloat32(v1, v2); + + MlasSveStoreFloat32(p, &D[0], v5); + MlasSveStoreFloat32(p, &D[16], v3); + MlasSveStoreFloat32(p, &D[32], v4); + MlasSveStoreFloat32(p, &D[48], v1); +} + +MLASCALL +MLAS_FORCEINLINE +void +Transpose_SVE256_8x8(float* D, const float* B, size_t ldb) +{ + const static int VL = svcntw(); + + MLAS_SVBOOL p = svwhilelt_b32((int)0, VL); + + MLAS_SVFLOAT32 v1 = MlasSveLoadFloat32(p, &B[ldb * 0]); + MLAS_SVFLOAT32 v2 = MlasSveLoadFloat32(p, &B[ldb * 1]); + MLAS_SVFLOAT32 v4 = MlasSveLoadFloat32(p, &B[ldb * 2]); + MLAS_SVFLOAT32 v5 = MlasSveLoadFloat32(p, &B[ldb * 3]); + + MLAS_SVFLOAT32 v6 = MlasSveLoadFloat32(p, &B[ldb * 4]); + MLAS_SVFLOAT32 v7 = MlasSveLoadFloat32(p, &B[ldb * 5]); + MLAS_SVFLOAT32 v8 = MlasSveLoadFloat32(p, &B[ldb * 6]); + MLAS_SVFLOAT32 v9 = MlasSveLoadFloat32(p, &B[ldb * 7]); + + // First mix + MLAS_SVFLOAT32 v3 = MlasSveInterleaveLowFloat32(v1, v6); + v1 = MlasSveInterleaveHighFloat32(v1, v6); + + v6 = MlasSveInterleaveLowFloat32(v2, v7); + v2 = MlasSveInterleaveHighFloat32(v2, v7); + + v7 = MlasSveInterleaveLowFloat32(v4, v8); + v4 = MlasSveInterleaveHighFloat32(v4, v8); + + v8 = MlasSveInterleaveLowFloat32(v5, 
v9); + + v5 = MlasSveInterleaveHighFloat32(v5, v9); + + // Second mix + + v9 = MlasSveInterleaveLowFloat32(v3, v7); + v3 = MlasSveInterleaveHighFloat32(v3, v7); + + v7 = MlasSveInterleaveLowFloat32(v6, v8); + v6 = MlasSveInterleaveHighFloat32(v6, v8); + + v8 = MlasSveInterleaveLowFloat32(v1, v4); + v1 = MlasSveInterleaveHighFloat32(v1, v4); + + v4 = MlasSveInterleaveLowFloat32(v2, v5); + v2 = MlasSveInterleaveHighFloat32(v2, v5); + + // Third mix + v5 = MlasSveInterleaveLowFloat32(v9, v7); + v9 = MlasSveInterleaveHighFloat32(v9, v7); + + v7 = MlasSveInterleaveLowFloat32(v8, v4); + v8 = MlasSveInterleaveHighFloat32(v8, v4); + + v4 = MlasSveInterleaveLowFloat32(v3, v6); + v3 = MlasSveInterleaveHighFloat32(v3, v6); + + v6 = MlasSveInterleaveLowFloat32(v1, v2); + v1 = MlasSveInterleaveHighFloat32(v1, v2); + + // Store the results + + MlasSveStoreFloat32(p, &D[0], v5); + MlasSveStoreFloat32(p, &D[16], v9); + MlasSveStoreFloat32(p, &D[32], v4); + MlasSveStoreFloat32(p, &D[48], v3); + MlasSveStoreFloat32(p, &D[64], v7); + MlasSveStoreFloat32(p, &D[80], v8); + MlasSveStoreFloat32(p, &D[96], v6); + MlasSveStoreFloat32(p, &D[112], v1); +} + +MLASCALL +inline void +Transpose_SVE512_16x16(float* D, const float* B, size_t ldb) +{ + const static int VL = svcntw(); + MLAS_SVBOOL p = svwhilelt_b32((int)0, VL); + + MLAS_SVFLOAT32 v1 = MlasSveLoadFloat32(p, &B[ldb * 0]); + MLAS_SVFLOAT32 v2 = MlasSveLoadFloat32(p, &B[ldb * 1]); + MLAS_SVFLOAT32 v3 = MlasSveLoadFloat32(p, &B[ldb * 2]); + MLAS_SVFLOAT32 v4 = MlasSveLoadFloat32(p, &B[ldb * 3]); + + MLAS_SVFLOAT32 v5 = MlasSveLoadFloat32(p, &B[ldb * 4]); + MLAS_SVFLOAT32 v6 = MlasSveLoadFloat32(p, &B[ldb * 5]); + MLAS_SVFLOAT32 v7 = MlasSveLoadFloat32(p, &B[ldb * 6]); + MLAS_SVFLOAT32 v8 = MlasSveLoadFloat32(p, &B[ldb * 7]); + + MLAS_SVFLOAT32 v9 = MlasSveLoadFloat32(p, &B[ldb * 8]); + MLAS_SVFLOAT32 v10 = MlasSveLoadFloat32(p, &B[ldb * 9]); + MLAS_SVFLOAT32 v11 = MlasSveLoadFloat32(p, &B[ldb * 10]); + MLAS_SVFLOAT32 v12 = MlasSveLoadFloat32(p, &B[ldb * 11]); + + MLAS_SVFLOAT32 v13 = MlasSveLoadFloat32(p, &B[ldb * 12]); + MLAS_SVFLOAT32 v14 = MlasSveLoadFloat32(p, &B[ldb * 13]); + MLAS_SVFLOAT32 v15 = MlasSveLoadFloat32(p, &B[ldb * 14]); + MLAS_SVFLOAT32 v16 = MlasSveLoadFloat32(p, &B[ldb * 15]); + + /*========= FIRST MIX ==============*/ + + MLAS_SVFLOAT32 v17 = MlasSveInterleaveLowFloat32(v1, v9); + MLAS_SVFLOAT32 v18 = MlasSveInterleaveHighFloat32(v1, v9); + + MLAS_SVFLOAT32 v19 = MlasSveInterleaveLowFloat32(v2, v10); + MLAS_SVFLOAT32 v20 = MlasSveInterleaveHighFloat32(v2, v10); + + MLAS_SVFLOAT32 v21 = MlasSveInterleaveLowFloat32(v3, v11); + MLAS_SVFLOAT32 v22 = MlasSveInterleaveHighFloat32(v3, v11); + + MLAS_SVFLOAT32 v23 = MlasSveInterleaveLowFloat32(v4, v12); + MLAS_SVFLOAT32 v24 = MlasSveInterleaveHighFloat32(v4, v12); + + // + + MLAS_SVFLOAT32 v25 = MlasSveInterleaveLowFloat32(v5, v13); + MLAS_SVFLOAT32 v26 = MlasSveInterleaveHighFloat32(v5, v13); + + MLAS_SVFLOAT32 v27 = MlasSveInterleaveLowFloat32(v6, v14); + MLAS_SVFLOAT32 v28 = MlasSveInterleaveHighFloat32(v6, v14); + + MLAS_SVFLOAT32 v29 = MlasSveInterleaveLowFloat32(v7, v15); + MLAS_SVFLOAT32 v30 = MlasSveInterleaveHighFloat32(v7, v15); + + MLAS_SVFLOAT32 v31 = MlasSveInterleaveLowFloat32(v8, v16); + MLAS_SVFLOAT32 v32 = MlasSveInterleaveHighFloat32(v8, v16); + + /*========= SECOND MIX ==============*/ + + v1 = MlasSveInterleaveLowFloat32(v17, v25); + v9 = MlasSveInterleaveHighFloat32(v17, v25); + + v2 = MlasSveInterleaveLowFloat32(v18, v26); + v10 = MlasSveInterleaveHighFloat32(v18, v26); + + v3 
= MlasSveInterleaveLowFloat32(v19, v27); + v11 = MlasSveInterleaveHighFloat32(v19, v27); + + v4 = MlasSveInterleaveLowFloat32(v20, v28); + v12 = MlasSveInterleaveHighFloat32(v20, v28); + + // + v5 = MlasSveInterleaveLowFloat32(v21, v29); + v13 = MlasSveInterleaveHighFloat32(v21, v29); + + v6 = MlasSveInterleaveLowFloat32(v22, v30); + v14 = MlasSveInterleaveHighFloat32(v22, v30); + + v7 = MlasSveInterleaveLowFloat32(v23, v31); + v15 = MlasSveInterleaveHighFloat32(v23, v31); + + v8 = MlasSveInterleaveLowFloat32(v24, v32); + v16 = MlasSveInterleaveHighFloat32(v24, v32); + + /*======= Third Mix =================*/ + + v17 = MlasSveInterleaveLowFloat32(v1, v5); + v25 = MlasSveInterleaveHighFloat32(v1, v5); + + v18 = MlasSveInterleaveLowFloat32(v9, v13); + v26 = MlasSveInterleaveHighFloat32(v9, v13); + + v19 = MlasSveInterleaveLowFloat32(v2, v6); + v27 = MlasSveInterleaveHighFloat32(v2, v6); + + v20 = MlasSveInterleaveLowFloat32(v10, v14); + v28 = MlasSveInterleaveHighFloat32(v10, v14); + + v21 = MlasSveInterleaveLowFloat32(v3, v7); + v29 = MlasSveInterleaveHighFloat32(v3, v7); + + v22 = MlasSveInterleaveLowFloat32(v11, v15); + v30 = MlasSveInterleaveHighFloat32(v11, v15); + + v23 = MlasSveInterleaveLowFloat32(v4, v8); + v31 = MlasSveInterleaveHighFloat32(v4, v8); + + v24 = MlasSveInterleaveLowFloat32(v12, v16); + v32 = MlasSveInterleaveHighFloat32(v12, v16); + + /*======== Final Mix ================*/ + + v1 = MlasSveInterleaveLowFloat32(v17, v21); + v9 = MlasSveInterleaveHighFloat32(v17, v21); + + v2 = MlasSveInterleaveLowFloat32(v25, v29); + v10 = MlasSveInterleaveHighFloat32(v25, v29); + + v3 = MlasSveInterleaveLowFloat32(v18, v22); + v11 = MlasSveInterleaveHighFloat32(v18, v22); + + v4 = MlasSveInterleaveLowFloat32(v26, v30); + v12 = MlasSveInterleaveHighFloat32(v26, v30); + + v5 = MlasSveInterleaveLowFloat32(v19, v23); + v13 = MlasSveInterleaveHighFloat32(v19, v23); + + v6 = MlasSveInterleaveLowFloat32(v27, v31); + v14 = MlasSveInterleaveHighFloat32(v27, v31); + + v7 = MlasSveInterleaveLowFloat32(v20, v24); + v15 = MlasSveInterleaveHighFloat32(v20, v24); + + v8 = MlasSveInterleaveLowFloat32(v28, v32); + v16 = MlasSveInterleaveHighFloat32(v28, v32); + + // store the result. 
+ + MlasSveStoreFloat32(p, &D[0], v1); + MlasSveStoreFloat32(p, &D[16], v9); + MlasSveStoreFloat32(p, &D[32], v2); + MlasSveStoreFloat32(p, &D[48], v10); + // + MlasSveStoreFloat32(p, &D[64], v3); + MlasSveStoreFloat32(p, &D[80], v11); + MlasSveStoreFloat32(p, &D[96], v4); + MlasSveStoreFloat32(p, &D[112], v12); + // + MlasSveStoreFloat32(p, &D[128], v5); + MlasSveStoreFloat32(p, &D[144], v13); + MlasSveStoreFloat32(p, &D[160], v6); + MlasSveStoreFloat32(p, &D[176], v14); + // + MlasSveStoreFloat32(p, &D[192], v7); + MlasSveStoreFloat32(p, &D[208], v15); + MlasSveStoreFloat32(p, &D[224], v8); + MlasSveStoreFloat32(p, &D[240], v16); +} + +template +inline void +TransposePackBNx8( + float* D, + const float* B, + size_t ldb +) +{ + for (unsigned n = 0; n < N / 8; n++) { + Transpose_SVE256_8x8(D, B, ldb); + D += 8; + B += ldb * 8; + } +} + +MLAS_SVE_TARGET +template +void +MlasSveTransposePackBNx4( + float* D, + const float* B, + size_t ldb +) +{ + for (unsigned n = 0; n < N / 4; n++) { + if (VL() == 16) { + Transpose_SVE512_4x4(&D[0], &B[0], ldb); + } else if (VL() == 8) { + Transpose_SVE256_4x4(&D[0], &B[0], ldb); + } else if (VL() == 4) { + Transpose_SVE128_4x4(&D[0], &B[0], ldb); + } + + D += 4; + B += ldb * 4; + } +} + // GCC: Pop options after SVE-specific functions #ifndef __clang__ #pragma GCC pop_options #endif -#endif - +#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sve/sgemm_sve.cpp b/onnxruntime/core/mlas/lib/sve/sgemm_sve.cpp new file mode 100644 index 0000000000000..5863b5ba0e3fa --- /dev/null +++ b/onnxruntime/core/mlas/lib/sve/sgemm_sve.cpp @@ -0,0 +1,586 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + sgemm_sve.cpp + +Abstract: + + This module contains the implementation of SVE-based sgemm operations +--*/ + +#ifdef __ARM_FEATURE_SVE + +#include "mlasi_sve.h" + +template +inline void +processrows_8( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ res, + size_t k, + size_t n, + size_t lda, + size_t ldc, + float alpha, + size_t vl +) +{ + constexpr size_t k_step = 64; // Tunable tile size + svfloat32_t zero_vec = MlasSvedupFloat32(0.f); + for (size_t col = 0; col < n; col += vl) { + svbool_t pg = svwhilelt_b32(col, n); + // Output pointers + float* out0 = res + 0 * ldc + col; + float* out1 = res + 1 * ldc + col; + float* out2 = res + 2 * ldc + col; + float* out3 = res + 3 * ldc + col; + float* out4 = res + 4 * ldc + col; + float* out5 = res + 5 * ldc + col; + float* out6 = res + 6 * ldc + col; + float* out7 = res + 7 * ldc + col; + // Accumulators initialized to zero + svfloat32_t acc0 = zero_vec; + svfloat32_t acc1 = zero_vec; + svfloat32_t acc2 = zero_vec; + svfloat32_t acc3 = zero_vec; + svfloat32_t acc4 = zero_vec; + svfloat32_t acc5 = zero_vec; + svfloat32_t acc6 = zero_vec; + svfloat32_t acc7 = zero_vec; + for (size_t k_block = 0; k_block < k; k_block += k_step) { + size_t k_max = std::min(k_block + k_step, k); + // Temporary partial sums + svfloat32_t partial0 = zero_vec; + svfloat32_t partial1 = zero_vec; + svfloat32_t partial2 = zero_vec; + svfloat32_t partial3 = zero_vec; + svfloat32_t partial4 = zero_vec; + svfloat32_t partial5 = zero_vec; + svfloat32_t partial6 = zero_vec; + svfloat32_t partial7 = zero_vec; + for (size_t p = k_block; p < k_max; ++p) { + const float* b_vec = b + p * PACKED_B_BLOCK_WIDTH + col; + svfloat32_t bvals = MlasSveLoadFloat32(pg, b_vec); + svfloat32_t a0, a1, a2, a3, a4, a5, a6, a7; + if constexpr (!Alpha1) { + a0 = MlasSvedupFloat32(a[0 * lda + p] * alpha); + a1 = 
MlasSvedupFloat32(a[1 * lda + p] * alpha); + a2 = MlasSvedupFloat32(a[2 * lda + p] * alpha); + a3 = MlasSvedupFloat32(a[3 * lda + p] * alpha); + a4 = MlasSvedupFloat32(a[4 * lda + p] * alpha); + a5 = MlasSvedupFloat32(a[5 * lda + p] * alpha); + a6 = MlasSvedupFloat32(a[6 * lda + p] * alpha); + a7 = MlasSvedupFloat32(a[7 * lda + p] * alpha); + } else { + a0 = MlasSvedupFloat32(a[0 * lda + p]); + a1 = MlasSvedupFloat32(a[1 * lda + p]); + a2 = MlasSvedupFloat32(a[2 * lda + p]); + a3 = MlasSvedupFloat32(a[3 * lda + p]); + a4 = MlasSvedupFloat32(a[4 * lda + p]); + a5 = MlasSvedupFloat32(a[5 * lda + p]); + a6 = MlasSvedupFloat32(a[6 * lda + p]); + a7 = MlasSvedupFloat32(a[7 * lda + p]); + } + partial0 = MlasSveMultiplyAddFloat32(pg, bvals, a0, partial0); + partial1 = MlasSveMultiplyAddFloat32(pg, bvals, a1, partial1); + partial2 = MlasSveMultiplyAddFloat32(pg, bvals, a2, partial2); + partial3 = MlasSveMultiplyAddFloat32(pg, bvals, a3, partial3); + partial4 = MlasSveMultiplyAddFloat32(pg, bvals, a4, partial4); + partial5 = MlasSveMultiplyAddFloat32(pg, bvals, a5, partial5); + partial6 = MlasSveMultiplyAddFloat32(pg, bvals, a6, partial6); + partial7 = MlasSveMultiplyAddFloat32(pg, bvals, a7, partial7); + } + // Accumulate partials into accumulators + acc0 = MlasSveAddFloat32(pg, acc0, partial0); + acc1 = MlasSveAddFloat32(pg, acc1, partial1); + acc2 = MlasSveAddFloat32(pg, acc2, partial2); + acc3 = MlasSveAddFloat32(pg, acc3, partial3); + acc4 = MlasSveAddFloat32(pg, acc4, partial4); + acc5 = MlasSveAddFloat32(pg, acc5, partial5); + acc6 = MlasSveAddFloat32(pg, acc6, partial6); + acc7 = MlasSveAddFloat32(pg, acc7, partial7); + } + if constexpr (!ZeroMode) { + acc0 = MlasSveAddFloat32(pg, acc0, svld1(pg, out0)); + acc1 = MlasSveAddFloat32(pg, acc1, svld1(pg, out1)); + acc2 = MlasSveAddFloat32(pg, acc2, svld1(pg, out2)); + acc3 = MlasSveAddFloat32(pg, acc3, svld1(pg, out3)); + acc4 = MlasSveAddFloat32(pg, acc4, svld1(pg, out4)); + acc5 = MlasSveAddFloat32(pg, acc5, svld1(pg, out5)); + acc6 = MlasSveAddFloat32(pg, acc6, svld1(pg, out6)); + acc7 = MlasSveAddFloat32(pg, acc7, svld1(pg, out7)); + } + // Store results + MlasSveStoreFloat32(pg, out0, acc0); + MlasSveStoreFloat32(pg, out1, acc1); + MlasSveStoreFloat32(pg, out2, acc2); + MlasSveStoreFloat32(pg, out3, acc3); + MlasSveStoreFloat32(pg, out4, acc4); + MlasSveStoreFloat32(pg, out5, acc5); + MlasSveStoreFloat32(pg, out6, acc6); + MlasSveStoreFloat32(pg, out7, acc7); + } +} + +template +inline void +processrows_6( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ res, + size_t k, + size_t n, + size_t lda, + size_t ldc, + float alpha, + size_t vl +) +{ + constexpr size_t k_step = 64; // Can be tuned per architecture + svfloat32_t zero_vec = MlasSvedupFloat32(0.f); + for (size_t col = 0; col < n; col += vl) { + svbool_t pg = svwhilelt_b32(col, n); + float* out0 = res + 0 * ldc + col; + float* out1 = res + 1 * ldc + col; + float* out2 = res + 2 * ldc + col; + float* out3 = res + 3 * ldc + col; + float* out4 = res + 4 * ldc + col; + float* out5 = res + 5 * ldc + col; + // Initialize accumulators to zero + svfloat32_t acc0 = zero_vec; + svfloat32_t acc1 = zero_vec; + svfloat32_t acc2 = zero_vec; + svfloat32_t acc3 = zero_vec; + svfloat32_t acc4 = zero_vec; + svfloat32_t acc5 = zero_vec; + for (size_t k_block = 0; k_block < k; k_block += k_step) { + size_t k_max = std::min(k_block + k_step, k); + svfloat32_t partial0 = zero_vec; + svfloat32_t partial1 = zero_vec; + svfloat32_t partial2 = zero_vec; + svfloat32_t partial3 
= zero_vec; + svfloat32_t partial4 = zero_vec; + svfloat32_t partial5 = zero_vec; + for (size_t p = k_block; p < k_max; ++p) { + const float* b_vec = b + p * PACKED_B_BLOCK_WIDTH + col; + svfloat32_t bvals = MlasSveLoadFloat32(pg, b_vec); + svfloat32_t a0, a1, a2, a3, a4, a5; + if constexpr (!Alpha1) { + a0 = MlasSvedupFloat32(a[0 * lda + p] * alpha); + a1 = MlasSvedupFloat32(a[1 * lda + p] * alpha); + a2 = MlasSvedupFloat32(a[2 * lda + p] * alpha); + a3 = MlasSvedupFloat32(a[3 * lda + p] * alpha); + a4 = MlasSvedupFloat32(a[4 * lda + p] * alpha); + a5 = MlasSvedupFloat32(a[5 * lda + p] * alpha); + } else { + a0 = MlasSvedupFloat32(a[0 * lda + p]); + a1 = MlasSvedupFloat32(a[1 * lda + p]); + a2 = MlasSvedupFloat32(a[2 * lda + p]); + a3 = MlasSvedupFloat32(a[3 * lda + p]); + a4 = MlasSvedupFloat32(a[4 * lda + p]); + a5 = MlasSvedupFloat32(a[5 * lda + p]); + } + partial0 = MlasSveMultiplyAddFloat32(pg, bvals, a0, partial0); + partial1 = MlasSveMultiplyAddFloat32(pg, bvals, a1, partial1); + partial2 = MlasSveMultiplyAddFloat32(pg, bvals, a2, partial2); + partial3 = MlasSveMultiplyAddFloat32(pg, bvals, a3, partial3); + partial4 = MlasSveMultiplyAddFloat32(pg, bvals, a4, partial4); + partial5 = MlasSveMultiplyAddFloat32(pg, bvals, a5, partial5); + } + acc0 = MlasSveAddFloat32(pg, acc0, partial0); + acc1 = MlasSveAddFloat32(pg, acc1, partial1); + acc2 = MlasSveAddFloat32(pg, acc2, partial2); + acc3 = MlasSveAddFloat32(pg, acc3, partial3); + acc4 = MlasSveAddFloat32(pg, acc4, partial4); + acc5 = MlasSveAddFloat32(pg, acc5, partial5); + } + // Add existing result values at the end (if not ZeroMode) + if constexpr (!ZeroMode) { + acc0 = MlasSveAddFloat32(pg, acc0, svld1(pg, out0)); + acc1 = MlasSveAddFloat32(pg, acc1, svld1(pg, out1)); + acc2 = MlasSveAddFloat32(pg, acc2, svld1(pg, out2)); + acc3 = MlasSveAddFloat32(pg, acc3, svld1(pg, out3)); + acc4 = MlasSveAddFloat32(pg, acc4, svld1(pg, out4)); + acc5 = MlasSveAddFloat32(pg, acc5, svld1(pg, out5)); + } + MlasSveStoreFloat32(pg, out0, acc0); + MlasSveStoreFloat32(pg, out1, acc1); + MlasSveStoreFloat32(pg, out2, acc2); + MlasSveStoreFloat32(pg, out3, acc3); + MlasSveStoreFloat32(pg, out4, acc4); + MlasSveStoreFloat32(pg, out5, acc5); + } +} + +template +inline void +processrows_4( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ res, + size_t k, + size_t n, + size_t lda, + size_t ldc, + float alpha, + size_t vl +) +{ + constexpr size_t k_step = 64; // Tunable tile size + svfloat32_t zero_vec = MlasSvedupFloat32(0.f); + for (size_t col = 0; col < n; col += vl) { + svbool_t pg = svwhilelt_b32(col, n); + float* out0 = res + 0 * ldc + col; + float* out1 = res + 1 * ldc + col; + float* out2 = res + 2 * ldc + col; + float* out3 = res + 3 * ldc + col; + // Start with clean zeroed accumulators + svfloat32_t acc0 = zero_vec; + svfloat32_t acc1 = zero_vec; + svfloat32_t acc2 = zero_vec; + svfloat32_t acc3 = zero_vec; + for (size_t k_block = 0; k_block < k; k_block += k_step) { + size_t k_max = std::min(k_block + k_step, k); + svfloat32_t partial0 = zero_vec; + svfloat32_t partial1 = zero_vec; + svfloat32_t partial2 = zero_vec; + svfloat32_t partial3 = zero_vec; + for (size_t p = k_block; p < k_max; ++p) { + const float* b_vec = b + p * PACKED_B_BLOCK_WIDTH + col; + svfloat32_t bvals = MlasSveLoadFloat32(pg, b_vec); + svfloat32_t a0, a1, a2, a3; + if constexpr (!Alpha1) { + a0 = MlasSvedupFloat32(a[0 * lda + p] * alpha); + a1 = MlasSvedupFloat32(a[1 * lda + p] * alpha); + a2 = MlasSvedupFloat32(a[2 * lda + p] * alpha); + a3 
= MlasSvedupFloat32(a[3 * lda + p] * alpha); + } else { + a0 = MlasSvedupFloat32(a[0 * lda + p]); + a1 = MlasSvedupFloat32(a[1 * lda + p]); + a2 = MlasSvedupFloat32(a[2 * lda + p]); + a3 = MlasSvedupFloat32(a[3 * lda + p]); + } + partial0 = MlasSveMultiplyAddFloat32(pg, bvals, a0, partial0); + partial1 = MlasSveMultiplyAddFloat32(pg, bvals, a1, partial1); + partial2 = MlasSveMultiplyAddFloat32(pg, bvals, a2, partial2); + partial3 = MlasSveMultiplyAddFloat32(pg, bvals, a3, partial3); + } + acc0 = MlasSveAddFloat32(pg, acc0, partial0); + acc1 = MlasSveAddFloat32(pg, acc1, partial1); + acc2 = MlasSveAddFloat32(pg, acc2, partial2); + acc3 = MlasSveAddFloat32(pg, acc3, partial3); + } + // Final addition of existing result (if ZeroMode == false) + if constexpr (!ZeroMode) { + acc0 = MlasSveAddFloat32(pg, acc0, svld1(pg, out0)); + acc1 = MlasSveAddFloat32(pg, acc1, svld1(pg, out1)); + acc2 = MlasSveAddFloat32(pg, acc2, svld1(pg, out2)); + acc3 = MlasSveAddFloat32(pg, acc3, svld1(pg, out3)); + } + // Store the final accumulated results + MlasSveStoreFloat32(pg, out0, acc0); + MlasSveStoreFloat32(pg, out1, acc1); + MlasSveStoreFloat32(pg, out2, acc2); + MlasSveStoreFloat32(pg, out3, acc3); + } +} + +template +inline void +processrows_2( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ res, + size_t k, + size_t n, + int lda, + int ldc, + float alpha, + size_t vl +) +{ + constexpr size_t k_step = 64; // Tune this value as needed + svfloat32_t zero_vec = MlasSvedupFloat32(0.f); + for (size_t col = 0; col < n; col += vl) { + svbool_t pg = svwhilelt_b32(col, n); + float* out0 = res + 0 * ldc + col; + float* out1 = res + 1 * ldc + col; + // Always start with zero accumulators + svfloat32_t acc0 = zero_vec; + svfloat32_t acc1 = zero_vec; + for (size_t k_block = 0; k_block < k; k_block += k_step) { + size_t k_max = std::min(k_block + k_step, k); + svfloat32_t partial0 = zero_vec; + svfloat32_t partial1 = zero_vec; + for (size_t p = k_block; p < k_max; ++p) { + const float* b_vec = b + p * PACKED_B_BLOCK_WIDTH + col; + svfloat32_t bvals = MlasSveLoadFloat32(pg, b_vec); + svfloat32_t a0, a1; + if constexpr (!Alpha1) { + a0 = MlasSvedupFloat32(a[0 * lda + p] * alpha); + a1 = MlasSvedupFloat32(a[1 * lda + p] * alpha); + } else { + a0 = MlasSvedupFloat32(a[0 * lda + p]); + a1 = MlasSvedupFloat32(a[1 * lda + p]); + } + partial0 = MlasSveMultiplyAddFloat32(pg, bvals, a0, partial0); + partial1 = MlasSveMultiplyAddFloat32(pg, bvals, a1, partial1); + } + acc0 = MlasSveAddFloat32(pg, acc0, partial0); + acc1 = MlasSveAddFloat32(pg, acc1, partial1); + } + // Add existing values at the end (if ZeroMode == false) + if constexpr (!ZeroMode) { + acc0 = MlasSveAddFloat32(pg, acc0, svld1(pg, out0)); + acc1 = MlasSveAddFloat32(pg, acc1, svld1(pg, out1)); + } + MlasSveStoreFloat32(pg, out0, acc0); + MlasSveStoreFloat32(pg, out1, acc1); + } +} + +template +inline void +processrows_1( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ res, + size_t k, + size_t n, + size_t lda, + size_t ldc, + float alpha, + size_t vl +) +{ + constexpr size_t k_step = 64; // Tune based on target hardware + svfloat32_t zero_vec = MlasSvedupFloat32(0.f); + for (size_t col = 0; col < n; col += vl) { + svbool_t pg = svwhilelt_b32(col, n); + float* out0 = res + 0 * ldc + col; + // Always start with zero accumulator + svfloat32_t acc0 = zero_vec; + for (size_t k_block = 0; k_block < k; k_block += k_step) { + size_t k_max = std::min(k_block + k_step, k); + svfloat32_t partial = 
zero_vec;
+            for (size_t p = k_block; p < k_max; ++p) {
+                const float* b_vec = b + p * PACKED_B_BLOCK_WIDTH + col;
+                svfloat32_t bvals = MlasSveLoadFloat32(pg, b_vec);
+                svfloat32_t a0;
+                if constexpr (!Alpha1) {
+                    a0 = MlasSvedupFloat32(a[p + 0 * lda] * alpha);
+                } else {
+                    a0 = MlasSvedupFloat32(a[p + 0 * lda]);
+                }
+                partial = MlasSveMultiplyAddFloat32(pg, bvals, a0, partial);
+            }
+            acc0 = MlasSveAddFloat32(pg, acc0, partial);
+        }
+        // In Add mode (ZeroMode == false), add the existing values of res at the end
+        if constexpr (!ZeroMode) {
+            svfloat32_t prev = MlasSveLoadFloat32(pg, out0);
+            acc0 = MlasSveAddFloat32(pg, acc0, prev);
+        }
+        // Store the final result
+        MlasSveStoreFloat32(pg, out0, acc0);
+    }
+}
+
+template <auto ProcessFn>
+inline void
+ProcessRowsTemplate(
+    const float* __restrict A,
+    size_t lda,
+    const float* __restrict B,
+    float* __restrict C,
+    size_t ldc,
+    size_t K,
+    size_t N,
+    float alpha
+)
+{
+    size_t n = 0;
+    const size_t vl = svcntw();
+    while (n < N) {
+        int cols = (n + PACKED_B_BLOCK_WIDTH <= N) ? PACKED_B_BLOCK_WIDTH : (N - n);
+        ProcessFn(A, B, C, K, cols, lda, ldc, alpha, vl);
+        B += cols * K;
+        C += cols;
+        n += cols;
+    }
+}
+
+size_t MLASCALL
+MlasSgemmKernelZero_sve(
+    const float* A,
+    const float* B,
+    float* C,
+    size_t CountK,
+    size_t CountM,
+    size_t CountN,
+    size_t lda,
+    size_t ldc,
+    float alpha
+)
+{
+    if (alpha == 1.0f) {
+        if (CountM >= 8) {
+            ProcessRowsTemplate<processrows_8<true, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 8;
+        } else if (CountM >= 6) {
+            ProcessRowsTemplate<processrows_6<true, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 6;
+        } else if (CountM >= 4) {
+            ProcessRowsTemplate<processrows_4<true, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 4;
+        } else if (CountM >= 2) {
+            ProcessRowsTemplate<processrows_2<true, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 2;
+        } else
+            ProcessRowsTemplate<processrows_1<true, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+        return 1;
+    } else {
+        if (CountM >= 8) {
+            ProcessRowsTemplate<processrows_8<false, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 8;
+        } else if (CountM >= 6) {
+            ProcessRowsTemplate<processrows_6<false, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 6;
+        } else if (CountM >= 4) {
+            ProcessRowsTemplate<processrows_4<false, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 4;
+        } else if (CountM >= 2) {
+            ProcessRowsTemplate<processrows_2<false, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 2;
+        } else
+            ProcessRowsTemplate<processrows_1<false, true>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+        return 1;
+    }
+}
+
+size_t MLASCALL
+MlasSgemmKernelAdd_sve(
+    const float* A,
+    const float* B,
+    float* C,
+    size_t CountK,
+    size_t CountM,
+    size_t CountN,
+    size_t lda,
+    size_t ldc,
+    float alpha
+)
+{
+    if (alpha == 1.0f) {
+        if (CountM >= 8) {
+            ProcessRowsTemplate<processrows_8<true, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 8;
+        } else if (CountM >= 6) {
+            ProcessRowsTemplate<processrows_6<true, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 6;
+        } else if (CountM >= 4) {
+            ProcessRowsTemplate<processrows_4<true, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 4;
+        } else if (CountM >= 2) {
+            ProcessRowsTemplate<processrows_2<true, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 2;
+        } else
+            ProcessRowsTemplate<processrows_1<true, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+        return 1;
+    } else {
+        if (CountM >= 8) {
+            ProcessRowsTemplate<processrows_8<false, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 8;
+        } else if (CountM >= 6) {
+            ProcessRowsTemplate<processrows_6<false, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 6;
+        } else if (CountM >= 4) {
+            ProcessRowsTemplate<processrows_4<false, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 4;
+        } else if (CountM >= 2) {
+            ProcessRowsTemplate<processrows_2<false, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+            return 2;
+        } else
+            ProcessRowsTemplate<processrows_1<false, false>>(A, lda, B, C, ldc, CountK, CountN, alpha);
+        return 1;
+    }
+}
+
+void MLAS_SVE_TARGET MLASCALL
+SVE_TRANSPOSE(float*& D, const float*& b, size_t ldb, size_t& x)
+{
+    const static int VL = svcntw();
+    if (VL == 16) {
+        while (x >= 16) {
+            Transpose_SVE512_16x16(&D[0], &b[0], ldb);
+            D += 256;
+            b += 16;
+            x = x - 16;
+        }
+    } else if (VL == 8) {
+        while (x >= 8) {
+            TransposePackBNx8<16>(&D[0], &b[0], ldb);
+            D += 128;
+            b += 8;
+            x = x - 8;
+        }
+    }
+
+    while (x >= 4) {
+        MlasSveTransposePackBNx4<16>(&D[0], &b[0], ldb);
+
+        D += 16 * 4;
+        b += 4;
+        x = x - 4;
+    }
+}
+
+void MLAS_SVE_TARGET MLASCALL
+SCATTER_STORE(float* d, const float* b)
+{
+    MLAS_SVBOOL pb = svwhilelt_b32((int)0, 4);        // enable the first 4 lanes
+    MLAS_SVFLOAT32 vec0 = MlasSveLoadFloat32(pb, b);  // Load a set of 4 elements
+
+    // Single-lane predicates: element k of vec0 lands at d[15 * k + k], i.e. d[16 * k],
+    // matching the NEON path's stores to d[0], d[16], d[32] and d[48].
+    svuint32_t idx = svindex_u32(0, 1);
+    MLAS_SVBOOL pb_first_half = svcmpeq_u32(pb, idx, svdup_n_u32(0));
+    MLAS_SVBOOL pb_second_half = svcmpeq_u32(pb, idx, svdup_n_u32(1));
+    MLAS_SVBOOL pb_third_half = svcmpeq_u32(pb, idx, svdup_n_u32(2));
+    MLAS_SVBOOL pb_fourth_half = svcmpeq_u32(pb, idx, svdup_n_u32(3));
+
+    MlasSveStoreFloat32(pb_first_half, &d[0], vec0);
+    MlasSveStoreFloat32(pb_second_half, &d[15], vec0);
+    MlasSveStoreFloat32(pb_third_half, &d[30], vec0);
+    MlasSveStoreFloat32(pb_fourth_half, &d[45], vec0);
+}
+
+void MLAS_SVE_TARGET MLASCALL
+SVE_LOAD_STORE(float* D, const float* b)
+{
+    for (int i = 0; i < MLAS_SGEMM_STRIDEN_THREAD_ALIGN; i += VL()) {
+        svfloat32_t vec0 = MlasSveLoadFloat32(svptrue_b32(), b + i);
+        MlasSveStoreFloat32(svptrue_b32(), D + i, vec0);
+    }
+}
+
+void MLAS_SVE_TARGET MLASCALL
+SVE_ZERO_INITIALIZE(float* d)
+{
+    // Zero one packed-B row of PACKED_B_BLOCK_WIDTH floats for any SVE vector length.
+    for (int i = 0; i < PACKED_B_BLOCK_WIDTH; i += VL()) {
+        MlasSveStoreFloat32(svptrue_b32(), d + i, svdup_f32(0));
+    }
+}
+#endif
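
A minimal scalar sketch (not part of the patch) of what the single-row SVE kernel computes over one packed-B panel. It assumes B was packed by the existing MLAS packing routines into panels whose rows are PACKED_B_BLOCK_WIDTH (16) floats apart, which is what the indexing b + p * PACKED_B_BLOCK_WIDTH + col in the kernels above implies. The helper name ReferenceRow1PackedPanel and its signature are illustrative only; it mirrors processrows_1 with alpha == 1 and ZeroMode enabled and can serve as a reference when validating the SVE path.

// Reference sketch only -- not part of the diff above.
#include <cstddef>

// Computes one row of C against one packed-B panel, the scalar equivalent of
// processrows_1<true, true> (alpha == 1, ZeroMode): C is overwritten, not accumulated.
static void ReferenceRow1PackedPanel(
    const float* a,        // one row of A: CountK elements, stride 1
    const float* packedB,  // one packed panel: CountK rows of 16 floats (PACKED_B_BLOCK_WIDTH)
    float* c,              // one row of C: `cols` elements (cols <= 16)
    std::size_t CountK,
    std::size_t cols)
{
    for (std::size_t col = 0; col < cols; ++col) {
        float sum = 0.0f;
        for (std::size_t p = 0; p < CountK; ++p) {
            sum += a[p] * packedB[p * 16 + col];  // same indexing as b + p * PACKED_B_BLOCK_WIDTH + col
        }
        c[col] = sum;  // ZeroMode: store, do not add to the existing value
    }
}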