Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ else()
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h)
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp)
set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ")
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/sgemm_sve.cpp)
set_source_files_properties(${MLAS_SRC_DIR}/sve/sgemm_sve.cpp PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+sve -O3 -ffast-math -funroll-loops")
list(APPEND mlas_private_compile_definitions MLAS_USE_SVE)
endif()

Expand Down
106 changes: 94 additions & 12 deletions onnxruntime/core/mlas/lib/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Module Name:

#include "mlasi.h"

#if defined(MLAS_USE_SVE)
#include "sve/mlasi_sve.h"
#endif

//
// Define the number of rows from matrix A to transpose to a local buffer.
//
Expand Down Expand Up @@ -259,8 +263,15 @@ Return Value:

do {

#if defined(MLAS_NEON_INTRINSICS)
vst4q_f32(D, vld4q_f32(b));
#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
SVE_LOAD_STORE(D, b);
}
else{
vst4q_f32(D, vld4q_f32(b));
}
#elif defined(MLAS_NEON_INTRINSICS)
vst4q_f32(D, vld4q_f32(b));
#else
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);
MLAS_FLOAT32X4 t1 = MlasLoadFloat32x4(&b[4]);
Expand Down Expand Up @@ -303,8 +314,14 @@ Return Value:
float* d = D;
const float* b = B;

#if defined(MLAS_NEON_INTRINSICS)
vst4q_f32(d, ZeroFloat32x4x4);
#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as before - what happens when MLAS_USE_SVE is false and MLAS_NEON_INTRINSICS is true and HasArmSve() evaluates to true ?

SVE_ZERO_INITIALIZE(d);
} else {
vst4q_f32(d, ZeroFloat32x4x4);
}
#elif defined(MLAS_NEON_INTRINSICS)
vst4q_f32(d, ZeroFloat32x4x4);
#else
MlasStoreAlignedFloat32x4(d, ZeroFloat32x4);
MlasStoreAlignedFloat32x4(d + 4, ZeroFloat32x4);
Expand Down Expand Up @@ -486,6 +503,21 @@ Return Value:
x -= 4;
}

#elif defined(MLAS_USE_SVE)
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
SVE_TRANSPOSE(D,b,ldb,x);
}
else
{
while (x >= 4) {

MlasSgemmTransposePackBNx4<16>(&D[0], &b[0], ldb);

D += 16 * 4;
b += 4;
x -= 4;
}
}
#else

while (x >= 4) {
Expand Down Expand Up @@ -564,8 +596,15 @@ Return Value:
const float* b = B;

if ((CountY & 8) != 0) {

MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb);
#if defined(MLAS_USE_SVE)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()){
MlasSveTransposePackBNx4<8>(&d[0], &b[0], ldb);}
else{
MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb);
}
#else
MlasSgemmTransposePackBNx4<8>(&d[0], &b[0], ldb);
#endif

d += 8;
b += ldb * 8;
Expand All @@ -584,7 +623,15 @@ Return Value:

if ((CountY & 4) != 0) {

MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb);
#if defined(MLAS_USE_SVE)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()){
MlasSveTransposePackBNx4<4>(&d[0], &b[0], ldb);}
else{
MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb);
}
#else
MlasSgemmTransposePackBNx4<4>(&d[0], &b[0], ldb);
#endif

d += 4;
b += ldb * 4;
Expand Down Expand Up @@ -631,7 +678,19 @@ Return Value:

if ((CountY & 1) != 0) {

#if defined(MLAS_NEON_INTRINSICS)
#if defined(MLAS_USE_SVE) && defined(MLAS_NEON_INTRINSICS)
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve())
{
SCATTER_STORE(&d[0],&b[0]);
}
else{
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);

MlasStoreLaneFloat32x4<0>(&d[0], t0);
MlasStoreLaneFloat32x4<1>(&d[16], t0);
MlasStoreLaneFloat32x4<2>(&d[32], t0);
MlasStoreLaneFloat32x4<3>(&d[48], t0);}
#elif defined(MLAS_NEON_INTRINSICS)
MLAS_FLOAT32X4 t0 = MlasLoadFloat32x4(&b[0]);

MlasStoreLaneFloat32x4<0>(&d[0], t0);
Expand Down Expand Up @@ -1004,8 +1063,7 @@ Return Value:
#endif

MLAS_FORCEINLINE
float*
MlasSgemmKernelLoop(
float* MlasSgemmKernelLoop(
const float* A,
const float* B,
float* C,
Expand Down Expand Up @@ -1059,18 +1117,41 @@ Return Value:
{
while (CountM > 0) {

size_t RowsHandled;
size_t RowsHandled = 0;

#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS)
RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);

#else

if (ZeroMode) {

#if defined(MLAS_USE_SVE)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
RowsHandled = MlasSgemmKernelZero_sve(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
} else {
RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
}
#else
RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
#endif

} else {

#if defined(MLAS_USE_SVE)
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) {
RowsHandled = MlasSgemmKernelAdd_sve(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
} else {
RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
}
#else
RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha);
}
#endif

}

#endif // platform check

C += ldc * RowsHandled;
A += lda * RowsHandled;
CountM -= RowsHandled;
Expand All @@ -1079,6 +1160,7 @@ Return Value:
return C;
}


void
MlasSgemmOperation(
CBLAS_TRANSPOSE TransA,
Expand Down
Loading
Loading