diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8af161b524bee..83d1751e55543 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -89,6 +89,7 @@ option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF) +option(onnxruntime_USE_RVV "Build with RISC-V Vector support in MLAS" OFF) option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index bde73252449dc..0233254ad50ad 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -435,6 +435,8 @@ else() set(X86 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(RISCV64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") set(LOONGARCH64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^s390x$") @@ -903,6 +905,48 @@ endif() set(MLAS_SOURCE_IS_NOT_SET 0) endif() endif() + if(RISCV64 AND MLAS_SOURCE_IS_NOT_SET) + file(GLOB_RECURSE mlas_platform_srcs CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/scalar/*.cpp") + + if(onnxruntime_USE_RVV) + set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}") + set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS} -march=rv64gcv -mabi=lp64d") + check_cxx_source_compiles(" + #include + #include + int main() { + size_t vl = __riscv_vsetvl_e32m1(4); + return static_cast(vl == 0); + }" + HAS_RISCV64_RVV + ) + set(CMAKE_REQUIRED_FLAGS "${OLD_CMAKE_REQUIRED_FLAGS}") + unset(OLD_CMAKE_REQUIRED_FLAGS) + + if(HAS_RISCV64_RVV) + list(APPEND mlas_platform_srcs + ${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp + ) + set_source_files_properties( + ${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp + ${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp + PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d") + list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1) + else() + message( + WARNING + "onnxruntime_USE_RVV was requested, but the compiler does not support rv64gcv RVV intrinsics. Falling back to scalar MLAS kernels.") + endif() + endif() + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") @@ -997,4 +1041,4 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) endif() endif() -endif() \ No newline at end of file +endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4e5636572b94a..bd12b50b7af43 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1400,6 +1400,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) SET(MLAS_BENCH_DIR ${TEST_SRC_DIR}/mlas/bench) file(GLOB_RECURSE MLAS_BENCH_SOURCE_FILES "${MLAS_BENCH_DIR}/*.cpp" "${MLAS_BENCH_DIR}/*.h") + list(FILTER MLAS_BENCH_SOURCE_FILES EXCLUDE REGEX "${MLAS_BENCH_DIR}/riscv64/.*") onnxruntime_add_executable(onnxruntime_mlas_benchmark ${MLAS_BENCH_SOURCE_FILES} ${ONNXRUNTIME_ROOT}/core/framework/error_code.cc) target_include_directories(onnxruntime_mlas_benchmark PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark onnxruntime_util ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) @@ -1418,6 +1419,33 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo) endif() set_target_properties(onnxruntime_mlas_benchmark PROPERTIES FOLDER "ONNXRuntimeTest") + + endif() + + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") + set(MLAS_RISCV64_BENCH_DIR ${TEST_SRC_DIR}/mlas/bench/riscv64) + + onnxruntime_add_executable( + onnxruntime_mlas_sgemm_riscv_bench + ${MLAS_RISCV64_BENCH_DIR}/sgemm_riscv_bench.cpp) + target_include_directories(onnxruntime_mlas_sgemm_riscv_bench PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc) + target_link_libraries( + onnxruntime_mlas_sgemm_riscv_bench + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_sgemm_riscv_bench PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_sgemm_riscv_bench PROPERTIES FOLDER "ONNXRuntimeTest") + + onnxruntime_add_executable( + onnxruntime_mlas_softmax_riscv_compare + ${MLAS_RISCV64_BENCH_DIR}/softmax_rvv_compare.cpp) + target_include_directories( + onnxruntime_mlas_softmax_riscv_compare + PRIVATE ${ONNXRUNTIME_ROOT} ${ONNXRUNTIME_ROOT}/core/mlas/inc) + target_link_libraries( + onnxruntime_mlas_softmax_riscv_compare + PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) + target_compile_definitions(onnxruntime_mlas_softmax_riscv_compare PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(onnxruntime_mlas_softmax_riscv_compare PROPERTIES FOLDER "ONNXRuntimeTest") endif() if(WIN32) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index f7c2908d0ab8b..04e99d206bd06 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -60,6 +60,9 @@ Module Name: #if defined(__s390x__) #define MLAS_TARGET_S390X #endif +#if defined(__riscv) && defined(__riscv_xlen) && (__riscv_xlen == 64) +#define MLAS_TARGET_RISCV64 +#endif #if defined(__VSX__) #define MLAS_TARGET_POWER diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 4916062f2b4f9..a677ee5087672 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -876,7 +876,7 @@ Return Value: // float Maximum; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64) Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); #else Maximum = MlasReduceMaximumF32Kernel(Input, D); @@ -894,7 +894,7 @@ Return Value: float* Temp = LogSoftmax ? nullptr : Output; float Accumulation; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64) Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #else Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); @@ -910,7 +910,7 @@ Return Value: // float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); #else @@ -922,7 +922,7 @@ Return Value: // float Parameters[] = {1.0f / Accumulation}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); #else MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 662e757a47998..1fa4c90913b24 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -352,7 +352,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) + defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) || \ + defined(MLAS_TARGET_RISCV64) typedef size_t @@ -1018,6 +1019,36 @@ extern "C" { MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelLasx; MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelLasx; MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelLasx; +#elif defined(MLAS_TARGET_RISCV64) +#if defined(MLAS_USE_RVV) + MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelRvv; + void MlasSgemmCopyPackBRvv( + float* D, + const float* B, + size_t ldb, + size_t CountX, + size_t CountY); +#endif + size_t MLASCALL MlasSgemmKernelZero( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha); + size_t MLASCALL MlasSgemmKernelAdd( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha); #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; @@ -1167,6 +1198,12 @@ extern "C" { MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel; +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) + MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL MlasComputeSumExpF32KernelRvv; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelRvv; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelRvv; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelRvv; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx; MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx512F; @@ -1442,7 +1479,7 @@ struct MLAS_PLATFORM { #endif -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_RISCV64) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif #if defined(MLAS_TARGET_LARCH64) @@ -1507,7 +1544,7 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif -#if defined(MLAS_USE_SVE) || defined(MLAS_TARGET_AMD64) +#if defined(MLAS_USE_SVE) || defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_RISCV64) MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e9f140a2ee0f7..191ee1ab2f2f8 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -27,8 +27,10 @@ Module Name: #include "kleidiai/mlasi_kleidiai.h" #endif -#include +#include +#include #include +#include #if defined(MLAS_TARGET_POWER) #if defined(__linux__) @@ -49,6 +51,54 @@ Module Name: #include #endif +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) && defined(__linux__) +#include +#include +#ifndef COMPAT_HWCAP_ISA_V +#define COMPAT_HWCAP_ISA_V (1UL << ('V' - 'A')) +#endif +#endif + +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) +namespace { + +bool +MlasStringEqualsIgnoreCase( + const char* value, + const char* expected + ) +{ + while (*value != '\0' && *expected != '\0') { + const auto lhs = static_cast(*value); + const auto rhs = static_cast(*expected); + if (std::tolower(lhs) != std::tolower(rhs)) { + return false; + } + ++value; + ++expected; + } + + return *value == '\0' && *expected == '\0'; +} + +bool +MlasShouldForceScalarRiscv( + const char* value + ) +{ + if (value == nullptr || value[0] == '\0') { + return false; + } + + return MlasStringEqualsIgnoreCase(value, "1") || + MlasStringEqualsIgnoreCase(value, "true") || + MlasStringEqualsIgnoreCase(value, "on") || + MlasStringEqualsIgnoreCase(value, "yes"); +} + +} // namespace +#endif + #if defined(MLAS_TARGET_ARM64) #if defined(_WIN32) @@ -265,6 +315,33 @@ Return Value: this->CastF16ToF32Kernel = nullptr; this->CastF32ToF16Kernel = nullptr; +#if defined(MLAS_TARGET_RISCV64) + this->GemmFloatKernel = nullptr; + this->ErfKernelRoutine = MlasErfKernel; + this->LogisticKernelRoutine = MlasLogisticKernel; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + +#if defined(MLAS_USE_RVV) + bool has_rvv = true; +#if defined(__linux__) + has_rvv = (getauxval(AT_HWCAP) & COMPAT_HWCAP_ISA_V) != 0; +#endif + if (MlasShouldForceScalarRiscv(std::getenv("ORT_MLAS_RISCV_FORCE_SCALAR"))) { + has_rvv = false; + } + if (has_rvv) { + this->GemmFloatKernel = MlasGemmFloatKernelRvv; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelRvv; + this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelRvv; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelRvv; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelRvv; + } +#endif +#endif + #if defined(MLAS_TARGET_AMD64_IX86) // diff --git a/onnxruntime/core/mlas/lib/riscv64/sgemm_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/sgemm_kernel_rvv.cpp new file mode 100644 index 0000000000000..c6e43e2c8bcd4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/sgemm_kernel_rvv.cpp @@ -0,0 +1,275 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sgemm_kernel_rvv.cpp + +Abstract: + + This module implements an RVV kernel for the single precision matrix/matrix + multiply operation (SGEMM) on riscv64. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV) + +#include + +namespace { + +// The packed B layout stays 16 columns wide to match MLAS, but each tile is +// consumed in runtime-sized RVV chunks so the kernel is not tied to a fixed +// VLEN such as 128 or 256 bits. +constexpr size_t kPackedCountN = 16; + +template +MLAS_FORCEINLINE +void +MlasStoreAccumulatorRvv( + float* C, + vfloat32m4_t Accumulator, + size_t vl, + float alpha + ) +{ +#if defined(_WIN32) + + if constexpr (AlphaIsOne) { + UNREFERENCED_PARAMETER(alpha); + } + +#endif + + if constexpr (!AlphaIsOne) { + Accumulator = __riscv_vfmul_vf_f32m4(Accumulator, alpha, vl); + } + + if constexpr (!ZeroMode) { + Accumulator = __riscv_vfadd_vv_f32m4(Accumulator, __riscv_vle32_v_f32m4(C, vl), vl); + } + + __riscv_vse32_v_f32m4(C, Accumulator, vl); +} + +template +MLAS_FORCEINLINE +size_t +MlasSgemmKernelRvv( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ) +{ + static_assert(Rows >= 1 && Rows <= 4, "unsupported RVV SGEMM tile height"); + +#if defined(_WIN32) + + if constexpr (Rows == 1) { + UNREFERENCED_PARAMETER(lda); + UNREFERENCED_PARAMETER(ldc); + } + + if constexpr (AlphaIsOne) { + UNREFERENCED_PARAMETER(alpha); + } + +#endif + + const float* packed_b_block = B; + float* c_block = C; + size_t remaining_n_total = CountN; + + do { + const size_t count_n_block = remaining_n_total >= kPackedCountN ? kPackedCountN : remaining_n_total; + size_t remaining_n_block = count_n_block; + size_t column_offset = 0; + float* c = c_block; + + while (remaining_n_block > 0) { + // Split a packed 16-column tile into however many lanes the current + // machine exposes for e32,m4. This keeps the kernel VLEN-agnostic. + const size_t vl = __riscv_vsetvl_e32m4(remaining_n_block); + vfloat32m4_t row0_block = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t row1_block; + vfloat32m4_t row2_block; + vfloat32m4_t row3_block; + + if constexpr (Rows >= 2) { + row1_block = __riscv_vfmv_v_f_f32m4(0.0f, vl); + } + if constexpr (Rows >= 3) { + row2_block = __riscv_vfmv_v_f_f32m4(0.0f, vl); + } + if constexpr (Rows >= 4) { + row3_block = __riscv_vfmv_v_f_f32m4(0.0f, vl); + } + + const float* a = A; + const float* b = packed_b_block + column_offset; + size_t k = CountK; + + while (k >= 2) { + const float row0_a0 = a[0]; + const float row0_a1 = a[1]; + vfloat32m4_t b_elements = __riscv_vle32_v_f32m4(b, vl); + row0_block = __riscv_vfmacc_vf_f32m4(row0_block, row0_a0, b_elements, vl); + + if constexpr (Rows >= 2) { + row1_block = __riscv_vfmacc_vf_f32m4(row1_block, a[lda], b_elements, vl); + } + if constexpr (Rows >= 3) { + row2_block = __riscv_vfmacc_vf_f32m4(row2_block, a[lda * 2], b_elements, vl); + } + if constexpr (Rows >= 4) { + row3_block = __riscv_vfmacc_vf_f32m4(row3_block, a[lda * 3], b_elements, vl); + } + + b_elements = __riscv_vle32_v_f32m4(b + kPackedCountN, vl); + row0_block = __riscv_vfmacc_vf_f32m4(row0_block, row0_a1, b_elements, vl); + + if constexpr (Rows >= 2) { + row1_block = __riscv_vfmacc_vf_f32m4(row1_block, a[lda + 1], b_elements, vl); + } + if constexpr (Rows >= 3) { + row2_block = __riscv_vfmacc_vf_f32m4(row2_block, a[lda * 2 + 1], b_elements, vl); + } + if constexpr (Rows >= 4) { + row3_block = __riscv_vfmacc_vf_f32m4(row3_block, a[lda * 3 + 1], b_elements, vl); + } + + a += 2; + b += kPackedCountN * 2; + k -= 2; + } + + if (k > 0) { + vfloat32m4_t b_elements = __riscv_vle32_v_f32m4(b, vl); + row0_block = __riscv_vfmacc_vf_f32m4(row0_block, a[0], b_elements, vl); + + if constexpr (Rows >= 2) { + row1_block = __riscv_vfmacc_vf_f32m4(row1_block, a[lda], b_elements, vl); + } + if constexpr (Rows >= 3) { + row2_block = __riscv_vfmacc_vf_f32m4(row2_block, a[lda * 2], b_elements, vl); + } + if constexpr (Rows >= 4) { + row3_block = __riscv_vfmacc_vf_f32m4(row3_block, a[lda * 3], b_elements, vl); + } + } + + MlasStoreAccumulatorRvv(c, row0_block, vl, alpha); + + if constexpr (Rows >= 2) { + MlasStoreAccumulatorRvv(c + ldc, row1_block, vl, alpha); + } + if constexpr (Rows >= 3) { + MlasStoreAccumulatorRvv(c + ldc * 2, row2_block, vl, alpha); + } + if constexpr (Rows >= 4) { + MlasStoreAccumulatorRvv(c + ldc * 3, row3_block, vl, alpha); + } + + c += vl; + column_offset += vl; + remaining_n_block -= vl; + } + + c_block += count_n_block; + packed_b_block += CountK * kPackedCountN; + remaining_n_total -= count_n_block; + + } while (remaining_n_total > 0); + + return Rows; +} + +template +MLAS_FORCEINLINE +size_t +MlasGemmFloatKernelRvvDispatchRows( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ) +{ + if (CountM >= 4) { + return MlasSgemmKernelRvv(A, B, C, CountK, CountN, lda, ldc, alpha); + } + + if (CountM == 3) { + return MlasSgemmKernelRvv(A, B, C, CountK, CountN, lda, ldc, alpha); + } + + if (CountM >= 2) { + return MlasSgemmKernelRvv(A, B, C, CountK, CountN, lda, ldc, alpha); + } + + return MlasSgemmKernelRvv(A, B, C, CountK, CountN, lda, ldc, alpha); +} + +template +MLAS_FORCEINLINE +size_t +MlasGemmFloatKernelRvvDispatch( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ) +{ + if (alpha == 1.0f) { + return MlasGemmFloatKernelRvvDispatchRows( + A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } + + return MlasGemmFloatKernelRvvDispatchRows( + A, B, C, CountK, CountM, CountN, lda, ldc, alpha); +} + +} // namespace + +size_t +MLASCALL +MlasGemmFloatKernelRvv( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +{ + if (ZeroMode) { + return MlasGemmFloatKernelRvvDispatch(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } + + return MlasGemmFloatKernelRvvDispatch(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); +} + +#endif // defined(MLAS_USE_RVV) diff --git a/onnxruntime/core/mlas/lib/riscv64/sgemm_pack_b_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/sgemm_pack_b_rvv.cpp new file mode 100644 index 0000000000000..b2ec24e3fbfdc --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/sgemm_pack_b_rvv.cpp @@ -0,0 +1,115 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sgemm_pack_b_rvv.cpp + +Abstract: + + This module implements an RVV packing helper for the single precision + matrix/matrix multiply operation (SGEMM) on riscv64. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV) + +#include + +namespace { + +// Keep MLAS packing in 16-column tiles, but let RVV decide the actual chunk +// size at runtime via vsetvl so the same code works across different VLENs. +constexpr size_t kPackedCountN = 16; + +MLAS_FORCEINLINE +void +MlasStoreZeroPaddedBlock( + float* D, + const float* B, + size_t CountX + ) +{ + size_t remaining = kPackedCountN; + size_t offset = 0; + + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + __riscv_vse32_v_f32m4(D + offset, __riscv_vfmv_v_f_f32m4(0.0f, vl), vl); + offset += vl; + remaining -= vl; + } + + remaining = CountX; + offset = 0; + + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + __riscv_vse32_v_f32m4(D + offset, __riscv_vle32_v_f32m4(B + offset, vl), vl); + offset += vl; + remaining -= vl; + } +} + +MLAS_FORCEINLINE +void +MlasStoreFullBlock( + float* D, + const float* B + ) +{ + size_t remaining = kPackedCountN; + size_t offset = 0; + + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + __riscv_vse32_v_f32m4(D + offset, __riscv_vle32_v_f32m4(B + offset, vl), vl); + offset += vl; + remaining -= vl; + } +} + +} // namespace + +void +MlasSgemmCopyPackBRvv( + float* D, + const float* B, + size_t ldb, + size_t CountX, + size_t CountY + ) +{ + while (CountX >= kPackedCountN) { + const float* b = B; + size_t y = CountY; + + do { + MlasStoreFullBlock(D, b); + D += kPackedCountN; + b += ldb; + y--; + } while (y > 0); + + B += kPackedCountN; + CountX -= kPackedCountN; + } + + if (CountX > 0) { + size_t y = CountY; + + do { + MlasStoreZeroPaddedBlock(D, B, CountX); + D += kPackedCountN; + B += ldb; + y--; + } while (y > 0); + } +} + +#endif // defined(MLAS_USE_RVV) diff --git a/onnxruntime/core/mlas/lib/riscv64/softmax_kernel_rvv.cpp b/onnxruntime/core/mlas/lib/riscv64/softmax_kernel_rvv.cpp new file mode 100644 index 0000000000000..dc548b56d676e --- /dev/null +++ b/onnxruntime/core/mlas/lib/riscv64/softmax_kernel_rvv.cpp @@ -0,0 +1,207 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax_kernel_rvv.cpp + +Abstract: + + This module implements RVV kernels for the softmax critical path on + riscv64. The implementation keeps the scope intentionally small and + focuses on the float32 primitives used by Softmax and LogSoftmax: + reduction, sum-exp, normalization, and log-softmax output. + +--*/ + +#include "mlasi.h" + +#if defined(MLAS_USE_RVV) + +#include + +namespace { + +constexpr float kExpLowerRangeSumExp = -88.3762626647949f; +constexpr float kRoundingBias = MLAS_ROUNDING_BIAS_MAGIC; +constexpr float kLog2Reciprocal = 1.44269504088896341f; +constexpr float kLog2High = -6.93145752e-1f; +constexpr float kLog2Low = -1.42860677e-6f; +constexpr float kPoly0 = 0x1.694000p-10f; +constexpr float kPoly1 = 0x1.125edcp-7f; +constexpr float kPoly2 = 0x1.555b5ap-5f; +constexpr float kPoly3 = 0x1.555450p-3f; +constexpr float kPoly4 = 0x1.fffff6p-2f; +constexpr float kPoly56 = 0x1.000000p+0f; +constexpr int32_t kMaximumExponentBits = 0x3F800000; + +MLAS_FORCEINLINE +vfloat32m1_t +MlasComputeExpVectorRvv( + vfloat32m1_t value, + size_t vl + ) +{ + value = __riscv_vfmax_vf_f32m1(value, kExpLowerRangeSumExp, vl); + + vfloat32m1_t scaled = __riscv_vfmul_vf_f32m1(value, kLog2Reciprocal, vl); + vfloat32m1_t biased = __riscv_vfadd_vf_f32m1(scaled, kRoundingBias, vl); + vfloat32m1_t reduced_m = __riscv_vfsub_vf_f32m1(biased, kRoundingBias, vl); + vfloat32m1_t reduced = __riscv_vfadd_vv_f32m1( + __riscv_vfmul_vf_f32m1(reduced_m, kLog2High, vl), value, vl); + reduced = __riscv_vfadd_vv_f32m1( + __riscv_vfmul_vf_f32m1(reduced_m, kLog2Low, vl), reduced, vl); + + vfloat32m1_t poly = __riscv_vfmv_v_f_f32m1(kPoly0, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly1, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly2, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly3, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly4, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly56, vl); + poly = __riscv_vfadd_vf_f32m1( + __riscv_vfmul_vv_f32m1(poly, reduced, vl), kPoly56, vl); + + vint32m1_t exponent_bits = __riscv_vreinterpret_v_f32m1_i32m1(biased); + exponent_bits = __riscv_vsll_vx_i32m1(exponent_bits, 23, vl); + exponent_bits = __riscv_vadd_vx_i32m1(exponent_bits, kMaximumExponentBits, vl); + vfloat32m1_t scale = __riscv_vreinterpret_v_i32m1_f32m1(exponent_bits); + + return __riscv_vfmul_vv_f32m1(poly, scale, vl); +} + +MLAS_FORCEINLINE +float +MlasReduceSumRvv( + vfloat32m1_t value, + size_t vl + ) +{ + vfloat32m1_t accumulator = __riscv_vfmv_s_f_f32m1(0.0f, 1); + accumulator = __riscv_vfredusum_vs_f32m1_f32m1(value, accumulator, vl); + return __riscv_vfmv_f_s_f32m1_f32(accumulator); +} + +MLAS_FORCEINLINE +float +MlasReduceMaxRvv( + vfloat32m1_t value, + size_t vl + ) +{ + vfloat32m1_t accumulator = + __riscv_vfmv_s_f_f32m1(std::numeric_limits::lowest(), 1); + accumulator = __riscv_vfredmax_vs_f32m1_f32m1(value, accumulator, vl); + return __riscv_vfmv_f_s_f32m1_f32(accumulator); +} + +} // namespace + +float +MLASCALL +MlasReduceMaximumF32KernelRvv( + const float* Input, + size_t N + ) +{ + float maximum = std::numeric_limits::lowest(); + + while (N > 0) { + size_t vl = __riscv_vsetvl_e32m1(N); + vfloat32m1_t input = __riscv_vle32_v_f32m1(Input, vl); + input = __riscv_vfmax_vf_f32m1(input, maximum, vl); + maximum = MlasReduceMaxRvv(input, vl); + + Input += vl; + N -= vl; + } + + return maximum; +} + +float +MLASCALL +MlasComputeSumExpF32KernelRvv( + const float* Input, + float* Output, + size_t N, + const float* NegativeMaximum + ) +{ + const float negative_maximum = *NegativeMaximum; + float accumulation = 0.0f; + + while (N > 0) { + size_t vl = __riscv_vsetvl_e32m1(N); + vfloat32m1_t input = __riscv_vle32_v_f32m1(Input, vl); + vfloat32m1_t shifted = __riscv_vfadd_vf_f32m1(input, negative_maximum, vl); + vfloat32m1_t exp_value = MlasComputeExpVectorRvv(shifted, vl); + + if (Output != nullptr) { + __riscv_vse32_v_f32m1(Output, exp_value, vl); + Output += vl; + } + + accumulation += MlasReduceSumRvv(exp_value, vl); + + Input += vl; + N -= vl; + } + + return accumulation; +} + +void +MLASCALL +MlasComputeSoftmaxOutputF32KernelRvv( + float* Output, + size_t N, + const float* Parameters + ) +{ + const float scale = Parameters[0]; + + while (N > 0) { + size_t vl = __riscv_vsetvl_e32m1(N); + vfloat32m1_t output = __riscv_vle32_v_f32m1(Output, vl); + output = __riscv_vfmul_vf_f32m1(output, scale, vl); + __riscv_vse32_v_f32m1(Output, output, vl); + + Output += vl; + N -= vl; + } +} + +void +MLASCALL +MlasComputeLogSoftmaxOutputF32KernelRvv( + const float* Input, + float* Output, + size_t N, + const float* Parameters + ) +{ + const float negative_maximum = Parameters[0]; + const float logarithm = Parameters[1]; + + while (N > 0) { + size_t vl = __riscv_vsetvl_e32m1(N); + vfloat32m1_t input = __riscv_vle32_v_f32m1(Input, vl); + input = __riscv_vfadd_vf_f32m1(input, negative_maximum, vl); + input = __riscv_vfsub_vf_f32m1(input, logarithm, vl); + __riscv_vse32_v_f32m1(Output, input, vl); + + Input += vl; + Output += vl; + N -= vl; + } +} + +#endif // defined(MLAS_USE_RVV) diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 7836b1f89b0c4..88d0308bfa21e 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -247,6 +247,13 @@ Return Value: --*/ { +#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV) && !defined(FORCE_GENERIC_ALGORITHMS) + if (GetMlasPlatform().GemmFloatKernel != nullptr) { + MlasSgemmCopyPackBRvv(D, B, ldb, CountX, CountY); + return; + } +#endif + // // Copy data from matrix B into the destination buffer 16 columns at a // time. @@ -1004,6 +1011,14 @@ Return Value: #if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); +#elif defined(MLAS_TARGET_RISCV64) && !defined(FORCE_GENERIC_ALGORITHMS) + if (GetMlasPlatform().GemmFloatKernel != nullptr) { + RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); + } else if (ZeroMode) { + RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } else { + RowsHandled = MlasSgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); + } #else if (ZeroMode) { RowsHandled = MlasSgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, alpha); diff --git a/onnxruntime/test/mlas/bench/riscv64/README.md b/onnxruntime/test/mlas/bench/riscv64/README.md new file mode 100644 index 0000000000000..136c40d39430f --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/README.md @@ -0,0 +1,77 @@ +# RISC-V MLAS Benchmarks + +This directory stores the standalone benchmarks and compare tools used while +bringing up and tuning the RVV path in MLAS. + +Files: + +- `sgemm_riscv_bench.cpp`: standalone SGEMM timing harness with checksum + output. Useful for RVV versus scalar comparisons. +- `softmax_rvv_compare.cpp`: scalar versus RVV validation and timing tool for + the Softmax critical path. + +These tools are intentionally kept separate from `onnxruntime_mlas_benchmark`. +Each source file has its own `main()` and is built as an independent target. + +## Build + +On a riscv64 RVV build, first regenerate the build tree: + +```bash +python3 tools/ci_build/build.py \ + --config Release \ + --build_dir build/k1_rvv_resync \ + --update \ + --skip_tests \ + --skip_pip_install \ + --skip_submodule_sync \ + --no_sve \ + --enable_rvv +``` + +Then build both standalone tools directly with CMake: + +```bash +cmake --build build/k1_rvv_resync/Release \ + --config Release \ + --target onnxruntime_mlas_sgemm_riscv_bench onnxruntime_mlas_softmax_riscv_compare \ + -- -j8 +``` + +The resulting binaries are typically placed under: + +```bash +build/k1_rvv_resync/Release/onnxruntime_mlas_sgemm_riscv_bench +build/k1_rvv_resync/Release/onnxruntime_mlas_softmax_riscv_compare +``` + +## SGEMM examples + +RVV, packed-B: + +```bash +taskset -c 0 build/k1_rvv_resync/Release/onnxruntime_mlas_sgemm_riscv_bench \ + --m=128 --n=3072 --k=768 --iters=10 --warmup=3 --pack_b=1 --trans_a=0 --trans_b=0 +``` + +Scalar baseline on the same binary: + +```bash +ORT_MLAS_RISCV_FORCE_SCALAR=1 taskset -c 0 \ + build/k1_rvv_resync/Release/onnxruntime_mlas_sgemm_riscv_bench \ + --m=128 --n=3072 --k=768 --iters=10 --warmup=3 --pack_b=1 --trans_a=0 --trans_b=0 +``` + +## Softmax examples + +```bash +taskset -c 0 build/k1_rvv_resync/Release/onnxruntime_mlas_softmax_riscv_compare +``` + +## Notes + +- The RVV SGEMM path is written to be VLEN-agnostic. The MLAS packing format + remains 16 columns wide, but each tile is consumed using runtime `vsetvl` + chunking so the same binary works across different VLENs such as 128 and 256. +- `ORT_MLAS_RISCV_FORCE_SCALAR=1` disables the RVV dispatch at runtime and is + the preferred way to gather scalar baselines from the same build. diff --git a/onnxruntime/test/mlas/bench/riscv64/sgemm_riscv_bench.cpp b/onnxruntime/test/mlas/bench/riscv64/sgemm_riscv_bench.cpp new file mode 100644 index 0000000000000..d94840ffec518 --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/sgemm_riscv_bench.cpp @@ -0,0 +1,240 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sgemm_riscv_bench.cpp + +Abstract: + + This module implements a standalone SGEMM benchmark used while tuning the + RISC-V MLAS path. It is intentionally separate from the Google Benchmark + suite so it can print pack time, compute time, checksum, and compare RVV + against scalar execution via ORT_MLAS_RISCV_FORCE_SCALAR. + +--*/ + +#include "mlas.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct Options { + size_t m = 128; + size_t n = 3072; + size_t k = 768; + size_t iters = 20; + size_t warmup = 3; + bool pack_b = false; + bool trans_a = false; + bool trans_b = false; + float alpha = 1.0f; + float beta = 0.0f; +}; + +void PrintUsage(const char* argv0) { + std::cout + << "Usage: " << argv0 << " [--m=N] [--n=N] [--k=N] [--iters=N] [--warmup=N]\n" + << " [--pack_b=0|1] [--trans_a=0|1] [--trans_b=0|1]\n" + << " [--alpha=F] [--beta=F]\n"; +} + +bool ParseBool(std::string_view value) { + return value == "1" || value == "true" || value == "on" || value == "yes"; +} + +float MakeValue(size_t index) { + uint32_t x = static_cast(index * 747796405u + 2891336453u); + x ^= x >> 16; + x *= 2246822519u; + x ^= x >> 13; + const uint32_t bucket = x % 2048u; + return (static_cast(bucket) / 1024.0f) - 1.0f; +} + +Options ParseArgs(int argc, char** argv) { + Options options; + + for (int i = 1; i < argc; ++i) { + std::string_view arg(argv[i]); + if (arg == "--help" || arg == "-h") { + PrintUsage(argv[0]); + std::exit(0); + } + + const auto split = arg.find('='); + if (split == std::string_view::npos || split == 0 || split + 1 >= arg.size()) { + continue; + } + + const std::string_view key = arg.substr(0, split); + const std::string_view value = arg.substr(split + 1); + + if (key == "--m") { + options.m = std::strtoull(value.data(), nullptr, 10); + } else if (key == "--n") { + options.n = std::strtoull(value.data(), nullptr, 10); + } else if (key == "--k") { + options.k = std::strtoull(value.data(), nullptr, 10); + } else if (key == "--iters") { + options.iters = std::strtoull(value.data(), nullptr, 10); + } else if (key == "--warmup") { + options.warmup = std::strtoull(value.data(), nullptr, 10); + } else if (key == "--pack_b") { + options.pack_b = ParseBool(value); + } else if (key == "--trans_a") { + options.trans_a = ParseBool(value); + } else if (key == "--trans_b") { + options.trans_b = ParseBool(value); + } else if (key == "--alpha") { + options.alpha = std::strtof(value.data(), nullptr); + } else if (key == "--beta") { + options.beta = std::strtof(value.data(), nullptr); + } + } + + return options; +} + +template +double TimeLoop(size_t iterations, Fn&& fn) { + const auto begin = std::chrono::steady_clock::now(); + for (size_t i = 0; i < iterations; ++i) { + fn(); + } + const auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - begin).count(); +} + +} // namespace + +int main(int argc, char** argv) { + const Options options = ParseArgs(argc, argv); + + if (options.m == 0 || options.n == 0 || options.k == 0 || options.iters == 0) { + std::cerr << "m, n, k, and iters must be > 0" << std::endl; + return 1; + } + + const size_t a_size = options.m * options.k; + const size_t b_size = options.n * options.k; + const size_t c_size = options.m * options.n; + + std::vector a(a_size); + std::vector b(b_size); + std::vector c(c_size, 0.0f); + + for (size_t i = 0; i < a.size(); ++i) { + a[i] = MakeValue(i); + } + for (size_t i = 0; i < b.size(); ++i) { + b[i] = MakeValue(i + a.size()); + } + + const CBLAS_TRANSPOSE trans_a = options.trans_a ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE trans_b = options.trans_b ? CblasTrans : CblasNoTrans; + const size_t lda = options.trans_a ? options.m : options.k; + const size_t ldb = options.trans_b ? options.k : options.n; + const size_t ldc = options.n; + + std::vector packed_b; + double pack_ms = 0.0; + + if (options.pack_b) { + const size_t packed_b_size = MlasGemmPackBSize(trans_a, trans_b, options.n, options.k, nullptr); + if (packed_b_size == 0) { + std::cerr << "packing is not supported for this configuration" << std::endl; + return 2; + } + + packed_b.resize(packed_b_size); + + pack_ms = TimeLoop(options.iters, [&]() { + MlasGemmPackB(trans_a, trans_b, options.n, options.k, b.data(), ldb, packed_b.data(), nullptr); + }); + + MlasGemmPackB(trans_a, trans_b, options.n, options.k, b.data(), ldb, packed_b.data(), nullptr); + } + + auto run_once = [&]() { + if (options.beta == 0.0f) { + std::fill(c.begin(), c.end(), 0.0f); + } + + if (options.pack_b) { + MlasGemm( + trans_a, + options.m, + options.n, + options.k, + options.alpha, + a.data(), + lda, + packed_b.data(), + options.beta, + c.data(), + ldc, + nullptr, + nullptr); + } else { + MlasGemm( + trans_a, + trans_b, + options.m, + options.n, + options.k, + options.alpha, + a.data(), + lda, + b.data(), + ldb, + options.beta, + c.data(), + ldc, + nullptr, + nullptr); + } + }; + + for (size_t i = 0; i < options.warmup; ++i) { + run_once(); + } + + const double compute_ms = TimeLoop(options.iters, run_once); + const double avg_compute_ms = compute_ms / static_cast(options.iters); + const double avg_pack_ms = pack_ms / static_cast(options.iters); + const double flops = 2.0 * static_cast(options.m) * static_cast(options.n) * + static_cast(options.k); + const double gflops = flops / (avg_compute_ms * 1.0e6); + const double checksum = std::accumulate(c.begin(), c.end(), 0.0); + + std::cout << std::fixed << std::setprecision(4); + std::cout << "M=" << options.m + << " N=" << options.n + << " K=" << options.k + << " pack_b=" << (options.pack_b ? 1 : 0) + << " trans_a=" << (options.trans_a ? 1 : 0) + << " trans_b=" << (options.trans_b ? 1 : 0) + << " iters=" << options.iters + << " warmup=" << options.warmup << '\n'; + if (options.pack_b) { + std::cout << "pack_total_ms=" << pack_ms << " pack_avg_ms=" << avg_pack_ms << '\n'; + } + std::cout << "compute_total_ms=" << compute_ms + << " compute_avg_ms=" << avg_compute_ms + << " gflops=" << gflops << '\n'; + std::cout << "checksum=" << checksum << std::endl; + + return 0; +} diff --git a/onnxruntime/test/mlas/bench/riscv64/softmax_rvv_compare.cpp b/onnxruntime/test/mlas/bench/riscv64/softmax_rvv_compare.cpp new file mode 100644 index 0000000000000..e4411d3920408 --- /dev/null +++ b/onnxruntime/test/mlas/bench/riscv64/softmax_rvv_compare.cpp @@ -0,0 +1,241 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax_rvv_compare.cpp + +Abstract: + + This module implements a standalone RVV versus scalar validation and + timing tool for the Softmax critical path on riscv64. + +--*/ + +#include "mlas.h" + +#include + +#if !defined(MLAS_TARGET_RISCV64) + +int main() { + std::cout << "softmax_rvv_compare is only supported on riscv64." << std::endl; + return 0; +} + +#elif !defined(MLAS_USE_RVV) + +int main() { + std::cout << "softmax_rvv_compare requires an RVV-enabled MLAS build." << std::endl; + return 0; +} + +#else + +#include "core/mlas/lib/mlasi.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +struct CompareStats { + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + double checksum_scalar = 0.0; + double checksum_rvv = 0.0; +}; + +struct TimingStats { + double scalar_ms = 0.0; + double rvv_ms = 0.0; +}; + +void ScalarSoftmaxRow(const float* input, float* output, size_t d, bool log_softmax, bool smooth_softmax) { + float maximum = MlasReduceMaximumF32Kernel(input, d); + if (smooth_softmax && maximum < 0.0f) { + maximum = 0.0f; + } + + const float negative_maximum = -maximum; + + if (log_softmax) { + float accumulation = MlasComputeSumExpF32Kernel(input, nullptr, d, &negative_maximum); + if (smooth_softmax) { + accumulation += std::exp(-maximum); + } + + const float parameters[2] = {negative_maximum, std::log(accumulation)}; + MlasComputeLogSoftmaxOutputF32Kernel(input, output, d, parameters); + return; + } + + float accumulation = MlasComputeSumExpF32Kernel(input, output, d, &negative_maximum); + if (smooth_softmax) { + accumulation += std::exp(-maximum); + } + + const float parameters[1] = {1.0f / accumulation}; + MlasComputeSoftmaxOutputF32Kernel(output, d, parameters); +} + +void RvvSoftmaxRow(const float* input, float* output, size_t d, bool log_softmax, bool smooth_softmax) { + auto& platform = GetMlasPlatform(); + + float maximum = platform.ReduceMaximumF32Kernel(input, d); + if (smooth_softmax && maximum < 0.0f) { + maximum = 0.0f; + } + + const float negative_maximum = -maximum; + + if (log_softmax) { + float accumulation = platform.ComputeSumExpF32Kernel(input, nullptr, d, &negative_maximum); + if (smooth_softmax) { + accumulation += std::exp(-maximum); + } + + const float parameters[2] = {negative_maximum, std::log(accumulation)}; + platform.ComputeLogSoftmaxOutputF32Kernel(input, output, d, parameters); + return; + } + + float accumulation = platform.ComputeSumExpF32Kernel(input, output, d, &negative_maximum); + if (smooth_softmax) { + accumulation += std::exp(-maximum); + } + + const float parameters[1] = {1.0f / accumulation}; + platform.ComputeSoftmaxOutputF32Kernel(output, d, parameters); +} + +CompareStats CompareCase(size_t rows, size_t d, bool log_softmax, bool smooth_softmax) { + std::vector input(rows * d); + std::vector scalar_output(rows * d); + std::vector rvv_output(rows * d); + + std::mt19937 rng( + static_cast(rows * 131 + d * 17 + (log_softmax ? 7 : 0) + (smooth_softmax ? 19 : 0))); + std::uniform_real_distribution dist(-150.0f, 190.0f); + + for (float& value : input) { + value = dist(rng); + } + + for (size_t row = 0; row < rows; ++row) { + const float* row_input = input.data() + row * d; + ScalarSoftmaxRow(row_input, scalar_output.data() + row * d, d, log_softmax, smooth_softmax); + RvvSoftmaxRow(row_input, rvv_output.data() + row * d, d, log_softmax, smooth_softmax); + } + + CompareStats stats; + for (size_t i = 0; i < rows * d; ++i) { + const float scalar = scalar_output[i]; + const float rvv = rvv_output[i]; + const float abs_diff = std::fabs(scalar - rvv); + const float rel_diff = abs_diff / std::max(std::fabs(scalar), 1.0e-12f); + stats.max_abs_diff = std::max(stats.max_abs_diff, abs_diff); + stats.max_rel_diff = std::max(stats.max_rel_diff, rel_diff); + stats.checksum_scalar += scalar; + stats.checksum_rvv += rvv; + } + + return stats; +} + +TimingStats TimeCase(size_t rows, size_t d, size_t repeats, bool log_softmax, bool smooth_softmax) { + std::vector input(rows * d); + std::vector scalar_output(rows * d); + std::vector rvv_output(rows * d); + + std::mt19937 rng(static_cast(rows * 97 + d * 29 + repeats)); + std::uniform_real_distribution dist(-10.0f, 10.0f); + + for (float& value : input) { + value = dist(rng); + } + + const auto scalar_begin = std::chrono::steady_clock::now(); + for (size_t repeat = 0; repeat < repeats; ++repeat) { + for (size_t row = 0; row < rows; ++row) { + ScalarSoftmaxRow(input.data() + row * d, scalar_output.data() + row * d, d, log_softmax, smooth_softmax); + } + } + const auto scalar_end = std::chrono::steady_clock::now(); + + const auto rvv_begin = std::chrono::steady_clock::now(); + for (size_t repeat = 0; repeat < repeats; ++repeat) { + for (size_t row = 0; row < rows; ++row) { + RvvSoftmaxRow(input.data() + row * d, rvv_output.data() + row * d, d, log_softmax, smooth_softmax); + } + } + const auto rvv_end = std::chrono::steady_clock::now(); + + TimingStats stats; + stats.scalar_ms = + std::chrono::duration_cast >(scalar_end - scalar_begin).count(); + stats.rvv_ms = + std::chrono::duration_cast >(rvv_end - rvv_begin).count(); + return stats; +} + +void PrintCompareCase(const std::string& name, size_t rows, size_t d, bool log_softmax, bool smooth_softmax) { + const auto stats = CompareCase(rows, d, log_softmax, smooth_softmax); + std::cout << name << " rows=" << rows << " d=" << d << " log_softmax=" << log_softmax + << " smooth=" << smooth_softmax << '\n'; + std::cout << " max_abs_diff=" << std::setprecision(9) << stats.max_abs_diff + << " max_rel_diff=" << stats.max_rel_diff << '\n'; + std::cout << " checksum_scalar=" << std::setprecision(12) << stats.checksum_scalar + << " checksum_rvv=" << stats.checksum_rvv << '\n'; +} + +void PrintTimingCase( + const std::string& name, size_t rows, size_t d, size_t repeats, bool log_softmax, bool smooth_softmax) { + const auto stats = TimeCase(rows, d, repeats, log_softmax, smooth_softmax); + const double speedup = stats.rvv_ms > 0.0 ? stats.scalar_ms / stats.rvv_ms : 0.0; + std::cout << name << " rows=" << rows << " d=" << d << " repeats=" << repeats + << " log_softmax=" << log_softmax << " smooth=" << smooth_softmax << '\n'; + std::cout << " scalar_ms=" << std::fixed << std::setprecision(3) << stats.scalar_ms + << " rvv_ms=" << stats.rvv_ms << " speedup=" << speedup << "x\n"; +} + +} // namespace + +int main() { + auto& platform = GetMlasPlatform(); + + std::cout << std::boolalpha; + std::cout << "dispatch_is_rvv_reduce=" + << (platform.ReduceMaximumF32Kernel == MlasReduceMaximumF32KernelRvv) << '\n'; + std::cout << "dispatch_is_rvv_sumexp=" + << (platform.ComputeSumExpF32Kernel == MlasComputeSumExpF32KernelRvv) << '\n'; + std::cout << "dispatch_is_rvv_softmax=" + << (platform.ComputeSoftmaxOutputF32Kernel == MlasComputeSoftmaxOutputF32KernelRvv) << '\n'; + std::cout << "dispatch_is_rvv_logsoftmax=" + << (platform.ComputeLogSoftmaxOutputF32Kernel == MlasComputeLogSoftmaxOutputF32KernelRvv) << '\n'; + std::cout << '\n'; + + PrintCompareCase("regression_case_3x128_softmax", 3, 128, false, true); + PrintCompareCase("regression_case_3x128_logsoftmax", 3, 128, true, true); + PrintCompareCase("regression_case_63x95_softmax", 63, 95, false, true); + PrintCompareCase("regression_case_16x211_softmax", 16, 211, false, true); + std::cout << '\n'; + + PrintTimingCase("perf_case_attention_like", 4096, 128, 100, false, true); + PrintTimingCase("perf_case_long_seq", 1024, 1024, 20, false, true); + + return 0; +} + +#endif diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4b231011832e0..f42617ba1b04c 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -888,6 +888,9 @@ def generate_build_tree( if args.enable_arm_neon_nchwc: cmake_args += ["-Donnxruntime_USE_ARM_NEON_NCHWC=ON"] + if args.enable_rvv: + cmake_args += ["-Donnxruntime_USE_RVV=ON"] + if not args.no_sve: cmake_args += ["-Donnxruntime_USE_SVE=ON"] diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index e30c5f8979183..b40bf4c2b25c6 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -673,6 +673,11 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: cpu_group.add_argument( "--enable_arm_neon_nchwc", action="store_true", help="Enables building with NCHWc ARM kernels." ) + cpu_group.add_argument( + "--enable_rvv", + action="store_true", + help="Enable riscv64 MLAS kernels that use the RISC-V Vector extension.", + ) # --- DNNL (formerly MKL-DNN / oneDNN) --- dnnl_group = parser.add_argument_group("DNNL Execution Provider")