Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,11 @@ endif()
if(RISCV64 AND MLAS_SOURCE_IS_NOT_SET)
file(GLOB_RECURSE mlas_platform_srcs CONFIGURE_DEPENDS
"${MLAS_SRC_DIR}/scalar/*.cpp")
# Remove scalar depthwise kernel; replaced by the vectorized version
list(REMOVE_ITEM mlas_platform_srcs
"${MLAS_SRC_DIR}/scalar/SconvDepthwiseKernelScalar.cpp")
list(APPEND mlas_platform_srcs
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp)

if(onnxruntime_USE_RVV)
set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}")
Expand All @@ -929,11 +934,18 @@ endif()
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
)
# RVV depthwise replaces the MLAS_FLOAT32X4 version
list(REMOVE_ITEM mlas_platform_srcs
"${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp")
set_source_files_properties(
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d")
list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1)
else()
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ enum MLAS_CONV_ALGORITHM {
MlasConvAlgorithmExpandThenGemm,
MlasConvAlgorithmExpandThenGemmSegmented,
MlasConvAlgorithmDepthwiseMultiplierGreaterThan1,
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)
MlasConvAlgorithmDepthwise,
#endif
};
Expand Down
17 changes: 8 additions & 9 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ Return Value:
}
}

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)

void
MlasDepthwiseThreaded(
Expand Down Expand Up @@ -1119,7 +1119,7 @@ Return Value:
return;
}

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)

if (Algorithm == MlasConvAlgorithmDepthwise) {
// Fill the Working Buffer with Zero for use by the depthwise kernel.
Expand Down Expand Up @@ -1178,7 +1178,7 @@ Return Value:
}


#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)

if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) {
const size_t BatchGroupCount = BatchCount * GroupCount;
Expand Down Expand Up @@ -1277,7 +1277,7 @@ Return Value:
break;
}

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)

case MlasConvAlgorithmDepthwise:
{
Expand Down Expand Up @@ -1549,15 +1549,14 @@ Return Value:
}
#endif

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)

// Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
// Scalar (WASM_SCALAR) / vectorized (ARM64/RISCV64) direct conv for depthwise convolution.
// Currently only support 3x3 kernel with padding <=1 and dilations = 1
// and on ARM64, it is further restricted to strides = 1.
// and on ARM64/RISCV64, it is further restricted to strides = 1.
// TODO: support more general depthwise convolution.

// On ARM64, only support stride = 1 for depthwise conv.
#if defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)
bool depthwise_conv_stride_support_check = Parameters->StrideShape[0] == 1 && Parameters->StrideShape[1] == 1;
#else
bool depthwise_conv_stride_support_check = true;
Expand Down
18 changes: 17 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,13 @@ extern "C" {
MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelRvv;
MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelRvv;
MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelRvv;
MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelRvv;
MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelRvv;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelRvv;
MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelRvv;
MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelRvv;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelRvv;
MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelRvv;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx;
Expand Down Expand Up @@ -1553,6 +1560,15 @@ struct MLAS_PLATFORM {
MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel;
#endif

#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV)
MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel;
MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel;
MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel;
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel;
MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount];
uint32_t NchwcBlockSize;
#endif

MLAS_COMPUTE_ERF_FP16_KERNEL* ErfFP16KernelRoutine = nullptr;
MLAS_COMPUTE_GELU_FP16_KERNEL* GeluFP16KernelRoutine = nullptr;
MLAS_COMPUTE_TANH_FP16_KERNEL* TanhFP16KernelRoutine = nullptr;
Expand Down Expand Up @@ -1760,7 +1776,7 @@ MlasFp32FromBits(
#pragma warning(pop)
#endif

#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_RISCV64)
void
MLASCALL
MlasConvDepthwiseFloat_CHW(
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,15 @@ Return Value:
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelRvv;
this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelRvv;
this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelRvv;

this->NchwcBlockSize = 16;
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelRvv;
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelRvv;
this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelRvv;
this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelRvv;
this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelRvv;
this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelRvv;
this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelRvv;
Comment on lines +342 to +349
}
#endif
#endif
Expand Down
Loading
Loading