From da4023c222f427861dafddacc94b26150a94bcfa Mon Sep 17 00:00:00 2001 From: xiaying Date: Thu, 19 Dec 2024 16:20:00 +0800 Subject: [PATCH 1/6] MNN:Sync: Sync Interal 3.0.2 --- .gitignore | 7 - 3rd_party/OpenCLHeaders/CL/cl2.hpp | 18 + 3rd_party/OpenCLHeaders/CL/cl_ext.h | 22 + CMakeLists.txt | 32 +- .../compute/DeconvolutionWithStride.cpp | 0 .../compute/DeconvolutionWithStride.hpp | 0 .../cpubackend}/compute/GemmInt8Executor.cpp | 0 .../cpubackend}/compute/GemmInt8Executor.hpp | 0 docs/compile/cmake.md | 4 +- docs/compile/other.md | 13 + docs/transformers/llm.md | 51 +- docs/transformers/models.md | 3 +- express/MathOp.cpp | 23 +- express/NeuralNetWorkOp.cpp | 242 ++-- include/MNN/MNNDefine.h | 2 +- include/MNN/MNNForwardType.h | 8 +- include/MNN/expr/MathOp.hpp | 13 +- project/android/build_32.sh | 2 +- project/android/build_64.sh | 2 +- project/ios/MNN.xcodeproj/project.pbxproj | 38 +- pymnn/CMakeLists.txt | 7 +- pymnn/pip_package/MNN/__init__.py | 1 + pymnn/pip_package/MNN/audio/__init__.py | 96 ++ pymnn/pip_package/MNN/llm/__init__.py | 22 +- pymnn/pip_package/build_deps.py | 6 +- pymnn/pip_package/setup.py | 5 +- pymnn/src/MNN.cc | 18 +- pymnn/src/audio.h | 105 ++ pymnn/src/llm.h | 41 +- schema/current/MNN_generated.h | 204 +++- schema/default/MNN.fbs | 10 +- source/backend/cpu/CPUBackend.hpp | 3 + source/backend/cpu/CPUBinaryInt8.cpp | 16 +- source/backend/cpu/CPUDeconvolution.cpp | 374 ++---- source/backend/cpu/CPUDeconvolution.hpp | 11 +- source/backend/cpu/CPUInstanceNorm.cpp | 11 +- source/backend/cpu/CPUMoments.cpp | 11 +- source/backend/cpu/CPUOPRegister.cpp | 6 + source/backend/cpu/CPURelu.cpp | 100 +- source/backend/cpu/CPUStft.cpp | 75 ++ source/backend/cpu/CPUStft.hpp | 31 + .../arm/arm32/MNNReluWithSlopeChannelInt8.S | 62 +- .../arm/arm32/MNNWinogradMatrixProductLeft.S | 225 ---- .../arm/arm32/MNNWinogradMatrixProductRight.S | 223 ---- .../arm/arm64/MNNReluWithSlopeChannelInt8.S | 61 +- .../arm/arm64/MNNWinogradMatrixProductLeft.S | 171 --- 
.../arm/arm64/MNNWinogradMatrixProductRight.S | 164 --- .../backend/cpu/compute/CommonOptFunction.cpp | 22 +- .../backend/cpu/compute/CommonOptFunction.h | 3 +- .../backend/cpu/compute/Int8FunctionsOpt.cpp | 24 +- source/backend/cpu/compute/Int8FunctionsOpt.h | 2 +- .../cpu/compute/WinogradOptFunction.cpp | 67 -- .../cpu/compute/WinogradOptFunction.hpp | 3 - .../cpu/x86_x64/FunctionDispatcher.cpp | 1 - .../cpu/x86_x64/sse/FunctionSummary.hpp | 2 +- .../backend/cpu/x86_x64/sse/MathFunctions.cpp | 43 - source/backend/metal/AllShader.cpp | 66 +- source/backend/metal/MetalAttention.mm | 832 +++---------- source/backend/metal/MetalAttentionShader.hpp | 636 ++++++++++ source/backend/metal/MetalConvolution1x1.hpp | 2 +- source/backend/metal/MetalConvolution1x1.mm | 27 +- .../backend/metal/MetalConvolutionCommon.hpp | 18 +- .../backend/metal/MetalConvolutionCommon.mm | 48 +- source/backend/metal/MetalDeconvolution.hpp | 9 - source/backend/metal/MetalDeconvolution.mm | 108 +- .../metal/shader/MetalConvolution1x1.metal | 57 +- .../metal/shader/MetalDeconvolution.metal | 9 +- .../backend/opencl/core/BufferConvertor.cpp | 74 ++ .../backend/opencl/core/BufferConvertor.hpp | 2 + source/backend/opencl/core/BufferPool.cpp | 1 - source/backend/opencl/core/OpenCLBackend.cpp | 152 ++- source/backend/opencl/core/OpenCLBackend.hpp | 24 +- .../opencl/core/runtime/OpenCLRuntime.cpp | 83 +- .../opencl/core/runtime/OpenCLRuntime.hpp | 8 +- .../opencl/core/runtime/OpenCLWrapper.cpp | 96 +- .../opencl/core/runtime/OpenCLWrapper.hpp | 22 +- .../execution/buffer/ConvBufExecution.cpp | 11 +- .../buffer/ConvBufLowMemoryExecution.cpp | 8 +- .../buffer/DepthwiseConvBufExecution.cpp | 12 +- .../opencl/execution/cl/buffer_convert_buf.cl | 2 +- source/backend/opencl/execution/cl/conv_2d.cl | 28 + .../opencl/execution/cl/conv_2d_buf.cl | 167 ++- .../opencl/execution/cl/conv_2d_int_buf.cl | 125 +- .../execution/cl/depthwise_conv2d_buf.cl | 20 +- .../opencl/execution/cl/glmem_convert.cl | 211 ++++ 
.../opencl/execution/cl/opencl_program.cc | 540 +++++++-- .../opencl/execution/cl/opencl_source_map.hpp | 2 + .../opencl/execution/image/ConvExecution.cpp | 28 +- .../image/ConvLowMemoryExecution.cpp | 12 + .../vulkan/buffer/execution/VulkanPRelu.cpp | 2 +- source/core/Interpreter.cpp | 3 + source/core/Pipeline.cpp | 2 +- source/core/TensorUtils.cpp | 12 +- source/core/TensorUtils.hpp | 6 + source/shape/ShapeConcat.cpp | 2 +- source/shape/ShapeRegister.cpp | 6 + source/shape/ShapeStft.cpp | 38 + source/shape/SizeComputer.hpp | 9 + test.sh | 4 +- test/CMakeLists.txt | 4 + test/op/DeconvolutionTest.cpp | 162 ++- test/op/MomentsTest.cpp | 2 + test/op/PReLUTest.cpp | 32 +- test/op/StftTest.cpp | 62 + test/sharedmem/AhardWareBufferTest.cpp | 176 +++ test/speed/StftSpeed.cpp | 40 + tools/audio/CMakeLists.txt | 44 + tools/audio/README.md | 9 + tools/audio/include/audio/audio.hpp | 169 +++ tools/audio/source/audio.cpp | 515 ++++++++ tools/audio/test/CMakeLists.txt | 20 + tools/audio/test/audio_test.cpp | 228 ++++ .../source/common/FullQuantAndCoding.cpp | 7 +- tools/converter/source/common/cli.cpp | 15 +- tools/cpp/CMakeLists.txt | 6 +- tools/cpp/GpuInterTest.cpp | 224 ++-- tools/script/register.py | 33 +- transformers/llm/.gitignore | 7 - transformers/llm/datasets/get-sharegpt.sh | 2 - .../llm/datasets/get-wikitext-2-raw.sh | 2 - .../llm/datasets/visualization/stats.py | 116 -- .../llm/datasets/visualization/time.py | 83 -- transformers/llm/engine/CMakeLists.txt | 24 +- transformers/llm/engine/app/ppl_demo.cpp | 61 - .../llm/engine/{app => }/embedding_demo.cpp | 0 .../engine/include/evaluation/MemMonitor.hpp | 42 - .../llm/engine/include/evaluation/dataset.hpp | 33 - .../engine/include/evaluation/evaluation.hpp | 68 -- transformers/llm/engine/include/llm/llm.hpp | 83 +- .../mnn-llm/LLMInferenceEngineWrapper.mm | 22 +- .../llm/engine/{app => }/llm_demo.cpp | 79 +- .../{test/bench_cn.txt => model/bench.txt} | 0 .../llm/engine/src/LlmSessionInfo.cpp | 87 -- 
transformers/llm/engine/src/dataset.cpp | 223 ---- transformers/llm/engine/src/evaluation.cpp | 30 - transformers/llm/engine/src/llm.cpp | 1036 +++++++++++------ transformers/llm/engine/src/llmconfig.cpp | 47 - transformers/llm/engine/src/llmconfig.hpp | 182 +-- transformers/llm/engine/src/perplexity.cpp | 318 ----- transformers/llm/engine/src/perplexity.hpp | 65 -- transformers/llm/engine/src/prompt.cpp | 110 -- transformers/llm/engine/src/prompt.hpp | 61 - transformers/llm/engine/src/sampler.cpp | 551 --------- transformers/llm/engine/src/sampler.hpp | 141 --- transformers/llm/engine/src/tokenizer.cpp | 2 +- transformers/llm/engine/test/bench_en.txt | 3 - transformers/llm/eval/evaluate_perplexity.py | 2 +- transformers/llm/export/.gitignore | 4 - transformers/llm/export/README.md | 36 +- transformers/llm/export/llmexport.py | 664 +++++++---- transformers/llm/export/requirements.txt | 4 +- 151 files changed, 6404 insertions(+), 5762 deletions(-) rename {source/backend/cpu => backupcode/cpubackend}/compute/DeconvolutionWithStride.cpp (100%) rename {source/backend/cpu => backupcode/cpubackend}/compute/DeconvolutionWithStride.hpp (100%) rename {source/backend/cpu => backupcode/cpubackend}/compute/GemmInt8Executor.cpp (100%) rename {source/backend/cpu => backupcode/cpubackend}/compute/GemmInt8Executor.hpp (100%) create mode 100644 pymnn/pip_package/MNN/audio/__init__.py create mode 100644 pymnn/src/audio.h create mode 100644 source/backend/cpu/CPUStft.cpp create mode 100644 source/backend/cpu/CPUStft.hpp delete mode 100644 source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S delete mode 100644 source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S delete mode 100644 source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S delete mode 100644 source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S create mode 100644 source/backend/metal/MetalAttentionShader.hpp create mode 100644 source/backend/opencl/execution/cl/glmem_convert.cl create mode 100644 
source/shape/ShapeStft.cpp create mode 100644 test/op/StftTest.cpp create mode 100644 test/sharedmem/AhardWareBufferTest.cpp create mode 100644 test/speed/StftSpeed.cpp create mode 100644 tools/audio/CMakeLists.txt create mode 100644 tools/audio/README.md create mode 100644 tools/audio/include/audio/audio.hpp create mode 100644 tools/audio/source/audio.cpp create mode 100644 tools/audio/test/CMakeLists.txt create mode 100644 tools/audio/test/audio_test.cpp delete mode 100644 transformers/llm/.gitignore delete mode 100644 transformers/llm/datasets/get-sharegpt.sh delete mode 100644 transformers/llm/datasets/get-wikitext-2-raw.sh delete mode 100644 transformers/llm/datasets/visualization/stats.py delete mode 100644 transformers/llm/datasets/visualization/time.py delete mode 100644 transformers/llm/engine/app/ppl_demo.cpp rename transformers/llm/engine/{app => }/embedding_demo.cpp (100%) delete mode 100644 transformers/llm/engine/include/evaluation/MemMonitor.hpp delete mode 100644 transformers/llm/engine/include/evaluation/dataset.hpp delete mode 100644 transformers/llm/engine/include/evaluation/evaluation.hpp rename transformers/llm/engine/{app => }/llm_demo.cpp (62%) rename transformers/llm/engine/{test/bench_cn.txt => model/bench.txt} (100%) delete mode 100644 transformers/llm/engine/src/LlmSessionInfo.cpp delete mode 100644 transformers/llm/engine/src/dataset.cpp delete mode 100644 transformers/llm/engine/src/evaluation.cpp delete mode 100644 transformers/llm/engine/src/llmconfig.cpp delete mode 100644 transformers/llm/engine/src/perplexity.cpp delete mode 100644 transformers/llm/engine/src/perplexity.hpp delete mode 100644 transformers/llm/engine/src/prompt.cpp delete mode 100644 transformers/llm/engine/src/prompt.hpp delete mode 100644 transformers/llm/engine/src/sampler.cpp delete mode 100644 transformers/llm/engine/src/sampler.hpp delete mode 100644 transformers/llm/engine/test/bench_en.txt delete mode 100644 transformers/llm/export/.gitignore diff --git 
a/.gitignore b/.gitignore index ca1a4320f..d11586811 100644 --- a/.gitignore +++ b/.gitignore @@ -361,10 +361,3 @@ pymnn_build/ # mnncompress generated MNN_compression_pb2.py - -# model path -model/ - -# datasets -datasets/* -!datasets/*.sh \ No newline at end of file diff --git a/3rd_party/OpenCLHeaders/CL/cl2.hpp b/3rd_party/OpenCLHeaders/CL/cl2.hpp index 305e88f30..b74fdbe11 100644 --- a/3rd_party/OpenCLHeaders/CL/cl2.hpp +++ b/3rd_party/OpenCLHeaders/CL/cl2.hpp @@ -3810,6 +3810,24 @@ class Buffer : public Memory } } + Buffer( + const Context& context, + cl_mem_flags flags, + const cl_import_properties_arm *properties, + void *memory, + size_type size, + cl_int* err = NULL) + { + cl_int error; + object_ = ::clImportMemoryARM(context(), flags, properties, memory, size, &error); + + detail::errHandler(error, __CREATE_BUFFER_ERR); + if (err != NULL) { + *err = error; + } + } + + /*! * \brief Construct a Buffer from a host container via iterators using a specified context. * IteratorType must be random access. diff --git a/3rd_party/OpenCLHeaders/CL/cl_ext.h b/3rd_party/OpenCLHeaders/CL/cl_ext.h index 7b101d737..47afb42f2 100644 --- a/3rd_party/OpenCLHeaders/CL/cl_ext.h +++ b/3rd_party/OpenCLHeaders/CL/cl_ext.h @@ -430,6 +430,23 @@ typedef struct _cl_mem_android_native_buffer_host_ptr } cl_mem_android_native_buffer_host_ptr; +/********************************* +* cl_qcom_ahardwarebuffer_host_ptr extension +*********************************/ + +#define CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM 0x4119 + +typedef struct _cl_mem_ahardwarebuffer_host_ptr +{ + /* Type of external memory allocation. */ + /* Must be CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM for Android Hardware buffers. 
*/ + cl_mem_ext_host_ptr ext_host_ptr; + + /* Virtual pointer to the android hardware buffer */ + void* ahb_ptr; + +} cl_mem_ahardwarebuffer_host_ptr; + /****************************************** * cl_img_yuv_image extension * ******************************************/ @@ -583,6 +600,11 @@ typedef intptr_t cl_import_properties_arm; /* Protected DMA BUF memory type value for CL_IMPORT_TYPE_ARM property */ #define CL_IMPORT_TYPE_PROTECTED_ARM 0x40B5 +#define CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM 0x41E2 +#define CL_IMPORT_DMA_BUF_DATA_CONSISTENCY_WITH_HOST_ARM 0x41E3 +#define CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM SIZE_MAX +#define CL_IMPORT_ANDROID_HARDWARE_BUFFER_PLANE_INDEX_ARM 0x41EF +#define CL_IMPORT_ANDROID_HARDWARE_BUFFER_LAYER_INDEX_ARM 0x41F0 /* This extension adds a new function that allows for direct memory import into * OpenCL via the clImportMemoryARM function. diff --git a/CMakeLists.txt b/CMakeLists.txt index c7768e340..6048bf4d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,9 +20,7 @@ endif() project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM) # complier options set(CMAKE_C_STANDARD 99) -IF (NOT (CMAKE_CXX_STANDARD EQUAL 17)) - set(CMAKE_CXX_STANDARD 11) -ENDIF() +set(CMAKE_CXX_STANDARD 11) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/cmake" @@ -49,7 +47,7 @@ option(MNN_BUILD_TOOLS "Build tools/cpp or not" ON) option(MNN_BUILD_QUANTOOLS "Build Quantized Tools or not" OFF) option(MNN_EVALUATION "Build Evaluation Tools or not" OFF) option(MNN_BUILD_CONVERTER "Build Converter" OFF) -option(MNN_SUPPORT_DEPRECATED_OP "Enable MNN's tflite quantized op" ON) +option(MNN_SUPPORT_DEPRECATED_OP "Enable MNN's tflite quantized op" OFF) option(MNN_DEBUG_MEMORY "MNN Debug Memory Access" OFF) option(MNN_DEBUG_TENSOR_SIZE "Enable Tensor Size" OFF) option(MNN_GPU_TRACE "Enable MNN Gpu Debug" OFF) @@ -74,6 +72,7 @@ option(MNN_JNI "Build MNN Jni for java to use" OFF) option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF) 
option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF) option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF) +option(MNN_BUILD_AUDIO "Build audio api in MNN." OFF) IF (OHOS AND MNN_INTERNAL) include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake) @@ -192,6 +191,9 @@ endif() if(MNN_SUPPORT_TRANSFORMER_FUSE) add_definitions(-DMNN_SUPPORT_TRANSFORMER_FUSE) endif() +if(MNN_BUILD_AUDIO) + add_definitions(-DMNN_BUILD_AUDIO) +endif() # debug options if(MNN_DEBUG_MEMORY) add_definitions(-DMNN_DEBUG_MEMORY) @@ -287,7 +289,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android") endif() option(MNN_USE_CPP11 "Enable MNN use c++11" ON) if (NOT MSVC) - if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17)) + if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) set(CMAKE_CXX_STANDARD 17) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") @@ -463,6 +465,10 @@ IF(MNN_BUILD_OPENCV) list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_HDRS}) list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_IMGHDRS}) ENDIF() +IF(MNN_BUILD_AUDIO) + file(GLOB MNN_AUDIO_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/tools/audio/include/audio/*.hpp PARENT_SCOPE) + list(APPEND MNN_EXTRA_HEADERS ${MNN_AUDIO_HDRS}) +ENDIF() IF(MNN_BUILD_LLM) file(GLOB MNN_LLM_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/*) list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/llm.hpp) @@ -775,6 +781,14 @@ IF(MNN_BUILD_OPENCV AND NOT MNN_SEP_BUILD) ENDIF() target_sources(MNN PRIVATE $<TARGET_OBJECTS:MNNOpenCV>) ENDIF() +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tools/audio) +IF(MNN_BUILD_AUDIO AND NOT MNN_SEP_BUILD) + IF(MSVC) + target_compile_definitions(MNNAudio PRIVATE "-DBUILDING_MNN_DLL" INTERFACE "-DUSING_MNN_DLL") + ENDIF() + message(STATUS "### build MNNAudio into MNN") + target_sources(MNN PRIVATE $<TARGET_OBJECTS:MNNAudio>) +ENDIF() if(CMAKE_SYSTEM_NAME MATCHES "^Linux") @@ -884,6 +898,14 @@ ELSE() 
SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/cv/imgproc ) ENDFOREACH() ENDIF() + IF(MNN_BUILD_AUDIO) + if (NOT MNN_AAPL_FMWK) + INSTALL(FILES ${MNN_AUDIO_HDRS} DESTINATION include/MNN/audio) + endif() + FOREACH(HDR ${MNN_AUDIO_HDRS}) + SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/audio/ ) + ENDFOREACH() + ENDIF() IF(MNN_BUILD_LLM) if (NOT MNN_AAPL_FMWK) INSTALL(FILES ${MNN_LLM_HDRS} DESTINATION include/MNN/llm) diff --git a/source/backend/cpu/compute/DeconvolutionWithStride.cpp b/backupcode/cpubackend/compute/DeconvolutionWithStride.cpp similarity index 100% rename from source/backend/cpu/compute/DeconvolutionWithStride.cpp rename to backupcode/cpubackend/compute/DeconvolutionWithStride.cpp diff --git a/source/backend/cpu/compute/DeconvolutionWithStride.hpp b/backupcode/cpubackend/compute/DeconvolutionWithStride.hpp similarity index 100% rename from source/backend/cpu/compute/DeconvolutionWithStride.hpp rename to backupcode/cpubackend/compute/DeconvolutionWithStride.hpp diff --git a/source/backend/cpu/compute/GemmInt8Executor.cpp b/backupcode/cpubackend/compute/GemmInt8Executor.cpp similarity index 100% rename from source/backend/cpu/compute/GemmInt8Executor.cpp rename to backupcode/cpubackend/compute/GemmInt8Executor.cpp diff --git a/source/backend/cpu/compute/GemmInt8Executor.hpp b/backupcode/cpubackend/compute/GemmInt8Executor.hpp similarity index 100% rename from source/backend/cpu/compute/GemmInt8Executor.hpp rename to backupcode/cpubackend/compute/GemmInt8Executor.hpp diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 9307038ad..a4d45bca4 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -16,7 +16,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_BUILD_QUANTOOLS | 是否构建MNN的量化工具,默认为`OFF` | | MNN_EVALUATION | 是否构建MNN的评估工具,默认为`OFF` | | MNN_BUILD_CONVERTER | 是否构建MNN的转换工具,默认为`OFF` | -| MNN_SUPPORT_DEPRECATED_OP | 是否支持Tflite的量化算子,默认为`ON` | +| MNN_SUPPORT_DEPRECATED_OP | 
是否支持Tflite的量化算子等已经废弃的算子,用于兼容历史模型(1.1.0版本之前),默认为`OFF` | | MNN_DEBUG_MEMORY | 是否开启MNN内存调试,默认为`OFF` | | MNN_DEBUG_TENSOR_SIZE | 是否开启MNN tensor size调试,默认为`OFF` | | MNN_GPU_TRACE | 是否开启MNN GPU调试,默认为`OFF` | @@ -32,6 +32,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_ENABLE_COVERAGE | 是否开启MNN的代码覆盖率,默认为`OFF` | | MNN_BUILD_PROTOBUFFER | 是否使用MNN中的`protobuffer`,默认为`ON` | | MNN_BUILD_OPENCV | 是否构建MNN的OpenCV功能,默认为`OFF` | +| MNN_BUILD_AUDIO | 是否构建MNN的Audio功能,默认为`OFF` | | MNN_INTERNAL | 是否构建MNN的一些内部功能,如:日志;默认为`OFF` | | MNN_JNI | 是否构建MNN的JNI支持,默认为`OFF` | | MNN_METAL | 是否构建`Metal`后端,默认为`OFF` | @@ -79,6 +80,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_CVCORE | 构建MNN的OpenCV功能是否开启`core`功能,默认为`ON` | | MNN_OPENCV_TEST | 构建MNN的OpenCV功能是否开启单元测试,默认为`OFF` | | MNN_OPENCV_BENCH | 构建MNN的OpenCV功能是否开启性能benchmark,默认为`OFF` | +| MNN_AUDIO_TEST | 构建MNN的Audio功能是否开启单元测试,默认为`OFF` | | MNN_VULKAN_IMAGE | 构建MNN的Vulkan后端时采用Image内存模式,以便支持FP16和部分移动端上GPU的加速,默认为`ON` | | MNN_LOW_MEMORY | 是否支持低内存模式,支持低内存模式使用权值量化模型并设置`low_memory`则会使用计算时反量化,默认为`OFF` | | MNN_CPU_WEIGHT_DEQUANT_GEMM | 是否编译CPU权重反量化的矩阵乘Kernel, 如果打开该编译宏并且在CPU推理时设置MNN::BackendConfig::MemoryMode=Memory_Normal,就会使用权重反量化算子进行权重量化模型的推理,默认为`OFF` | diff --git a/docs/compile/other.md b/docs/compile/other.md index d0209f61b..f6418cc27 100644 --- a/docs/compile/other.md +++ b/docs/compile/other.md @@ -133,6 +133,19 @@ - `libMNNOpenCV.so` MNN OpenCV函数库 - `opencv_test` MNN OpenCV单元测试 - `opencv_bench` MNN OpenCV性能测试 +## MNN Audio库 +- 相关编译选项 + - `MNN_BUILD_AUDIO` 是否编译Audio函数接口 + - `MNN_AUDIO_TEST` 是否编译Audio单元测试 +- 编译命令 + ```bash + mkdir build && cd build + cmake .. 
-DMNN_BUILD_AUDIO=ON -DMNN_AUDIO_TEST=ON + make -j4 + ``` +- 编译产物 + - `libMNNAudio.so` MNN Audio函数库 + - `audio_test` MNN Audio单元测试 ## 示例工程 - 相关编译选项 diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index b0fbd4932..bdd40739f 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -49,7 +49,7 @@ python llmexport.py \ ### 功能 - 直接转为mnn模型,使用`--export mnn`,注意,你需要先安装pymnn或者通过`--mnnconvert`选项指定MNNConvert工具的地址,两种条件必须满足其中一个。如果没有安装pymnn并且没有通过`--mnnconvert`指定MNNConvert工具的地址,那么llmexport.py脚本会在目录"../../../build/"下寻找MNNConvert工具,需保证该目录下存在MNNConvert文件。此方案目前支持导出4bit和8bit模型 -- 如果直接转为mnn模型遇到问题,或者需要其他bits数的量化(如5bit/6bit),可以先将模型先转为onnx模型,使用`--export onnx`,然后使用./MNNConvert工具将onnx模型转为mnn模型: +- 如果直接转为mnn模型遇到问题,或者需要其他bits数的量化(如5bit/6bit),可以先将模型先转为onnx模型,使用`--export onnx`,然后使用./MNNConvert工具将onnx模型转为mnn模型: ``` ./MNNConvert --modelFile ../transformers/llm/export/model/onnx/llm.onnx --MNNModel llm.mnn --keepInputFormat --weightQuantBits=4 --weightQuantBlock=128 -f ONNX --transformerFuse=1 --allowCustomOp --saveExternalData @@ -98,13 +98,17 @@ options: [从源码编译](../compile/other.html#id4) 在原有编译过程中增加必需编译宏即可: ``` --DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true +-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true ``` - 需要开启视觉功能时,增加相关编译宏 ``` -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true ``` +- 需要开启音频功能时,增加相关编译宏 +``` +-DLLM_SUPPORT_AUDIO=true +``` #### mac / linux / windows @@ -137,7 +141,7 @@ sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN ``` #### Web -环境配置参考 https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web +环境配置参考 https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web - 编译库,产出 `libMNN.a`,`libMNN_Express.a`,`libllm.a` @@ -189,7 +193,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - visual_model: 当使用VL模型时,visual_model的实际路径为`base_dir 
+ visual_model`,默认为`base_dir + 'visual.mnn'` - 推理配置 - max_new_tokens: 生成时最大token数,默认为`512` - - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false`, 目前只有CPU后端支持设置为`true`. + - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false` - quant_qkv: CPU attention 算子中`query, key, value`是否量化,可选为:`0, 1, 2, 3, 4`,默认为`0`,含义如下: - 0: key和value都不量化 - 1: 使用非对称8bit量化存储key @@ -205,19 +209,6 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - thread_num: CPU推理使用硬件线程数,默认为:`4`; OpenCL推理时使用`68` - precision: 推理使用精度策略,默认为:`"low"`,尽量使用`fp16` - memory: 推理使用内存策略,默认为:`"low"`,开启运行时量化 -- Sampler配置 - - sampler_type: 使用的sampler种类,目前支持`greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`8种基本sampler,外加`mixed`(混合sampler)。当选择`mixed`时,依次执行mixed_samplers中的sampler。默认为`mixed`。 - - mixed_samplers: 当`sampler_type`为`mixed`时有效,默认为`["topK", "tfs", "typical", "topP", "min_p", "temperature"]` - - temperature: `temperature`, `topP`, `minP`, `tfsZ`, `typical`中temerature值,默认为1.0 - - topK: `topK`中top K 个的个数,默认为40 - - topP: `topP`中top P的值,默认为0.9 - - minP: `minP`中min P的值,默认为0.1 - - tfsZ: `tfs`中Z的值,默认为1.0,即不使用tfs算法 - - typical: `typical`中p的值,默认为1.0,即不使用typical算法 - - penalty: `penalty`中对于logits的惩罚项,默认为0.0,即不惩罚 - - n_gram: `penalty`中最大存储的ngram大小,默认为8 - - ngram_factor: `penalty`中对于重复ngram的额外惩罚,默认为1.0,即没有额外惩罚 - - penalty_sampler: `penalty`中最后一步采用的sampling策略,可选"greedy"或"temperature",默认greedy. 
##### 配置文件示例 - `config.json` @@ -229,15 +220,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt "backend_type": "cpu", "thread_num": 4, "precision": "low", - "memory": "low", - "sampler_type": "mixed", - "mixed_samplers": ["topK", "tfs", "typical", "topP", "min_p", "temperature"], - "temperature": 1.0, - "topK": 40, - "topP": 0.9, - "tfsZ": 1.0, - "minP": 0.1, - "reuse_kv": true + "memory": "low" } ``` - `llm_config.json` @@ -261,8 +244,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt #### 推理用法 `llm_demo`的用法如下: -pc端直接推理 -```bash +``` # 使用config.json ## 交互式聊天 ./llm_demo model_dir/config.json @@ -276,16 +258,15 @@ pc端直接推理 ./llm_demo model_dir/llm.mnn prompt.txt ``` -android手机端adb推理用法: -```bash -# 利用adb push将链接库push到手机上 -adb shell mkdir /data/local/tmp/llm -adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm -``` - - 对于视觉大模型,在prompt中嵌入图片输入 ``` https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg介绍一下图片里的内容 +# 指定图片大小 +280, 420https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg介绍一下图片里的内容 +``` +- 对于音频大模型,在prompt中嵌入音频输入 +``` +介绍一下音频里的内容 ``` #### GPTQ权重加载 diff --git a/docs/transformers/models.md b/docs/transformers/models.md index 5587b41a5..ee83463a4 100644 --- a/docs/transformers/models.md +++ b/docs/transformers/models.md @@ -47,4 +47,5 @@ | [reader-lm-0.5b](https://huggingface.co/jinaai/reader-lm-0.5b) | [Q4_1](https://modelscope.cn/models/MNN/reader-lm-0.5b-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/reader-lm-0.5b-MNN) | | [reader-lm-1.5b](https://huggingface.co/jinaai/reader-lm-1.5b) | [Q4_1](https://modelscope.cn/models/MNN/reader-lm-1.5b-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/reader-lm-1.5b-MNN) | | [TinyLlama-1.1B-Chat-v1.0](https://modelscope.cn/models/AI-ModelScope/TinyLlama-1.1B-Chat-v1.0/summary) | [Q4_1](https://modelscope.cn/models/MNN/TinyLlama-1.1B-Chat-MNN) | 
[Q4_1](https://huggingface.co/taobao-mnn/TinyLlama-1.1B-Chat-MNN) | -| [Yi-6B-Chat](https://modelscope.cn/models/01ai/Yi-6B-Chat/summary) | [Q4_1](https://modelscope.cn/models/MNN/Yi-6B-Chat-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/Yi-6B-Chat-MNN) | \ No newline at end of file +| [Yi-6B-Chat](https://modelscope.cn/models/01ai/Yi-6B-Chat/summary) | [Q4_1](https://modelscope.cn/models/MNN/Yi-6B-Chat-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/Yi-6B-Chat-MNN) | +| [QwQ-32B-Preview](https://modelscope.cn/models/Qwen/QwQ-32B-Preview/summary) | [Q4_1](https://modelscope.cn/models/MNN/QwQ-32B-Preview-MNN) | [Q4_1](https://huggingface.co/taobao-mnn/QwQ-32B-Preview-MNN) | \ No newline at end of file diff --git a/express/MathOp.cpp b/express/MathOp.cpp index c14a902f1..eaf93f9fc 100644 --- a/express/MathOp.cpp +++ b/express/MathOp.cpp @@ -1208,7 +1208,7 @@ VARP _LinSpace(VARP start, VARP stop, VARP num) { return (Variable::create(Expr::create(std::move(op), {start, stop, num}))); } -VARP _EltwiseProdInt8(VARP x, VARP y, +VARP _EltwiseProdInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale) @@ -1219,7 +1219,7 @@ VARP _EltwiseProdInt8(VARP x, VARP y, output_weight, output_bias, output_scale, output_tensorScale); } -VARP _EltwiseSumInt8(VARP x, VARP y, +VARP _EltwiseSumInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale) @@ -1230,7 +1230,7 @@ VARP _EltwiseSumInt8(VARP x, VARP y, output_weight, output_bias, output_scale, output_tensorScale); } -VARP 
_EltwiseSubInt8(VARP x, VARP y, +VARP _EltwiseSubInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale) @@ -1241,7 +1241,7 @@ VARP _EltwiseSubInt8(VARP x, VARP y, output_weight, output_bias, output_scale, output_tensorScale); } -VARP _EltwiseMaxInt8(VARP x, VARP y, +VARP _EltwiseMaxInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale) @@ -1320,5 +1320,20 @@ VARP _Histogram(VARP x, int bin, int min, int max, int channel) { return (Variable::create(Expr::create(std::move(op), {x}))); } +#ifdef MNN_BUILD_AUDIO +VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abs) { + std::unique_ptr op(new OpT); + op->type = OpType_Stft; + op->main.type = OpParameter_StftParam; + auto param = new StftParamT; + param->n_fft = n_fft; + param->hop_length = hop_length; + param->abs = abs; + op->main.value = param; + EXPRP expr = Expr::create(std::move(op), {sample, window}); + return Variable::create(expr); +} +#endif + } // namespace Express } // namespace MNN diff --git a/express/NeuralNetWorkOp.cpp b/express/NeuralNetWorkOp.cpp index 18d58c3ec..28580d792 100644 --- a/express/NeuralNetWorkOp.cpp +++ b/express/NeuralNetWorkOp.cpp @@ -70,7 +70,7 @@ VARP _Scalar(const void* ptr, halide_type_t type) { ptr: A pointer. Indicates the values. shape: A vector, the shape of the variable. format: A enum, NCHW/NHWC/NC4HW4 is allowed. -type: The type of the elements of the resulting variable. +type: The type of the elements of the resulting variable. 
Returns: output: A constant variable. */ @@ -118,7 +118,7 @@ VARP _InnerProduct(std::vector&& weight, std::vector&& bias, VARP ipParam->biasTerm = 1; } ipParam->weightSize = (int)weight.size(); - + ipParam->weight = std::move(weight); ipParam->bias = std::move(bias); return (Variable::create(Expr::create(ipOp.get(), {x}))); @@ -369,9 +369,9 @@ VARP _MaxPool(VARP x, INTS kernel, INTS stride, PaddingMode pad, INTS pads) { } /*Reshapes a variable. Args: -x: A variable. +x: A variable. shape: A vector, the shape of the target variable. -original_format: A enum, only NCHW/NHWC is allowed, NC4HW4 is not allowed, +original_format: A enum, only NCHW/NHWC is allowed, NC4HW4 is not allowed, as it provides additional information(x comes from NCHW or NHWC) When x is NC4HW4. Returns: output: A variable with the same type as `x`. @@ -387,7 +387,7 @@ VARP _Reshape(VARP x, INTS shape, Dimensionformat original_format) { } /*Reshapes a variable. Args: -x: A variable. +x: A variable. shape: A variable, the shape of the target variable. Returns: output: A variable with the same type as `x`. @@ -415,10 +415,10 @@ VARP _Scale(VARP x, int channels, std::vector&& scales, std::vectormain.AsScale()->biasData = std::move(bias); return (Variable::create(Expr::create(std::move(scale), {x}))); } -/*Given an input value x, it computes the output as x if x > 0 and slope * x if x <= 0. +/*Given an input value x, it computes the output as x if x > 0 and slope * x if x <= 0. Args: -x: A variable. -slope: A float, a positive float value, it leakes the negative part by multiplying with `slope` rather than setting it to 0.0f. +x: A variable. +slope: A float, a positive float value, it leakes the negative part by multiplying with `slope` rather than setting it to 0.0f. Returns: output: A variable with the same type as `x`. */ @@ -432,7 +432,7 @@ VARP _Relu(VARP x, float slope) { } /*Given an input value x, it computes Rectified Linear 6: min(max(x, 0), 6). Args: -x: A variable. +x: A variable. 
Returns: output: A variable with the same type as `x`. */ @@ -445,9 +445,9 @@ VARP _Relu6(VARP x, float minValue, float maxValue) { relu->main.AsRelu6()->minValue = minValue; return (Variable::create(Expr::create(relu.get(), {x}))); } -/*Given an input value x, it computes the output as x if x > 0 and slopes * x if x <= 0. +/*Given an input value x, it computes the output as x if x > 0 and slopes * x if x <= 0. Args: -x: A variable, must be 4-D with NC4HW4 format. +x: A variable, must be 4-D with NC4HW4 format. slopes: A vector, has save size as x. Returns: output: A variable with the same type as `x`. @@ -497,10 +497,10 @@ VARP _Softsign(VARP features) { /*Concatenates variables along one dimension. Args: values: A list of variables a single variable. -axis: A int. Dimension along which to concatenate. -Must be in the range [-rank(values), rank(values)). -As in Python, indexing for axis is 0-based. -Positive axis in the rage of [0, rank(values)) refers to axis-th dimension. +axis: A int. Dimension along which to concatenate. +Must be in the range [-rank(values), rank(values)). +As in Python, indexing for axis is 0-based. +Positive axis in the rage of [0, rank(values)) refers to axis-th dimension. And negative axis refers to axis + rank(values)-th dimension. Returns: A variable resulting from concatenation of the input variables. @@ -516,7 +516,7 @@ VARP _Concat(VARPS values, int axis) { /*Convert a variable to another format(possibily added after `input`). Args: input: A variable. -format: The target format. +format: The target format. Returns: A variable. If `input` is already `format`, then return `input` directly, otherwize add a variable after `input` with `format`. */ @@ -537,7 +537,7 @@ VARP _Convert(VARP input, Dimensionformat format) { /*Splits a variable value into a list of sub variables. Args: value: The variable to split. -size_splits: A vector, a 1-D integer containing the sizes of each output variable along axis. 
+size_splits: A vector, a 1-D integer containing the sizes of each output variable along axis. axis: A int, the dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0 Returns: A list of variables. @@ -645,7 +645,7 @@ VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim) { /*Convert a variable to another format(possibily added before `input`). Args: input: A variable. -format: The target format. +format: The target format. Returns: A variable. If `input` is already `format`, then return `input` directly, otherwize add a variable before `input` with `format`. */ @@ -735,15 +735,15 @@ VARP _PoolGrad(VARP originInput, VARP originOutput, VARP inputGrad, INTS kernel, pool->main.AsPool()->type = (PoolType)type; return (Variable::create(Expr::create(std::move(pool), {originInput, originOutput, inputGrad}))); } -/*Crop images. +/*Crop images. Args: -images: 4-D variable of NC4HW4 format. +images: 4-D variable of NC4HW4 format. size: A variable. It takes the shape of `size` as output cropped variable's shape while omits the values/format of `size`. -axis: A int indicating the dimention to crop. Must be >=2. All dimensions up to but excluding `axis` are preserved, while the dimensions including and trailing `axis` are cropped. +axis: A int indicating the dimention to crop. Must be >=2. All dimensions up to but excluding `axis` are preserved, while the dimensions including and trailing `axis` are cropped. offset: A vector of int indicating the offsets. length(`offset`) must be >=1 and <=2. If length(`offset`) is 1, then all dimensions are offset by this amount.Otherwise, the number of offsets must equal the number of cropped axes in each dimension accordingly. Returns: The cropped 4-D variable of NC4HW4 format. 
-*/ +*/ VARP _Crop(VARP images, VARP size, int axis, INTS offset) { std::unique_ptr crop(new OpT); crop->type = OpType_Crop; @@ -753,13 +753,13 @@ VARP _Crop(VARP images, VARP size, int axis, INTS offset) { crop->main.AsCrop()->offset = offset; return (Variable::create(Expr::create(std::move(crop), {images, size}))); } -/*Resize images. +/*Resize images. Args: -images: 4-D variable of NC4HW4 format. -xScale: A float. +images: 4-D variable of NC4HW4 format. +xScale: A float. yScale: A float. Returns: -The resized 4-D variable of NC4HW4 format. +The resized 4-D variable of NC4HW4 format. */ VARP _Resize(VARP images, float xScale, float yScale) { std::unique_ptr resize(new OpT); @@ -773,8 +773,8 @@ VARP _Resize(VARP images, float xScale, float yScale) { /*Pads a variable. Args: x: A variable. -paddings: A variable of type Halide_Type_Int. The shape is [n, 2] where n is the rank of variable. -mode: A enum, One of PadValueMode_CONSTANT, PadValueMode_SYMMETRIC, or PadValueMode_REFLECT. +paddings: A variable of type Halide_Type_Int. The shape is [n, 2] where n is the rank of variable. +mode: A enum, One of PadValueMode_CONSTANT, PadValueMode_SYMMETRIC, or PadValueMode_REFLECT. Returns: A variable. Has the same type as x. */ @@ -802,7 +802,7 @@ VARP _Pad(VARP x, VARP paddings, PadValueMode mode) { /*Returns a variable with an additional dimension inserted at index axis. Args: input: A variable. -axis: A int, specifying the dimension index at which to expand the shape of input. +axis: A int, specifying the dimension index at which to expand the shape of input. Given an input of D dimensions, axis must be in range [-(D+1), D] (inclusive). Returns: A variable with the same data as input, with an additional dimension inserted at the index specified by axis. @@ -827,7 +827,7 @@ VARP _ExpandDims(VARP input, VARP axis) { input: A variable. Returns: A variable of Halide_Type_Int. 
-*/ +*/ VARP _Shape(VARP input, bool nchw) { std::unique_ptr shape(new OpT); shape->type = OpType_Shape; @@ -838,13 +838,13 @@ VARP _Shape(VARP input, bool nchw) { } /*Stacks a list of rank-R variables into one rank-(R+1) variable. Packs the list of variables in `values` into a ariable with rank one higher than each variable in values, -by packing them along the axis dimension. +by packing them along the axis dimension. Given a list of length N of variables of shape (A, B, C); -if axis == 0 then the output variable will have the shape (N, A, B, C). +if axis == 0 then the output variable will have the shape (N, A, B, C). if axis == 1 then the output variable will have the shape (A, N, B, C). Etc. Args: values: A list of variable objects with the same shape and type. -axis: An int. The axis to stack along. Defaults to the first dimension. Negative values wrap around, +axis: An int. The axis to stack along. Defaults to the first dimension. Negative values wrap around, so the valid range is [-(R+1), R+1). Returns: output: A stacked variable with the same type as `values`. @@ -858,13 +858,13 @@ VARP _Stack(VARPS values, int axis) { return (Variable::create(Expr::create(std::move(pack), values))); } /*Extracts crops from the input image variable and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) -to a common output size specified by crop_size. -Returns a variable with crops from the input image at positions defined at the bounding box locations in boxes. -The cropped boxes are all resized (with bilinear or nearest neighbor interpolation) to a fixed size = [crop_height, crop_width]. +to a common output size specified by crop_size. +Returns a variable with crops from the input image at positions defined at the bounding box locations in boxes. +The cropped boxes are all resized (with bilinear or nearest neighbor interpolation) to a fixed size = [crop_height, crop_width]. 
The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth](supposing NHWC format). Arguments: image: A 4-D variable of shape [batch, image_height, image_width, depth](supposing NHWC format). Both image_height and image_width need to be positive. -boxes: A 2-D variable of shape [num_boxes, 4]. The i-th row of the variable specifies the coordinates of a box in the box_ind[i] image and is specified in normalized coordinates [y1, x1, y2, x2]. +boxes: A 2-D variable of shape [num_boxes, 4]. The i-th row of the variable specifies the coordinates of a box in the box_ind[i] image and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled crop is an up-down flipped version of the original image. The width dimension is treated similarly. Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to extrapolate the input image values. box_ind: A 1-D variable of shape [num_boxes] with int values in [0, batch). The value of box_ind[i] specifies the image that the i-th box refers to. crop_size: A 1-D variable of 2 elements, size = [crop_height, crop_width]. All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive. @@ -893,7 +893,7 @@ VARP _CropAndResize(VARP image, VARP boxes, VARP box_ind, VARP crop_size, Interp /*Creates a variable filled with a scalar value. Args: dims: A variable. Must be 1-D Halide_Type_Int. Represents the shape of the output variable. -value: A variable. 0-D (scalar). Value to fill the returned variable. +value: A variable. 0-D (scalar). Value to fill the returned variable. Returns: A variable. Has the same type as value. 
*/ @@ -918,7 +918,7 @@ VARP _Tile(VARP input, VARP multiples) { } /*Gather slices from params according to indices. Arguments: -params: The variable from which to gather values. +params: The variable from which to gather values. indices: Index variable. Must be Halide_Type_Int in range [0, ndims(params)-1]. Returns: Output: Values from params gathered from indices given by indices. @@ -930,10 +930,10 @@ VARP _Gather(VARP params, VARP indices) { } /*Gather slices from params axis according to indices. Arguments: -params: The variable from which to gather values. +params: The variable from which to gather values. indices: Index variable. Must be Halide_Type_Int in range [0, ndims(params)-1]. -axis: A int, the axis in params to gather indices from. Supports negative indexes. -If set to 0, it's same as _Gather. Currently only 0 is supported. +axis: A int, the axis in params to gather indices from. Supports negative indexes. +If set to 0, it's same as _Gather. Currently only 0 is supported. Returns: Output: Values from params gathered from indices given by indices. */ @@ -951,8 +951,8 @@ VARP _GatherV2(VARP params, VARP indices, VARP axis) { /*Removes dimensions of size 1 from the shape of a variable. Args: input: A variable. The input to squeeze. -axis: A vector, Defaults to {}. If specified, only squeezes the dimensions listed. The dimension index starts at 0. -Must be in the range [-rank(input), rank(input)). +axis: A vector, Defaults to {}. If specified, only squeezes the dimensions listed. The dimension index starts at 0. +Must be in the range [-rank(input), rank(input)). Returns: A variable. Has the same type as input. Contains the same data as input, but has one or more dimensions of size 1 removed. 
*/ @@ -1062,24 +1062,24 @@ VARP _GatherElements(VARP params, VARP indices, VARP axis) { } /*BatchToSpace for N-D variables -This operation reshapes the "batch" dimension 0 into M + 1 dimensions of shape block_shape + [batch], -interleaves these blocks back into the grid defined by the spatial dimensions [1, ..., M], -to obtain a result with the same rank as the input. -The spatial dimensions of this intermediate result are then optionally cropped according to crops to +This operation reshapes the "batch" dimension 0 into M + 1 dimensions of shape block_shape + [batch], +interleaves these blocks back into the grid defined by the spatial dimensions [1, ..., M], +to obtain a result with the same rank as the input. +The spatial dimensions of this intermediate result are then optionally cropped according to crops to produce the output. This is the reverse of SpaceToBatch. See below for a precise description. Arguments: input: must be 4-D with NC4HW4 format. N-D with shape input_shape = [batch] + spatial_shape + remaining_shape, where spatial_shape has M dimensions. block_shape: 1-D with shape [M], all values must be >= 1. -crops: 2-D with shape [M, 2], all values must be >= 0. crops[i] = [crop_start, crop_end] specifies the amount to crop from input dimension i + 1, +crops: 2-D with shape [M, 2], all values must be >= 0. crops[i] = [crop_start, crop_end] specifies the amount to crop from input dimension i + 1, which corresponds to spatial dimension i. It is required that crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]. 
This operation is equivalent to the following steps: -Reshape input to reshaped of shape: [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), +Reshape input to reshaped of shape: [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), input_shape[1], ..., input_shape[N-1]] -Permute dimensions of reshaped to produce permuted of shape +Permute dimensions of reshaped to produce permuted of shape [batch / prod(block_shape),input_shape[1], block_shape[0], ..., input_shape[M], block_shape[M-1],input_shape[M+1], ..., input_shape[N-1]] -Reshape permuted to produce reshaped_permuted of shape +Reshape permuted to produce reshaped_permuted of shape [batch / prod(block_shape),input_shape[1] * block_shape[0], ..., input_shape[M] * block_shape[M-1],input_shape[M+1], ..., input_shape[N-1]] -Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops to produce the output of shape: +Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops to produce the output of shape: [batch / prod(block_shape),input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],input_shape[M+1], ..., input_shape[N-1]] Some examples: for the following input of shape [4, 1, 1, 3], block_shape = [2, 2], and crops = [[0, 0], [0, 0]]: @@ -1095,14 +1095,14 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) { std::unique_ptr op(new OpT); std::unique_ptr blob_blockShape(new BlobT); std::unique_ptr blob_paddings(new BlobT); - + auto info_block_shape = block_shape->getInfo(); auto info_crops = crops->getInfo(); MNN_ASSERT(info_block_shape != nullptr); MNN_ASSERT(info_crops != nullptr); MNN_ASSERT(halide_type_int == info_block_shape->type.code); MNN_ASSERT(halide_type_int == info_crops->type.code); - + blob_blockShape->dims = info_block_shape->dim; blob_blockShape->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order); 
blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type); @@ -1144,7 +1144,7 @@ VARP _MatrixBandPart(VARP input, VARP num_lower, VARP num_upper) { Args: x: A variable. must be 4-D with NC4HW4 format. axes: Array of ints. Axes along which to compute mean and variance. Ignored for this implementation: must be {2, 3} -shift: Not used in the current implementation. +shift: Not used in the current implementation. keepdims: produce moments with the same dimensionality as the input. Ignored for this implementation: must be true. Returns: Two variable objects: mean and variance. @@ -1153,7 +1153,7 @@ std::vector _Moments(VARP x, INTS axis, VARP shift, bool keepDims) { std::unique_ptr op(new OpT); axis = {2, 3}; keepDims = true; - // if axis != {2,3} or keepDims != true, print warning. + // if axis != {2,3} or keepDims != true, print warning. // ignore shift. op->type = OpType_Moments; auto momentsParam = new MomentsParamT; @@ -1168,11 +1168,11 @@ std::vector _Moments(VARP x, INTS axis, VARP shift, bool keepDims) { return res; } /*Computes the difference between two lists of numbers or strings. -Given a list x and a list y, this operation returns a list out that represents all values that are in x but not in y. -The returned list out is sorted in the same order that the numbers appear in x (duplicates are preserved). -This operation also returns a list idx that represents the position of each out element in x. +Given a list x and a list y, this operation returns a list out that represents all values that are in x but not in y. +The returned list out is sorted in the same order that the numbers appear in x (duplicates are preserved). +This operation also returns a list idx that represents the position of each out element in x. Arguments: -x: 1-D variable of type Halide_Type_Int. Values to keep. +x: 1-D variable of type Halide_Type_Int. Values to keep. y: 1-D variable of type Halide_Type_Int. Values to remove. 
Returns: Output out: 1-D variable of type Halide_Type_Int. Values present in x but not in y. @@ -1184,8 +1184,8 @@ VARP _SetDiff1D(VARP x, VARP y) { op->main.value = nullptr; return Variable::create(Expr::create(std::move(op), {x, y})); } -/*Rearranges blocks of spatial data, into depth. -More specifically, it outputs a copy of the input variable where values from the height and width dimensions are moved to the depth dimension. +/*Rearranges blocks of spatial data, into depth. +More specifically, it outputs a copy of the input variable where values from the height and width dimensions are moved to the depth dimension. The block_size indicates the input block size. Non-overlapping blocks of size block_size x block_size are rearranged into depth at each location. The depth of the output variable is block_size * block_size * input_depth. @@ -1207,11 +1207,11 @@ VARP _SpaceToDepth(VARP input, int block_size) { return Variable::create(Expr::create(std::move(op), {input})); } -/*This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape block_shape, -and interleaves these blocks with the "batch" dimension +/*This operation divides "spatial" dimensions [1, ..., M] of the input into a grid of blocks of shape block_shape, +and interleaves these blocks with the "batch" dimension such that in the output, the spatial dimensions [1, ..., M] correspond to the position within the grid, and the batch dimension combines both the position within a spatial block and the original batch position. -Prior to division into blocks, the spatial dimensions of the input are optionally zero padded according to paddings. +Prior to division into blocks, the spatial dimensions of the input are optionally zero padded according to paddings. See below for a precise description. Args: input: A variable. must be 4-D with NC4HW4 format. N-D with shape input_shape = [batch] + spatial_shape + remaining_shape, where spatial_shape has M dimensions. 
@@ -1232,7 +1232,7 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) { MNN_ASSERT(info_paddings != nullptr); MNN_ASSERT(halide_type_int == info_block_shape->type.code); MNN_ASSERT(halide_type_int == info_paddings->type.code); - + blob_blockShape->dims = info_block_shape->dim; blob_blockShape->dataFormat = (MNN::MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order); blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type); @@ -1271,9 +1271,9 @@ VARP _ZerosLike(VARP input) { } /*Unpacks the given dimension of a rank-R tensor into rank-(R-1) variable. For example, given a variable of shape (A, B, C, D); -If axis == 0 then the i'th variable in output is the slice value[i, :, :, :] and each variable in output will have shape (B, C, D). +If axis == 0 then the i'th variable in output is the slice value[i, :, :, :] and each variable in output will have shape (B, C, D). (Note that the dimension unpacked along is gone, unlike split). -If axis == 1 then the i'th variable in output is the slice value[:, i, :, :] and each variable in output will have shape (A, C, D). +If axis == 1 then the i'th variable in output is the slice value[:, i, :, :] and each variable in output will have shape (A, C, D). Args: value: A rank R > 0 variable to be unstacked. num: An int. The length of the dimension axis. Automatically inferred if None (the default). @@ -1304,13 +1304,13 @@ std::vector _Unstack(VARP value, int axis) { for (int i = 0; i < size; ++i) { res.emplace_back(Variable::create(expr, i)); } - return res; + return res; } /*Returns the rank of a variable. Returns a 0-D int32 variable representing the rank of input. -Note: The rank of a variable is not the same as the rank of a matrix. -It's the number of indices required to uniquely select each element of the variable. +Note: The rank of a variable is not the same as the rank of a matrix. +It's the number of indices required to uniquely select each element of the variable. 
It's also known as "order", "degree", or "ndims." Args: input: A variable. @@ -1326,9 +1326,9 @@ VARP _Rank(VARP input) { } /*Creates a sequence of numbers. Args: -start: A 0-D variable (scalar). -limit: A 0-D variable (scalar). -delta: A 0-D variable (scalar). +start: A 0-D variable (scalar). +limit: A 0-D variable (scalar). +delta: A 0-D variable (scalar). */ VARP _Range(VARP start, VARP limit, VARP delta) { std::unique_ptr op(new OpT); @@ -1338,9 +1338,9 @@ VARP _Range(VARP start, VARP limit, VARP delta) { op->main.value = rangeParam; return Variable::create(Expr::create(std::move(op), {start, limit, delta})); } -/*Rearranges data from depth into blocks of spatial data. +/*Rearranges data from depth into blocks of spatial data. It is the reverse transformation of SpaceToDepth. More specifically, -it outputs a copy of the input variable where values from the depth dimension are moved in spatial blocks to the height and width dimensions. +it outputs a copy of the input variable where values from the depth dimension are moved in spatial blocks to the height and width dimensions. Args: input: A variable. block_size: An int that is >= 2. The size of the spatial block, same as in Space2Depth. @@ -1356,25 +1356,25 @@ VARP _DepthToSpace(VARP input, int block_size) { op->main.value = depthtospaceParam; return Variable::create(Expr::create(std::move(op), {input})); } -/*SSD network's priorbox layer. +/*SSD network's priorbox layer. Args: -feature: A variable. Contains the feature map. Namely bottom[0] in caffe. +feature: A variable. Contains the feature map. Namely bottom[0] in caffe. image: A variable. Contains the image. Namely bottom[1] in caffe. -min_size: Minimum box size (in pixels). +min_size: Minimum box size (in pixels). max_size: Maximum box size (in pixels). -aspect_ratio: Various of aspect ratios. Duplicate ratios are ignored. If none is provided, use default 1.0. -flip: If true, flips each aspect ratio. 
For example, if there is aspect ratio "r", generates aspect ratio "1.0/r" as well. Default true. -clip: If true, clips the prior so that it is within [0, 1]. Default false. -variance: Variance for adjusting the prior bboxes. -img_h: image height. If 0, uses information in image. +aspect_ratio: Various of aspect ratios. Duplicate ratios are ignored. If none is provided, use default 1.0. +flip: If true, flips each aspect ratio. For example, if there is aspect ratio "r", generates aspect ratio "1.0/r" as well. Default true. +clip: If true, clips the prior so that it is within [0, 1]. Default false. +variance: Variance for adjusting the prior bboxes. +img_h: image height. If 0, uses information in image. img_w: image width. If 0, uses information in image. -step_h: step in height. -step_w: step in width. -offset: Offset to the top left corner of each cell. -Returns: -A variable. +step_h: step in height. +step_w: step in width. +offset: Offset to the top left corner of each cell. +Returns: +A variable. */ -VARP _PriorBox(VARP feature, VARP image, std::vector min_size, std::vector max_size, std::vectoraspect_ratio, +VARP _PriorBox(VARP feature, VARP image, std::vector min_size, std::vector max_size, std::vectoraspect_ratio, bool flip, bool clip, std::vectorvariance, unsigned int img_h, unsigned int img_w, float step_h, float step_w, float offset) { std::unique_ptr op(new OpT); @@ -1395,12 +1395,12 @@ VARP _PriorBox(VARP feature, VARP image, std::vector min_size, std::vecto op->main.value = param; return Variable::create(Expr::create(std::move(op), {feature, image})); } -/*SSD network's permute layer. +/*SSD network's permute layer. Args: -input: A variable. Contains the feature map. Namely bottom[0] in caffe. +input: A variable. Contains the feature map. Namely bottom[0] in caffe. dims: A vector. Contains the order. -Returns: -A variable. +Returns: +A variable. 
*/ VARP _Permute(VARP input, INTS dims) { std::unique_ptr op(new OpT); @@ -1411,27 +1411,27 @@ VARP _Permute(VARP input, INTS dims) { op->main.value = param; return Variable::create(Expr::create(std::move(op), {input})); } -/*SSD network's detectionoutput layer. +/*SSD network's detectionoutput layer. Args: -location: A variable. +location: A variable. confidence: A variable. priorbox: A variable. num_classes: number of classes. -share_location: indicates wheter share location between different classes, default true. -background_label_id: default = 0. +share_location: indicates wheter share location between different classes, default true. +background_label_id: default = 0. nms_threshhold: nonmaximumsupression threshhold. mns_topk: nonmaximumsupression topk. -code_type: indicates the mode to encode bbox, default = CORNER. -variance_encoded_in_target: indicates whether encode variance in target, default false. -keep_top_k: indicates the number of boxes kept, default -1(all boxes are kept). -confidence_threshold: the threshhold for confidence. +code_type: indicates the mode to encode bbox, default = CORNER. +variance_encoded_in_target: indicates whether encode variance in target, default false. +keep_top_k: indicates the number of boxes kept, default -1(all boxes are kept). +confidence_threshold: the threshhold for confidence. visualize_threshold: The threshold used to visualize the detection results. -Returns: -A variable. +Returns: +A variable. 
*/ -VARP _DetectionOutput(VARP location, VARP confidence, VARP priorbox, - unsigned int num_classes, bool share_location, int background_label_id, - float nms_threshhold, int nms_topk, int code_type, +VARP _DetectionOutput(VARP location, VARP confidence, VARP priorbox, + unsigned int num_classes, bool share_location, int background_label_id, + float nms_threshhold, int nms_topk, int code_type, bool variance_encoded_in_target, int keep_top_k, float confidence_threshold, float visualize_threshold){ std::unique_ptr op(new OpT); @@ -1451,26 +1451,26 @@ VARP _DetectionOutput(VARP location, VARP confidence, VARP priorbox, op->main.value = param; return Variable::create(Expr::create(std::move(op), {location, confidence, priorbox})); } -/*SSD network's detectionpostprocess layer. +/*SSD network's detectionpostprocess layer. Args: -encode_boxes: A variable. +encode_boxes: A variable. class_predictions: A variable. anchors: A variable. num_classes: number of classes. max_detections: A int, indicates max detections. -max_class_per_detection: A int, indicates max class per detection. -detections_per_class: A int, indicates detections per class. +max_class_per_detection: A int, indicates max class per detection. +detections_per_class: A int, indicates detections per class. nms_threshhold: A float, the threshold for nms. -iou_threshold: A float, the threshold for iou. -use_regular_nms: A bool, indicates whether use regular nms method, only false is implemented currently. -centersize_encoding: A float vector, indicates the centersize encoding. -Returns: +iou_threshold: A float, the threshold for iou. +use_regular_nms: A bool, indicates whether use regular nms method, only false is implemented currently. +centersize_encoding: A float vector, indicates the centersize encoding. 
+Returns: 4 variable, detection_boxes, detection_class, detection_scores, num_detections */ -std::vector _DetectionPostProcess(VARP encode_boxes, VARP class_predictions, VARP anchors, - int num_classes, int max_detections, - int max_class_per_detection, int detections_per_class, - float nms_threshold, float iou_threshold, +std::vector _DetectionPostProcess(VARP encode_boxes, VARP class_predictions, VARP anchors, + int num_classes, int max_detections, + int max_class_per_detection, int detections_per_class, + float nms_threshold, float iou_threshold, bool use_regular_nms, std::vector centersize_encoding){ std::unique_ptr op(new OpT); op->type = OpType_DetectionPostProcess; @@ -1649,7 +1649,7 @@ VARP _Conv(std::vector&& weight, std::vector&& bias, std::vector< } conv2D->bias = bias; - + conv2D->symmetricQuan->weight = std::move(weight); conv2D->symmetricQuan->zeroPoint = std::move(inputZeroPoint); conv2D->symmetricQuan->outputZeroPoint = std::move(outputZeroPoint); diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index 695a55cad..bd0b72a30 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -75,6 +75,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 3 #define MNN_VERSION_MINOR 0 -#define MNN_VERSION_PATCH 1 +#define MNN_VERSION_PATCH 2 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." 
STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */ diff --git a/include/MNN/MNNForwardType.h b/include/MNN/MNNForwardType.h index c115113ea..31665c1ec 100644 --- a/include/MNN/MNNForwardType.h +++ b/include/MNN/MNNForwardType.h @@ -40,14 +40,16 @@ typedef enum { MNN_FORWARD_USER_2 = 10, MNN_FORWARD_USER_3 = 11, - MNN_FORWARD_ALL, + MNN_FORWARD_ALL = 12, /* Apply arm extension instruction set to accelerate some Ops, this forward type is only used in MNN internal, and will be active automatically when user set forward type to be MNN_FORWARD_CPU and extension instruction set is valid on hardware. */ - MNN_FORWARD_CPU_EXTENSION - + MNN_FORWARD_CPU_EXTENSION = 13, + // use for shared memory on android device + + MNN_MEMORY_AHARDWAREBUFFER = 14 } MNNForwardType; typedef enum { diff --git a/include/MNN/expr/MathOp.hpp b/include/MNN/expr/MathOp.hpp index c5595fa5a..9d7e41763 100644 --- a/include/MNN/expr/MathOp.hpp +++ b/include/MNN/expr/MathOp.hpp @@ -13,7 +13,7 @@ namespace MNN { namespace Express { //BinaryOPs MNN_PUBLIC VARP _Add(VARP x, VARP y); -MNN_PUBLIC VARP _Subtract(VARP x, VARP y); +MNN_PUBLIC VARP _Subtract(VARP x, VARP y); MNN_PUBLIC VARP _Multiply(VARP x, VARP y); MNN_PUBLIC VARP _Divide(VARP x, VARP y); MNN_PUBLIC VARP _Pow(VARP x, VARP y); @@ -92,19 +92,19 @@ MNN_PUBLIC VARP _Prod(VARP a, VARP b, std::vector coeff); MNN_PUBLIC VARP _Sum(VARP a, VARP b, std::vector coeff); MNN_PUBLIC VARP _Max(VARP a, VARP b, std::vector coeff); MNN_PUBLIC VARP _Sub(VARP a, VARP b, std::vector coeff); -MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y, +MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); -MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y, +MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, 
VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); -MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y, +MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); -MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y, +MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y, std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); @@ -138,6 +138,9 @@ MNN_PUBLIC VARP _CumSum(VARP x, int axis, bool exclusive = false, bool reverse = MNN_PUBLIC VARP _CumProd(VARP x, int axis); MNN_PUBLIC VARPS _Svd(VARP x); MNN_PUBLIC VARP _Histogram(VARP x, int bin, int min, int max, int channel = -1); +#ifdef MNN_BUILD_AUDIO +MNN_PUBLIC VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abse = true); +#endif }; // namespace Express }; // namespace MNN diff --git a/project/android/build_32.sh b/project/android/build_32.sh index 24f0eb8cc..c9f9f7b24 100755 --- a/project/android/build_32.sh +++ b/project/android/build_32.sh @@ -4,7 +4,7 @@ cmake ../../../ \ -DCMAKE_BUILD_TYPE=Release \ -DANDROID_ABI="armeabi-v7a" \ -DANDROID_STL=c++_static \ --DANDROID_NATIVE_API_LEVEL=android-14 \ +-DANDROID_NATIVE_API_LEVEL=android-26 \ -DANDROID_TOOLCHAIN=clang \ -DMNN_USE_LOGCAT=false \ -DMNN_USE_SSE=OFF \ diff --git a/project/android/build_64.sh 
b/project/android/build_64.sh index 34b18057e..328d8ea74 100755 --- a/project/android/build_64.sh +++ b/project/android/build_64.sh @@ -8,7 +8,7 @@ cmake ../../../ \ -DMNN_BUILD_BENCHMARK=ON \ -DMNN_USE_SSE=OFF \ -DMNN_BUILD_TEST=ON \ --DANDROID_NATIVE_API_LEVEL=android-21 \ +-DANDROID_NATIVE_API_LEVEL=android-26 \ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ -DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3 $4 $5 $6 $7 diff --git a/project/ios/MNN.xcodeproj/project.pbxproj b/project/ios/MNN.xcodeproj/project.pbxproj index 6aafd2121..36b0971e1 100644 --- a/project/ios/MNN.xcodeproj/project.pbxproj +++ b/project/ios/MNN.xcodeproj/project.pbxproj @@ -486,11 +486,9 @@ 92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; 92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; }; 92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; - 92FF02E623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */; }; 92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; 92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; 92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */; }; - 92FF02EC23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */; }; 
92FF02EE23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */; }; 92FF02F223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */; }; 92FF02F423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */; }; @@ -530,11 +528,9 @@ 92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; 92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; }; 92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; - 92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */; }; 92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; 92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; 92FF032A23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */; }; - 92FF032C23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */; }; 92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */; }; 
92FF033223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */; }; 92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */; }; @@ -592,7 +588,6 @@ 92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */; }; 92FF03A323AA0B5A00AC97F6 /* ConvOpt.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */; }; 92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */; }; - 92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */; }; 92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */; }; 92FF03A723AA0B5A00AC97F6 /* ConvolutionIntFactory.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */; }; 92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */; }; @@ -609,7 +604,6 @@ 92FF03B923AA0B5A00AC97F6 /* ConvOpt.h in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023B23AA0B5600AC97F6 /* ConvOpt.h */; }; 92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */; }; 92FF03BD23AA0B5A00AC97F6 /* Int8FunctionsOpt.h in Headers */ = {isa = PBXBuildFile; fileRef = 92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */; }; - 92FF03BE23AA0B5A00AC97F6 /* DeconvolutionWithStride.cpp in Sources */ = {isa = 
PBXBuildFile; fileRef = 92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */; }; 92FF03BF23AA0B5A00AC97F6 /* ConvolutionTiledExecutor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */; }; 92FF03C323AA0B5A00AC97F6 /* CPUEltwise.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024523AA0B5700AC97F6 /* CPUEltwise.cpp */; }; 92FF03C423AA0B5A00AC97F6 /* CPUInterp.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 92FF024623AA0B5700AC97F6 /* CPUInterp.cpp */; }; @@ -740,8 +734,6 @@ 95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */; }; 95772DD02C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */; }; 958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; }; - 958B046429D2C89D00FC3AEF /* GemmInt8Executor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */; }; - 958B046629D2C8AF00FC3AEF /* GemmInt8Executor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */; }; 95CE1DFF2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 95CE1DFE2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S */; }; 95CE1E012AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 95CE1E002AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S */; }; C43C81FA251894A600A0FF84 /* CommonOptFunctionNeon.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */; }; @@ -1342,11 +1334,9 @@ 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = 
PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = ""; }; 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = ""; }; 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = ""; }; - 92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductLeft.S; sourceTree = ""; }; 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = ""; }; 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = ""; }; 92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit.S; sourceTree = ""; }; - 92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductRight.S; sourceTree = ""; }; 92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannel.S; sourceTree = ""; }; 92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBlitC3ToFloatRGBA.S; sourceTree = ""; }; 92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */ = {isa = PBXFileReference; fileEncoding = 4; 
lastKnownFileType = sourcecode.asm; path = MNNUInt8ToInt16WithOffsetC4Common.S; sourceTree = ""; }; @@ -1386,11 +1376,9 @@ 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = ""; }; 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = ""; }; 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = ""; }; - 92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductLeft.S; sourceTree = ""; }; 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = ""; }; 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = ""; }; 92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit.S; sourceTree = ""; }; - 92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNWinogradMatrixProductRight.S; sourceTree = ""; }; 92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannel.S; sourceTree = ""; }; 92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; 
path = MNNBlitC3ToFloatRGBA.S; sourceTree = ""; }; 92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNUInt8ToInt16WithOffsetC4Common.S; sourceTree = ""; }; @@ -1448,7 +1436,6 @@ 92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = Int8FunctionsOpt.cpp; sourceTree = ""; }; 92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvOpt.cpp; sourceTree = ""; }; 92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = OptimizedComputer.cpp; sourceTree = ""; }; - 92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = DeconvolutionWithStride.hpp; sourceTree = ""; }; 92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ConvolutionTiledExecutor.hpp; sourceTree = ""; }; 92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvolutionIntFactory.cpp; sourceTree = ""; }; 92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = WinogradOptFunction.cpp; sourceTree = ""; }; @@ -1465,7 +1452,6 @@ 92FF023B23AA0B5600AC97F6 /* ConvOpt.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ConvOpt.h; sourceTree = ""; }; 92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = OptimizedComputer.hpp; sourceTree = ""; }; 
92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Int8FunctionsOpt.h; sourceTree = ""; }; - 92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = DeconvolutionWithStride.cpp; sourceTree = ""; }; 92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvolutionTiledExecutor.cpp; sourceTree = ""; }; 92FF024523AA0B5700AC97F6 /* CPUEltwise.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUEltwise.cpp; sourceTree = ""; }; 92FF024623AA0B5700AC97F6 /* CPUInterp.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUInterp.cpp; sourceTree = ""; }; @@ -1597,8 +1583,6 @@ 95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM82.S; sourceTree = ""; }; 95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM86.S; sourceTree = ""; }; 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = ""; }; - 958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GemmInt8Executor.cpp; sourceTree = ""; }; - 958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = 
GemmInt8Executor.hpp; sourceTree = ""; }; 95CE1DFE2AC57F6200EFB51E /* MNNReluWithSlopeChannelInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannelInt8.S; sourceTree = ""; }; 95CE1E002AC57F7600EFB51E /* MNNReluWithSlopeChannelInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNReluWithSlopeChannelInt8.S; sourceTree = ""; }; C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CommonOptFunctionNeon.cpp; sourceTree = ""; }; @@ -2643,11 +2627,9 @@ 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */, 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */, 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, - 92FF016623AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */, 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, 92FF016A23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */, - 92FF016C23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */, 92FF016E23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */, 92FF017223AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */, 92FF017423AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */, @@ -2737,11 +2719,9 @@ 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */, 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */, 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, - 92FF01A723AA0B4E00AC97F6 /* MNNWinogradMatrixProductLeft.S */, 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, 92FF01AB23AA0B4E00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S */, - 92FF01AD23AA0B4E00AC97F6 /* MNNWinogradMatrixProductRight.S */, 92FF01AF23AA0B4E00AC97F6 /* MNNReluWithSlopeChannel.S */, 92FF01B323AA0B4E00AC97F6 /* MNNBlitC3ToFloatRGBA.S */, 92FF01B523AA0B4E00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S */, @@ 
-2761,8 +2741,6 @@ children = ( CEA82BD92A15F8AD002CBC95 /* IdstConvolutionInt8.cpp */, CEA82BDA2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp */, - 958B046529D2C8AF00FC3AEF /* GemmInt8Executor.hpp */, - 958B046329D2C89D00FC3AEF /* GemmInt8Executor.cpp */, C48CAE2528900C4A00271A6D /* ConvInt8Winograd.cpp */, C48CAE2428900C4A00271A6D /* ConvInt8Winograd.hpp */, 4A224A1227D0C56E000A9260 /* ConvolutionWinogradBridge.cpp */, @@ -2790,7 +2768,6 @@ 92FF022323AA0B5600AC97F6 /* Int8FunctionsOpt.cpp */, 92FF022523AA0B5600AC97F6 /* ConvOpt.cpp */, 92FF022623AA0B5600AC97F6 /* OptimizedComputer.cpp */, - 92FF022723AA0B5600AC97F6 /* DeconvolutionWithStride.hpp */, 92FF022823AA0B5600AC97F6 /* ConvolutionTiledExecutor.hpp */, 92FF022923AA0B5600AC97F6 /* ConvolutionIntFactory.cpp */, 92FF022A23AA0B5600AC97F6 /* WinogradOptFunction.cpp */, @@ -2807,7 +2784,6 @@ 92FF023B23AA0B5600AC97F6 /* ConvOpt.h */, 92FF023E23AA0B5600AC97F6 /* OptimizedComputer.hpp */, 92FF023F23AA0B5600AC97F6 /* Int8FunctionsOpt.h */, - 92FF024023AA0B5600AC97F6 /* DeconvolutionWithStride.cpp */, 92FF024123AA0B5600AC97F6 /* ConvolutionTiledExecutor.cpp */, ); path = compute; @@ -2939,7 +2915,6 @@ buildActionMask = 2147483647; files = ( 48C84B89250F711700EE7666 /* StaticModule.hpp in Headers */, - 958B046629D2C8AF00FC3AEF /* GemmInt8Executor.hpp in Headers */, 1F501F812397BA5B004E8721 /* AutoTime.hpp in Headers */, 92FF04A523AA0BFB00AC97F6 /* AutoStorage.h in Headers */, EBECA3A124643D4E0062C7A3 /* MNNAsmGlobal.h in Headers */, @@ -3105,7 +3080,6 @@ 92FF03C923AA0B5A00AC97F6 /* CPUMatMul.hpp in Headers */, EBECA39924643D320062C7A3 /* Arm82Relu.hpp in Headers */, 4838EA7C2611BFE20027232C /* CPUGridSample.hpp in Headers */, - 92FF03A523AA0B5A00AC97F6 /* DeconvolutionWithStride.hpp in Headers */, 92FF03D123AA0B5A00AC97F6 /* CPUTopKV2.hpp in Headers */, 92FF033F23AA0B5A00AC97F6 /* CPUArgMax.hpp in Headers */, 92FF034C23AA0B5A00AC97F6 /* CPUSetDiff1D.hpp in Headers */, @@ -3335,7 +3309,6 @@ 92FF038623AA0B5A00AC97F6 /* 
CPULinSpace.cpp in Sources */, 4819FB2D24C1396A0050BD09 /* GeometryConv2D.cpp in Sources */, 48747D63245D9E33000B9709 /* GeometryPermute.cpp in Sources */, - 92FF032C23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */, 48BB6EF625220AA80056E195 /* MNNTranspose32Bit4x4.S in Sources */, CE072A1C2C91AEE700F190FD /* MNNRGBAToBGRFast.S in Sources */, CEE9B95C2A3AA4D4006438F2 /* MNNBilinearSampleC8.S in Sources */, @@ -3597,7 +3570,6 @@ 48FD12BF2466A88D009E9102 /* GeometryConv2DBackPropFilter.cpp in Sources */, 92FF02F923AA0B5A00AC97F6 /* MNNGemmint8to32_8x4_Unit.S in Sources */, 95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */, - 92FF02E623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */, 48747D64245D9E33000B9709 /* GeometryTile.cpp in Sources */, 92FF043723AA0B7100AC97F6 /* ShapeDetectionOutput.cpp in Sources */, 92FF042623AA0B7100AC97F6 /* ShapeCosineSimilarity.cpp in Sources */, @@ -3633,7 +3605,6 @@ 92FF043023AA0B7100AC97F6 /* ShapeQuantizedAvgPool.cpp in Sources */, 92FF030623AA0B5A00AC97F6 /* MNNStrassenMergeCFunction.S in Sources */, 92FF033223AA0B5A00AC97F6 /* MNNBlitC3ToFloatRGBA.S in Sources */, - 92FF03BE23AA0B5A00AC97F6 /* DeconvolutionWithStride.cpp in Sources */, 92FF044923AA0B7100AC97F6 /* ShapeGatherND.cpp in Sources */, 489D7AB32550FDC900AD896A /* MetalPReLU.mm in Sources */, 19D0FE7028534C4500B74B1A /* MetalSoftmax.mm in Sources */, @@ -3787,13 +3758,11 @@ 92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */, 92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */, CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */, - 92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */, 92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */, CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */, 92FF045D23AA0B7100AC97F6 /* ShapeCast.cpp in Sources */, 92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */, 92FF02D723AA0B5A00AC97F6 /* MNNConvRunForUnitDepthWiseUint8.S in 
Sources */, - 958B046429D2C89D00FC3AEF /* GemmInt8Executor.cpp in Sources */, 92FF026123AA0B5A00AC97F6 /* CPUCropAndResize.cpp in Sources */, 48FA474923AA127B00172C3B /* MathOp.cpp in Sources */, 4819FB3C24C69E680050BD09 /* GeometryBatchMatMul.cpp in Sources */, @@ -3826,7 +3795,6 @@ 92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */, 4896D37F25FE2A6B00717702 /* MNNConvRunForLineDepthwiseFP16.S in Sources */, 92FF044323AA0B7100AC97F6 /* ShapeTopKV2.cpp in Sources */, - 92FF02EC23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */, 48C84BA1250F725600EE7666 /* InitNet.cpp in Sources */, 4894C6E927016F7200D8BE79 /* CPUResizeCache.cpp in Sources */, 4DD1791B2684815A00B0098F /* ShapeSetDiff1D.cpp in Sources */, @@ -4164,7 +4132,7 @@ METAL_LIBRARY_FILE_BASE = mnn; ONLY_ACTIVE_ARCH = YES; OTHER_CFLAGS = ""; - PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk; + PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde; PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; PROVISIONING_PROFILE_SPECIFIER = ""; "PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = ""; @@ -4260,7 +4228,7 @@ IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk; + PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; }; @@ -4287,7 +4255,7 @@ IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vjk; + PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; }; diff --git a/pymnn/CMakeLists.txt b/pymnn/CMakeLists.txt index 7e3694407..aedbbe378 100644 --- a/pymnn/CMakeLists.txt +++ b/pymnn/CMakeLists.txt @@ -16,6 +16,7 @@ option(PYMNN_TRAIN_API "MNN train API be exposed" OFF) 
option(PYMNN_INTERNAL_SERVING "Internal use only." OFF) option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON) option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF) +option(PYMNN_AUDIO_API "MNN Audio API be exposed" ON) option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF) if (PYMNN_OHOS_INTERNAL) @@ -91,6 +92,10 @@ if(PYMNN_CVCORE) target_compile_definitions(mnnpybridge PRIVATE PYMNN_CVCORE) endif() +if(PYMNN_AUDIO_API) + target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API) +endif() + if(PYMNN_INTERNAL_SERVING) message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING") target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING) @@ -197,7 +202,7 @@ else() endif() export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN) else() - target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV) + target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio) if(PYMNN_USE_ALINNPYTHON) target_link_libraries(mnnpybridge PRIVATE AliNNPython) endif() diff --git a/pymnn/pip_package/MNN/__init__.py b/pymnn/pip_package/MNN/__init__.py index 5c2a4b1c4..89ed46b14 100644 --- a/pymnn/pip_package/MNN/__init__.py +++ b/pymnn/pip_package/MNN/__init__.py @@ -9,3 +9,4 @@ from . import optim from . import numpy from . import cv +from . import audio \ No newline at end of file diff --git a/pymnn/pip_package/MNN/audio/__init__.py b/pymnn/pip_package/MNN/audio/__init__.py new file mode 100644 index 000000000..98b970209 --- /dev/null +++ b/pymnn/pip_package/MNN/audio/__init__.py @@ -0,0 +1,96 @@ +from _mnncengine.audio import * +import _mnncengine.audio as _F +import MNN.expr as _expr +import MNN.numpy as _np +import MNN + +# Enum Types +# enum WINDOW_TYPE +HAMMING = 0 +HANNING = 1 +POVEY = 2 +RECTANGULAR = 3 +BLACKMAN = 4 +# enum PadValueMode +CONSTANT = 0 +REFLECT = 1 +SYMMETRIC = 2 +EDGE = 3 + +""" +Loads a portion of an audio file. + +Parameters: + filename (str): The path to the audio file. 
+ frame_offset (int): The offset in frames from which to start loading the audio data. Default is 0. + num_frames (int): The number of frames to load. If set to -1, the entire audio file will be loaded. Default is -1. + +Returns: + The result of loading the specified portion of the audio var and the sample rate. +""" +def load(filename, sr = 0, frame_offset = 0, num_frames = -1): + return _F.load(filename, sr, frame_offset, num_frames) + +""" +Saves an audio var to a file. +Parameters: + filename (str): The path to the audio file. + audio (Var): The audio var to save. + sample_rate (int): The sample rate of the audio var. +Returns: + None +""" +def save(filename, audio, sample_rate): + return _F.save(filename, audio, sample_rate) + +""" +Generates a Hamming window. +Parameters: + window_size (int): The size of the window. + periodic (bool): Whether the window is periodic. Default is False. + alpha (float): The alpha parameter of the Hamming window. Default is 0.54. + beta (float): The beta parameter of the Hamming window. Default is 0.46. +Returns: + The Hamming window. +""" +def hamming_window(window_size, periodic = False, alpha = 0.54, beta = 0.46): + return _F.hamming_window(window_size, periodic, alpha, beta) + +""" +Generates a Hann window. +Parameters: + window_size (int): The size of the window. + periodic (bool): Whether the window is periodic. Default is False. +Returns: + The Hann window. 
+""" +def hanning_window(window_size, periodic = False): + return _F.hanning_window(window_size, periodic) + +def melscale_fbanks(n_mels, n_fft, sampe_rate = 16000, htk = True, norm = False, + f_min = 0.0, f_max = 0.0): + return _F.melscale_fbanks(n_mels, n_fft, sampe_rate, htk, norm, f_min, f_max) + +def spectrogram(waveform, n_fft = 400, hop_length = 0, win_length = 0, window_type = HANNING, + pad_left = 0, pad_right = 0, center = False, normalized = False, pad_mode = REFLECT, + power = 2.0): + return _F.spectrogram(waveform, n_fft, hop_length, win_length, window_type, pad_left, + pad_right, center, normalized, pad_mode, power) + + +def mel_spectrogram(waveform, n_mels, n_fft, sampe_rate = 16000, htk = True, norm = False, + f_min = 0.0, f_max = 0.0, hop_length = 0, win_length = 0, window_type = HANNING, + pad_left = 0, pad_right = 0, center = False, normalized = False, pad_mode = REFLECT, + power = 2.0): + return _F.mel_spectrogram(waveform, n_mels, n_fft, sampe_rate, htk, norm, f_min, f_max, + hop_length, win_length, window_type, pad_left, pad_right, center, + normalized, pad_mode, power) + +def fbank(waveform, sample_rate = 16000, n_mels = 80, n_fft = 400, hop_length = 160, + dither = 0.0, preemphasis = 0.97): + return _F.fbank(waveform, sample_rate, n_mels, n_fft, hop_length, dither, preemphasis) + + +def whisper_fbank(waveform, sample_rate = 16000, n_mels = 128, n_fft = 400, + hop_length = 160, chunk_len = 30): + return _F.whisper_fbank(waveform, sample_rate, n_mels, n_fft, hop_length, chunk_len) \ No newline at end of file diff --git a/pymnn/pip_package/MNN/llm/__init__.py b/pymnn/pip_package/MNN/llm/__init__.py index f144b3e06..ebf4cf84e 100644 --- a/pymnn/pip_package/MNN/llm/__init__.py +++ b/pymnn/pip_package/MNN/llm/__init__.py @@ -57,7 +57,25 @@ def response(self, prompt, stream = False): ''' return super.response(prompt, stream) -def create(config_path): + def txt_embedding(self, prompt): + ''' + get prompt's embedding + + Parameters + ---------- + 
prompt : input prompt + + Returns + ------- + res : embedding var + + Example: + ------- + >>> res = qwen.txt_embedding('Hello') + ''' + return super.txt_embedding(prompt) + +def create(config_path, embedding_model = False): ''' create LLM instance by `config.json` @@ -73,4 +91,4 @@ def create(config_path): ------- >>> qwen = llm.create('./qwen-1.8b-int4/config.json') ''' - return _F.create(config_path) \ No newline at end of file + return _F.create(config_path, embedding_model) \ No newline at end of file diff --git a/pymnn/pip_package/build_deps.py b/pymnn/pip_package/build_deps.py index 6ee2398a5..c3fa00278 100644 --- a/pymnn/pip_package/build_deps.py +++ b/pymnn/pip_package/build_deps.py @@ -99,7 +99,7 @@ def build_deps(): if IS_WINDOWS: os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\ -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\ - -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNConvertDeps') + -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_BUILD_AUDIO=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNConvertDeps') elif IS_LINUX: extra_opts += '-DMNN_TENSORRT=ON \ -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' ' @@ -113,14 +113,14 @@ def build_deps(): os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \ -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \ - .. && make MNN MNNTrain MNNConvertDeps -j32') + -DMNN_BUILD_AUDIO=ON .. 
&& make MNN MNNTrain MNNConvertDeps -j32') else: extra_opts += ' -DMNN_INTERNAL=ON ' if USE_INTERNAL else ' ' extra_opts += ' -DMNN_BUILD_TORCH=ON ' if USE_TORCH else ' ' print(extra_opts) os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \ -DMNN_BUILD_SHARED_LIBS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\ - -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_BUILD_AUDIO=ON\ .. && make MNN MNNConvertDeps -j64') ################################################################################ # Building dependent libraries diff --git a/pymnn/pip_package/setup.py b/pymnn/pip_package/setup.py index 633968edb..3f818b733 100644 --- a/pymnn/pip_package/setup.py +++ b/pymnn/pip_package/setup.py @@ -166,7 +166,7 @@ def configure_extension_build(): ] if check_env_flag('WERROR'): extra_compile_args.append('-Werror') - extra_compile_args += ['-DPYMNN_EXPR_API', '-DPYMNN_NUMPY_USABLE', '-DPYMNN_OPENCV_API'] + extra_compile_args += ['-DPYMNN_EXPR_API', '-DPYMNN_NUMPY_USABLE', '-DPYMNN_OPENCV_API', '-DPYMNN_AUDIO_API'] if IS_LINUX and USE_INTERNAL: extra_compile_args += ['-DPYMNN_INTERNAL_SERVING'] if args.env == 'daily': @@ -177,6 +177,7 @@ def configure_extension_build(): engine_library_dirs = [os.path.join(root_dir, BUILD_DIR)] engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "train")] engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "cv")] + engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "tools", "audio")] engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "source", "backend", "tensorrt")] engine_library_dirs += [os.path.join(root_dir, BUILD_DIR, "source", "backend", "cuda")] if USE_TRT or USE_CUDA: @@ -214,6 +215,8 @@ def configure_extension_build(): engine_include_dirs += [os.path.join(root_dir, "3rd_party", "rapidjson")] # cv include engine_include_dirs += [os.path.join(root_dir, "tools", "cv", "include")] + # audio 
include + engine_include_dirs += [os.path.join(root_dir, "tools", "audio", "include")] # llm include engine_include_dirs += [os.path.join(root_dir, "transformers", "llm", "engine", "include")] engine_include_dirs += [os.path.join(root_dir, "3rd_party")] diff --git a/pymnn/src/MNN.cc b/pymnn/src/MNN.cc index 1ec9a15e1..1f065e27e 100644 --- a/pymnn/src/MNN.cc +++ b/pymnn/src/MNN.cc @@ -22,6 +22,9 @@ using namespace MNN::Express; #ifdef PYMNN_OPENCV_API #include "cv/cv.hpp" #endif +#ifdef PYMNN_AUDIO_API +#include "audio/audio.hpp" +#endif #endif // PYMNN_EXPR_API #ifdef BUILD_OPTYPE @@ -64,6 +67,9 @@ using RegularizationMethod = ParameterOptimizer::RegularizationMethod; #ifdef PYMNN_OPENCV_API #include "cv.h" #endif +#ifdef PYMNN_AUDIO_API +#include "audio.h" +#endif #endif #ifdef PYMNN_LLM_API @@ -1587,7 +1593,8 @@ static PyObject* PyMNNTensor_repr(PyObject *self) { #ifdef PYMNN_NUMPY_USABLE auto content = PyMNNTensor_getNumpyData(((PyMNNTensor*)self), NULL); #else - auto content = PyMNNVar_read_as_tuple((PyMNNVar*)self, NULL); + // print shape of tensor + auto content = PyMNNTensor_getShape((PyMNNTensor*)self, NULL); #endif auto reprfunc = PyObject_GetAttrString(content, "__repr__"); auto str = PyEval_CallObject(reprfunc, NULL); @@ -2713,6 +2720,15 @@ PyMODINIT_FUNC MOD_INIT_FUNC(void) { def_method(cv_module, &PyMNNCV_methods[i]); } #endif +#ifdef PYMNN_AUDIO_API + // audio submodule + auto audio_module = def_submodule(m, "audio"); + // add methods of audio + constexpr int audio_method_num = sizeof(PyMNNAUDIO_methods) / sizeof(PyMethodDef); + for (int i = 0; i < audio_method_num; i++) { + def_method(audio_module, &PyMNNAUDIO_methods[i]); + } +#endif #endif #ifdef PYMNN_LLM_API // llm submodule diff --git a/pymnn/src/audio.h b/pymnn/src/audio.h new file mode 100644 index 000000000..a7afc6ec4 --- /dev/null +++ b/pymnn/src/audio.h @@ -0,0 +1,105 @@ +// MNN AUDIO +static PyObject *PyMNNAUDIO_load(PyObject *self, PyObject *args) { + const char *filename = NULL; + int sr 
= 0, frame_offset = 0, num_frames = -1; + if (PyArg_ParseTuple(args, "s|iii", &filename, &sr, &frame_offset, &num_frames) && filename) { + return toPyObj(AUDIO::load(filename, sr, frame_offset, num_frames)); + } + PyMNN_ERROR("load require args: (string, int, int, int)"); +} +static PyObject *PyMNNAUDIO_save(PyObject *self, PyObject *args) { + const char *filename = NULL; + PyObject *audio = nullptr; + int sample_rate = 0; + if (PyArg_ParseTuple(args, "sOi", &filename, &audio, &sample_rate) && filename && isVar(audio)) { + return toPyObj(AUDIO::save(filename, toVar(audio), sample_rate)); + } + PyMNN_ERROR("save require args: (string, Var, int)"); +} +static PyObject *PyMNNAUDIO_hamming_window(PyObject *self, PyObject *args) { + int window_size = 0, periodic = 0; + float alpha = 0.54, beta = 0.46; + if (PyArg_ParseTuple(args, "i|iff", &window_size, &periodic, &alpha, &beta)) { + return toPyObj(AUDIO::hamming_window(window_size, periodic, alpha, beta)); + } + PyMNN_ERROR("hamming_window require args: (int, |bool, float, float)"); +} +static PyObject *PyMNNAUDIO_hann_window(PyObject *self, PyObject *args) { + int window_size = 0, periodic = 0; + if (PyArg_ParseTuple(args, "i|i", &window_size, &periodic)) { + return toPyObj(AUDIO::hann_window(window_size, periodic)); + } + PyMNN_ERROR("hann_window require args: (int, |bool)"); +} +static PyObject *PyMNNAUDIO_melscale_fbanks(PyObject *self, PyObject *args) { + AUDIO::MelscaleParams mel; + if (PyArg_ParseTuple(args, "ii|ifff", &mel.n_mels, &mel.n_fft, &mel.sample_rate, &mel.htk, &mel.norm, &mel.f_min, &mel.f_max)) { + return toPyObj(AUDIO::melscale_fbanks(&mel)); + } + PyMNN_ERROR("melscale_fbanks require args: (int, int, |int, bool, bool, float, float)"); +} +static PyObject *PyMNNAUDIO_spectrogram(PyObject *self, PyObject *args) { + PyObject *waveform = nullptr; + AUDIO::SpectrogramParams spec; + if (PyArg_ParseTuple(args, "O|iiiiiiiiiif", &waveform, &spec.n_fft, &spec.hop_length, &spec.win_length, + &spec.window_type, 
&spec.pad_left, &spec.pad_right, &spec.center, &spec.normalized, + &spec.pad_mode, &spec.power) && + isVar(waveform)) { + return toPyObj(AUDIO::spectrogram(toVar(waveform), &spec)); + } + PyMNN_ERROR("spectrogram require args: (Var, |int, int, int, int, int, int, bool, bool, PadValueMode, float)"); +} +static PyObject *PyMNNAUDIO_mel_spectrogram(PyObject *self, PyObject *args) { + PyObject *waveform = nullptr; + AUDIO::MelscaleParams mel; + AUDIO::SpectrogramParams spec; + int n_fft = 400; + if (PyArg_ParseTuple(args, "O|iiifiiifiiiii", &waveform, &mel.n_mels, &mel.n_fft, &mel.sample_rate, &mel.htk, + &mel.norm, &mel.f_min, &mel.f_max, &spec.hop_length, &spec.win_length, &spec.window_type, + &spec.pad_left, &spec.pad_right, &spec.center, &spec.normalized, &spec.pad_mode, + &spec.power) && + isVar(waveform)) { + spec.n_fft = mel.n_fft; + return toPyObj(AUDIO::mel_spectrogram(toVar(waveform), &mel, &spec)); + } + PyMNN_ERROR( + "mel_spectrogram require args: (Var, |int, bool, bool, float, float, int, int, int, int, int, bool, bool, " + "PadValueMode, float)" + "int)"); +} +static PyObject *PyMNNAUDIO_fbank(PyObject *self, PyObject *args) { + PyObject *waveform = nullptr; + int sample_rate = 16000, n_mels = 80, n_fft = 400, hop_length = 160; + float dither = 0.0, preemphasis = 0.97; + if (PyArg_ParseTuple(args, "O|iiiiff", &waveform, &sample_rate, &n_mels, &n_fft, &hop_length, &dither, + &preemphasis) && + isVar(waveform)) { + return toPyObj( + AUDIO::fbank(toVar(waveform), sample_rate, n_mels, n_fft, hop_length, dither, preemphasis)); + } + PyMNN_ERROR("fbank require args: (Var, |int, int, int, int, float, float)"); +} + +static PyObject *PyMNNAUDIO_whisper_fbank(PyObject *self, PyObject *args) { + PyObject *waveform = nullptr; + int sample_rate = 16000, n_mels = 128, n_fft = 400, hop_length = 160, chunk_len = 30; + if (PyArg_ParseTuple(args, "O|iiiii", &waveform, &sample_rate, &n_mels, &n_fft, &hop_length, &chunk_len) && + isVar(waveform)) { + return 
toPyObj(AUDIO::whisper_fbank(toVar(waveform), sample_rate, n_mels, n_fft, hop_length, chunk_len)); + } + PyMNN_ERROR("whisper_fbank require args: (Var, |int, int, int, int, int)"); +} + +static PyMethodDef PyMNNAUDIO_methods[] = { + register_methods(AUDIO, + load, "load", + save, "save", + hamming_window, "hamming_window", + hann_window, "hann_window", + melscale_fbanks, "melscale_fbanks", + spectrogram, "spectrogram", + mel_spectrogram, "mel_spectrogram", + fbank, "fbank", + whisper_fbank, "whisper_fbank" + ) +}; diff --git a/pymnn/src/llm.h b/pymnn/src/llm.h index 0d363fe98..93329cd60 100644 --- a/pymnn/src/llm.h +++ b/pymnn/src/llm.h @@ -4,6 +4,7 @@ typedef struct { PyObject_HEAD MNN::Transformer::Llm* llm; + bool is_embedding = false; } LLM; static PyObject* PyMNNLLM_new(struct _typeobject *type, PyObject *args, PyObject *kwds) { @@ -25,6 +26,9 @@ static PyObject* PyMNNLLM_load(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_forward(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } PyObject *input_ids = nullptr; if (!PyArg_ParseTuple(args, "O", &input_ids) && isInts(input_ids)) { Py_RETURN_NONE; @@ -37,6 +41,9 @@ static PyObject* PyMNNLLM_forward(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_generate(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } PyObject *input_ids = nullptr; if (!PyArg_ParseTuple(args, "O", &input_ids) && isInts(input_ids)) { Py_RETURN_NONE; @@ -46,6 +53,9 @@ static PyObject* PyMNNLLM_generate(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } const char* query = NULL; int stream = 0; if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) { @@ -57,6 +67,9 @@ static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } const char* prompt = NULL; int 
use_template = 0; if (!PyArg_ParseTuple(args, "s|p", &prompt, &use_template)) { @@ -67,6 +80,9 @@ static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_tokenizer_decode(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } PyObject *id = nullptr; if (!PyArg_ParseTuple(args, "O", &id) && isInt(id)) { Py_RETURN_NONE; @@ -75,6 +91,19 @@ static PyObject* PyMNNLLM_tokenizer_decode(LLM *self, PyObject *args) { return string2Object(query); } +static PyObject* PyMNNLLM_txt_embedding(LLM *self, PyObject *args) { + if (!self->is_embedding) { + Py_RETURN_NONE; + } + const char* query = NULL; + if (!PyArg_ParseTuple(args, "s", &query)) { + Py_RETURN_NONE; + } + auto embeds = getVar(); + *(embeds->var) = ((MNN::Transformer::Embedding*)self->llm)->txt_embedding(query); + return (PyObject *)embeds; +} + static PyMethodDef PyMNNLLM_methods[] = { {"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."}, {"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."}, @@ -82,6 +111,7 @@ static PyMethodDef PyMNNLLM_methods[] = { {"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."}, {"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."}, {"tokenizer_decode", (PyCFunction)PyMNNLLM_tokenizer_decode, METH_VARARGS, "tokenizer decode."}, + {"txt_embedding", (PyCFunction)PyMNNLLM_txt_embedding, METH_VARARGS, "txt embedding."}, {NULL} /* Sentinel */ }; @@ -131,14 +161,21 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) { return NULL; } const char* path = NULL; - if (!PyArg_ParseTuple(args, "s", &path)) { + int embedding_model = 0; + if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) { return NULL; } LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL); if (!llm) { return NULL; } - llm->llm = MNN::Transformer::Llm::createLLM(path); + if 
(embedding_model) { + llm->llm = MNN::Transformer::Embedding::createEmbedding(path); + llm->is_embedding = true; + } else { + llm->llm = MNN::Transformer::Llm::createLLM(path); + } + return (PyObject*)llm; } diff --git a/schema/current/MNN_generated.h b/schema/current/MNN_generated.h index bb4f48a44..1c9647c87 100644 --- a/schema/current/MNN_generated.h +++ b/schema/current/MNN_generated.h @@ -33,6 +33,9 @@ struct FmhaV2ParamT; struct FmhcaParam; struct FmhcaParamT; +struct StftParam; +struct StftParamT; + struct WhileParam; struct WhileParamT; @@ -78,6 +81,8 @@ inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable(); inline const flatbuffers::TypeTable *FmhcaParamTypeTable(); +inline const flatbuffers::TypeTable *StftParamTypeTable(); + inline const flatbuffers::TypeTable *WhileParamTypeTable(); inline const flatbuffers::TypeTable *IfParamTypeTable(); @@ -252,6 +257,7 @@ enum OpType { OpType_Svd = 153, OpType_Histogram = 154, OpType_DynamicQuant = 155, + OpType_Stft = 156, OpType_Plugin = 256, OpType_Select = 257, OpType_ZerosLike = 258, @@ -287,7 +293,7 @@ enum OpType { OpType_MAX = OpType_GridSample }; -inline const OpType (&EnumValuesOpType())[182] { +inline const OpType (&EnumValuesOpType())[183] { static const OpType values[] = { OpType_AbsVal, OpType_QuantizedAdd, @@ -440,6 +446,7 @@ inline const OpType (&EnumValuesOpType())[182] { OpType_Svd, OpType_Histogram, OpType_DynamicQuant, + OpType_Stft, OpType_Plugin, OpType_Select, OpType_ZerosLike, @@ -633,7 +640,7 @@ inline const char * const *EnumNamesOpType() { "Svd", "Histogram", "DynamicQuant", - "", + "Stft", "", "", "", @@ -1193,11 +1200,12 @@ enum OpParameter { OpParameter_FmhaV2Param = 96, OpParameter_FmhcaParam = 97, OpParameter_AttentionParam = 98, + OpParameter_StftParam = 99, OpParameter_MIN = OpParameter_NONE, - OpParameter_MAX = OpParameter_AttentionParam + OpParameter_MAX = OpParameter_StftParam }; -inline const OpParameter (&EnumValuesOpParameter())[99] { +inline const OpParameter 
(&EnumValuesOpParameter())[100] { static const OpParameter values[] = { OpParameter_NONE, OpParameter_QuantizedAdd, @@ -1297,7 +1305,8 @@ inline const OpParameter (&EnumValuesOpParameter())[99] { OpParameter_GroupNorm, OpParameter_FmhaV2Param, OpParameter_FmhcaParam, - OpParameter_AttentionParam + OpParameter_AttentionParam, + OpParameter_StftParam }; return values; } @@ -1403,13 +1412,14 @@ inline const char * const *EnumNamesOpParameter() { "FmhaV2Param", "FmhcaParam", "AttentionParam", + "StftParam", nullptr }; return names; } inline const char *EnumNameOpParameter(OpParameter e) { - if (e < OpParameter_NONE || e > OpParameter_AttentionParam) return ""; + if (e < OpParameter_NONE || e > OpParameter_StftParam) return ""; const size_t index = static_cast(e); return EnumNamesOpParameter()[index]; } @@ -1810,6 +1820,10 @@ template<> struct OpParameterTraits { static const OpParameter enum_value = OpParameter_AttentionParam; }; +template<> struct OpParameterTraits { + static const OpParameter enum_value = OpParameter_StftParam; +}; + struct OpParameterUnion { OpParameter type; void *value; @@ -2625,6 +2639,14 @@ struct OpParameterUnion { return type == OpParameter_AttentionParam ? reinterpret_cast(value) : nullptr; } + StftParamT *AsStftParam() { + return type == OpParameter_StftParam ? + reinterpret_cast(value) : nullptr; + } + const StftParamT *AsStftParam() const { + return type == OpParameter_StftParam ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type); @@ -3084,6 +3106,82 @@ inline flatbuffers::Offset CreateFmhcaParam( flatbuffers::Offset CreateFmhcaParam(flatbuffers::FlatBufferBuilder &_fbb, const FmhcaParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct StftParamT : public flatbuffers::NativeTable { + typedef StftParam TableType; + int32_t n_fft; + int32_t hop_length; + bool abs; + StftParamT() + : n_fft(0), + hop_length(0), + abs(true) { + } +}; + +struct StftParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef StftParamT NativeTableType; + static const flatbuffers::TypeTable *MiniReflectTypeTable() { + return StftParamTypeTable(); + } + int32_t n_fft() const { + return GetField(4, 0); + } + int32_t hop_length() const { + return GetField(6, 0); + } + bool abs() const { + return GetField(8, 1) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, 4) && + VerifyField(verifier, 6) && + VerifyField(verifier, 8) && + verifier.EndTable(); + } + StftParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StftParamT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const StftParamT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StftParamBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_n_fft(int32_t n_fft) { + fbb_.AddElement(4, n_fft, 0); + } + void add_hop_length(int32_t hop_length) { + fbb_.AddElement(6, hop_length, 0); + } + void add_abs(bool abs) { + fbb_.AddElement(8, static_cast(abs), 1); + } + explicit StftParamBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + StftParamBuilder &operator=(const 
StftParamBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateStftParam( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t n_fft = 0, + int32_t hop_length = 0, + bool abs = true) { + StftParamBuilder builder_(_fbb); + builder_.add_hop_length(hop_length); + builder_.add_n_fft(n_fft); + builder_.add_abs(abs); + return builder_.Finish(); +} + +flatbuffers::Offset CreateStftParam(flatbuffers::FlatBufferBuilder &_fbb, const StftParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct WhileParamT : public flatbuffers::NativeTable { typedef WhileParam TableType; std::string cond_graph; @@ -3863,6 +3961,9 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const AttentionParam *main_as_AttentionParam() const { return main_type() == OpParameter_AttentionParam ? static_cast(main()) : nullptr; } + const StftParam *main_as_StftParam() const { + return main_type() == OpParameter_StftParam ? 
static_cast(main()) : nullptr; + } const flatbuffers::String *name() const { return GetPointer(10); } @@ -4292,6 +4393,10 @@ template<> inline const AttentionParam *Op::main_as() const { return main_as_AttentionParam(); } +template<> inline const StftParam *Op::main_as() const { + return main_as_StftParam(); +} + struct OpBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5167,6 +5272,38 @@ inline flatbuffers::Offset CreateFmhcaParam(flatbuffers::FlatBufferB _heads); } +inline StftParamT *StftParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new StftParamT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void StftParam::UnPackTo(StftParamT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = n_fft(); _o->n_fft = _e; }; + { auto _e = hop_length(); _o->hop_length = _e; }; + { auto _e = abs(); _o->abs = _e; }; +} + +inline flatbuffers::Offset StftParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StftParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateStftParam(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateStftParam(flatbuffers::FlatBufferBuilder &_fbb, const StftParamT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StftParamT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _n_fft = _o->n_fft; + auto _hop_length = _o->hop_length; + auto _abs = _o->abs; + return MNN::CreateStftParam( + _fbb, + _n_fft, + _hop_length, + _abs); +} + inline WhileParamT *WhileParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new WhileParamT(); UnPackTo(_o, _resolver); @@ -6015,6 +6152,10 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, auto ptr = reinterpret_cast(obj); return 
verifier.VerifyTable(ptr); } + case OpParameter_StftParam: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -6425,6 +6566,10 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case OpParameter_StftParam: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -6823,6 +6968,10 @@ inline flatbuffers::Offset OpParameterUnion::Pack(flatbuffers::FlatBufferB auto ptr = reinterpret_cast(value); return CreateAttentionParam(_fbb, ptr, _rehasher).Union(); } + case OpParameter_StftParam: { + auto ptr = reinterpret_cast(value); + return CreateStftParam(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7221,6 +7370,10 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS value = new AttentionParamT(*reinterpret_cast(u.value)); break; } + case OpParameter_StftParam: { + value = new StftParamT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -7718,6 +7871,11 @@ inline void OpParameterUnion::Reset() { delete ptr; break; } + case OpParameter_StftParam: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; @@ -7907,12 +8065,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 }, + { flatbuffers::ET_INT, 0, 0 }, { flatbuffers::ET_INT, 0, 0 } }; static const flatbuffers::TypeFunction type_refs[] = { OpTypeTypeTable }; - static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 }; + static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 }; static const char * const names[] = { "AbsVal", "QuantizedAdd", @@ -8065,6 +8224,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { "Svd", "Histogram", "DynamicQuant", + "Stft", "Plugin", "Select", "ZerosLike", @@ -8098,7 +8258,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() { "GridSample" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_ENUM, 182, type_codes, type_refs, values, names + flatbuffers::ST_ENUM, 183, type_codes, type_refs, values, names }; return &tt; } @@ -8203,7 +8363,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() { { flatbuffers::ET_SEQUENCE, 0, 94 }, { flatbuffers::ET_SEQUENCE, 0, 95 }, { 
flatbuffers::ET_SEQUENCE, 0, 96 }, - { flatbuffers::ET_SEQUENCE, 0, 97 } + { flatbuffers::ET_SEQUENCE, 0, 97 }, + { flatbuffers::ET_SEQUENCE, 0, 98 } }; static const flatbuffers::TypeFunction type_refs[] = { QuantizedAddTypeTable, @@ -8303,7 +8464,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() { GroupNormTypeTable, FmhaV2ParamTypeTable, FmhcaParamTypeTable, - AttentionParamTypeTable + AttentionParamTypeTable, + StftParamTypeTable }; static const char * const names[] = { "NONE", @@ -8404,10 +8566,11 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() { "GroupNorm", "FmhaV2Param", "FmhcaParam", - "AttentionParam" + "AttentionParam", + "StftParam" }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_UNION, 99, type_codes, type_refs, nullptr, names + flatbuffers::ST_UNION, 100, type_codes, type_refs, nullptr, names }; return &tt; } @@ -8550,6 +8713,23 @@ inline const flatbuffers::TypeTable *FmhcaParamTypeTable() { return &tt; } +inline const flatbuffers::TypeTable *StftParamTypeTable() { + static const flatbuffers::TypeCode type_codes[] = { + { flatbuffers::ET_INT, 0, -1 }, + { flatbuffers::ET_INT, 0, -1 }, + { flatbuffers::ET_BOOL, 0, -1 } + }; + static const char * const names[] = { + "n_fft", + "hop_length", + "abs" + }; + static const flatbuffers::TypeTable tt = { + flatbuffers::ST_TABLE, 3, type_codes, nullptr, nullptr, names + }; + return &tt; +} + inline const flatbuffers::TypeTable *WhileParamTypeTable() { static const flatbuffers::TypeCode type_codes[] = { { flatbuffers::ET_STRING, 0, -1 }, diff --git a/schema/default/MNN.fbs b/schema/default/MNN.fbs index d415bddcb..e5e588a66 100644 --- a/schema/default/MNN.fbs +++ b/schema/default/MNN.fbs @@ -168,6 +168,7 @@ enum OpType : int { Svd = 153, Histogram = 154, DynamicQuant = 155, + Stft = 156, Plugin = 256, //The Type load from plugin //Training Op Start from 257 @@ -239,6 +240,12 @@ table FmhcaParam { heads: int; } +table StftParam { + n_fft: int; + hop_length: int; + abs: 
bool = true; +} + table WhileParam { // The name of condition subgraph. cond_graph: string; @@ -414,7 +421,8 @@ union OpParameter { GroupNorm, FmhaV2Param, FmhcaParam, - AttentionParam + AttentionParam, + StftParam } table Op { diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index b23bc8ead..3ec321c99 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -237,6 +237,9 @@ class CastWrapExecution : public Execution { CPUBackend::addCreator(opType, &_temp); \ } +#define REGISTER_CPU_OP_CREATOR_AUDIO(name, opType) \ + REGISTER_CPU_OP_CREATOR(name, opType) + } // namespace MNN #endif /* CPUBackend_hpp */ diff --git a/source/backend/cpu/CPUBinaryInt8.cpp b/source/backend/cpu/CPUBinaryInt8.cpp index a1da4a2e4..cf46a1af5 100644 --- a/source/backend/cpu/CPUBinaryInt8.cpp +++ b/source/backend/cpu/CPUBinaryInt8.cpp @@ -80,16 +80,16 @@ ErrorCode CPUBinaryInt8::onExecute(const std::vector& inputs, const std int inpBytes = 1; int outBytes = 1; + QuanPrePostParameters params; + + params.inputScale = mInputScales.data(); + params.outputScale = mOutputScales.data(); + params.outputZeroPoint = mOutputZeros.data(); + params.inputZeroPoint = mInputZeros.data(); + params.minValue = (ssize_t)mMinValue; + params.maxValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->max; MNN_CONCURRENCY_BEGIN(tId, schedule.second) { - QuanPrePostParameters params; - - params.inputScale = mInputScales.data(); - params.outputScale = mOutputScales.data(); - params.outputZeroPoint = mOutputZeros.data(); - params.inputZeroPoint = mInputZeros.data(); - params.minValue = (ssize_t)mMinValue; - params.maxValue = (ssize_t)TensorUtils::getDescribe(outputs[0])->quantAttr->max; int start = schedule.first * (int)tId; int realSize = schedule.first; diff --git a/source/backend/cpu/CPUDeconvolution.cpp b/source/backend/cpu/CPUDeconvolution.cpp index bdef005cb..8bfcf0738 100644 --- a/source/backend/cpu/CPUDeconvolution.cpp +++ 
b/source/backend/cpu/CPUDeconvolution.cpp @@ -18,7 +18,6 @@ #include "core/ConvolutionCommon.hpp" #include "compute/CommonOptFunction.h" #include "compute/ConvOpt.h" -#include "compute/DeconvolutionWithStride.hpp" //#define MNN_OPEN_TIME_TRACE #include @@ -83,63 +82,13 @@ static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outpu //printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw); core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false); } -// Int8 Weight. -static void _reorderWeightInt8(Backend* bn, const Convolution2DCommon* common, const int8_t* srcPtr, - std::shared_ptr& weight) { - auto core = static_cast(bn)->int8Functions(); - auto gcore = static_cast(bn)->functions(); - int UNIT, SRC_UNIT, DST_XUNIT; - core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - UNIT = gcore->pack; - int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY(); - std::vector shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; - - weight.reset(Tensor::createDevice(shape)); - bool succ = bn->onAcquireBuffer(weight.get(), Backend::STATIC); - if (!succ) { - MNN_ERROR("Memory not enough"); - return; - } - auto dstPtr = weight->host(); - ::memset(dstPtr, 0, weight->size()); - - int icDiv = UP_DIV(ic, SRC_UNIT); - for (int k = 0; k < kernelCount; ++k) { - auto srcK = srcPtr + k; - auto dstK = dstPtr + k * SRC_UNIT * UNIT * icDiv; - for (int x = 0; x < oc; ++x) { - int xout = x / UNIT; - int xin = x % UNIT; - auto srcY = srcK + x * kernelCount; - auto dstY = dstK + xout * SRC_UNIT * UNIT * icDiv * kernelCount + xin * SRC_UNIT; - for (int y = 0; y < ic; ++y) { - int yout = y / SRC_UNIT; - int yin = y % SRC_UNIT; - - const int dstIndex = yout * SRC_UNIT * UNIT + yin; - const int srcIndex = y * oc * kernelCount; - dstY[dstIndex] = srcY[srcIndex]; - } - } - } -} CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* 
convOp, Backend* backend, bool dynamicWeight) : MNN::CPUDeconvolutionCommon(input, convOp, backend, dynamicWeight) { auto core = static_cast(backend)->functions(); auto coreInt8 = static_cast(backend)->int8Functions(); int eP, lP, hP; core->MNNGetMatMulPackMode(&eP, &lP, &hP); - int UNIT, SRC_UNIT, DST_XUNIT; - coreInt8->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - bool ModeInt8 = false; - - if (CPUBackend::getDataType(input) == DataType_DT_INT8 || input->getType().bytes() == 1) { - eP = DST_XUNIT; - lP = SRC_UNIT; - hP = UNIT; - ModeInt8 = true; - } auto conv2d = convOp->main_as_Convolution2D(); auto layer = conv2d->common(); int outputCount = layer->outputCount(); @@ -155,30 +104,17 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen mWeight.reset(Tensor::createDevice(std::vector{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); std::shared_ptr cache(Tensor::createDevice({outputAlign * srcCount})); if (dynamicWeight) { - mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, ModeInt8)); + mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false)); mWeightTransformCache = cache; return; } const float* tempWeight = nullptr; - const int8_t* quanWeightInt8 = nullptr; int tempWeightSize = 0; - std::unique_ptr externalWeightTensor; std::shared_ptr quanCommon; - std::vector _bias(outputChannleUp4, 0); - std::vector _scale(outputChannleUp4, 0); - std::vector _beta(outputChannleUp4, 0); - auto biasPtr = _bias.data(); - auto scalePtr = _scale.data(); - auto betaPtr = _beta.data(); - - if (ModeInt8) { - ConvolutionCommon::getConvInt8Parameters(convOp, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr, betaPtr); - } else { - ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize); - } + ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize); bool success = 
backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) && backend->onAcquireBuffer(cache.get(), Backend::STATIC); @@ -196,26 +132,16 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw); tempWeight = (float*)lowpWeight.get(); } - if (!ModeInt8) { - mWeight.reset(Tensor::createDevice(std::vector{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); - success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC); - if (!success) { - mValid = false; - return; - } - auto dest = mWeight->host(); - _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host(), core); - } else { - mWeight.reset(Tensor::createDevice(std::vector{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); - success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC); - if (!success) { - mValid = false; - return; - } - _reorderWeightInt8(backend, layer, quanWeightInt8, mWeight); + mWeight.reset(Tensor::createDevice(std::vector{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); + success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC); + if (!success) { + mValid = false; + return; } + auto dest = mWeight->host(); + _transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host(), core); backend->onReleaseBuffer(cache.get(), Backend::STATIC); - mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, ModeInt8)); + mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false)); } CPUDeconvolution::~CPUDeconvolution() { @@ -261,68 +187,21 @@ ErrorCode CPUDeconvolution::onResize(const std::vector &inputs, const } CPUDeconvolutionOrigin::CPUDeconvolutionOrigin(const Tensor *input, Tensor *weight, const Op *convOp, Backend *b, bool ModeInt8) : CPUDeconvolutionBasic(input, convOp, b) { - if (ModeInt8) { - const auto weightDataPtr = 
weight->host(); - auto conv2d = convOp->main_as_Convolution2D(); - auto common = conv2d->common(); - auto pack = static_cast(b)->functions()->pack; - mResource = CPUConvolution::makeResourceInt8(backend(), convOp, pack); - CPUConvolution::MutableResourceInt8 mutableResource(mResource, b); - auto core = static_cast(b)->int8Functions(); - auto gemmKernel = core->Int8GemmKernel; - int UNIT, SRC_UNIT, DST_XUNIT; - core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - const auto kEleCnt = mCommon->kernelX() * mCommon->kernelY(); - const int ocDiv4 = UP_DIV(common->outputCount(), pack) * kEleCnt; - const int icDiv4 = UP_DIV(common->inputCount(), SRC_UNIT); - const int ocDivUnit = UP_DIV(common->outputCount(), UNIT); - const int oc4 = ocDiv4 / kEleCnt; - const int bias_elesize = ocDiv4 * pack; - // set offset if use SSE. - auto inputQuant = TensorUtils::getQuantInfo(input); - auto inputZeroPoint = inputQuant[1]; - std::vector _bias(bias_elesize, inputZeroPoint); -#ifdef MNN_USE_SSE - int actBits = conv2d->symmetricQuan()->nbits(); - if (actBits <= 7) { - gemmKernel = core->Int8GemmKernelFast; - } - for (int a = 0; a < kEleCnt; ++a){ - for (int oz = 0; oz < ocDivUnit * UNIT; ++oz) { - int offset = inputZeroPoint, oz4 = oz / UNIT, ozRemain = oz % UNIT; - for (int sz = 0; sz < icDiv4 * SRC_UNIT; ++sz) { - int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT; - int index = (((a * oc4 + oz4) * icDiv4 + sz4) * UNIT + ozRemain) * SRC_UNIT + szRemain; - auto weightInt8Data = weightDataPtr[index]; - offset += weightInt8Data * (-128); - } - if (oz < oc4 * pack) { - _bias[a * oc4 * pack + oz] = offset; - } - } - } -#else - if(conv2d->symmetricQuan() && conv2d->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ - gemmKernel = core->Int8GemmKernelFast; - } -#endif - mDeconvInt8Exe.reset(new GemmInt8Executor(b, mResource, convOp, gemmKernel, _bias)); - } + // Do nothing } ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector& inputs, const std::vector& outputs) { 
CPUDeconvolutionBasic::onResize(inputs, outputs); auto core = static_cast(backend())->functions(); - auto gcore = static_cast(backend())->int8Functions(); int bytes = core->bytes; auto input = inputs[0]; auto output = outputs[0]; auto oc = output->channel(); - int UNIT, SRC_UNIT, DST_XUNIT; - gcore->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); if (UP_DIV(oc, core->pack) * core->pack != inputs[2]->length(0)) { return INPUT_DATA_ERROR; } + int eP, lP, hP; + core->MNNGetMatMulPackMode(&eP, &lP, &hP); auto ocC4 = UP_DIV(output->channel(), core->pack); auto icC4 = UP_DIV(input->channel(), core->pack); @@ -339,136 +218,132 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector& inputs, c auto src_height = output->height(); auto src_width = output->width(); auto batch = output->batch(); + auto weightTensor = inputs[1]; + auto biasTensor = inputs[2]; auto kernelCount = ocC4 * mCommon->kernelX() * mCommon->kernelY(); - mPostFunctions.clear(); - auto plane = width * height * batch; - const int maxDepth = 5; + auto plane = width * height * batch; auto allocator = static_cast(backend())->getBufferAllocator(); - //int zeroPoint = 0; - - auto biasTensor = inputs[2]; - - // prepare for float2int8 if necessary. - auto outputQuant = TensorUtils::getQuantInfo(outputs[0]); - float scale = outputQuant[0]; - scale = (scale == 0.f ? 
0.f : 1.f / scale); - auto maxValue = outputQuant[3]; - auto minValue = outputQuant[2]; - auto zeroPoint = outputQuant[1]; - - AutoRelease tempInput(Tensor::createDevice({icC4, plane, core->pack})); - bool needReleaseTempInput = true; - int outi8 = 0; - if (CPUBackend::getDataType(output) == DataType_DT_INT8 || output->getType().bytes() == 1) { - outi8 = 1; + auto threadNumber = static_cast(backend())->threadNumber(); + auto tileCount = UP_DIV(plane, eP); + threadNumber = ALIMIN(tileCount, threadNumber); + auto im2colOutputStride = input->channel() * eP * core->bytes; + mGemmInput = allocator->alloc(threadNumber * im2colOutputStride); + auto gemmOutputStride = kernelCount * core->pack * eP * core->bytes; + mGemmOutput = allocator->alloc(threadNumber * gemmOutputStride); + auto outputSize = batch*src_width*src_height*ocC4*core->pack*core->bytes; + if (threadNumber > 1) { + mExtraOutput = allocator->alloc((threadNumber-1)*outputSize); } - if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { - mTempOutput.reset(Tensor::createDevice({batch, height, width, ocC4 * kw * kh * core->pack})); - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - mDeconvInt8Exe->onResize({input}, {mTempOutput.get()}); - if (mResource->mRelu) { - minValue = outputQuant[1]; - } + allocator->free(mGemmInput); + allocator->free(mGemmOutput); + if (threadNumber > 1) { + allocator->free(mExtraOutput); } - else { - mTempOutput.reset(Tensor::createDevice({kernelCount, plane, core->pack})); - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - mMatMul.reset(new StrassenMatrixComputor(backend(), true, maxDepth)); - // tempInput->buffer().host = (uint8_t*)inputPtr; - - needReleaseTempInput = false; - TensorUtils::getDescribeOrigin(tempInput.get())->mem = new CPUMemObj(nullptr, 
TensorUtils::getDescribeOrigin(input)->mem->chunk(), 0); - mMatMul->onEncode({tempInput.get(), inputs[1]}, {mTempOutput.get()}); - } - auto threadNumber = ((CPUBackend*)backend())->threadNumber(); - std::vector scales(core->pack * src_height * src_width * batch, scale); - MemChunk outputFp32Ptr; - if (outi8) { - outputFp32Ptr = allocator->alloc(batch * src_height * src_width * ocC4 * core->pack * bytes); - if (outputFp32Ptr.invalid()) { - return OUT_OF_MEMORY; - } - } - - mPostFunctions.emplace_back(std::make_pair([ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY, - strideX, threadNumber, src_width, src_height, plane, input, biasTensor, this, core, gcore, batch, outi8, scale, - minValue, maxValue, zeroPoint, outputFp32Ptr](uint8_t* outputPtr, int tId) { - auto colBufferPtr = mTempOutput->host(); - auto biasPtr = biasTensor->host(); - auto inputPtr = input->host(); + auto first = std::make_pair([=](uint8_t* outputPtr, int tId) { + auto gemmInputBufferPtr = mGemmInput.ptr() + tId * im2colOutputStride; + auto colBufferPtr = mGemmOutput.ptr() + tId * gemmOutputStride; + auto inputPtr = input->host(); auto unitBytes = core->pack * core->bytes; auto tempOutPtr = outputPtr; - auto float2Int8_step = src_height * src_width * batch; - if (outi8) { - tempOutPtr = outputFp32Ptr.ptr(); + if (tId > 0) { + tempOutPtr = mExtraOutput.ptr() + (tId-1) * outputSize; } - for (int z = (tId); z < ocC4; z += threadNumber) { - auto dstZ = tempOutPtr + z * src_height * src_width * batch * unitBytes; - auto srcZ = colBufferPtr + kw * kh * plane * z * unitBytes; - ::memset(dstZ, 0, src_width * src_height * batch * unitBytes); - for (int b = 0; b < batch; ++b) { - auto dstB = dstZ + b * src_width * src_height * unitBytes; - auto srcB = srcZ + b * width * height * unitBytes; - for (int oy = 0; oy < height; ++oy) { - for (int ox = 0; ox < width; ++ox) { - int srcStartX = ox * strideX - padX; - int srcStartY = oy * strideY - padY; - - int sfy = ALIMAX(0, (UP_DIV(-srcStartY, 
dilateY))); - int efy = ALIMIN(kh, UP_DIV(src_height - srcStartY, dilateY)); - - int sfx = ALIMAX(0, (UP_DIV(-srcStartX, dilateX))); - int efx = ALIMIN(kw, UP_DIV(src_width - srcStartX, dilateX)); - - auto dstStart = dstB + srcStartX * unitBytes + srcStartY * src_width * unitBytes; - auto srcStart = srcB + unitBytes * (ox + oy * width); - if (sfy >= efy || sfx >= efx) { - continue; - } - - for (int fy = sfy; fy < efy; ++fy) { - auto dstY = dstStart + fy * unitBytes * dilateY * src_width; - auto srcY = srcStart + fy * kw * plane * unitBytes; - core->MNNAddC4WithStride((const float*)(srcY + sfx * plane * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), plane * core->pack, dilateX * core->pack, efx - sfx); - } + ::memset(tempOutPtr, 0, outputSize); + + int l = mSrcCount; + int h = kernelCount * core->pack; + auto weightPtr = weightTensor->host(); + for (int index=tId; index < tileCount; index+=threadNumber) { + int xStart = index * eP; + int xEnd = ALIMIN(xStart + eP, plane); + int xCount = xEnd-xStart; + if (xCount <= 0) { + continue; + } + size_t parameters[7]; + parameters[0] = xCount * core->bytes; + parameters[1] = l; + parameters[2] = h; + parameters[3] = xCount * core->bytes * core->pack; + parameters[4] = 0; + parameters[5] = 0; + parameters[6] = 0; + const float* postParametersPtr = nullptr; + int32_t info[4]; + int32_t stride[4]; + stride[0] = xCount; + stride[1] = (int32_t)parameters[1]; + stride[2] = 0; + stride[3] = 0; + info[0] = 1; + info[1] = plane; + info[2] = xCount; + info[3] = 1; + auto aStart = inputPtr + xStart * unitBytes; + core->MNNPackC4ForMatMul_A((float*)(gemmInputBufferPtr), (const float**)(&aStart), info, stride); + if (xCount == eP) { + core->MNNPackedMatMul((float*)(colBufferPtr), (float*)gemmInputBufferPtr, (float*)weightPtr, parameters, postParametersPtr, nullptr, nullptr, nullptr); + } else { + core->MNNPackedMatMulRemain((float*)(colBufferPtr), (float*)gemmInputBufferPtr, (float*)weightPtr, xCount, parameters, 
postParametersPtr, nullptr, nullptr, nullptr); + } + // Col2Im + for (int z = 0; z < ocC4; ++z) { + auto dstZ = tempOutPtr + z * src_height * src_width * batch * unitBytes; + auto srcZ = colBufferPtr + kw * kh * xCount * z * unitBytes; + for (int x=0; x= efy || sfx >= efx) { + continue; + } + + for (int fy = sfy; fy < efy; ++fy) { + auto dstY = dstStart + fy * unitBytes * dilateY * src_width; + auto srcY = srcStart + fy * kw * xCount * unitBytes; + core->MNNAddC4WithStride((const float*)(srcY + sfx * xCount * unitBytes), (float*)(dstY + sfx * dilateX * unitBytes), xCount * core->pack, dilateX * core->pack, efx - sfx); } } } - core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr + unitBytes * z), src_height * src_width * batch, 0, 0, 1, mPostParameters.data()); - if (outi8) { - float scaleOne = scale; - float zeroOne = zeroPoint; - gcore->MNNFloat2Int8((float*)dstZ, (int8_t*)(outputPtr + z * float2Int8_step * core->pack), float2Int8_step, &scaleOne, minValue, maxValue, &zeroOne, 0); + } + }, threadNumber); + auto second = std::make_pair([ocC4, src_height, src_width, threadNumber, batch, biasTensor, this, outputSize, core](uint8_t* outputPtr, int tId) { + auto unitBytes = core->pack * core->bytes; + auto biasPtr = biasTensor->host(); + for (int z = tId; z < ocC4; z+=threadNumber) { + auto dstZ = outputPtr + z * src_height * src_width * batch * unitBytes; + if (threadNumber > 1) { + for (int index=0; indexMNNMatrixAdd((float*)(dstZ), (float*)(src), (float*)(dstZ), src_height * src_width * batch, 0, 0, 0, 1); + } } + core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr + unitBytes * z), src_height * src_width * batch, 0, 0, 1, mPostParameters.data()); } - }, threadNumber)); - if (outi8) { - allocator->free(outputFp32Ptr); - } - if (needReleaseTempInput) { - backend()->onReleaseBuffer(tempInput.get(), Backend::DYNAMIC); - } - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + 
+ }, threadNumber); + mExecuteFuntion = {first, second}; return NO_ERROR; } ErrorCode CPUDeconvolutionOrigin::onExecute(const std::vector& inputs, const std::vector& outputs) { auto inputPtr = inputs[0]->host(); auto outputPtr = outputs[0]->host(); - if (mDeconvInt8Exe.get() != nullptr) { - mDeconvInt8Exe->onExecute({inputs[0], inputs[1]}, {mTempOutput.get()}); - } - else { - mMatMul->onExecute(); - } - for (auto& unit : mPostFunctions) { + for (auto& unit : mExecuteFuntion) { MNN_CONCURRENCY_BEGIN(tId, unit.second) { unit.first(outputPtr, (int)tId); } @@ -482,15 +357,6 @@ class CPUDeconvolutionCreator : public CPUBackend::Creator { const MNN::Op* op, Backend* backend) const { auto convOp = op->main_as_Convolution2D(); auto common = convOp->common(); - if (backend->type() == MNN_FORWARD_CPU && inputs.size() == 1) { - if (common->strideY() > 1 || common->strideX() > 1) { - if (common->dilateX() == 1 && common->dilateY() == 1) { - if (common->kernelX() / common->strideX() > 2 || common->kernelY() / common->strideY() > 2) { - return new DeconvolutionWithStride(inputs[0], op, backend); - } - } - } - } return new CPUDeconvolution(inputs[0], op, backend, inputs.size() > 1); } }; diff --git a/source/backend/cpu/CPUDeconvolution.hpp b/source/backend/cpu/CPUDeconvolution.hpp index 82f7168d4..bea9f164a 100644 --- a/source/backend/cpu/CPUDeconvolution.hpp +++ b/source/backend/cpu/CPUDeconvolution.hpp @@ -12,7 +12,6 @@ #include "CPUConvolution.hpp" #include "compute/CommonOptFunction.h" #include "compute/StrassenMatmulComputor.hpp" -#include "compute/GemmInt8Executor.hpp" #include "core/TensorUtils.hpp" namespace MNN { class CPUDeconvolutionBasic : public CPUConvolution { @@ -44,11 +43,11 @@ class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; private: - std::shared_ptr mMatMul; - std::shared_ptr mDeconvInt8Exe; - std::vector, int>> mPostFunctions; - std::shared_ptr 
mTempOutput; - std::shared_ptr mResource; + MemChunk mGemmOutput; + MemChunk mGemmInput; + MemChunk mExtraOutput; + + std::vector, int>> mExecuteFuntion; }; class CPUDeconvolution : public CPUDeconvolutionCommon { diff --git a/source/backend/cpu/CPUInstanceNorm.cpp b/source/backend/cpu/CPUInstanceNorm.cpp index 851d97831..6ed4513e6 100644 --- a/source/backend/cpu/CPUInstanceNorm.cpp +++ b/source/backend/cpu/CPUInstanceNorm.cpp @@ -6,9 +6,10 @@ // Copyright © 2018, Alibaba Group Holding Limited // +#include "backend/cpu/CPUBackend.hpp" +#ifdef MNN_SUPPORT_DEPRECATED_OP #include "backend/cpu/CPUInstanceNorm.hpp" #include -#include "backend/cpu/CPUBackend.hpp" #include "core/Concurrency.h" #include #include "core/Macro.h" @@ -106,7 +107,9 @@ class CPUInstanceNormCreator : public CPUBackend::Creator { return new CPUInstanceNorm(backend, op); } }; - -REGISTER_CPU_OP_CREATOR(CPUInstanceNormCreator, OpType_InstanceNorm); - } // namespace MNN +#endif +namespace MNN { +REGISTER_CPU_OP_CREATOR_OLD(CPUInstanceNormCreator, OpType_InstanceNorm); +}; + diff --git a/source/backend/cpu/CPUMoments.cpp b/source/backend/cpu/CPUMoments.cpp index 40c2cccf2..8ad50904e 100644 --- a/source/backend/cpu/CPUMoments.cpp +++ b/source/backend/cpu/CPUMoments.cpp @@ -6,9 +6,10 @@ // Copyright © 2018, Alibaba Group Holding Limited // +#include "backend/cpu/CPUBackend.hpp" +#ifdef MNN_SUPPORT_DEPRECATED_OP #include "backend/cpu/CPUMoments.hpp" #include -#include "backend/cpu/CPUBackend.hpp" #include "core/Concurrency.h" #include #include "core/Macro.h" @@ -129,7 +130,9 @@ class CPUMomentsCreator : public CPUBackend::Creator { return new CPUMoments(backend, op); } }; - -REGISTER_CPU_OP_CREATOR(CPUMomentsCreator, OpType_Moments); - } // namespace MNN +#endif +namespace MNN { +REGISTER_CPU_OP_CREATOR_OLD(CPUMomentsCreator, OpType_Moments); +}; + diff --git a/source/backend/cpu/CPUOPRegister.cpp b/source/backend/cpu/CPUOPRegister.cpp index 37f868732..345f45ce5 100644 --- 
a/source/backend/cpu/CPUOPRegister.cpp +++ b/source/backend/cpu/CPUOPRegister.cpp @@ -78,6 +78,9 @@ extern void ___CPUTextureCreator__OpType_Texture__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE extern void ___CPUAttentionCreator__OpType_Attention__(); #endif +#ifdef MNN_BUILD_AUDIO +extern void ___CPUStftCreator__OpType_Stft__(); +#endif void registerCPUOps() { ___CPUCropAndResizeCreator__OpType_CropAndResize__(); ___CPUArgMaxCreator__OpType_ArgMax__(); @@ -156,5 +159,8 @@ ___CPUTextureCreator__OpType_Texture__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE ___CPUAttentionCreator__OpType_Attention__(); #endif +#ifdef MNN_BUILD_AUDIO +___CPUStftCreator__OpType_Stft__(); +#endif } } diff --git a/source/backend/cpu/CPURelu.cpp b/source/backend/cpu/CPURelu.cpp index 073556464..71bc41f16 100644 --- a/source/backend/cpu/CPURelu.cpp +++ b/source/backend/cpu/CPURelu.cpp @@ -46,16 +46,53 @@ ErrorCode CPURelu::onExecute(const std::vector& inputs, const std::vect auto& ob = outputs[0]->buffer(); if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { + auto core = static_cast(backend())->int8Functions(); + auto gcore = static_cast(backend())->functions(); const int8_t* srcO = (const int8_t*)ib.host; + int8_t* dstO = (int8_t*)ob.host; auto inInfo = TensorUtils::getQuantInfo(inputs[0]); auto outInfo = TensorUtils::getQuantInfo(outputs[0]); - if (inInfo != outInfo) { - MNN_PRINT("this relu int8 implementation has error when input output quant info mismatch\n"); - } - int8_t zeroPoint = int8_t(outInfo[1]); - int8_t* dstO = (int8_t*)ob.host; auto size = mRealSize; auto numberThread = ((CPUBackend*)backend())->threadNumber(); + + auto inputscale = inInfo[0]; + auto inputzero = (ssize_t)inInfo[1]; + auto outputzero = (ssize_t)outInfo[1]; + auto outputscale = outInfo[0] > 0.f ? 
1.0f / outInfo[0] : 0.f; + QuanPrePostParameters params; + params.maxValue = static_cast(inInfo[3]); + params.minValue = static_cast(inInfo[2]); + params.inputScale = &inputscale; + params.inputZeroPoint = &inputzero; + params.outputScale = &outputscale; + params.outputZeroPoint = &outputzero; + + if (((float*)mSlope.get())[0] != 0.f) { + // PRelu Int8 + int sizeQuad = size / gcore->pack; + int remain = size % gcore->pack; + int sizeDivide = UP_DIV(sizeQuad, numberThread); + + if (sizeQuad > 0) { + MNN_CONCURRENCY_BEGIN(tId, numberThread) { + + int number = sizeDivide; + if (tId == numberThread - 1) { + number = sizeQuad - tId * sizeDivide; + } + core->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + tId * gcore->pack * sizeDivide), srcO + tId * sizeDivide * gcore->pack, (const float*)(mSlope.get()), number, 1, ¶ms, gcore->pack); + + } + MNN_CONCURRENCY_END(); + } + if (remain > 0) { + ::memcpy(mCacheSrc.get(), srcO + sizeQuad * gcore->pack, remain); + core->MNNReluWithSlopeChannelInt8((int8_t*)mCacheDst.get(), (const int8_t*)(mCacheSrc.get()), (const float*)mSlope.get(), 1, 1, ¶ms, gcore->pack); + ::memcpy(dstO + sizeQuad * gcore->pack, mCacheDst.get(), remain); + } + return NO_ERROR; + } + int8_t zeroPoint = int8_t(outInfo[1]); int sizeQuad = size / 16; int remain = sizeQuad * 16; int sizeDivide = sizeQuad / numberThread; @@ -187,10 +224,6 @@ ErrorCode CPUPRelu::onResize(const std::vector& inputs, const std::vect mQuanScalesOutput = {outputScale}; mQuanZerosInput = {inputZero}; mQuanZerosOutput = {outputZero}; - auto p = mSlope.host(); - for (int i = 0; i < mSlope.buffer().dim[0].extent; ++i) { - p[i] = p[i] * inputScale * outputScale; - } } return NO_ERROR; } @@ -198,42 +231,53 @@ ErrorCode CPUPRelu::onResize(const std::vector& inputs, const std::vect ErrorCode CPUPRelu::onExecute(const std::vector& inputs, const std::vector& outputs) { auto& ib = inputs[0]->buffer(); auto& ob = outputs[0]->buffer(); - int sizeQuad = 1; - for (int i=2; i(backend())->functions(); 
auto coreInt8 = static_cast(backend())->int8Functions(); const int channel = ib.dim[1].extent; const int batch = ib.dim[0].extent; - int pack = 4; - int depthQuad = UP_DIV(channel, core->pack); - const uint8_t* srcO = (const uint8_t*)ib.host; + int pack = core->pack; + + const int8_t* srcO = (const int8_t*)ib.host; uint8_t* dstO = (uint8_t*)ob.host; + auto depthQuad = UP_DIV(channel, core->pack); auto totalCount = batch * depthQuad; auto numberThread = ((CPUBackend*)backend())->threadNumber(); + auto sizeQuad = UP_DIV(depthQuad, numberThread); + auto sizeCount = sizeQuad * batch * inputs[0]->width() * inputs[0]->height() * core->pack; + if (mUseInt8) { - depthQuad = UP_DIV(channel, pack); + auto inputInfo = TensorUtils::getDescribe(inputs[0])->quantAttr; + auto outputInfo = TensorUtils::getDescribe(outputs[0])->quantAttr; + auto inzero = (ssize_t)inputInfo->zero; + auto outzero = (ssize_t)outputInfo->zero; + auto outscale = outputInfo->scale > 0 ? 1.f / outputInfo->scale : 0.f; + QuanPrePostParameters params; + params.maxValue = static_cast(outputInfo->max); + params.minValue = static_cast(outputInfo->min); + params.inputScale = &inputInfo->scale; + params.inputZeroPoint = &inzero; + params.outputScale = &outscale; + params.outputZeroPoint = &outzero; MNN_CONCURRENCY_BEGIN(tId, numberThread) { - QuanPrePostParameters params; - params.maxValue = static_cast(TensorUtils::getDescribe(inputs[0])->quantAttr->max); - params.minValue = static_cast(TensorUtils::getDescribe(inputs[0])->quantAttr->min); - params.inputScale = mQuanScalesInput.data(); - params.inputZeroPoint = mQuanZerosInput.data(); - params.outputScale = mQuanScalesOutput.data(); - params.outputZeroPoint = mQuanZerosOutput.data(); - for (int b=tId; bMNNReluWithSlopeChannelInt8((int8_t*)(dstO + sizeQuad * pack * b), (const int8_t*)(srcO + sizeQuad * pack * b), (const float*)(mSlope.host() + core->bytes * pack * c), sizeQuad, 1, ¶ms); + + + auto number = ALIMIN(sizeQuad, depthQuad - tId * sizeQuad); + if 
(number > 0) { + auto sizeQ = number * batch * inputs[0]->width() * inputs[0]->height(); + coreInt8->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + tId * sizeCount), srcO + tId * sizeCount, (const float*)(mSlope.host() + tId * sizeQuad * pack * core->bytes), sizeQ / number, number, ¶ms, core->pack); } } MNN_CONCURRENCY_END(); return NO_ERROR; } + int hw = 1; + for (int i=2; iMNNReluWithSlopeChannel((float*)(dstO + sizeQuad * core->bytes * core->pack * b), (const float*)(srcO + sizeQuad * core->pack * core->bytes * b), (const float*)(mSlope.host() + core->bytes * core->pack * c), sizeQuad, 1); + core->MNNReluWithSlopeChannel((float*)(dstO + hw * core->bytes * core->pack * b), (const float*)(srcO + hw * core->pack * core->bytes * b), (const float*)(mSlope.host() + core->bytes * core->pack * c), hw, 1); } } MNN_CONCURRENCY_END(); diff --git a/source/backend/cpu/CPUStft.cpp b/source/backend/cpu/CPUStft.cpp new file mode 100644 index 000000000..5e6d40b54 --- /dev/null +++ b/source/backend/cpu/CPUStft.cpp @@ -0,0 +1,75 @@ +// +// CPUStft.cpp +// MNN +// +// Created by MNN on 2024/11/26. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_BUILD_AUDIO + +#include "backend/cpu/CPUStft.hpp" +#include "backend/cpu/CPUBackend.hpp" +#include "core/Concurrency.h" +#include "core/TensorUtils.hpp" +#include "core/Macro.h" +#include "compute/CommonOptFunction.h" + +namespace MNN { + +CPUStft::CPUStft(Backend* backend, int nfft, int hop_length, bool abs) + : Execution(backend), mNfft(nfft), mHopLength(hop_length), mAbs(abs) { + // nothing to do +} + +ErrorCode CPUStft::onResize(const std::vector &inputs, const std::vector &outputs) { + auto cpuBn = static_cast(backend()); + mTmpFrames.buffer().dim[0].extent = cpuBn->threadNumber(); + mTmpFrames.buffer().dim[1].extent = mNfft; + TensorUtils::getDescribe(&mTmpFrames)->dimensionFormat = MNN_DATA_FORMAT_NHWC; + mTmpFrames.buffer().dimensions = 2; + mTmpFrames.buffer().type = inputs[0]->getType(); + backend()->onAcquireBuffer(&mTmpFrames, Backend::DYNAMIC); + backend()->onReleaseBuffer(&mTmpFrames, Backend::DYNAMIC); + return NO_ERROR; +} + +ErrorCode CPUStft::onExecute(const std::vector& inputs, const std::vector& outputs) { + const float* sample = inputs[0]->host(); + const float* window = inputs[1]->host(); + float* buffer = mTmpFrames.host(); + float* output = outputs[0]->host(); + auto outputShape = outputs[0]->shape(); + int frames = outputShape[0]; + int col = outputShape[1]; + auto cpuBn = static_cast(backend()); + int threadNum = cpuBn->threadNumber(); + // div frames to threadNum + int threadNumber = std::min(threadNum, frames); + int sizeDivide = frames / threadNumber; + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { + int number = sizeDivide; + if (tId == threadNumber - 1) { + number = frames - tId * sizeDivide; + } + for (int i = tId * sizeDivide; i < tId * sizeDivide + number; ++i) { + MNNDftAbs(sample + i * mHopLength, window, output + i * col, buffer + tId * mNfft, mNfft); + } + }; + MNN_CONCURRENCY_END(); + + return NO_ERROR; +} + +class CPUStftCreator : public CPUBackend::Creator { 
+public: + virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, + const MNN::Op* op, Backend* backend) const { + auto stft = op->main_as_StftParam(); + return new CPUStft(backend, stft->n_fft(), stft->hop_length(), stft->abs()); + } +}; + +REGISTER_CPU_OP_CREATOR_AUDIO(CPUStftCreator, OpType_Stft); +} // namespace MNN +#endif // MNN_BUILD_AUDIO \ No newline at end of file diff --git a/source/backend/cpu/CPUStft.hpp b/source/backend/cpu/CPUStft.hpp new file mode 100644 index 000000000..e483a9b8c --- /dev/null +++ b/source/backend/cpu/CPUStft.hpp @@ -0,0 +1,31 @@ +// +// CPUStft.hpp +// MNN +// +// Created by MNN on 2024/11/26. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_BUILD_AUDIO +#ifndef CPUStft_hpp +#define CPUStft_hpp + +#include "core/Execution.hpp" + +namespace MNN { +class CPUStft : public Execution { +public: + CPUStft(Backend *backend, int nfft, int hop_length, bool abs); + virtual ~CPUStft() = default; + virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; +private: + int mNfft, mHopLength; + bool mAbs; + Tensor mTmpFrames; +}; + +} // namespace MNN + +#endif /* CPUStft.hpp */ +#endif // MNN_BUILD_AUDIO \ No newline at end of file diff --git a/source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S b/source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S index 4595733b8..4c09f69c6 100644 --- a/source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S +++ b/source/backend/cpu/arm/arm32/MNNReluWithSlopeChannelInt8.S @@ -48,9 +48,9 @@ beq PReluEnd cmp r3, #0 beq PReluEnd -vmov.f32 q12, #0.5 -vmov.f32 q13, #-0.5 .macro ROUND_TWO x0, x1 + vmov.f32 q12, #0.5 + vmov.f32 q13, #-0.5 vcgt.f32 q10, \x0, #0 vcgt.f32 q11, \x1, #0 vbsl.f32 q10, q12, q13 @@ -62,6 +62,8 @@ vmov.f32 q13, #-0.5 .endm .macro ROUND_ONE x0 + vmov.f32 q12, #0.5 + vmov.f32 q13, #-0.5 vcgt.f32 q10, 
\x0, #0 vbsl.f32 q10, q12, q13 vadd.f32 \x0, q10, \x0 @@ -69,11 +71,13 @@ vmov.f32 q13, #-0.5 .endm vld1.8 d30[0], [r8] -vdup.8 d30, d30[0] // inputZeroPoint - vld1.8 d31[0], [r6] +vdup.8 d30, d30[0] // inputZeroPoint vdup.8 d31, d31[0] // outputZeroPoint +ldr r6, [r5, #0] // inputScale +ldr r8, [r5, #4] // outputScale + PReluZLoop: vld1.32 {q14}, [r2]! @@ -93,17 +97,38 @@ vmovl.s16 q4, d3 vmovl.s16 q5, d4 vmovl.s16 q6, d5 -vclt.s8 q1, q0, #0 - vcvt.f32.s32 q3, q3 vcvt.f32.s32 q4, q4 vcvt.f32.s32 q5, q5 vcvt.f32.s32 q6, q6 - -vmul.f32 q3, q3, q14 -vmul.f32 q4, q4, q14 -vmul.f32 q5, q5, q14 -vmul.f32 q6, q6, q14 +// *input_scale +vld1.f32 {d14[0]}, [r6] +vld1.f32 {d14[1]}, [r8] // outputscale +vmul.f32 q3, q3, d14[0] +vmul.f32 q4, q4, d14[0] +vmul.f32 q5, q5, d14[0] +vmul.f32 q6, q6, d14[0] + +vclt.f32 q0, q3, #0 +vclt.f32 q1, q4, #0 +vclt.f32 q2, q5, #0 +vclt.f32 q12, q6, #0 + +// *slope +vmul.f32 q8, q3, q14 +vmul.f32 q9, q4, q14 +vmul.f32 q10, q5, q14 +vmul.f32 q11, q6, q14 + +vbit.32 q3, q8, q0 +vbit.32 q4, q9, q1 +vbit.32 q5, q10, q2 +vbit.32 q6, q11, q12 + +vmul.f32 q3, q3, d14[1] +vmul.f32 q4, q4, d14[1] +vmul.f32 q5, q5, d14[1] +vmul.f32 q6, q6, d14[1] ROUND_TWO q3, q4 ROUND_TWO q5, q6 @@ -122,8 +147,7 @@ vqmovn.s16 d19, q8 vmax.s8 q9, q9, q10 vmin.s8 q9, q9, q11 -vbit.8 q0, q9, q1 -vst1.8 {q0}, [r0]! +vst1.8 {q9}, [r0]! 
sub r5, r5, #4 cmp r5, #4 @@ -139,10 +163,18 @@ vmovl.s8 q1, d0 vsubw.s8 q1, q1, d30 vmovl.s16 q2, d2 -vclt.s8 d10, d0, #0 vcvt.f32.s32 q2, q2 -vmul.f32 q2, q2, q14 +// *input_scale +vld1.f32 {d14[0]}, [r6] +vld1.f32 {d14[1]}, [r8] // outputscale +vmul.f32 q2, q2, d14[0] +vclt.f32 q4, q2, #0 // index +// *slope +vmul.f32 q3, q2, q14 +vbit q2, q3, q4 +// *output_scale +vmul.f32 q2, q2, d14[1] ROUND_ONE q2 diff --git a/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S b/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S deleted file mode 100644 index 4796a7091..000000000 --- a/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductLeft.S +++ /dev/null @@ -1,225 +0,0 @@ -// -// MNNWinogradMatrixProductLeft.S -// MNN -// -// Created by MNN on 2018/08/22. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __arm__ -#ifndef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -asm_function MNNWinogradMatrixProductLeft -//void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - -//Auto: r0: S, r1:B, r2: M, r3:w -//Load From sp: r4:h, r5:k, r6:length - -push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9 -ldr r4, [sp, #32] -ldr r5, [sp, #36] -ldr r6, [sp, #40] - -//unitStepInFloat -mov r8, #16 // 4*sizeof(float) -mul r8, r6, r8 - -//srcYUnitStep -mul lr, r3, r8 -sub lr, lr, r8 -add r7, lr, r8 - -//B's step -mov r10, #4 -mul r10, r4, r10 - -LoopY: - push {r0, r3} - LoopX: - push {r0, r1} - vmov.i32 q14, #0 - mov r11, r6 - LoopUnitSetZero: - vst1.32 {q14}, [r2]! 
- subs r11, r11, #1 - bne LoopUnitSetZero - sub r2, r2, r8 - mov r12, r5 - - LK7: - cmp r12, #7 - blt LK4 - push {r3-r7} - LoopK7: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - vld1.32 {d1[1]}, [r1], r10 - vld1.32 {d2[0]}, [r1], r10 - vld1.32 {d2[1]}, [r1], r10 - vld1.32 {d3[0]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r7 - add r3, r1, r7 - add r4, r3, r7 - add r5, r4, r7 - add r6, r5, r7 - add r7, r6, r7 - - LoopUnitK7: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - vld1.32 {q13}, [r4]! - vmla.f32 q9, q13, d1[1] - vld1.32 {q12}, [r5]! - vmla.f32 q8, q12, d2[0] - vld1.32 {q13}, [r6]! - vmla.f32 q9, q13, d2[1] - vld1.32 {q12}, [r7]! - vmla.f32 q8, q12, d3[0] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! - subs r11, r11, #1 - bne LoopUnitK7 - sub r2, r2, r8 - sub r12, r12, #7 - add r0, r7, lr - vmov.32 r1, d30[0] - cmp r12, #7 - bge LoopK7 - pop {r3-r7} - - LK4: - cmp r12, #4 - blt LK3 - vmov.32 d30[1], r3 - vmov.32 d31[0], r4 - LoopK4: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - vld1.32 {d1[1]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r7 - add r3, r1, r7 - add r4, r3, r7 - - LoopUnitK4: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - vld1.32 {q13}, [r4]! - vmla.f32 q9, q13, d1[1] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! 
- subs r11, r11, #1 - bne LoopUnitK4 - sub r2, r2, r8 - sub r12, r12, #4 - add r0, r4, lr - vmov.32 r1, d30[0] - cmp r12, #4 - bge LoopK4 - vmov.32 r3, d30[1] - vmov.32 r4, d31[0] - - LK3: - cmp r12, #3 - blt LK1 - vmov.32 d30[1], r3 - vmov.32 d31[0], r4 - LoopK3: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r7 - add r3, r1, r7 - - LoopUnitK3: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! - subs r11, r11, #1 - bne LoopUnitK3 - sub r2, r2, r8 - sub r12, r12, #3 - add r0, r3, lr - vmov.32 r1, d30[0] - cmp r12, #3 - bge LoopK3 - vmov.32 r3, d30[1] - vmov.32 r4, d31[0] - - - - LK1: - cmp r12, #0 - beq LKEnd - - LoopK: - vld1.32 {d30[0]}, [r1], r10 - - vdup.32 q15, d30[0] - mov r11, r6 - LoopUnit: - vld1.32 {q0}, [r2] - vld1.32 {q1}, [r0]! - vmla.f32 q0, q1, q15 - - vst1.32 {q0}, [r2]! - subs r11, r11, #1 - bne LoopUnit - subs r12, r12, #1 - - sub r2, r2, r8 - add r0, r0, lr - bne LoopK - LKEnd: - pop {r0, r1} - subs r3, r3, #1 - add r0, r0, r8 - add r2, r2, r8 - - bne LoopX - pop {r0, r3} - add r1, r1, #4 //sizeof(float) - - subs r4, r4, #1 - bne LoopY - - - -pop {r4-r8, r10, r11, pc} - -#endif -#endif diff --git a/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S b/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S deleted file mode 100644 index b0a97197c..000000000 --- a/source/backend/cpu/arm/arm32/MNNWinogradMatrixProductRight.S +++ /dev/null @@ -1,223 +0,0 @@ -// -// MNNWinogradMatrixProductRight.S -// MNN -// -// Created by MNN on 2018/08/22. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __arm__ -#ifndef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -asm_function MNNWinogradMatrixProductRight -//void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - -//Auto: r0: S, r1:B, r2: M, r3:w -//Load From sp: r4:h, r5:k, r6:length - -push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9 -ldr r4, [sp, #32] -ldr r5, [sp, #36] -ldr r6, [sp, #40] - -//unitStepInFloat -mov r8, #16 // 4*sizeof(float) -mul r8, r6, r8 - -//srcYUnitStep -mul lr, r5, r8 - -//B's step -mov r10, #4 -mul r10, r4, r10 - -LoopY: - push {r1, r3} - LoopX: - push {r0, r1} - vmov.i32 q14, #0 - mov r11, r6 - LoopUnitSetZero: - vst1.32 {q14}, [r2]! - subs r11, r11, #1 - bne LoopUnitSetZero - sub r2, r2, r8 - mov r12, r5 - - LK7: - cmp r12, #7 - blt LK4 - push {r3-r7} - LoopK7: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - vld1.32 {d1[1]}, [r1], r10 - vld1.32 {d2[0]}, [r1], r10 - vld1.32 {d2[1]}, [r1], r10 - vld1.32 {d3[0]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r8 - add r3, r1, r8 - add r4, r3, r8 - add r5, r4, r8 - add r6, r5, r8 - add r7, r6, r8 - - LoopUnitK7: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - vld1.32 {q13}, [r4]! - vmla.f32 q9, q13, d1[1] - vld1.32 {q12}, [r5]! - vmla.f32 q8, q12, d2[0] - vld1.32 {q13}, [r6]! - vmla.f32 q9, q13, d2[1] - vld1.32 {q12}, [r7]! - vmla.f32 q8, q12, d3[0] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! 
- subs r11, r11, #1 - bne LoopUnitK7 - sub r2, r2, r8 - sub r12, r12, #7 - mov r0, r7 - vmov.32 r1, d30[0] - cmp r12, #7 - bge LoopK7 - pop {r3-r7} - - LK4: - cmp r12, #4 - blt LK3 - vmov.32 d30[1], r3 - vmov.32 d31[0], r4 - LoopK4: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - vld1.32 {d1[1]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r8 - add r3, r1, r8 - add r4, r3, r8 - - LoopUnitK4: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - vld1.32 {q13}, [r4]! - vmla.f32 q9, q13, d1[1] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! - subs r11, r11, #1 - bne LoopUnitK4 - sub r2, r2, r8 - - sub r12, r12, #4 - - mov r0, r4 - vmov.32 r1, d30[0] - cmp r12, #4 - bge LoopK4 - vmov.32 r3, d30[1] - vmov.32 r4, d31[0] - - LK3: - cmp r12, #3 - blt LK1 - vmov.32 d30[1], r3 - LoopK3: - vld1.32 {d0[0]}, [r1], r10 - vld1.32 {d0[1]}, [r1], r10 - vld1.32 {d1[0]}, [r1], r10 - mov r11, r6 - vmov.32 d30[0], r1 - - add r1, r0, r8 - add r3, r1, r8 - - LoopUnitK3: - vld1.32 {q8}, [r2] - vld1.32 {q12}, [r0]! - vmla.f32 q8, q12, d0[0] - vld1.32 {q13}, [r1]! - vmul.f32 q9, q13, d0[1] - vld1.32 {q12}, [r3]! - vmla.f32 q8, q12, d1[0] - - vadd.f32 q9, q8, q9 - vst1.32 {q9}, [r2]! - subs r11, r11, #1 - bne LoopUnitK3 - sub r2, r2, r8 - - sub r12, r12, #3 - - mov r0, r3 - vmov.32 r1, d30[0] - cmp r12, #3 - bge LoopK3 - vmov.32 r3, d30[1] - - - LK1: - cmp r12, #0 - beq LKEnd - - LoopK: - vld1.32 {d30[0]}, [r1], r10 - - vdup.32 q15, d30[0] - mov r11, r6 - LoopUnit: - vld1.32 {q0}, [r2] - vld1.32 {q1}, [r0]! - vmla.f32 q0, q1, q15 - - vst1.32 {q0}, [r2]! 
- subs r11, r11, #1 - bne LoopUnit - subs r12, r12, #1 - - sub r2, r2, r8 - bne LoopK - LKEnd: - pop {r0, r1} - subs r3, r3, #1 - add r2, r2, r8 - add r1, r1, #4 //sizeof(float) - - bne LoopX - pop {r1, r3} - add r0, r0, lr - - subs r4, r4, #1 - bne LoopY - - - -pop {r4-r8, r10, r11, pc} - -#endif -#endif diff --git a/source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S b/source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S index e1622504c..4128118f9 100644 --- a/source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S +++ b/source/backend/cpu/arm/arm64/MNNReluWithSlopeChannelInt8.S @@ -25,8 +25,10 @@ asm_function MNNReluWithSlopeChannelInt8 // MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) // Auto load: // x0: dst, x1: src, x2: slope, x3: planeNumber, x4: depthQuad, x5: params -// Load from x5: x8: inputZeroPoint, x9: outputZeroPoint, x10: minValue, x11: maxValue +// Load from x5: x9: outputZeroPoint, x10: minValue, x11: maxValue +ldr x12, [x5, #0] +ldr x13, [x5, #8] ldr x8, [x5, #16] ldr x9, [x5, #24] ldr x10, [x5, #32] @@ -43,10 +45,12 @@ beq End cmp x4, #0 beq End -ld1r {v29.8b}, [x8] // inputZeroPoint -ld1r {v28.8b}, [x9] // outputZeroPoint +ld1r {v29.16b}, [x8] // inputZeroPoint +ld1r {v28.16b}, [x9] // outputZeroPoint dup v26.16b, w10 dup v27.16b, w11 +ld1r {v24.4s}, [x12] // inputscale +ld1r {v25.4s}, [x13] // outputscale /* Quant parameters */ @@ -60,7 +64,6 @@ ble PReluL1 PReluL4Loop: ld1 {v0.16b}, [x1], #16 -cmlt v30.16b, v0.16b, #0 // mask0: x<0 sxtl v1.8h, v0.8b sxtl2 v2.8h, v0.16b @@ -76,10 +79,33 @@ scvtf v4.4s, v4.4s scvtf v5.4s, v5.4s scvtf v6.4s, v6.4s -fmul v3.4s, v3.4s, v31.4s -fmul v4.4s, v4.4s, v31.4s -fmul v5.4s, v5.4s, v31.4s -fmul v6.4s, v6.4s, v31.4s +// input_scale +fmul v3.4s, v3.4s, v24.4s +fmul v4.4s, v4.4s, v24.4s +fmul v5.4s, v5.4s, v24.4s +fmul v6.4s, v6.4s, v24.4s + +fcmle v7.4s, v3.4s, #0 +fcmle v8.4s, v4.4s, #0 
+fcmle v9.4s, v5.4s, #0 +fcmle v10.4s, v6.4s, #0 + +// *slope +fmul v11.4s, v3.4s, v31.4s +fmul v12.4s, v4.4s, v31.4s +fmul v13.4s, v5.4s, v31.4s +fmul v14.4s, v6.4s, v31.4s + +bit v3.16b, v11.16b, v7.16b +bit v4.16b, v12.16b, v8.16b +bit v5.16b, v13.16b, v9.16b +bit v6.16b, v14.16b, v10.16b + +// *output_scale +fmul v3.4s, v3.4s, v25.4s +fmul v4.4s, v4.4s, v25.4s +fmul v5.4s, v5.4s, v25.4s +fmul v6.4s, v6.4s, v25.4s fcvtas v3.4s, v3.4s fcvtas v4.4s, v4.4s @@ -99,8 +125,7 @@ sqxtn2 v9.16b, v8.8h smax v9.16b, v9.16b, v26.16b smin v9.16b, v9.16b, v27.16b -bit v0.16b, v9.16b, v30.16b -st1 {v0.16b}, [x0], #16 +st1 {v9.16b}, [x0], #16 sub x5, x5, #4 cmp x5, #4 @@ -113,13 +138,20 @@ beq PReluL1End PReluL1Loop: ld1 {v0.s}[0], [x1], #4 -cmlt v30.8b, v0.8b, #0 sxtl v1.8h, v0.8b ssubw v1.8h, v1.8h, v29.8b sxtl v1.4s, v1.4h scvtf v1.4s, v1.4s -fmul v1.4s, v1.4s, v31.4s +// *input_scale +fmul v1.4s, v1.4s, v24.4s +fcmle v7.4s, v1.4s, #0 +// *slope +fmul v11.4s, v1.4s, v31.4s +bit v1.16b, v11.16b, v7.16b +// *output_scale +fmul v1.4s, v1.4s, v25.4s + fcvtas v1.4s, v1.4s sqxtn v1.4h, v1.4s saddw v1.8h, v1.8h, v28.8b @@ -127,8 +159,7 @@ sqxtn v1.8b, v1.8h smax v1.8b, v1.8b, v26.8b smin v1.8b, v1.8b, v27.8b -bit v0.8b, v1.8b, v30.8b -st1 {v0.s}[0], [x0], #4 +st1 {v1.s}[0], [x0], #4 subs x5, x5, #1 bne PReluL1Loop @@ -144,4 +175,4 @@ End: ldp d14, d15, [sp], #64 ret -#endif \ No newline at end of file +#endif diff --git a/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S b/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S deleted file mode 100644 index f013aac62..000000000 --- a/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductLeft.S +++ /dev/null @@ -1,171 +0,0 @@ -// -// MNNWinogradMatrixProductLeft.S -// MNN -// -// Created by MNN on 2018/08/22. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -asm_function MNNWinogradMatrixProductLeft -//void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - -//Auto: x0: S, x1:B, x2: M, x3:w, x4:h, x5:k, x6:length - -//unitStepInFloat -mov x8, #16 // 4*sizeof(float) -mul x8, x6, x8 - -//srcYUnitStep -mul x9, x3, x8 -sub x9, x9, x8 -add x7, x9, x8 - -//B's step -mov x10, #4 -mul x10, x4, x10 - -LoopY: - mov v4.d[0], x0 - mov v4.d[1], x3 - LoopX: - mov v5.d[0], x0 - mov v5.d[1], x1 - movi v30.4s, #0 - mov x11, x6 - LoopUnitSetZero: - st1 {v30.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnitSetZero - sub x2, x2, x8 - mov x12, x5 - - LK4: - cmp x12, #4 - blt LK3 - mov v6.d[0], x3 - mov v6.d[1], x4 - LoopK4: - ld1 {v0.s}[0], [x1], x10 - ld1 {v0.s}[1], [x1], x10 - ld1 {v0.s}[2], [x1], x10 - ld1 {v0.s}[3], [x1], x10 - mov x11, x6 - mov v7.d[0], x1 - - add x1, x0, x7 - add x3, x1, x7 - add x4, x3, x7 - - LoopUnitK4: - ld1 {v16.4s}, [x2] - ld1 {v20.4s}, [x0], #16 - fmla v16.4s, v20.4s, v0.s[0] - ld1 {v21.4s}, [x1], #16 - fmul v17.4s, v21.4s, v0.s[1] - ld1 {v20.4s}, [x3], #16 - fmla v16.4s, v20.4s, v0.s[2] - ld1 {v21.4s}, [x4], #16 - fmla v17.4s, v21.4s, v0.s[3] - - fadd v17.4s, v16.4s, v17.4s - st1 {v17.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnitK4 - sub x2, x2, x8 - - sub x12, x12, #4 - - add x0, x4, x9 - mov x1, v7.d[0] - cmp x12, #4 - bge LoopK4 - mov x3, v6.d[0] - mov x4, v6.d[1] - - LK3: - cmp x12, #3 - blt LK1 - mov v6.d[0], x3 - LoopK3: - ld1 {v0.s}[0], [x1], x10 - ld1 {v0.s}[1], [x1], x10 - ld1 {v0.s}[2], [x1], x10 - mov x11, x6 - mov v7.d[0], x1 - - add x1, x0, x7 - add x3, x1, x7 - - LoopUnitK3: - ld1 {v16.4s}, [x2] - ld1 {v20.4s}, [x0], #16 - fmla v16.4s, v20.4s, v0.s[0] - ld1 {v21.4s}, [x1], #16 - fmul v17.4s, v21.4s, v0.s[1] - ld1 {v20.4s}, [x3], #16 - fmla v16.4s, v20.4s, v0.s[2] - - fadd v17.4s, v16.4s, v17.4s - 
st1 {v17.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnitK3 - sub x2, x2, x8 - - sub x12, x12, #3 - - add x0, x3, x9 - mov x1, v7.d[0] - cmp x12, #3 - bge LoopK3 - mov x3, v6.d[0] - - - LK1: - cmp x12, #0 - beq LKEnd - - LoopK: - ld1 {v31.s}[0], [x1], x10 - - dup v31.4s, v31.s[0] - mov x11, x6 - LoopUnit: - ld1 {v0.4s}, [x2] - ld1 {v1.4s}, [x0], #16 - fmla v0.4s, v1.4s, v31.4s - - st1 {v0.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnit - subs x12, x12, #1 - - sub x2, x2, x8 - add x0, x0, x9 - bne LoopK - LKEnd: - mov x0, v5.d[0] - mov x1, v5.d[1] - subs x3, x3, #1 - add x0, x0, x8 - add x2, x2, x8 - - bne LoopX - mov x0, v4.d[0] - mov x3, v4.d[1] - add x1, x1, #4 //sizeof(float) - - subs x4, x4, #1 - bne LoopY - - - - ret - -#endif diff --git a/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S b/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S deleted file mode 100644 index 5542e3a93..000000000 --- a/source/backend/cpu/arm/arm64/MNNWinogradMatrixProductRight.S +++ /dev/null @@ -1,164 +0,0 @@ -// -// MNNWinogradMatrixProductRight.S -// MNN -// -// Created by MNN on 2018/08/22. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -asm_function MNNWinogradMatrixProductRight -//void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - -//Auto: x0: S, x1:B, x2: M, x3:w, x4:h, x5:k, x6:length - -//unitStepInFloat -mov x8, #16 // 4*sizeof(float) -mul x8, x6, x8 - -//srcYUnitStep -mul x9, x5, x8 - -//B's step -mov x10, #4 -mul x10, x4, x10 - -LoopY: - mov v4.d[0], x1 - mov v4.d[1], x3 - LoopX: - mov v5.d[0], x0 - mov v5.d[1], x1 - movi v30.4s, #0 - mov x11, x6 - LoopUnitSetZero: - st1 {v30.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnitSetZero - sub x2, x2, x8 - mov x12, x5 - - LK4: - cmp x12, #4 - blt LK3 - mov v6.d[0], x3 - mov v6.d[1], x4 - LoopK4: - ld1 {v0.s}[0], [x1], x10 - ld1 {v0.s}[1], [x1], x10 - ld1 {v0.s}[2], [x1], x10 - ld1 {v0.s}[3], [x1], x10 - mov x11, x6 - mov v7.d[0], x1 - - add x1, x0, x8 - add x3, x1, x8 - add x4, x3, x8 - - LoopUnitK4: - ld1 {v16.4s}, [x2] - ld1 {v20.4s}, [x0], #16 - fmla v16.4s, v20.4s, v0.s[0] - ld1 {v21.4s}, [x1], #16 - fmul v17.4s, v21.4s, v0.s[1] - ld1 {v20.4s}, [x3], #16 - fmla v16.4s, v20.4s, v0.s[2] - ld1 {v21.4s}, [x4], #16 - fmla v17.4s, v21.4s, v0.s[3] - - fadd v17.4s, v16.4s, v17.4s - st1 {v17.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnitK4 - sub x2, x2, x8 - sub x12, x12, #4 - mov x0, x4 - - mov x1, v7.d[0] - cmp x12, #4 - bge LoopK4 - mov x3, v6.d[0] - mov x4, v6.d[1] - - LK3: - cmp x12, #3 - blt LK1 - mov v6.d[0], x3 - LoopK3: - ld1 {v0.s}[0], [x1], x10 - ld1 {v0.s}[1], [x1], x10 - ld1 {v0.s}[2], [x1], x10 - mov x11, x6 - mov v7.d[0], x1 - - add x1, x0, x8 - add x3, x1, x8 - - LoopUnitK3: - ld1 {v16.4s}, [x2] - ld1 {v20.4s}, [x0], #16 - fmla v16.4s, v20.4s, v0.s[0] - ld1 {v21.4s}, [x1], #16 - fmul v17.4s, v21.4s, v0.s[1] - ld1 {v20.4s}, [x3], #16 - fmla v16.4s, v20.4s, v0.s[2] - - fadd v17.4s, v16.4s, v17.4s - st1 {v17.4s}, [x2], #16 - subs x11, 
x11, #1 - bne LoopUnitK3 - sub x2, x2, x8 - sub x12, x12, #3 - mov x0, x4 - mov x1, v7.d[0] - cmp x12, #3 - bge LoopK3 - mov x3, v6.d[0] - - LK1: - cmp x12, #0 - beq LKEnd - - LoopK: - ld1 {v31.s}[0], [x1], x10 - - dup v31.4s, v31.s[0] - mov x11, x6 - LoopUnit: - ld1 {v0.4s}, [x2] - ld1 {v1.4s}, [x0], #16 - fmla v0.4s, v1.4s, v31.4s - - st1 {v0.4s}, [x2], #16 - subs x11, x11, #1 - bne LoopUnit - subs x12, x12, #1 - - sub x2, x2, x8 - bne LoopK - LKEnd: - mov x0, v5.d[0] - mov x1, v5.d[1] - subs x3, x3, #1 - add x2, x2, x8 - add x1, x1, #4 //sizeof(float) - - bne LoopX - mov x1, v4.d[0] - mov x3, v4.d[1] - add x0, x0, x9 - - subs x4, x4, #1 - bne LoopY - - - - ret - -#endif diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 0eefd7bb9..734851753 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -23,6 +23,9 @@ #include "../CPUBinary.hpp" #include "../CPUUnary.hpp" #include "../CPUPool.hpp" +#ifndef M_PI +#define M_PI 3.141592654 +#endif #define PACK 4 #define FLOAT float using Vec = MNN::Math::Vec; @@ -314,7 +317,7 @@ static void MNNSumByAxisLForMatmul_A(float* dest, int8_t* source, const float* s dest += (step * blockNum); realDstCount -= step; srcInt8 += col_buffer_unit_size; - } while(realDstCount > 0); + } while(realDstCount > 0); } template @@ -3099,6 +3102,21 @@ void MNNSiLuLowp(float* dst, const float* src, size_t dataSize) { #endif } +void MNNDftAbs(const float* input, const float* window, float* output, float* buffer, int nfft) { + for (int i = 0; i < nfft; ++i) { + buffer[i] = input[i] * window[i]; + } + for (int k = 0; k < nfft / 2 + 1; ++k) { + float real_sum = 0.f, imag_sum = 0.f; + for (int n = 0; n < nfft; ++n) { + float angle = 2 * M_PI * k * n / nfft; + real_sum += buffer[n] * std::cos(angle); + imag_sum -= buffer[n] * std::sin(angle); + } + output[k] = std::sqrt(real_sum * real_sum + imag_sum * imag_sum); + } +} + 
static void _MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFunctions::MNNPackedSparseMatMul& packedSparseMatMul) { if(sparseBlockOC == 4) { packedSparseMatMul = MNNPackedSparseMatMulEpx4; @@ -3202,7 +3220,7 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNFp16ToFp8 = MNNFp16ToFp8; gCoreFunction->MNNFp8ToFp32 = MNNFp8ToFp32; gCoreFunction->MNNFp8ToFp16 = MNNFp8ToFp16; - + // MatMul gCoreFunction->MNNGetMatMulPackMode = MNNGetMatMulPackMode; gCoreFunction->MNNPackC4ForMatMul_A = MNNPackC4ForMatMul_A; diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 9dac6d66e..0159aa286 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -101,6 +101,7 @@ void MNNGeluCommon(float* dst, const float* src, size_t size); void MNNGeluStandardCommon(float* dst, const float* src, size_t size); void MNNSoftmax(float* dest, const float* source, size_t size); void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false); +void MNNDftAbs(const float* input, const float* window, float* output, float* buffer, int nfft); // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n void MNNGetMatMulPackMode(int* eP, int *lP, int* hP); @@ -313,7 +314,7 @@ struct CoreFunctions { void(*MNNPoolingMax)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput, int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, int strideHeight, int padWidth, int padHeight, int padType, int countType); - + void(*MNNPoolingMaxWithRedice)(const void* channelInput, int inputWidth, int inputHeight, void *channelOutput, int outputWidth, int outputHeight, int kernelWidth, int kernelHeight, int strideWidth, int strideHeight, int padWidth, int padHeight, int padType, int countType, int *RediceOutput); diff --git 
a/source/backend/cpu/compute/Int8FunctionsOpt.cpp b/source/backend/cpu/compute/Int8FunctionsOpt.cpp index 7dd218564..da25f6f95 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.cpp +++ b/source/backend/cpu/compute/Int8FunctionsOpt.cpp @@ -29,7 +29,7 @@ void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx); void MNNAvgPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor); -void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params); +void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack); #if defined(__aarch64__) // aarch32 sdot workaround void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); @@ -1543,7 +1543,7 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, } } -static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) { +static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack) { #ifdef MNN_USE_SSE float offset = 128.f; uint8_t* srcPtr = (uint8_t*)src; @@ -1554,24 +1554,22 @@ const int8_t* srcPtr = src; int8_t* dstPtr = dst; #endif float mulVal = 0.f; - float inputScale = params->inputScale[0]; - float outputScale = params->outputScale[0]; float inputZero = 
static_cast(params->inputZeroPoint[0]) + offset; float outputZero = static_cast(params->outputZeroPoint[0]) + offset; int32_t minval = params->minValue + offset; int32_t maxval = params->maxValue + offset; for (int j = 0;j < depthQuad; ++j) { - const float* slopeZ = slope + 4 * j; - const auto srcZ = srcPtr + 4 * j * planeNumber; - auto dstZ = dstPtr + 4 * j * planeNumber; + const float* slopeZ = slope + pack * j; + const auto srcZ = srcPtr + pack * j * planeNumber; + auto dstZ = dstPtr + pack * j * planeNumber; for (int i = 0; i < planeNumber; ++i) { - for (int c = 0; c < 4; ++c) { - if ((float)srcZ[4 * i + c] < inputZero) { - mulVal = (srcZ[4 * i + c] - inputZero) * slopeZ[c]; - dstZ[4 * i + c] = ALIMIN(ALIMAX(static_cast(roundf(mulVal)) + outputZero, minval), maxval); - } else { - dstZ[4 * i + c] = srcZ[4 * i + c]; + for (int c = 0; c < pack; ++c) { + float valInput = (static_cast(srcZ[pack * i + c]) - inputZero) * params->inputScale[0]; + if (valInput < 0) { + valInput *= slopeZ[c]; } + auto mulVal = valInput * params->outputScale[0] + outputZero; + dstZ[pack * i + c] = ALIMIN(ALIMAX(static_cast(roundf(mulVal)), minval), maxval); } } } diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index eb405e6e8..460a8dfcd 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -113,7 +113,7 @@ struct CoreInt8Functions { void (*MNNAvgPoolInt8)(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor); // Relu - void (*MNNReluWithSlopeChannelInt8)(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params); + void (*MNNReluWithSlopeChannelInt8)(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack); }; void 
MNNCoreInt8FunctionInit(); CoreInt8Functions* MNNGetInt8CoreFunctions(); diff --git a/source/backend/cpu/compute/WinogradOptFunction.cpp b/source/backend/cpu/compute/WinogradOptFunction.cpp index 23b83eab2..31a602077 100644 --- a/source/backend/cpu/compute/WinogradOptFunction.cpp +++ b/source/backend/cpu/compute/WinogradOptFunction.cpp @@ -16,77 +16,10 @@ using Vec4 = MNN::Math::Vec; #define DEFAULT_UNIT 8 -extern "C" { -void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length); -void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length); -} - -#ifndef MNN_USE_NEON - -// M = BT * S , M = w*h * l, S = w*k * l, B = h*k -void MNNWinogradMatrixProductLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length) { - auto unitStep = 4 * length; - for (int y = 0; y < h; ++y) { - auto dstY = M + y * w * unitStep; - for (int x = 0; x < w; ++x) { - auto dstX = dstY + x * unitStep; - auto srcX = S + x * unitStep; - ::memset(dstX, 0, unitStep * sizeof(float)); - for (int i = 0; i < k; ++i) { - auto b = B[i * h + y]; - auto srcY = srcX + i * w * unitStep; - if (0.0f == b) { - continue; - } - for (int j = 0; j < unitStep; ++j) { - dstX[j] += srcY[j] * b; - } - } - } - } -} -// M = S * B , M = w*h * l, S = k*h * l, B = w*k -void MNNWinogradMatrixProductRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length) { - auto unitStep = 4 * length; - for (int y = 0; y < h; ++y) { - auto dstY = M + y * w * unitStep; - auto srcY = S + y * k * unitStep; - - for (int x = 0; x < w; ++x) { - auto dstX = dstY + x * unitStep; - ::memset(dstX, 0, unitStep * sizeof(float)); - for (int i = 0; i < k; ++i) { - auto srcX = srcY + i * unitStep; - auto b = B[i * h + x]; - if (0.0f == b) { - continue; - } - for (int j = 0; j < unitStep; ++j) { - dstX[j] += srcX[j] * b; - } - } - } - } -} -#endif namespace MNN 
{ - -void WinogradFunction::productLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length) { - MNNWinogradMatrixProductLeft(S, B, M, w, h, k, length); -} - -void WinogradFunction::productRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, - size_t length) { - MNNWinogradMatrixProductRight(S, B, M, w, h, k, length); -} int WinogradFunction::getPreferNumber() { return DEFAULT_UNIT; } diff --git a/source/backend/cpu/compute/WinogradOptFunction.hpp b/source/backend/cpu/compute/WinogradOptFunction.hpp index 579811cc5..8608a2fde 100644 --- a/source/backend/cpu/compute/WinogradOptFunction.hpp +++ b/source/backend/cpu/compute/WinogradOptFunction.hpp @@ -15,9 +15,6 @@ namespace MNN { class WinogradFunction { public: - static void productLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - static void productRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); - static int getPreferNumber(); typedef void (*TransformFunc)(const float* srcBlock, float* dstStart, size_t srcStep, size_t dstStep); diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index 21c8bd408..fc82e6971 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -132,7 +132,6 @@ void MNNInt8FunctionInit() { auto core = MNN::MNNGetInt8CoreFunctions(); core->MNNAvgPoolInt8 = MNNAvgPoolUint8; core->MNNMaxPoolInt8 = MNNMaxPoolInt8_; - core->MNNReluWithSlopeChannelInt8 = _SSE_MNNReluWithSlopeChannelInt8; if (cpuFlags & libyuv::kCpuHasSSE41) { core->MNNFloat2Int8 = _SSE_MNNFloat2Int8; core->MNNInt8ScaleToFloat = _SSE_MNNInt8ScaleToFloat; diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index 5f8653066..0867c7c07 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ 
b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -36,7 +36,7 @@ void _SSE_MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, void _SSE_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad); void _SSE_MNNGelu(float* dst, const float* src, size_t size, float* parameters); -void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params); +void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, const QuanPrePostParameters *params, size_t pack); void _SSE_MNNHardSwish(float* dst, const float* src, size_t size); diff --git a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp index b9e857006..8238cb187 100644 --- a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp @@ -290,46 +290,3 @@ void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float } } -void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) { - uint8_t* dstO = (uint8_t*)dst; - uint8_t* srcO = (uint8_t*)src; - auto outputZero = _mm_set1_ps(static_cast(params->outputZeroPoint[0])); - __m128 maxValue = _mm_set1_ps(params->maxValue); - __m128 minValue = _mm_set1_ps(params->minValue); - auto offset = _mm_set1_epi32(128); - auto zero = _mm_set1_epi32(0); - __m128 plus = _mm_set1_ps(0.5f); - __m128 minus = _mm_set1_ps(-0.5f); - __m128i zeroPointValue = _mm_set1_epi32(static_cast(params->inputZeroPoint[0]) + 128); - for (int j = 0;j < depthQuad; ++j) { - auto slopeZ = _mm_loadu_ps(slope + 4 * j); - const uint8_t* srcZ = srcO + 4 * j * planeNumber; - uint8_t* dstZ = dstO + 4 * j * planeNumber; - int32_t srcZ_ext[4] = {*(int32_t*)srcZ, 0, 0, 0}; - for 
(int i = 0; i < planeNumber; ++i) { - // auto srcData8 = _mm_loadu_si32(srcZ); - auto srcData8 = _mm_castps_si128(_mm_loadu_ps((float*)srcZ_ext)); - auto srcData16 = _mm_unpacklo_epi8(srcData8, zero); - auto srcData32 = _mm_unpacklo_epi16(srcData16, zero); - srcData32 = _mm_sub_epi32(srcData32, zeroPointValue); - auto srcDataf = _mm_cvtepi32_ps(srcData32); - auto mask1 = _mm_cmplt_ps(srcDataf, _mm_castsi128_ps(zero)); - auto mask0 = _mm_cmpge_ps(srcDataf, _mm_castsi128_ps(zero)); - auto f = _mm_mul_ps(srcDataf, slopeZ); - f = _mm_add_ps(f, outputZero); - f = _mm_min_ps(f, maxValue); - f = _mm_max_ps(f, minValue); - auto r = _mm_add_ps(_mm_and_ps(srcDataf, mask0), _mm_and_ps(f, mask1)); - auto m0 = _mm_cmplt_ps(r, _mm_castsi128_ps(zero)); - m0 = _mm_blendv_ps(plus, minus, m0); - r = _mm_add_ps(r, m0); - // Round to zero - auto d0 = _mm_cvtps_epi32(_mm_round_ps(r, 3)); - d0 = _mm_add_epi32(d0, offset); - d0 = _mm_packs_epi32(d0, d0); - d0 = _mm_packus_epi16(d0, d0); - *((int*)dstZ + i) = _mm_cvtsi128_si32(d0); - } - } -} - diff --git a/source/backend/metal/AllShader.cpp b/source/backend/metal/AllShader.cpp index 7cafcce6c..afb0c6fda 100644 --- a/source/backend/metal/AllShader.cpp +++ b/source/backend/metal/AllShader.cpp @@ -1428,7 +1428,6 @@ const char* shader_MetalDeconvolution_metal = " int output_height;\n" " int output_size;\n" " int output_slice;\n" -" \n" " int kernel_x;\n" " int kernel_y;\n" " int kernel_size;\n" @@ -1438,12 +1437,10 @@ const char* shader_MetalDeconvolution_metal = " int pad_y;\n" " int dilation_x;\n" " int dilation_y;\n" -" \n" " int delta_ky;\n" " int delta_kx;\n" " int delta_iy;\n" " int delta_ix;\n" -" int has_bias;\n" " int batch;\n" " conv_activation_type activation;\n" "};\n" @@ -1494,8 +1491,8 @@ const char* shader_MetalDeconvolution_metal = " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= 
cst.batch*cst.output_slice) return;\n" -" \n" -" FLOAT4 result=FLOAT4(biasTerms[(int)(gid.z/cst.batch)]);\n" +" int oz=(int)gid.z/cst.batch;\n" +" FLOAT4 result=FLOAT4(biasTerms[oz]);\n" " \n" " int oy=(int)gid.y+cst.pad_y;\n" " int ox=(int)gid.x+cst.pad_x;\n" @@ -1512,7 +1509,7 @@ const char* shader_MetalDeconvolution_metal = " int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n" " int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n" " \n" -" auto z_wt=wt+(int)gid.z*cst.kernel_size;\n" +" auto z_wt=wt+oz*cst.kernel_size;\n" " auto z_in=in+(int)gid.z*cst.input_size;\n" " for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n" " for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n" @@ -1670,6 +1667,7 @@ const char* shader_MetalConvolution1x1_metal = " int batch;\n" " int block_size;\n" " conv_activation_type activation;\n" +" float scale_coef;\n" "};\n" "kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" @@ -1711,7 +1709,7 @@ const char* shader_MetalConvolution1x1_metal = " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device MNN::char4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" -" const device float4 *dequantScale [[buffer(5)]],\n" +" const device M4 *dequantScale [[buffer(5)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " int rx=gid.x*CONV_UNROLL;\n" @@ -1724,8 +1722,8 @@ const char* shader_MetalConvolution1x1_metal = " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n" " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n" " for (int bi=0; bi= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " int rx=gid.x*CONV_UNROLL;\n" @@ -2187,8 +2185,8 @@ const char* shader_MetalConvolution1x1_metal = " int 
computeSize=min(cst.output_size-rx,CONV_UNROLL);\n" " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n" " for (int bi=0; bi -#include -using namespace metal; -struct Param { - int query_seq_len; - int key_seq_len; - int head_num; - int group; - int head_dim; - float scale; -}; -#define SIMD_GROUP_WIDTH 32 - -kernel void prefill(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output [[buffer(2)]], - device T* past_key [[buffer(3)]], -#ifdef FLOAT_MASK - const device T* mask [[buffer(4)]], -#else - const device int* mask [[buffer(4)]], -#endif - constant Param& param [[buffer(5)]], -#ifdef SIMD_GROUP_MATRIX - uint3 gid[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] -#else - uint3 gid[[thread_position_in_grid]] -#endif -) { -#ifdef SIMD_GROUP_MATRIX - - /* - // Read: - ftype 0~127 ---> input: [M16, K8] - ftype 128~255 ---> input: [K8, N16] - // Write: - ftype 0~255 ---> input: [N2, M2, M8, N8] - */ - - simdgroup_float8x8 sga[2]; - simdgroup_float8x8 sgb[2]; - simdgroup_float8x8 sgd[4]; - for (int i = 0; i < 4; i++){ - sgd[i] = make_filled_simdgroup_matrix(0.f); - } - - int kl = tiitg % 2;// 0~1 - int rcl = tiitg / 2;// 0~15 - - const int slq = gid.x; // q_seq_len/16 -> M/16 - const int slk = gid.y; // k_seq_len/16 -> N/16 - const int z = gid.z; // head_num - - /** Q: - threadgroup: [M16, K8] - each thread: K4 - layout: [M, B, K] -> [M/16, M16, B, K/8, K2, K4] - index : [slq, rcl, z, 0, kl, K4] - offset: ((slq * 16 + rcl) * B + z) * K + (0 * 2 + kl) * 4 + 0 - */ - /** K: - threadgroup: [K8, N16] - each thread: N4 - layout: [N, B/G, K] -> [N/16, N16, B/G, K/8, K2, K4] - index : [slk, rcl, B/G, 0, kl, 0] - offset: ((slk * 16 + rcl) * B/G + z/G) * K + 0 * 8 + kl * 4 + 0 - */ - /** output: - threadgroup: [M16, N16] - each thread: N8 - layout: [B, M, N] -> [B, M/16, M16, N/16, N2, N8] - index : [z, 
sl, rcl, kl, 0] - offset: (z * M + sl * 16 + rcl) * N + slk * 16 + kl * 8 + 0 - */ - - int group = param.group; - int zin = z / param.group; - int q_seq_len = param.query_seq_len; - int k_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - const int stride = head_num * head_dim / group; - - threadgroup float sdata[256] = {0.f}; - - int idx_slq = slq * 16 + rcl < q_seq_len ? slq * 16 + rcl : q_seq_len - 1; - int idx_slk = slk * 16 + rcl < k_seq_len ? slk * 16 + rcl : k_seq_len - 1; - - auto A_offset = input0 + (idx_slq * head_num + z) * head_dim + (0 * 2 + kl) * 4 + 0; - auto B_offset = input1 + (idx_slk * head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0; - - for(int i = 0; i < head_dim; i += 8){ - sdata[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; - sdata[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; - sdata[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; - sdata[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - - sdata[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0]; - sdata[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1]; - sdata[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2]; - sdata[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); - - simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); - - simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); - simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); - simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); - simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); - simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); - simdgroup_store(sgd[2], (threadgroup 
float*)sdata + 128, 8); - simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // [N2, M2, M8, N8] - float Vscale = (float)param.scale; - - auto xy_out = output + (z * q_seq_len + slq * 16 + rcl) * k_seq_len + slk * 16 + kl * 8 + 0; - if(slq * 16 + rcl < q_seq_len) { - if(slk * 16 + kl * 8 + 0 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 0] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 0))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 0))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[0] = out0; - } - if(slk * 16 + kl * 8 + 1 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 1] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 1))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 1))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[1] = out0; - } - if(slk * 16 + kl * 8 + 2 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 2] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 2))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 2))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[2] = out0; - } - if(slk * 16 + kl * 8 + 3 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 3] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 3))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 3))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[3] = out0; - } - if(slk * 16 + kl * 8 + 4 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 4] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 4))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 4))] == 0 ? 
-FLT_MAX : out0; - #endif - xy_out[4] = out0; - } - if(slk * 16 + kl * 8 + 5 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 5] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 5))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 5))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[5] = out0; - } - if(slk * 16 + kl * 8 + 6 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 6] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 6))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 6))] == 0 ? -FLT_MAX : out0; - #endif - xy_out[6] = out0; - } - if(slk * 16 + kl * 8 + 7 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 7] * Vscale; - #ifdef FLOAT_MASK - out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 7))] + out0; - #else - out0 = mask[((slq * 16 + rcl) * key_seq_len + (slk * 16 + kl * 8 + 7))] == 0 ? 
-FLT_MAX : out0; - #endif - xy_out[7] = out0; - } - } - -#else - const int x = gid.x; // query_seq_len - const int y = gid.y; // head_num - const int z = gid.z; // key_seq_len - - if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) { - return; - } - int group = param.group; - int query_seq_len = param.query_seq_len; - int key_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - - const int offset = head_num * head_dim; - const int offset_head = y * head_dim; - const int offset_head_kv = (y / param.group) * head_dim; - const device T* A_offset = input0 + x * offset + offset_head; - - float Vscale = (float)param.scale; - - device const T* B_offset = input1 + z * offset / group + offset_head_kv; - const int output_offset = y * query_seq_len * key_seq_len; - float out0 = 0.0; - - for(int i = 0; i < head_dim; ++i){ - float A = (float)(A_offset[i]); - float B = (float)(B_offset[i]); - out0 += B * A; - } - - out0 *= Vscale; - -#ifdef FLOAT_MASK - out0 = mask[((x + 0) * key_seq_len + (z + 0))] + out0; -#else - out0 = mask[((x + 0) * key_seq_len + (z + 0))] == 0 ? 
-FLT_MAX : out0; -#endif - output[output_offset + x * key_seq_len + z] = (T)out0; -#endif -} - -kernel void decode(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output [[buffer(2)]], - device T* past_key [[buffer(3)]], -#ifdef FLOAT_MASK - const device T* mask [[buffer(4)]], -#else - const device int* mask [[buffer(4)]], -#endif - constant Param& param [[buffer(5)]], -#ifdef SIMD_GROUP_REDUCE - uint3 gid[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] -#else - uint3 gid[[thread_position_in_grid]] -#endif -) { - const int x = gid.x; // query_seq_len - const int y = gid.y; // head_num - const int z = gid.z; // key_seq_len - if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) { - return; - } - int group = param.group; - - int key_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - - const int offset = head_num * head_dim; - const int offset_head = y * head_dim; - const int offset_head_kv = (y / param.group) * head_dim; - const device T* A_offset = input0 + x * offset + offset_head; - device T* Pastkey_offset = past_key + z * offset / group + offset_head_kv; - float Vscale = (float)param.scale; - - const device T *B_offset = input1 + offset_head_kv; - float out = 0.0; - -#ifdef SIMD_GROUP_REDUCE - { - for(int i = tiisg; i < head_dim; i+=SIMD_GROUP_WIDTH){ - float A = A_offset[i]; - float B = (float)Pastkey_offset[i]; - - out += A * B; - } - } - out = simd_sum(out); - if(tiisg == 0) { - out *= Vscale; - output[y * key_seq_len + z] = (T)out; - } -#else - { - for(int i = 0; i < head_dim; i++){ - float A = A_offset[i]; - float B = (float)Pastkey_offset[i]; - - out += A * B; - } - } - out *= Vscale; - output[y * key_seq_len + z] = (T)out; -#endif -} - -)metal"; - -static const char* gCopyPastKV = R"metal( -#include -using namespace metal; -struct Param { - int head_count; - int 
kv_seq_len; - int src_offset; - int dst_offset; -}; -kernel void copy(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output0 [[buffer(2)]], - device T* output1 [[buffer(3)]], - constant Param& param [[buffer(4)]], - uint3 gid[[thread_position_in_grid]] -) { - const int x = gid.x; // head_num / group * head_dim / 4 - const int y = gid.y; // kv_seq_len - if (x >= param.head_count || y >= param.kv_seq_len) { - return; - } - const int index = y * param.head_count + x; - output0[param.dst_offset + index] = input0[param.src_offset + index]; - output1[param.dst_offset + index] = input1[param.src_offset + index]; -} -)metal"; - -static const char* gMatMulQKV = R"metal( - -#include -#include -using namespace metal; -struct Param { - int query_seq_len; - int key_seq_len; - int head_num; - int group; - int head_dim; - float scale; -}; -#define SIMD_GROUP_WIDTH 32 -kernel void prefill(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output [[buffer(2)]], - device T* past_value [[buffer(3)]], - constant Param& param [[buffer(4)]], -#ifdef SIMD_GROUP_MATRIX - uint3 gid[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] -#else - uint3 gid[[thread_position_in_grid]] -#endif -) { -#ifdef SIMD_GROUP_MATRIX - /* - // Read: - ftype 0~127 ---> input: [M16, K8] - ftype 128~255 ---> input: [K8, N16] - // Write: - ftype 0~255 ---> input: [N2, M2, M8, N8] - */ - - simdgroup_float8x8 sga[2]; - simdgroup_float8x8 sgb[2]; - simdgroup_float8x8 sgd[4]; - for (int i = 0; i < 4; i++){ - sgd[i] = make_filled_simdgroup_matrix(0.f); - } - - int kl = tiitg % 2;// 0~1 - int rcl = tiitg / 2;// 0~15 - - int nl = tiitg % 4;// 0~3 - int kcl = tiitg / 4;// 0~7 - - const int sl = gid.x; // q_seq_len/16 -> M/16 - const int hm = gid.y; // head_dim/16 -> N/16 - const int z = gid.z; // head_num - - /** QK: - 
threadgroup: [M16, K8] - each thread: K4 - layout: [B, M, K] -> [B, M/16, M16, K/8, K2, K4] - index : [z, sl, rcl, ml, kl, K4] - offset: (z * M + sl * 16 + rcl) * K + (0 * 2 + kl) * 4 + 0 - */ - /** V: - threadgroup: [K8, N16] - each thread: N4 - layout: [K, B/G, N] -> [K/8, K8, B/G, N/16, N4, N4] - index : [0, kcl, B/G, hm, nl, 0] - offset: ((0 * 8 + kcl) * B/G + z/G) * N + hm * 16 + nl * 4 + 0 - */ - /** output: - threadgroup: [M16, N16] - each thread: N8 - layout: [M, B, N] -> [M/16, M16, B, N/16, N2, N8] - index : [sl, rcl, B, kl, 0] - offset: ((sl * 16 + rcl) * B + z) * N + hm * 16 + kl * 8 + 0 - */ - - int group = param.group; - int zin = z / param.group; - int qk_seq_len = param.query_seq_len; - int value_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - const int stride = head_num * head_dim / group; - - threadgroup float sdata[256] = {0.f}; - - int idx_qk_sl = sl * 16 + rcl < qk_seq_len ? (sl * 16 + rcl) : qk_seq_len - 1; - - auto A_offset = input0 + (z * qk_seq_len + idx_qk_sl) * value_seq_len + (0 * 2 + kl) * 4 + 0; - auto B_offset = input1 + ((0 * 8 + kcl) * head_num / group + zin) * head_dim + hm * 16 + nl * 4 + 0; - - for(int i = 0; i < value_seq_len; i += 8){ - sdata[rcl * 8 + kl * 4 + 0] = (i + kl * 4 + 0 < value_seq_len) ? A_offset[i + 0] : 0.0; - sdata[rcl * 8 + kl * 4 + 1] = (i + kl * 4 + 1 < value_seq_len) ? A_offset[i + 1] : 0.0; - sdata[rcl * 8 + kl * 4 + 2] = (i + kl * 4 + 2 < value_seq_len) ? A_offset[i + 2] : 0.0; - sdata[rcl * 8 + kl * 4 + 3] = (i + kl * 4 + 3 < value_seq_len) ? A_offset[i + 3] : 0.0; - - sdata[128 + kcl * 16 + nl * 4 + 0] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 0 < head_dim) ? B_offset[i * stride + 0] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 1] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 1 < head_dim) ? B_offset[i * stride + 1] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 2] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 2 < head_dim) ? 
B_offset[i * stride + 2] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 3] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 3 < head_dim) ? B_offset[i * stride + 3] : 0.0; - threadgroup_barrier(mem_flags::mem_threadgroup); - - simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); - - simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); - - simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); - simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); - simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); - simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); - simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); - simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8); - simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // [N2, M2, M8, N8] - auto xy_out = output + ((sl * 16 + rcl) * head_num + z) * head_dim + hm * 16 + kl * 8 + 0; - if(sl * 16 + rcl < qk_seq_len) { - if(hm * 16 + kl * 8 + 0 < head_dim) { - xy_out[0] = sdata[(kl * 16 + rcl) * 8 + 0]; - } - if(hm * 16 + kl * 8 + 1 < head_dim) { - xy_out[1] = sdata[(kl * 16 + rcl) * 8 + 1]; - } - if(hm * 16 + kl * 8 + 2 < head_dim) { - xy_out[2] = sdata[(kl * 16 + rcl) * 8 + 2]; - } - if(hm * 16 + kl * 8 + 3 < head_dim) { - xy_out[3] = sdata[(kl * 16 + rcl) * 8 + 3]; - } - if(hm * 16 + kl * 8 + 4 < head_dim) { - xy_out[4] = sdata[(kl * 16 + rcl) * 8 + 4]; - } - if(hm * 16 + kl * 8 + 5 < head_dim) { - xy_out[5] = sdata[(kl * 16 + rcl) * 8 + 5]; - } - if(hm * 16 + kl * 8 + 6 < head_dim) { - xy_out[6] = sdata[(kl * 16 + rcl) * 8 + 6]; - } - if(hm * 16 + kl * 8 + 7 < head_dim) { - xy_out[7] = sdata[(kl * 16 + rcl) * 8 + 7]; - } - } - -#else - 
const int x = gid.x; // kv_seq_len - const int y = gid.y; // head_num - const int z = gid.z; // head_dim - if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) { - return; - } - int group = param.group; - int yin = y / param.group; - int qk_seq_len = param.query_seq_len; - int value_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - const int stride = head_num * head_dim / group; - const int offset_head = yin * head_dim + z; - - device const T *A_offset = input0 + (y * qk_seq_len + x) * value_seq_len; - device const T *B_offset = input1 + offset_head; - float out = 0.0; - - for(int i = 0; i < value_seq_len; ++i){ - float A0 = (float)A_offset[i]; - float B = (float)B_offset[i*stride]; - out += A0 * B; - } - output[ x * stride * group + (y * head_dim + z)] = out; -#endif -} - -kernel void decode(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output [[buffer(2)]], - device T* past_value [[buffer(3)]], - constant Param& param [[buffer(4)]], -#ifdef SIMD_GROUP_REDUCE - uint3 gid[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]] -#else - uint3 gid[[thread_position_in_grid]] -#endif -) { - const int x = gid.x; // query_seq_len - const int y = gid.y; // head_num - const int z = gid.z; // head_dim - if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) { - return; - } - int group = param.group; - int yin = y / param.group; - - int value_seq_len = param.key_seq_len; - int head_num = param.head_num; - int head_dim = param.head_dim; - const int stride = head_num * head_dim / group; - const int offset_head = yin * head_dim + z; - - device const T *A_offset = input0 + y * value_seq_len; - device T *Pastvalue_offset = past_value + offset_head; - float out = 0; - -#ifdef SIMD_GROUP_REDUCE - for(int i = tiisg; i < value_seq_len; i+=SIMD_GROUP_WIDTH){ - float A = 
(float)A_offset[i]; - float B = (float)Pastvalue_offset[i * stride]; - - out += A * B; - } - out = simd_sum(out); - if(tiisg == 0) { - output[(y * head_dim + z)] = (T)out; - } -#else - for(int i = 0; i < value_seq_len; i++){ - float A = (float)A_offset[i]; - float B = (float)Pastvalue_offset[i * stride]; - - out += A * B; - } - output[(y * head_dim + z)] = (T)out; -#endif -} -)metal"; - namespace MNN { class AttentionBufExecution : public MetalExecution { public: @@ -621,7 +41,8 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { private: void _init(); - void reallocKVCache(); + void reallocKVCache(int history_len); + void compilerShader(const std::vector &inputs); bool mKVCache; std::shared_ptr mCache; float mScale; @@ -639,6 +60,14 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { id mParamQKV; id mParamSoftmax; id mParamCopy; + +private: + bool mQkSimdReduce = false; + bool mQkSimdMatrix = false; + bool mSftmSimdReduce = false; + bool mQkvSimdReduce = false; + bool mQkvSimdMatrix = false; + bool mUseHeadNum2 = false; }; struct Param { @@ -648,6 +77,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { int group; int head_dim; float scale; + int max_kv_len; }; AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cahce) : MetalExecution(backend) , mKVCache(kv_cahce) { @@ -659,29 +89,48 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { auto context = (__bridge MNNMetalContext *)mtbn->context(); mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly]; mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly]; - mParamCopy = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly]; + mParamCopy = [context newDeviceBuffer:5 * sizeof(int) access:CPUWriteOnly]; mTempQK.reset(Tensor::createDevice({0, 0})); mTempSoftMax.reset(Tensor::createDevice({0, 0})); } -void AttentionBufExecution::reallocKVCache() 
{ - if (!mKVCache || mCache->mPastLength < mCache->mMaxLength) { +void AttentionBufExecution::reallocKVCache(int history_len) { + /* + when kv-cache + decoding: past_len > max_len, realloc and copy past_len cache + prefill : max_len == 0 (first prefill), realloc and copy history_len cache + prefill : max_len > 0 (not first prefill) && past_len >= max_len, realloc and copy history_len cache. copy current prompt in copy shader(not this function) + prefill : max_len > 0 (not first prefill) && past_len < max_len, not realloc and no need copy history_len cache, just copy current prompt in copy shader(not this function) + + */ + if (!mKVCache) { + return; + } + + if (mIsDecode && mCache->mPastLength < mCache->mMaxLength) { return; } + // not first prefill (do reuse_kvcache) and total past_len < max_len + if(!mIsDecode && mCache->mMaxLength > 0 && mCache->mPastLength < mCache->mMaxLength && history_len != 0) { + return; + } auto mtbn = static_cast(backend()); int byte = 4; if(mtbn->useFp16InsteadFp32()) { byte = 2; } - bool needCopy = mCache->mMaxLength > 0; + bool needCopy = history_len > 0; + + size_t old_size = mKvNumHead * history_len * mHeadDim * byte; + size_t old_piece_size = history_len * byte; + size_t old_piece_stride = mCache->mMaxLength * byte; - size_t old_size = mKvNumHead * mCache->mMaxLength * mHeadDim * byte; mCache->mMaxLength = mCache->mPastLength + mExpandChunk; // past_key: [1, numhead, headdim, maxlen] auto new_key = Tensor::createDevice({mCache->mMaxLength, mKvNumHead, mHeadDim}); // past_value: [1, numhead, maxlen, headdim] - auto new_value = Tensor::createDevice({mCache->mMaxLength, mKvNumHead, mHeadDim}); + auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mCache->mMaxLength}); size_t size = mKvNumHead * mCache->mMaxLength * mHeadDim * byte; backend()->onAcquireBuffer(new_key, Backend::STATIC); backend()->onAcquireBuffer(new_value, Backend::STATIC); @@ -696,67 +145,42 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) 
override { auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get()); auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second; - ::memcpy(new_value_ptr, value_ptr, old_size); + for(int i = 0; i < mKvNumHead * mHeadDim; i++) { + ::memcpy(new_value_ptr + i * mCache->mMaxLength * byte, value_ptr + i * old_piece_stride, old_piece_size); + } } mCache->mPastKey.reset(new_key); mCache->mPastValue.reset(new_value); } - -void AttentionBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { - - auto query = inputs[0]; - auto key = inputs[1]; - auto value = inputs[2]; +void AttentionBufExecution::compilerShader(const std::vector &inputs) { auto mask = inputs[3]; auto mtbn = static_cast(backend()); - auto context = (__bridge MNNMetalContext *)mtbn->context(); - auto shape = query->shape(); - int seq_len = shape[1]; - mNumHead = shape[2]; - mHeadDim = shape[3]; - mScale = 1.0 / sqrt(mHeadDim); - mIsDecode = seq_len == 1; - if (mCache->mPastLength == 0 || seq_len > 1) { - mCache->mPastLength = seq_len; - } - mCache->mKv_seq_len = mCache->mPastLength; - if(mIsDecode){ - mCache->mKv_seq_len = mCache->mPastLength + 1; - } - mKvNumHead = key->shape()[2]; - auto rt = (MetalRuntime*)mtbn->runtime(); - bool supportSimdReduce = rt->supportSimdGroupReduce(); - bool supportSimdMatrix = rt->supportSimdGroupMatrix(); - - // decode and thread number not too large - bool qkSimdReduce = supportSimdReduce && seq_len == 1 && mCache->mKv_seq_len * mNumHead < mHeadDim * 32; - // loop_k can divide 8, thus avoid branch - bool qkSimdMatrix = supportSimdMatrix && seq_len >= 16 && mHeadDim % 8 == 0; + auto context = (__bridge MNNMetalContext *)mtbn->context(); - bool sftmSimdReduce = supportSimdReduce; - bool qkvSimdReduce = supportSimdReduce && seq_len == 1 && mHeadDim * mNumHead < mCache->mKv_seq_len * 32; - bool qkvSimdMatrix = supportSimdMatrix && seq_len >= 
16; - // Init Kernel bool float_mask = (mask->getType() == halide_type_of()); std::string T = "float"; - std::string T4 = "float4"; if (mtbn->useFp16InsteadFp32()) { T = "half"; - T4 = "half4"; } std::vector qkKeys = { {"matmul_qk_div_mask", T} }; - if(qkSimdReduce) { + if(mQkSimdReduce) { qkKeys.emplace_back("SIMD_GROUP_REDUCE"); } + + // QK matmul total thread is large + mUseHeadNum2 = mIsDecode && mCache->mKv_seq_len > 1024; + if(mUseHeadNum2) { + qkKeys.emplace_back("HEAD_NUM_2"); + } std::vector qkvKeys = { {"matmul_qkv", T} }; - if(qkvSimdReduce) { + if(mQkvSimdReduce) { qkvKeys.emplace_back("SIMD_GROUP_REDUCE"); } std::vector qkPrefillKeys = { @@ -765,17 +189,17 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { if (float_mask) { qkPrefillKeys.emplace_back("FLOAT_MASK"); } - if(qkSimdMatrix) { + if(mQkSimdMatrix) { qkPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } std::vector qkvPrefillKeys = { {"matmul_qkv", T, "FOR_PREFILL"} }; - if(qkvSimdMatrix) { + if(mQkvSimdMatrix) { qkvPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } std::vector copyPastKeys = { - {"pastkv_copy", T4} + {"pastkv_copy", T} }; std::vector> keys = { qkKeys, @@ -791,6 +215,13 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { gMatMulQKV, gCopyPastKV }; + std::vector shaders = { + "decode_qk", + "decode_qkv", + "prefill_qk", + "prefill_qkv", + "copy" + }; std::vector> pipelines(keys.size()); for (int i=0; ifindPipeline(keys[i]); @@ -803,14 +234,8 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { [dic setValue:@"1" forKey:@(keys[i][j].c_str())];; } option.preprocessorMacros = dic; - if(std::find(keys[i].begin(), keys[i].end(), "FOR_PREFILL") != keys[i].end()) { - pipeline = mtbn->makeComputePipelineWithSourceOption(sources[i], "prefill", option); - } else if(i == 4){ - pipeline = mtbn->makeComputePipelineWithSourceOption(sources[i], "copy", option); - - } else { - pipeline = 
mtbn->makeComputePipelineWithSourceOption(sources[i], "decode", option); - } + + pipeline = mtbn->makeComputePipelineWithSourceOption(sources[i], shaders[i].c_str(), option); rt->insertPipeline(keys[i], pipeline); } pipelines[i] = pipeline; @@ -826,15 +251,65 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { MNN_ASSERT(nil != mKernelPrefill_qkv); MNN_ASSERT(nil != mKernel_copy); - if(sftmSimdReduce) { + if(mSftmSimdReduce) { mKernel_softmax = [context pipelineWithName:@"softmax_plane_sg" fp16:mtbn->useFp16InsteadFp32()]; } else { mKernel_softmax = [context pipelineWithName:@"softmax_plane" fp16:mtbn->useFp16InsteadFp32()]; } +} + +void AttentionBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { + + auto query = inputs[0]; + auto key = inputs[1]; + auto value = inputs[2]; + auto mask = inputs[3]; + auto mtbn = static_cast(backend()); + auto context = (__bridge MNNMetalContext *)mtbn->context(); + auto shape = query->shape(); + int seq_len = shape[1]; + mNumHead = shape[2]; + mHeadDim = shape[3]; + mScale = 1.0 / sqrt(mHeadDim); + mIsDecode = seq_len == 1; + + int history_len = mCache->mPastLength; + // first prefill set history_len to 0 + if(!mIsDecode && mask->length(2) == mask->length(3)) { + history_len = 0; + } + if (!mIsDecode) { + mCache->mPastLength = mask->length(3); + } + mCache->mKv_seq_len = mCache->mPastLength; + if(mIsDecode){ + mCache->mKv_seq_len = mCache->mPastLength + 1; + } + mKvNumHead = key->shape()[2]; + + auto rt = (MetalRuntime*)mtbn->runtime(); + bool supportSimdReduce = rt->supportSimdGroupReduce(); + bool supportSimdMatrix = rt->supportSimdGroupMatrix(); + + // decode and thread number not too large + mQkSimdReduce = supportSimdReduce && seq_len == 1; + // loop_k can divide 8, thus avoid branch + mQkSimdMatrix = supportSimdMatrix && seq_len >= 16 && mHeadDim % 8 == 0; + + mSftmSimdReduce = supportSimdReduce; + mQkvSimdReduce = supportSimdReduce && seq_len == 1 && mHeadDim 
* mNumHead < mCache->mKv_seq_len * 32; + mQkvSimdMatrix = supportSimdMatrix && seq_len >= 16; + + // start to compile attention shaders + compilerShader(inputs); + int group_size = mNumHead / mKvNumHead; - reallocKVCache(); + // kv-cache realloc function + reallocKVCache(history_len); + + // temp tensor alloc memory bool needMalloc = mTempQK->length(0) != mNumHead; if (mIsDecode) { if (mTempQK->length(1) != mCache->mMaxLength) { @@ -844,19 +319,20 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { mTempQK->setLength(1, mCache->mMaxLength); mTempSoftMax->setLength(0, mNumHead); mTempSoftMax->setLength(1, mCache->mMaxLength); + } else { - if (mTempQK->length(1) != mCache->mPastLength * mCache->mPastLength) { + if (mTempQK->length(1) != seq_len * mCache->mPastLength) { needMalloc = true; } mTempQK->setLength(0, mNumHead); - mTempQK->setLength(1, mCache->mPastLength * mCache->mPastLength); + mTempQK->setLength(1, seq_len * mCache->mPastLength); mTempSoftMax->setLength(0, mNumHead); - mTempSoftMax->setLength(1, mCache->mPastLength * mCache->mPastLength); + mTempSoftMax->setLength(1, seq_len * mCache->mPastLength); } if (needMalloc) { auto res = backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC) && backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); if (!res) { - MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal\n"); + MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal %d\n", mCache->mPastLength); return; } } @@ -870,6 +346,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { param->head_num = mNumHead; param->group = group_size; param->query_seq_len = seq_len; + param->max_kv_len = mCache->mMaxLength; } // For softmax parameter int inside, outside; @@ -878,7 +355,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { outside = mNumHead; } else { inside = 1; - outside = mCache->mKv_seq_len * mNumHead; + outside = seq_len * mNumHead; } int 
axis = mCache->mKv_seq_len; { @@ -889,22 +366,35 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { softmax[2] = outside; softmax[3] = 0; } + // Run Copy Kernel { auto copyp = (int*)mParamCopy.contents; - copyp[0] = mKvNumHead * mHeadDim / 4; + copyp[0] = mKvNumHead * mHeadDim; int copy_line; if(mIsDecode) { + /* + each decode fill one kv_seq. + Key -> K-Cache : [1, mKvNumHead, mHeadDim] -> [mCache->mKv_seq_len + 1, mKvNumHead, mHeadDim] + Value -> V-Cache : [1, mKvNumHead, mHeadDim] -> [mKvNumHead, mHeadDim, mCache->mKv_seq_len + 1] + */ copyp[1] = 1; - copyp[2] = 0; + copyp[2] = mCache->mMaxLength; copyp[3] = (mCache->mKv_seq_len - 1) * copyp[0]; + copyp[4] = mCache->mKv_seq_len - 1; copy_line = 1; } else { - copyp[1] = mCache->mKv_seq_len; - copyp[2] = 0; - copyp[3] = 0; - copy_line = mCache->mKv_seq_len; + /* + first time copy. + Key -> K-Cache : [mCache->mKv_seq_len, mKvNumHead, mHeadDim] -> [mCache->mKv_seq_len, mKvNumHead, mHeadDim] + Value -> V-Cache : [mCache->mKv_seq_len, mKvNumHead, mHeadDim] -> [mKvNumHead, mHeadDim, mCache->mMaxLength (fill when decode)] + */ + copyp[1] = seq_len; + copyp[2] = mCache->mMaxLength; + copyp[3] = history_len * copyp[0]; + copyp[4] = history_len; + copy_line = seq_len; } id pipeline = mKernel_copy; @@ -916,7 +406,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { [encoder setBuffer:mParamCopy offset:0 atIndex:4]; std::pair gl; - gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mKvNumHead * mHeadDim / 4, copy_line, 1)]; + gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(mKvNumHead * mHeadDim, copy_line, 1)]; [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; @@ -931,21 +421,27 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { } [encoder setComputePipelineState:pipeline]; MetalBackend::setTensor(query, encoder, 0); - MetalBackend::setTensor(key, encoder, 1); - 
MetalBackend::setTensor(mTempQK.get(), encoder, 2); - MetalBackend::setTensor(mCache->mPastKey.get(), encoder, 3); - MetalBackend::setTensor(mask, encoder, 4); - [encoder setBuffer:mParamQKV offset:0 atIndex:5]; + MetalBackend::setTensor(mTempQK.get(), encoder, 1); + MetalBackend::setTensor(mCache->mPastKey.get(), encoder, 2); + MetalBackend::setTensor(mask, encoder, 3); + [encoder setBuffer:mParamQKV offset:0 atIndex:4]; + int decode_grid_y = mNumHead; + if(mUseHeadNum2) { + decode_grid_y = (decode_grid_y + 1) / 2; + } std::pair gl; - if(qkSimdReduce) { - gl = std::make_pair(MTLSizeMake(seq_len, mNumHead, mCache->mKv_seq_len), MTLSizeMake(32, 1, 1)); - } else if(qkSimdMatrix) { + if(mQkSimdReduce) { + gl = std::make_pair(MTLSizeMake(seq_len, decode_grid_y, mCache->mKv_seq_len), MTLSizeMake(32, 1, 1)); + } else if(mQkSimdMatrix) { gl = std::make_pair(MTLSizeMake(UP_DIV(seq_len, 16), UP_DIV(mCache->mKv_seq_len, 16), mNumHead), MTLSizeMake(32, 1, 1)); + } else if(mIsDecode){ + gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, decode_grid_y, mCache->mKv_seq_len)]; } else { gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mCache->mKv_seq_len)]; } [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + } // Run Softmax Kernel { @@ -956,13 +452,14 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { int thread_group_size = 32; std::pair gl; - if(sftmSimdReduce) { + if(mSftmSimdReduce) { gl = std::make_pair(MTLSizeMake(inside, outside, 1), MTLSizeMake(thread_group_size, 1, 1)); } else { gl = [context computeBestGroupAndLocal: mKernel_softmax threads:MTLSizeMake(inside, outside, 1)]; } [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + } // Run QKV Kernel { @@ -974,27 +471,26 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { } [encoder setComputePipelineState:pipeline]; MetalBackend::setTensor(mTempSoftMax.get(), 
encoder, 0); - MetalBackend::setTensor(value, encoder, 1); - MetalBackend::setTensor(outputs[0], encoder, 2); - MetalBackend::setTensor(mCache->mPastValue.get(), encoder, 3); - [encoder setBuffer:mParamQKV offset:0 atIndex:4]; + MetalBackend::setTensor(outputs[0], encoder, 1); + MetalBackend::setTensor(mCache->mPastValue.get(), encoder, 2); + [encoder setBuffer:mParamQKV offset:0 atIndex:3]; std::pair gl; - if(qkvSimdReduce) { + if(mQkvSimdReduce) { gl = std::make_pair(MTLSizeMake(seq_len, mNumHead, mHeadDim), MTLSizeMake(32, 1, 1)); - } else if(qkvSimdMatrix){ + } else if(mQkvSimdMatrix){ gl = std::make_pair(MTLSizeMake(UP_DIV(seq_len, 16), UP_DIV(mHeadDim, 16), mNumHead), MTLSizeMake(32, 1, 1)); - //printf("qk:%d %d %d, softmax:%d %d %d, qkv:%d %d %d\n", seq_len, mNumHead, mCache->mKv_seq_len, inside, outside, 1, seq_len, mNumHead, mHeadDim); } else { gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seq_len, mNumHead, mHeadDim)]; } [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + } // Update status if(mIsDecode){ mCache->mPastLength += 1; mCache->mKv_seq_len = mCache->mPastLength + 1; } - //printf("qk:%d %d %d, softmax:%d %d %d, qkv:%d %d %d\n", seq_len, mNumHead, mCache->mKv_seq_len, inside, outside, 1, seq_len, mNumHead, mHeadDim); + return; } diff --git a/source/backend/metal/MetalAttentionShader.hpp b/source/backend/metal/MetalAttentionShader.hpp new file mode 100644 index 000000000..9209be406 --- /dev/null +++ b/source/backend/metal/MetalAttentionShader.hpp @@ -0,0 +1,636 @@ +// +// MetalAttentionShader.hpp +// MNN +// +// Created by MNN on b'2024/12/03'. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if MNN_METAL_ENABLED +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +const char* gMatMulDivMask = R"metal( +#include +#include +using namespace metal; +struct Param { + int query_seq_len; + int key_seq_len; + int head_num; + int group; + int head_dim; + float scale; + int max_kv_len; +}; +#define SIMD_GROUP_WIDTH 32 + +kernel void prefill_qk(const device T* input0 [[buffer(0)]], + device T* output [[buffer(1)]], + device T* past_key [[buffer(2)]], +#ifdef FLOAT_MASK + const device T* mask [[buffer(3)]], +#else + const device int* mask [[buffer(3)]], +#endif + constant Param& param [[buffer(4)]], +#ifdef SIMD_GROUP_MATRIX + uint3 gid[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +#else + uint3 gid[[thread_position_in_grid]] +#endif +) { +#ifdef SIMD_GROUP_MATRIX + + /* + // Read: + ftype 0~127 ---> input: [M16, K8] + ftype 128~255 ---> input: [K8, N16] + // Write: + ftype 0~255 ---> input: [N2, M2, M8, N8] + */ + + simdgroup_float8x8 sga[2]; + simdgroup_float8x8 sgb[2]; + simdgroup_float8x8 sgd[4]; + for (int i = 0; i < 4; i++){ + sgd[i] = make_filled_simdgroup_matrix(0.f); + } + + int kl = tiitg % 2;// 0~1 + int rcl = tiitg / 2;// 0~15 + + const int slq = gid.x; // q_seq_len/16 -> M/16 + const int slk = gid.y; // k_seq_len/16 -> N/16 + const int z = gid.z; // head_num + + /** Q: + threadgroup: [M16, K8] + each thread: K4 + layout: [M, B, K] -> [M/16, M16, B, K/8, K2, K4] + index : [slq, rcl, z, 0, kl, K4] + offset: ((slq * 16 + rcl) * B + z) * K + (0 * 2 + kl) * 4 + 0 + */ + /** K: + threadgroup: [K8, N16] + each thread: N4 + layout: [N, B/G, K] -> [N/16, N16, B/G, K/8, K2, K4] + index : [slk, rcl, B/G, 0, kl, 0] + offset: ((slk * 16 + rcl) * B/G + z/G) * K + 0 * 8 + kl * 4 + 0 + */ + /** output: + threadgroup: [M16, N16] + each thread: N8 + layout: [B, M, N] -> [B, M/16, M16, N/16, N2, N8] + 
index : [z, sl, rcl, kl, 0] + offset: (z * M + sl * 16 + rcl) * N + slk * 16 + kl * 8 + 0 + */ + + int group = param.group; + int zin = z / param.group; + int q_seq_len = param.query_seq_len; + int k_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + + threadgroup float sdata[256] = {0.f}; + + int idx_slq = slq * 16 + rcl < q_seq_len ? slq * 16 + rcl : q_seq_len - 1; + int idx_slk = slk * 16 + rcl < k_seq_len ? slk * 16 + rcl : k_seq_len - 1; + + auto A_offset = input0 + (idx_slq * head_num + z) * head_dim + (0 * 2 + kl) * 4 + 0; + auto B_offset = past_key + (idx_slk * head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0; + + for(int i = 0; i < head_dim; i += 8){ + sdata[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; + sdata[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; + sdata[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; + sdata[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; + + sdata[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0]; + sdata[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1]; + sdata[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2]; + sdata[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); + + simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16); + simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); + + simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); + simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); + simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); + simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); + simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); + simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8); + 
simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // [N2, M2, M8, N8] + float Vscale = (float)param.scale; + + auto xy_out = output + (z * q_seq_len + slq * 16 + rcl) * k_seq_len + slk * 16 + kl * 8 + 0; + if(slq * 16 + rcl < q_seq_len) { + if(slk * 16 + kl * 8 + 0 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 0] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 0))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 0))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[0] = out0; + } + if(slk * 16 + kl * 8 + 1 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 1] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 1))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 1))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[1] = out0; + } + if(slk * 16 + kl * 8 + 2 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 2] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 2))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 2))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[2] = out0; + } + if(slk * 16 + kl * 8 + 3 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 3] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 3))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 3))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[3] = out0; + } + if(slk * 16 + kl * 8 + 4 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 4] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 4))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 4))] == 0 ? 
-FLT_MAX : out0; + #endif + xy_out[4] = out0; + } + if(slk * 16 + kl * 8 + 5 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 5] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 5))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 5))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[5] = out0; + } + if(slk * 16 + kl * 8 + 6 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 6] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 6))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 6))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[6] = out0; + } + if(slk * 16 + kl * 8 + 7 < k_seq_len) { + auto out0 = sdata[(kl * 16 + rcl) * 8 + 7] * Vscale; + #ifdef FLOAT_MASK + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 7))] + out0; + #else + out0 = mask[((slq * 16 + rcl) * k_seq_len + (slk * 16 + kl * 8 + 7))] == 0 ? 
-FLT_MAX : out0; + #endif + xy_out[7] = out0; + } + } + +#else + const int x = gid.x; // query_seq_len + const int y = gid.y; // head_num + const int z = gid.z; // key_seq_len + + if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) { + return; + } + int group = param.group; + int query_seq_len = param.query_seq_len; + int key_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + + const int offset = head_num * head_dim; + const int offset_head = y * head_dim; + const int offset_head_kv = (y / group) * head_dim; + const device T* A_offset = input0 + x * offset + offset_head; + + float Vscale = (float)param.scale; + + device const T* B_offset = past_key + z * offset / group + offset_head_kv; + const int output_offset = y * query_seq_len * key_seq_len; + float out0 = 0.0; + + for(int i = 0; i < head_dim; ++i){ + float A = (float)(A_offset[i]); + float B = (float)(B_offset[i]); + out0 += B * A; + } + + out0 *= Vscale; + +#ifdef FLOAT_MASK + out0 = mask[((x + 0) * key_seq_len + (z + 0))] + out0; +#else + out0 = mask[((x + 0) * key_seq_len + (z + 0))] == 0 ? 
-FLT_MAX : out0; +#endif + output[output_offset + x * key_seq_len + z] = (T)out0; +#endif +} + +kernel void decode_qk(const device T* input0 [[buffer(0)]], + device T* output [[buffer(1)]], + device T* past_key [[buffer(2)]], +#ifdef FLOAT_MASK + const device T* mask [[buffer(3)]], +#else + const device int* mask [[buffer(3)]], +#endif + constant Param& param [[buffer(4)]], +#ifdef SIMD_GROUP_REDUCE + uint3 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +#else + uint3 gid[[thread_position_in_grid]] +#endif +) { + int x = gid.x; // query_seq_len + int y = gid.y; // head_num + int z = gid.z; // key_seq_len + +#ifdef HEAD_NUM_2 + y = y * 2; +#endif + if (x >= param.query_seq_len || y >= param.head_num || z >= param.key_seq_len) { + return; + } + int group = param.group; + + int key_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + + const int offset = head_num * head_dim; + const int offset_head = y * head_dim; + const int offset_head_kv = (y / param.group) * head_dim; + const device T* A_offset = input0 + x * offset + offset_head; + device T* Pastkey_offset = past_key + z * offset / group + offset_head_kv; + float Vscale = (float)param.scale; + float out = 0.0; + +#ifdef HEAD_NUM_2 + const device T* A_offset_1 = A_offset + head_dim; + device T* Pastkey_offset_1 = past_key + z * offset / group + ((y+1) / param.group) * head_dim; + float out_1 = 0.0; +#endif + +#ifdef SIMD_GROUP_REDUCE + for(int i = tiisg; i < head_dim; i+=SIMD_GROUP_WIDTH){ + float A = A_offset[i]; + float B = (float)Pastkey_offset[i]; + + out += A * B; + } + +#ifdef HEAD_NUM_2 + if(y + 1 < param.head_num) { + for(int i = tiisg; i < head_dim; i+=SIMD_GROUP_WIDTH){ + float A = A_offset_1[i]; + float B = (float)Pastkey_offset_1[i]; + + out_1 += A * B; + } + } +#endif + out = simd_sum(out); + +#ifdef HEAD_NUM_2 + if(y + 1 < param.head_num) { + out_1 = simd_sum(out_1); + if(tiisg 
== 1) { + out_1 *= Vscale; + output[(y+1) * key_seq_len + z] = (T)out_1; + } + } +#endif + if(tiisg == 0) { + out *= Vscale; + output[y * key_seq_len + z] = (T)out; + } + +#else + { + for(int i = 0; i < head_dim; i++){ + float A = A_offset[i]; + float B = (float)Pastkey_offset[i]; + + out += A * B; + } + } + out *= Vscale; + output[y * key_seq_len + z] = (T)out; + +#ifdef HEAD_NUM_2 + if(y + 1 < param.head_num) { + for(int i = 0; i < head_dim; i++){ + float A = A_offset_1[i]; + float B = (float)Pastkey_offset_1[i]; + + out_1 += A * B; + } + out_1 *= Vscale; + output[(y+1) * key_seq_len + z] = (T)out_1; + } +#endif + +#endif +} + +)metal"; + +const char* gCopyPastKV = R"metal( +#include +using namespace metal; +struct Param { + int head_count; + int q_seq_len; + int max_kv_len; + int dst_k_offset; + int dst_v_offset; +}; +kernel void copy(const device T* input0 [[buffer(0)]], + const device T* input1 [[buffer(1)]], + device T* output0 [[buffer(2)]], + device T* output1 [[buffer(3)]], + constant Param& param [[buffer(4)]], + uint3 gid[[thread_position_in_grid]] +) { + const int x = gid.x; // head_num / group * head_dim + const int y = gid.y; // q_seq_len + if (x >= param.head_count || y >= param.q_seq_len) { + return; + } + const int index = y * param.head_count + x; + output0[param.dst_k_offset + index] = input0[index]; + output1[param.dst_v_offset + x * param.max_kv_len + y] = input1[index]; +} +)metal"; + +const char* gMatMulQKV = R"metal( + +#include +#include +using namespace metal; +struct Param { + int query_seq_len; + int key_seq_len; + int head_num; + int group; + int head_dim; + float scale; + int max_kv_len; +}; +#define SIMD_GROUP_WIDTH 32 +kernel void prefill_qkv(const device T* input0 [[buffer(0)]], + device T* output [[buffer(1)]], + device T* past_value [[buffer(2)]], + constant Param& param [[buffer(3)]], +#ifdef SIMD_GROUP_MATRIX + uint3 gid[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint 
tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +#else + uint3 gid[[thread_position_in_grid]] +#endif +) { +#ifdef SIMD_GROUP_MATRIX + /* + // Read: + ftype 0~127 ---> input: [M16, K8] + ftype 128~255 ---> input: [K8, N16] + // Write: + ftype 0~255 ---> input: [N2, M2, M8, N8] + */ + + simdgroup_float8x8 sga[2]; + simdgroup_float8x8 sgb[2]; + simdgroup_float8x8 sgd[4]; + for (int i = 0; i < 4; i++){ + sgd[i] = make_filled_simdgroup_matrix(0.f); + } + + int kl = tiitg % 2;// 0~1 + int rcl = tiitg / 2;// 0~15 + + int nl = tiitg % 4;// 0~3 + int kcl = tiitg / 4;// 0~7 + + const int sl = gid.x; // q_seq_len/16 -> M/16 + const int hm = gid.y; // head_dim/16 -> N/16 + const int z = gid.z; // head_num + + /** QK: + threadgroup: [M16, K8] + each thread: K4 + layout: [B, M, K] -> [B, M/16, M16, K/8, K2, K4] + index : [z, sl, rcl, ml, kl, K4] + offset: (z * M + sl * 16 + rcl) * K + (0 * 2 + kl) * 4 + 0 + */ + /** V: + threadgroup: [K8, N16] + each thread: N4 + layout: [K, B/G, N] -> [K/8, K8, B/G, N/16, N4, N4] + index : [0, kcl, B/G, hm, nl, 0] + offset: ((0 * 8 + kcl) * B/G + z/G) * N + hm * 16 + nl * 4 + 0 + */ + /** output: + threadgroup: [M16, N16] + each thread: N8 + layout: [M, B, N] -> [M/16, M16, B, N/16, N2, N8] + index : [sl, rcl, B, kl, 0] + offset: ((sl * 16 + rcl) * B + z) * N + hm * 16 + kl * 8 + 0 + */ + + int group = param.group; + int zin = z / group; + int q_seq_len = param.query_seq_len; + int value_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + + threadgroup float sdata[256] = {0.f}; + + int idx_qk_sl = sl * 16 + rcl < q_seq_len ? 
(sl * 16 + rcl) : q_seq_len - 1; + + auto A_offset = input0 + (z * q_seq_len + idx_qk_sl) * value_seq_len + (0 * 2 + kl) * 4 + 0; + auto B_offset = past_value + (zin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (0 * 8 + kcl); + + + for(int i = 0; i < value_seq_len; i += 8){ + sdata[rcl * 8 + kl * 4 + 0] = (i + kl * 4 + 0 < value_seq_len) ? A_offset[i + 0] : 0.0; + sdata[rcl * 8 + kl * 4 + 1] = (i + kl * 4 + 1 < value_seq_len) ? A_offset[i + 1] : 0.0; + sdata[rcl * 8 + kl * 4 + 2] = (i + kl * 4 + 2 < value_seq_len) ? A_offset[i + 2] : 0.0; + sdata[rcl * 8 + kl * 4 + 3] = (i + kl * 4 + 3 < value_seq_len) ? A_offset[i + 3] : 0.0; + + sdata[128 + kcl * 16 + nl * 4 + 0] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 0 < head_dim) ? B_offset[i + 0 * param.max_kv_len] : 0.0; + sdata[128 + kcl * 16 + nl * 4 + 1] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 1 < head_dim) ? B_offset[i + 1 * param.max_kv_len] : 0.0; + sdata[128 + kcl * 16 + nl * 4 + 2] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 2 < head_dim) ? B_offset[i + 2 * param.max_kv_len] : 0.0; + sdata[128 + kcl * 16 + nl * 4 + 3] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 3 < head_dim) ? 
B_offset[i + 3 * param.max_kv_len] : 0.0; + + + threadgroup_barrier(mem_flags::mem_threadgroup); + + simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); + + simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16); + simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); + + simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); + simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); + simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); + simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); + simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); + simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8); + simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // [N2, M2, M8, N8] + auto xy_out = output + ((sl * 16 + rcl) * head_num + z) * head_dim + hm * 16 + kl * 8 + 0; + if(sl * 16 + rcl < q_seq_len) { + if(hm * 16 + kl * 8 + 0 < head_dim) { + xy_out[0] = sdata[(kl * 16 + rcl) * 8 + 0]; + } + if(hm * 16 + kl * 8 + 1 < head_dim) { + xy_out[1] = sdata[(kl * 16 + rcl) * 8 + 1]; + } + if(hm * 16 + kl * 8 + 2 < head_dim) { + xy_out[2] = sdata[(kl * 16 + rcl) * 8 + 2]; + } + if(hm * 16 + kl * 8 + 3 < head_dim) { + xy_out[3] = sdata[(kl * 16 + rcl) * 8 + 3]; + } + if(hm * 16 + kl * 8 + 4 < head_dim) { + xy_out[4] = sdata[(kl * 16 + rcl) * 8 + 4]; + } + if(hm * 16 + kl * 8 + 5 < head_dim) { + xy_out[5] = sdata[(kl * 16 + rcl) * 8 + 5]; + } + if(hm * 16 + kl * 8 + 6 < head_dim) { + xy_out[6] = sdata[(kl * 16 + rcl) * 8 + 6]; + } + if(hm * 16 + kl * 8 + 7 < head_dim) { + xy_out[7] = sdata[(kl * 16 + rcl) * 8 + 7]; + } + } + +#else + const int x = gid.x; // kv_seq_len + const int y = gid.y; // head_num + const int z = gid.z; // head_dim + if (x >= 
param.query_seq_len || y >= param.head_num || z >= param.head_dim) { + return; + } + int group = param.group; + int yin = y / group; + int q_seq_len = param.query_seq_len; + int value_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + const int stride = head_num * head_dim / group; + const int offset_head = yin * head_dim + z; + + device const T *A_offset = input0 + (y * q_seq_len + x) * value_seq_len; + device const T *B_offset = past_value + offset_head * param.max_kv_len; + float out = 0.0; + + for(int i = 0; i < value_seq_len; ++i){ + float A0 = (float)A_offset[i]; + float B = (float)B_offset[i]; + out += A0 * B; + } + output[ x * stride * group + (y * head_dim + z)] = out; +#endif +} + +kernel void decode_qkv(const device T* input0 [[buffer(0)]], + device T* output [[buffer(1)]], + device T* past_value [[buffer(2)]], + constant Param& param [[buffer(3)]], +#ifdef SIMD_GROUP_REDUCE + uint3 gid[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +#else + uint3 gid[[thread_position_in_grid]] +#endif +) { + const int x = gid.x; // query_seq_len + const int y = gid.y; // head_num + const int z = gid.z; // head_dim + if (x >= param.query_seq_len || y >= param.head_num || z >= param.head_dim) { + return; + } + + int yin = y / param.group; + int value_seq_len = param.key_seq_len; + + int head_dim = param.head_dim; + + const int offset_head = (yin * head_dim + z) * param.max_kv_len; + + device const T *A_offset = input0 + y * value_seq_len; + device T *Pastvalue_offset = past_value + offset_head; + float out = 0; + +#ifdef SIMD_GROUP_REDUCE + for(int i = tiisg; i < value_seq_len; i+=SIMD_GROUP_WIDTH){ + float A = (float)A_offset[i]; + float B = (float)Pastvalue_offset[i]; + + out += A * B; + } + out = simd_sum(out); + if(tiisg == 0) { + output[(y * head_dim + z)] = (T)out; + } +#else + for(int i = 0; i < value_seq_len; i++){ + float A = 
(float)A_offset[i]; + float B = (float)Pastvalue_offset[i]; + + out += A * B; + } + output[(y * head_dim + z)] = (T)out; +#endif +} +)metal"; + +#endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ +#endif + diff --git a/source/backend/metal/MetalConvolution1x1.hpp b/source/backend/metal/MetalConvolution1x1.hpp index 672d433b6..bda5a483f 100644 --- a/source/backend/metal/MetalConvolution1x1.hpp +++ b/source/backend/metal/MetalConvolution1x1.hpp @@ -23,7 +23,7 @@ class MetalConvolution1x1 : public MetalConvolutionCommon { virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; virtual void onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) override; private: - MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr weight, std::shared_ptr bias, std::shared_ptr dequantScale, int dequantBits); + MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr weight, std::shared_ptr bias, std::shared_ptr dequantScale, int dequantBits, float scaleCoef); id mPipeline; std::pair mThreads; }; diff --git a/source/backend/metal/MetalConvolution1x1.mm b/source/backend/metal/MetalConvolution1x1.mm index 10ec73c9c..1d0c8d7c4 100644 --- a/source/backend/metal/MetalConvolution1x1.mm +++ b/source/backend/metal/MetalConvolution1x1.mm @@ -31,11 +31,12 @@ loadWeight(op, ldInt8Weight); } -MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr weight, std::shared_ptr bias, std::shared_ptr dequantScale, int dequantBits) : MetalConvolutionCommon(backend, op, bias) { +MetalConvolution1x1::MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr weight, std::shared_ptr bias, std::shared_ptr dequantScale, int dequantBits, float scaleCoef) : MetalConvolutionCommon(backend, op, bias) { mWeight = weight; mBias = bias; mDequantScaleBias = dequantScale; mDequantBits = dequantBits; + mScaleCoef = scaleCoef; } @@ -46,7 +47,7 @@ if (nullptr == dst) { return true; } - *dst = new 
MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits); + *dst = new MetalConvolution1x1(bn, op, mWeight, mBias, mDequantScaleBias, mDequantBits, mScaleCoef); return true; } @@ -72,12 +73,26 @@ auto context = (__bridge MNNMetalContext *)backend->context(); int blockSize = 1; if (mDequantScaleBias.get()) { - blockSize = (int)(mDequantScaleBias->usize() /sizeof(float) / oc_4 / 2 / 4); + int bytes = sizeof(float); + if(backend->useFp16InsteadFp32()) { + bytes = sizeof(__fp16); + } + blockSize = (int)(mDequantScaleBias->usize() / bytes / oc_4 / 2 / 4); } // create const buffer - int constants[] = {is, ic_4, ow, oh, os, oc_4, oc, ob, blockSize, mActivationType}; - mConstBuffer = backend->getConstBuffer(sizeof(constants)); - ::memcpy(mConstBuffer.contents, constants, sizeof(constants)); + mConstBuffer = backend->getConstBuffer(sizeof(Param)); + auto param = (Param *)mConstBuffer.contents; + param->input_size = is; + param->input_slice = ic_4; + param->output_width = ow; + param->output_height = oh; + param->output_size = os; + param->output_slice = oc_4; + param->output_channel = oc; + param->batch = ob; + param->block_size = blockSize; + param->activation = mActivationType; + param->scale_coef = mScaleCoef; MetalRuntime* rt = (MetalRuntime *)backend->runtime(); if (mDequantScaleBias.get()) { diff --git a/source/backend/metal/MetalConvolutionCommon.hpp b/source/backend/metal/MetalConvolutionCommon.hpp index a391d65e2..ac1175a2e 100644 --- a/source/backend/metal/MetalConvolutionCommon.hpp +++ b/source/backend/metal/MetalConvolutionCommon.hpp @@ -26,8 +26,21 @@ class MetalConvolutionCommon : public MetalExecution { virtual std::shared_ptr weightTransform(int group, int oc, int ic, int kh, int kw, const float *src, bool int8Weight = false, bool int4Weight = false); -private: - +protected: + struct Param { + int input_size; + int input_slice; + int output_width; + int output_height; + int output_size; + int output_slice; + int output_channel; + int 
batch; + int block_size; + int activation; + float scale_coef; + }; + protected: int mKernelX = 0; int mKernelY = 0; @@ -42,6 +55,7 @@ class MetalConvolutionCommon : public MetalExecution { std::shared_ptr mBias; std::shared_ptr mDequantScaleBias; int mDequantBits; + float mScaleCoef; id mConstBuffer = nil; }; diff --git a/source/backend/metal/MetalConvolutionCommon.mm b/source/backend/metal/MetalConvolutionCommon.mm index f47464209..e5b918935 100644 --- a/source/backend/metal/MetalConvolutionCommon.mm +++ b/source/backend/metal/MetalConvolutionCommon.mm @@ -97,7 +97,8 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src, } } -static std::shared_ptr getDequantScale(const float* scale, int size, MetalBackend *backend, bool asymmetric, int oc) { +template +static std::pair, float> getDequantScale(const float* scale, int size, MetalBackend *backend, bool asymmetric, int oc) { int totalCount = 0; if (asymmetric) { totalCount = size / 2; @@ -106,15 +107,32 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src, } int blockSize = totalCount / oc; int alignOutputCount = ALIGN_UP4(oc); - std::shared_ptr dequantScale(MNN::Tensor::createDevice({alignOutputCount, blockSize, (int)(sizeof(float) * 2)})); + std::shared_ptr dequantScale(MNN::Tensor::createDevice({alignOutputCount, blockSize, (int)(sizeof(DType) * 2)})); bool res = backend->onAcquireBuffer(dequantScale.get(), Backend::STATIC); if (!res) { MNN_ERROR("Buffer allocated error!\n"); - return nullptr; + return std::make_pair(nullptr, 1.0); } auto buffer0 = MetalBackend::getBuffer(dequantScale.get()); - auto dst_scale = (float*)((uint8_t*)[buffer0.first contents] + buffer0.second); + DType* dst_scale = (DType*)((uint8_t*)[buffer0.first contents] + buffer0.second); ::memset(dst_scale, 0, dequantScale->usize()); + + float coef = 1.0; + if(std::is_same::value) { + float max_data = 0.0; + for (int z=0; z max_data) { + max_data = temp; + } + } + } + coef = 65504.0 / 
max_data; + } if (asymmetric) { for (int z=0; zmain_as_Convolution2D(); @@ -166,12 +184,20 @@ void weightInBlock(int group, int oc, int ic, int kh, int kw, const FType *src, ic = size / kw / kh / (oc / group); } - // convert + // convert if (loadWeightInt8 && qnt->weight.get() != nullptr) { auto backend = static_cast(this->backend()); mWeight = weightTransform(group, oc, ic, kh, kw, (float*)qnt->weight.get(), !qnt->canUseInt4, qnt->canUseInt4); - auto dequantParams = getDequantScale(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc); - mDequantScaleBias = dequantParams; + if(backend->useFp16InsteadFp32()) { + auto dequantParams = getDequantScale<__fp16>(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc); + mDequantScaleBias = dequantParams.first; + mScaleCoef = dequantParams.second; + } else { + auto dequantParams = getDequantScale(qnt->alpha.get(), qnt->alpha.size(), backend, qnt->asymmetric, oc); + mDequantScaleBias = dequantParams.first; + mScaleCoef = dequantParams.second; + } + mDequantBits = qnt->canUseInt4 ? 
4:8; } else if (qnt && qnt->weightFloat.size() > 0) { mWeight = weightTransform(group, oc, ic, kh, kw, qnt->weightFloat.get(), false, false); diff --git a/source/backend/metal/MetalDeconvolution.hpp b/source/backend/metal/MetalDeconvolution.hpp index 5dc8dc4ed..c6066d646 100644 --- a/source/backend/metal/MetalDeconvolution.hpp +++ b/source/backend/metal/MetalDeconvolution.hpp @@ -24,16 +24,7 @@ class MetalDeconvolution : public MetalExecution { private: bool mDepthwise = false; int mGroup = 0; - int mKernelX = 0; - int mKernelY = 0; PadMode mPadMode = PadMode_CAFFE; - int mPadX = 0; - int mPadY = 0; - int mStrideX = 0; - int mStrideY = 0; - int mDilateX = 0; - int mDilateY = 0; - int mActivationType = 0; const MNN::Op *mOp = nullptr; diff --git a/source/backend/metal/MetalDeconvolution.mm b/source/backend/metal/MetalDeconvolution.mm index 4338d9e30..af3a6d9da 100755 --- a/source/backend/metal/MetalDeconvolution.mm +++ b/source/backend/metal/MetalDeconvolution.mm @@ -14,7 +14,33 @@ #if MNN_METAL_ENABLED namespace MNN { - +struct deconv_constants { + int input_width; + int input_height; + int input_size; + int input_slice; + int output_width; + int output_height; + int output_size; + int output_slice; + + int kernel_x; + int kernel_y; + int kernel_size; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; + + int delta_ky; + int delta_kx; + int delta_iy; + int delta_ix; + int batch; + int activation; +}; static int leastCommonMultiple(int m, int n) { int a = m, b = n; while(a != b){ @@ -130,17 +156,7 @@ void weightForDeconv(std::shared_ptr t, bool depthwise, const Convo auto common = deconv->common(); mOp = op; mDepthwise = op->type() == MNN::OpType_DeconvolutionDepthwise; - mGroup = common->group(); - mKernelX = common->kernelX(); - mKernelY = common->kernelY(); mPadMode = common->padMode(); - mPadX = common->padX(); - mPadY = common->padY(); - mStrideX = common->strideX(); - mStrideY = common->strideY(); - mDilateX = 
common->dilateX(); - mDilateY = common->dilateY(); - mActivationType = common->relu() ? 1 : (common->relu6() ? 2 : 0); // forcy downgrade to float like what CPU does std::shared_ptr qnt = NULL; @@ -167,9 +183,13 @@ void weightForDeconv(std::shared_ptr t, bool depthwise, const Convo mValid = false; return; } + auto weightBuffer = MetalBackend::getBuffer(mWeight.get()); + auto ptr = (uint8_t*)weightBuffer.first.contents + weightBuffer.second; if (mtbn->useFp16InsteadFp32()) { + ::memset(ptr, 0, weightSize * sizeof(int16_t)); weightForDeconv<__fp16>(mWeight, mDepthwise, deconv, qnt.get()); } else { + ::memset(ptr, 0, weightSize * sizeof(float)); weightForDeconv(mWeight, mDepthwise, deconv, qnt.get()); } mBias = biasForDeconv(backend, deconv, mtbn->useFp16InsteadFp32()); @@ -182,6 +202,24 @@ void weightForDeconv(std::shared_ptr t, bool depthwise, const Convo } else { mPipeline = [context pipelineWithName:@"deconv" fp16:mtbn->useFp16InsteadFp32()]; } + mConstBuffer = [context newDeviceBuffer:sizeof(deconv_constants) access:CPUWriteOnly]; + auto param = (deconv_constants*)mConstBuffer.contents; + + mGroup = common->group(); + param->kernel_x = common->kernelX(); + param->kernel_y = common->kernelY(); + param->kernel_size = common->kernelX() * common->kernelY(); + param->stride_x = common->strideX(); + param->stride_y = common->strideY(); + param->dilation_x = common->dilateX(); + param->dilation_y = common->dilateY(); + param->activation = common->relu() ? 1 : (common->relu6() ? 
2 : 0); + auto deltaKy = leastCommonMultiple(common->dilateY(), common->strideY()) / common->dilateY(); + auto deltaKx = leastCommonMultiple(common->dilateX(), common->strideX()) / common->dilateX(); + param->delta_kx = deltaKx; + param->delta_ky = deltaKy; + param->delta_iy = deltaKy * common->dilateY() / common->strideY(); + param->delta_ix = deltaKx * common->dilateX() / common->strideX(); } ErrorCode MetalDeconvolution::onResize(const std::vector &inputs, const std::vector &outputs) { @@ -197,46 +235,28 @@ void weightForDeconv(std::shared_ptr t, bool depthwise, const Convo const int padY = pad.second; // const buffer - auto deltaKy = leastCommonMultiple(mDilateY, mStrideY) / mDilateY; - auto deltaKx = leastCommonMultiple(mDilateX, mStrideX) / mDilateX; - - int consts[] = { - iw, - ih, - iw * ih, - iz, - ow, - oh, - ow * oh, - oz, - mKernelX, - mKernelY, - mKernelX * mKernelY, - mStrideX, - mStrideY, - padX, - padY, - mDilateX, - mDilateY, - deltaKy, - deltaKx, - deltaKy * mDilateY / mStrideY, - deltaKx * mDilateX / mStrideX, - 1, - ob, - mActivationType - }; - mConstBuffer = [context newDeviceBuffer:sizeof(consts) bytes:consts access:CPUWriteOnly]; + auto param = (deconv_constants*)mConstBuffer.contents; + param->input_width = iw; + param->input_height = ih; + param->input_size = iw * ih; + param->input_slice = iz; + param->output_width = ow; + param->output_height = oh; + param->output_size = ow * oh; + param->output_slice = oz; + param->batch = ob; + param->pad_x = padX; + param->pad_y = padY; mThreads = [context computeBestGroupAndLocal:mPipeline threads:MTLSizeMake((NSUInteger) ow, (NSUInteger)oh, (NSUInteger)oz * ob)]; return NO_ERROR; } void MetalDeconvolution::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { - auto input = inputs[0], output = outputs[0]; + auto input = inputs[0], output = outputs[0]; [encoder setComputePipelineState:mPipeline]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc 
*)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; + MetalBackend::setTensor(input, encoder, 0); + MetalBackend::setTensor(output, encoder, 1); [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; MetalBackend::setTensor(mWeight.get(), encoder, 3); MetalBackend::setTensor(mBias.get(), encoder, 4); diff --git a/source/backend/metal/shader/MetalConvolution1x1.metal b/source/backend/metal/shader/MetalConvolution1x1.metal index 962d95447..07ac77b3e 100644 --- a/source/backend/metal/shader/MetalConvolution1x1.metal +++ b/source/backend/metal/shader/MetalConvolution1x1.metal @@ -32,6 +32,7 @@ struct conv1x1_constants { int batch; int block_size; conv_activation_type activation; + float scale_coef; }; kernel void conv1x1_g1z4(const device ftype4 *in [[buffer(0)]], @@ -76,7 +77,7 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]], constant conv1x1_constants& cst [[buffer(2)]], const device MNN::char4x4 *wt [[buffer(3)]], const device ftype4 *biasTerms [[buffer(4)]], - const device float4 *dequantScale [[buffer(5)]], + const device ftype4 *dequantScale [[buffer(5)]], uint3 gid [[thread_position_in_grid]]) { if ((int)gid.x * CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return; @@ -90,8 +91,8 @@ kernel void conv1x1_g1z4_w8(const device ftype4 *in [[buffer(0)]], int computeSize = min(cst.output_size - rx, CONV_UNROLL); int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; for (int bi=0; bi= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return; @@ -580,8 +581,8 @@ kernel void conv1x1_g1z4_w4(const device ftype4 *in [[buffer(0)]], int computeSize = min(cst.output_size - rx, CONV_UNROLL); int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; 
for (int bi=0; bi= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch * cst.output_slice) return; - - FLOAT4 result = FLOAT4(biasTerms[(int)(gid.z / cst.batch)]); + int oz = (int)gid.z / cst.batch; + FLOAT4 result = FLOAT4(biasTerms[oz]); int oy = (int)gid.y + cst.pad_y; int ox = (int)gid.x + cst.pad_x; @@ -95,7 +92,7 @@ kernel void deconv_depthwise(const device ftype4 *in [[buffer(0)]], int min_iy = (oy - max_ky * cst.dilation_y) / cst.stride_y; int min_ix = (ox - max_kx * cst.dilation_x) / cst.stride_x; - auto z_wt = wt + (int)gid.z * cst.kernel_size; + auto z_wt = wt + oz * cst.kernel_size; auto z_in = in + (int)gid.z * cst.input_size; for (auto ky = max_ky, iy = min_iy; ky >= min_ky; ky -= cst.delta_ky, iy += cst.delta_iy) { for (auto kx = max_kx, ix = min_ix; kx >= min_kx; kx -= cst.delta_kx, ix += cst.delta_ix) { diff --git a/source/backend/opencl/core/BufferConvertor.cpp b/source/backend/opencl/core/BufferConvertor.cpp index 1f6abd82b..57139bb93 100644 --- a/source/backend/opencl/core/BufferConvertor.cpp +++ b/source/backend/opencl/core/BufferConvertor.cpp @@ -574,6 +574,80 @@ bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime return true; } +bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int memType, bool toDevice, bool toHost) { + std::set buildOptions; + auto srcDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat; + auto dstDimensionFormat = TensorUtils::getDescribe(output)->dimensionFormat; + if(runtime->getGpuMemType() == IMAGE){ + buildOptions.emplace("-DUSE_IMAGE"); + } + + buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(srcDimensionFormat)); + buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(dstDimensionFormat)); + std::vector outputShape; + std::shared_ptr kernelW; + if(toDevice){ + buildOptions.emplace("-DSHARED_TO_CL"); + kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, 
nullptr, output); + outputShape = tensorShapeFormat(output); + } else if(toHost){ + buildOptions.emplace("-DCL_TO_SHARED"); + kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, input, nullptr); + outputShape = tensorShapeFormat(input); + }else{ + MNN_PRINT("convertGLMemBetweenCLmem only support toDevice or toHost!\n"); + return false; + } + + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + uint32_t gws[3] = {static_cast(UP_DIV(shape[3], 4)), + static_cast(UP_DIV(shape[1], 4)), + static_cast(shape[0] * shape[2])}; + auto Kernel = kernelW->get(); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= Kernel.setArg(idx++, gws[0]); + ret |= Kernel.setArg(idx++, gws[1]); + ret |= Kernel.setArg(idx++, gws[2]); + if(toDevice){ + ret |= Kernel.setArg(idx++, *((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(input))->getMem()); + }else{ + if(runtime->getGpuMemType() == IMAGE) { + ret |= Kernel.setArg(idx++, openCLImage(input)); + } + else { + ret |= Kernel.setArg(idx++, openCLBuffer(input)); + } + } + if (toHost){ + ret |= Kernel.setArg(idx++, *((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(output))->getMem()); + }else{ + if(runtime->getGpuMemType() == IMAGE) { + ret |= Kernel.setArg(idx++, openCLImage(output)); + } else { + ret |= Kernel.setArg(idx++, openCLBuffer(output)); + } + } + ret |= Kernel.setArg(idx++, sizeof(shape), shape); + MNN_CHECK_CL_SUCCESS(ret, "setArg glmem_convert"); + + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(kernelW)); + const std::vector lws = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + cl::Event event; + cl_int res; + std::vector roundUpGroupWorkSize(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundUpGroupWorkSize[i] = ROUND_UP(gws[i], lws[i]); + } + + res = runtime->commandQueue().enqueueNDRangeKernel(Kernel, cl::NullRange, + cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1], 
roundUpGroupWorkSize[2]), + cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); + event.wait(); + MNN_CHECK_CL_SUCCESS(res, "glmem_convert"); + return true; +} + } // namespace OpenCL } // namespace MNN #endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/core/BufferConvertor.hpp b/source/backend/opencl/core/BufferConvertor.hpp index b1843226e..f14d21e80 100644 --- a/source/backend/opencl/core/BufferConvertor.hpp +++ b/source/backend/opencl/core/BufferConvertor.hpp @@ -14,6 +14,7 @@ #include "core/Macro.h" #include #include "backend/opencl/core/OpenCLRunningUtils.hpp" +#include "backend/opencl/core/OpenCLBackend.hpp" namespace MNN { namespace OpenCL { @@ -33,6 +34,7 @@ bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *outpu #endif bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime, bool toDevice, bool toHost, bool needWait = false, bool svmFlag = false); +bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int memType, bool toDevice, bool toHost); class BufferConvertor { public: diff --git a/source/backend/opencl/core/BufferPool.cpp b/source/backend/opencl/core/BufferPool.cpp index 89e09fcc5..4398d4bd1 100644 --- a/source/backend/opencl/core/BufferPool.cpp +++ b/source/backend/opencl/core/BufferPool.cpp @@ -28,7 +28,6 @@ cl::Buffer* BufferPool::alloc(size_t size, bool separate) { return nullptr; } mAllBuffer.insert(std::make_pair(node->buffer.get(), node)); - return node->buffer.get(); } diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index 23cb43b2e..132915c6c 100644 --- a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -333,8 +333,8 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp if(mOpenCLRuntime->getGpuMemType() == BUFFER) { size_t size; float typeSize = getBytes(nativeTensor); - if 
(nativeTensor->dimensions() >= 2) { - auto alignC = ROUND_UP(C, 8); + if (MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(nativeTensor)->dimensionFormat && nativeTensor->dimensions() >= 2) { + auto alignC = ROUND_UP(C, 4); // increment of height and width auto hR = ROUND_UP(H + 3, 4) - H; auto wR = ROUND_UP(W + 3, 4) - W; @@ -353,7 +353,6 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp } // Align when int4 memory size = ROUND_UP(size, 2); - if (storageType == DYNAMIC_SEPERATE) { auto buffer = mBufferPool->alloc(size*typeSize, true); ((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; @@ -593,32 +592,53 @@ bool OpenCLBackend::isCreateError() const { return mIsCreateError; } -void OpenCLBackend::_allocHostBuffer(int length, const Tensor* srcTensor) const { +bool OpenCLBackend::_allocHostBuffer(int length, const Tensor* srcTensor) const { auto memType = srcTensor->buffer().flags; - if (nullptr != mHostBuffer.second && length <= mHostBuffer.first && memType != MNN_FORWARD_OPENCL && memType != MNN_FORWARD_OPENGL) { - return; - } - if(memType == MNN_FORWARD_OPENCL){ - mDeviceBuffer = (cl::Buffer*)srcTensor->buffer().device; + if (nullptr != mHostBuffer.second && length <= mHostBuffer.first && memType != MNN_MEMORY_AHARDWAREBUFFER) { + return true; } + cl_int error; #ifdef __ANDROID__ - else if(memType == MNN_FORWARD_OPENGL && mOpenCLRuntime->isSupportGL()){ - cl_int error; - mDeviceTexture.reset(new cl::ImageGL(mOpenCLRuntime->context(), CL_MEM_READ_WRITE, GL_TEXTURE_2D, 0, (cl_GLuint)srcTensor->buffer().device, &error)); - std::vector map = {*mDeviceTexture.get()}; - mOpenCLRuntime->commandQueue().enqueueAcquireGLObjects(&map, NULL); - } + if(MNN_MEMORY_AHARDWAREBUFFER == memType){ + if (mOpenCLRuntime->isSupportAHD()){ + CLSharedMemReleaseBuffer *sharedMem = (CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(srcTensor); + if(sharedMem == nullptr || (sharedMem != nullptr && srcTensor->buffer().device != 
sharedMem->getSharedId())){ + if(mOpenCLRuntime->getGpuType() == MALI){ + const cl_import_properties_arm properties[] = {CL_IMPORT_TYPE_ARM, CL_IMPORT_TYPE_ANDROID_HARDWARE_BUFFER_ARM, 0}; + Backend::MemObj* SharedTmp = new CLSharedMemReleaseBuffer(srcTensor->buffer().device, new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)CL_MEM_READ_WRITE, properties, (void*)srcTensor->buffer().device, CL_IMPORT_MEMORY_WHOLE_ALLOCATION_ARM, &error)); + TensorUtils::setSharedMem(srcTensor, SharedTmp); + }else if(mOpenCLRuntime->getGpuType() == ADRENO){ + cl_mem_ahardwarebuffer_host_ptr myAHBmem = {0}; + myAHBmem.ext_host_ptr.allocation_type = CL_MEM_ANDROID_AHARDWAREBUFFER_HOST_PTR_QCOM; + myAHBmem.ext_host_ptr.host_cache_policy = CL_MEM_HOST_WRITEBACK_QCOM; + myAHBmem.ahb_ptr = (AHardwareBuffer*)srcTensor->buffer().device; + Backend::MemObj* SharedTmp = new CLSharedMemReleaseBuffer(srcTensor->buffer().device, new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_USE_HOST_PTR | CL_MEM_EXT_HOST_PTR_QCOM), 0, &myAHBmem, &error)); + TensorUtils::setSharedMem(srcTensor, SharedTmp); + } else{ + MNN_ERROR("This device not support AHardWareBuffer\n"); + return false; + } + if (error != CL_SUCCESS) { + MNN_ERROR("Alloc mAHardWareBuffer error, code:%d \n", error); + return false; + } + } + } else{ + MNN_ERROR("This device not support AHardWareBuffer\n"); + return false; + } + } else #endif - else{ + { MNN_ASSERT(length > 0); - cl_int res; mHostBuffer.first = length; - mHostBuffer.second.reset(new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR), (size_t)length, NULL, &res)); - if (nullptr == mHostBuffer.second.get() || res != CL_SUCCESS) { - MNN_ERROR("Alloc mHostBuffer %d error, code:%d \n", length, res); - return; + mHostBuffer.second.reset(new cl::Buffer(mOpenCLRuntime->context(), (cl_mem_flags)(CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR), (size_t)length, NULL, &error)); + if (nullptr == mHostBuffer.second.get() || error != 
CL_SUCCESS) { + MNN_ERROR("Alloc mHostBuffer %d error, code:%d \n", length, error); + return false; } } + return true; } void OpenCLBackend::copyFromDeviceInt8(const Tensor* srcTensor, const Tensor* dstTensor) const{ @@ -674,15 +694,15 @@ int OpenCLBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTe } void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, bool svmFlag, int memtype) const { +#ifdef __ANDROID__ + if(MNN_MEMORY_AHARDWAREBUFFER == memtype){ + convertBetweenAHDandCLmem(const_cast(srcTensor), const_cast(dstTensor), mOpenCLRuntime.get(), memtype, false, true); + return; + } +#endif #ifndef MNN_OPENCL_BUFFER_CLOSED if(mOpenCLRuntime->getGpuMemType() == BUFFER) { - if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){ - OpenCL::convertNC4HW4BufferToImage(srcTensor, const_cast(dstTensor), mOpenCLRuntime.get(), false, svmFlag); - std::vector map = {openCLImage(dstTensor)}; - mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL); - return; - } #ifdef MNN_SUPPORT_INTEL_SUBGROUP int cPack = TensorUtils::getTensorChannelPack(srcTensor); if (cPack == 16 && mOpenCLRuntime->isSupportedIntelSubgroup()) { @@ -710,17 +730,6 @@ void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTens else #endif /* MNN_OPENCL_BUFFER_CLOSED */ { - if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){ - std::vector bufferShape = MNN::OpenCL::tensorShapeFormat(srcTensor); - - mOpenCLRuntime.get()->commandQueue().enqueueCopyImage( - openCLImage(srcTensor), openCLImage(dstTensor), - {0, 0, 0}, {0, 0, 0}, - {(size_t)bufferShape[2]* UP_DIV(bufferShape[3], 4), (size_t)bufferShape[0]*bufferShape[1], 1}); - std::vector map = {openCLImage(dstTensor)}; - mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL); - return; - } switch (data_format) { case MNN_DATA_FORMAT_NHWC: OpenCL::convertImageToNHWCBuffer(srcTensor, const_cast(dstTensor), 
mOpenCLRuntime.get(), false, svmFlag); @@ -748,8 +757,7 @@ void OpenCLBackend::copyFromDevice(const Tensor* srcTensor, const Tensor* dstTen && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1) && MNN::MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat && (getDataType(srcTensor) == getDataType(dstTensor)) - && memType != MNN_FORWARD_OPENCL - && memType != MNN_FORWARD_OPENGL; + && memType != MNN_MEMORY_AHARDWAREBUFFER; if (mOpenCLRuntime->isSupportedFP16()) { // Fp16 if (dstTensor->getType().code == halide_type_float) { directCopy = false; @@ -792,15 +800,15 @@ void OpenCLBackend::copyFromDevice(const Tensor* srcTensor, const Tensor* dstTen void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor, MNN_DATA_FORMAT data_format, bool svmFlag, int memtype) const { // Format: Host -> OpenCL +#ifdef __ANDROID__ + if(MNN_MEMORY_AHARDWAREBUFFER == memtype){ + convertBetweenAHDandCLmem(const_cast(srcTensor), const_cast(dstTensor), mOpenCLRuntime.get(), memtype, true, false); + return; + } +#endif #ifndef MNN_OPENCL_BUFFER_CLOSED if(mOpenCLRuntime->getGpuMemType() == BUFFER) { - if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){ - OpenCL::convertImageToNC4HW4Buffer(srcTensor, const_cast(dstTensor),mOpenCLRuntime.get(), false, svmFlag); - std::vector map = {openCLImage(srcTensor)}; - mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL); - return; - } #ifdef MNN_SUPPORT_INTEL_SUBGROUP int cPack = TensorUtils::getTensorChannelPack(dstTensor); if (cPack == 16 && mOpenCLRuntime->isSupportedIntelSubgroup()) { @@ -821,17 +829,6 @@ void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor else #endif /* MNN_OPENCL_BUFFER_CLOSED */ { - if(MNN_FORWARD_OPENGL == memtype && mOpenCLRuntime->isSupportGL()){ - std::vector bufferShape = MNN::OpenCL::tensorShapeFormat(dstTensor); - - mOpenCLRuntime.get()->commandQueue().enqueueCopyImage( - 
openCLImage(srcTensor), openCLImage(dstTensor), - {0, 0, 0}, {0, 0, 0}, - {(size_t)bufferShape[2]* UP_DIV(bufferShape[3], 4), (size_t)bufferShape[0]*bufferShape[1], 1}); - std::vector map = {openCLImage(srcTensor)}; - mOpenCLRuntime->commandQueue().enqueueReleaseGLObjects(&map, NULL); - return; - } if (MNN_DATA_FORMAT_NHWC == data_format) { OpenCL::convertNHWCBufferToImage(srcTensor, const_cast(dstTensor), mOpenCLRuntime.get(), false, svmFlag); } else if (MNN_DATA_FORMAT_NCHW == data_format) { @@ -868,8 +865,7 @@ void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTenso && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1) && MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat && (getDataType(srcTensor) == getDataType(dstTensor)) - && memType != MNN_FORWARD_OPENCL - && memType != MNN_FORWARD_OPENGL; + && memType != MNN_MEMORY_AHARDWAREBUFFER; if (mOpenCLRuntime->isSupportedFP16()) { // Fp16 if (dstTensor->getType().code == halide_type_float) { directCopy = false; @@ -901,15 +897,13 @@ void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTenso #else auto res = mOpenCLRuntime->commandQueue().enqueueWriteBuffer(*mHostBuffer.second, CL_TRUE, 0, needSize, hostPtr); if(res != CL_SUCCESS) { - MNN_ERROR("OpenCL enqueue write error:%d\n", res); - return; + MNN_ERROR("OpenCL enqueue write error:%d\n", res); + return; } #endif //Covert format mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, srcDimensionFormat, false); - - return; } void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTensor) const{ @@ -918,33 +912,21 @@ void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dst if(MNN_FORWARD_CPU == srcMemtype && MNN_FORWARD_CPU == dstMemtype){ mCLRuntime->copyBetweenDevice(srcTensor, dstTensor); } else { - const Tensor* copyTensor = MNN_FORWARD_CPU != srcMemtype ? 
srcTensor : dstTensor; - MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(copyTensor)->dimensionFormat; - int memType = MNN_FORWARD_CPU != srcMemtype ? srcMemtype : dstMemtype; - if(MNN_FORWARD_OPENCL != memType && MNN_FORWARD_OPENGL != memType){ - MNN_PRINT("Unsupport ForwardType %d for OpenCL backend!\n", memType); - return; - } - if(mOpenCLRuntime->isSupportGL() && MNN_FORWARD_OPENGL == memType){ - MNN_PRINT("This Device can not find OpenCL GL_EXTENTION function!\n"); + const Tensor* hostTensor = MNN_FORWARD_CPU != srcMemtype ? srcTensor : dstTensor; + const Tensor* deviceTensor = MNN_FORWARD_CPU == srcMemtype ? srcTensor : dstTensor; + MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(deviceTensor)->dimensionFormat; + + bool alloc_error = _allocHostBuffer(0, hostTensor); + if(false == alloc_error){ + MNN_ERROR("Alloc _allocHostBuffer error\n"); return; } - _allocHostBuffer(0, copyTensor); - - MNN::Tensor interTensor(copyTensor, copyTensor->getDimensionType(), false); - TensorUtils::getDescribe(&interTensor)->dimensionFormat = data_format; - if(MNN_FORWARD_OPENCL == memType ){ - interTensor.buffer().device = (uint64_t)mDeviceBuffer; - }else if(MNN_FORWARD_OPENGL == memType){ - interTensor.buffer().device = (uint64_t)mDeviceTexture.get(); - }else{ - interTensor.buffer().device = (uint64_t)mHostBuffer.second.get(); - } + //Covert format if(MNN_FORWARD_CPU != srcMemtype){ - mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, data_format, false, srcMemtype); + mCLRuntime->convertToDevice(hostTensor, deviceTensor, data_format, false, srcMemtype); }else{ - mCLRuntime->convertFromDevice(srcTensor, (const Tensor*)&interTensor, data_format, false, dstMemtype); + mCLRuntime->convertFromDevice(deviceTensor, hostTensor, data_format, false, dstMemtype); } } } diff --git a/source/backend/opencl/core/OpenCLBackend.hpp b/source/backend/opencl/core/OpenCLBackend.hpp index 3f0abcefb..1d4a51ece 100644 --- a/source/backend/opencl/core/OpenCLBackend.hpp 
+++ b/source/backend/opencl/core/OpenCLBackend.hpp @@ -153,7 +153,7 @@ class OpenCLBackend : public Backend { void copyToDeviceInt8(const Tensor* srcTensor, const Tensor* dstTensor) const; void copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTensor) const; - void _allocHostBuffer(int length, const Tensor* srcTensor) const; + bool _allocHostBuffer(int length, const Tensor* srcTensor) const; const CLRuntime* mCLRuntime; @@ -171,8 +171,6 @@ class OpenCLBackend : public Backend { std::shared_ptr mOpenCLRuntime; mutable std::pair> mHostBuffer; - mutable cl::Buffer *mDeviceBuffer = nullptr; - mutable std::shared_ptr mDeviceTexture; BackendConfig::PrecisionMode mPrecision; BackendConfig::MemoryMode mMemory; bool mIsCreateError{false}; @@ -233,6 +231,26 @@ class TypedCreator : public OpenCLBackend::Creator { } }; +class CLSharedMemReleaseBuffer : public Backend::MemObj { +public: + CLSharedMemReleaseBuffer(uint64_t sharedId, cl::Buffer *bId) { + mSharedId = sharedId; + mBuffer = bId; + } + virtual ~ CLSharedMemReleaseBuffer() { + delete mBuffer; + } + uint64_t getSharedId(){ + return mSharedId; + } + cl::Buffer *getMem(){ + return mBuffer; + } +private: + uint64_t mSharedId; + cl::Buffer *mBuffer; +}; + } // namespace OpenCL } // namespace MNN #endif /* OpenCLBackend_hpp */ diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp index af7cbef36..076d50b7e 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp @@ -159,62 +159,43 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const } const std::string extensions = platforms[0].getInfo(); bool isPriorityHint = (extensions.find("cl_khr_priority_hints") != std::string::npos); - + std::vector context_properties; + if(mGpuType == ADRENO && !isPriorityHint){ + context_properties.push_back(CL_CONTEXT_PERF_HINT_QCOM); + 
context_properties.push_back(CL_PERF_HINT_HIGH_QCOM); + context_properties.push_back(CL_CONTEXT_PRIORITY_HINT_QCOM); + context_properties.push_back(CL_PRIORITY_HINT_LOW_QCOM); + mIsDeviceSupportedLowPower = true; + } + #ifdef ARM_OPENCL_PRINTF_DEBUG + context_properties.push_back(CL_PRINTF_CALLBACK_ARM); + context_properties.push_back((cl_context_properties)callback); + context_properties.push_back(CL_PRINTF_BUFFERSIZE_ARM); + context_properties.push_back(0x1000); + #endif + std::string deviceextensions = mFirstGPUDevicePtr.get()->getInfo(); +#ifdef MNN_USE_LIB_WRAPPER + mIsSupportAHD = (getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_arm_import_memory_android_hardware_buffer") + && mGpuType == MALI && OpenCLSymbolsOperator::getOpenclSymbolsPtr()->getFuncAddress(platforms[platformId](), "clImportMemoryARM")) + || (mGpuType == ADRENO && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_qcom_android_ahardwarebuffer_host_ptr")); +#endif if(nullptr != contextPtr){ - if(nullptr != glShared && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_khr_gl_sharing")){ - std::vector context_properties; - context_properties.reserve(7); - context_properties.push_back(CL_GL_CONTEXT_KHR); - context_properties.push_back((cl_context_properties)contextPtr); - context_properties.push_back(CL_EGL_DISPLAY_KHR); - context_properties.push_back((cl_context_properties)glShared); - context_properties.push_back(CL_CONTEXT_PLATFORM); - context_properties.push_back((cl_context_properties)platforms[platformId]()); - context_properties.push_back(0); - mContext = std::shared_ptr(new cl::Context(std::vector({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res)); - } - else{ - mContext = std::shared_ptr((cl::Context*)contextPtr, [](void* ptr) { - // Do nothing - }); - } + mContext = std::shared_ptr((cl::Context*)contextPtr, [](void* ptr) { + // Do nothing + }); }else{ - if(mGpuType == ADRENO && !isPriorityHint){ - std::vector context_properties; - 
context_properties.reserve(5); - context_properties.push_back(CL_CONTEXT_PERF_HINT_QCOM); - context_properties.push_back(CL_PERF_HINT_HIGH_QCOM); - context_properties.push_back(CL_CONTEXT_PRIORITY_HINT_QCOM); - context_properties.push_back(CL_PRIORITY_HINT_LOW_QCOM); - context_properties.push_back(0); - mContext = std::shared_ptr(new cl::Context(std::vector({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res)); - mIsDeviceSupportedLowPower = true; - }else{ - #ifdef ARM_OPENCL_PRINTF_DEBUG - cl_context_properties context_properties[] = - { - CL_CONTEXT_PLATFORM, (cl_context_properties)platforms[platformId](), - CL_PRINTF_CALLBACK_ARM, (cl_context_properties)callback, - CL_PRINTF_BUFFERSIZE_ARM, 0x1000, - 0 - }; - mContext = std::shared_ptr(new cl::Context(std::vector({*mFirstGPUDevicePtr}), context_properties, nullptr, nullptr, &res)); - #else - mContext = std::shared_ptr(new cl::Context(std::vector({*mFirstGPUDevicePtr}), nullptr, nullptr, nullptr, &res)); - #endif - } - - MNN_CHECK_CL_SUCCESS(res, "context"); - if (res != CL_SUCCESS) { - mIsCreateError = true; - return; - } + context_properties.push_back(0); + mContext = std::shared_ptr(new cl::Context(std::vector({*mFirstGPUDevicePtr}), context_properties.data(), nullptr, nullptr, &res)); + } + MNN_CHECK_CL_SUCCESS(res, "context"); + if (res != CL_SUCCESS) { + mIsCreateError = true; + return; } mIsDeviceSupportedLowPower = (mIsDeviceSupportedLowPower || isPriorityHint); #ifdef MNN_USE_LIB_WRAPPER - mIsSupportGL = !OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isGlError(); if(isPriorityHint) { if(true == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isPropError()) @@ -646,7 +627,7 @@ std::shared_ptr OpenCLRuntime::buildKernelWithCache(const std::strin buildOptionsStr += " -DCONVERT_OUTPUT16=convert_int16"; buildOptionsStr += " -DWI_DATA=write_imagei"; } else { - MNN_PRINT("opencl input datatype not support, bit:%d\n", output->getType().bits); + MNN_PRINT("opencl output datatype not support, 
bit:%d\n", output->getType().bits); MNN_ASSERT(false); } } else if(output->getType().code == halide_type_uint){ @@ -668,7 +649,7 @@ std::shared_ptr OpenCLRuntime::buildKernelWithCache(const std::strin buildOptionsStr += " -DCONVERT_OUTPUT16=convert_uint16"; buildOptionsStr += " -DWI_DATA=write_imageui"; } else { - MNN_PRINT("opencl input datatype not support, bit:%d\n", output->getType().bits); + MNN_PRINT("opencl output datatype not support, bit:%d\n", output->getType().bits); MNN_ASSERT(false); } } else { diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp index b5dfa5918..ac7c31f83 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp @@ -110,9 +110,9 @@ class OpenCLRuntime { return mCLVersion; } uint32_t getPrecisionLevel() const; - bool isSupportGL(){ - return mIsSupportGL; - } + bool isSupportAHD(){ + return mIsSupportAHD; + } #ifdef MNN_OPENCL_SVM_ENABLE cl_device_svm_capabilities getSvmCapabilities() { return mSvmCapabilities; @@ -215,7 +215,7 @@ class OpenCLRuntime { bool mSupportDotInt8 = false; bool mSupportDotAccInt8 = false; bool mSupportedIntelSubgroup = false; - bool mIsSupportGL = true; + bool mIsSupportAHD = false; GpuType mGpuType; MaliAr mMaliAr; float mCLVersion = 1.0f; diff --git a/source/backend/opencl/core/runtime/OpenCLWrapper.cpp b/source/backend/opencl/core/runtime/OpenCLWrapper.cpp index a83bfaa8f..b46952dfd 100644 --- a/source/backend/opencl/core/runtime/OpenCLWrapper.cpp +++ b/source/backend/opencl/core/runtime/OpenCLWrapper.cpp @@ -121,12 +121,24 @@ bool OpenCLSymbols::isPropError() { bool OpenCLSymbols::isQcomError() { return mQcomError; } - -bool OpenCLSymbols::isGlError() { - return mGlError; + +bool OpenCLSymbols::getFuncAddress(cl_platform_id platform, const char *func_name){ + if(clGetExtensionFunctionAddressForPlatform != nullptr){ + clImportMemoryARM = 
reinterpret_cast(clGetExtensionFunctionAddressForPlatform(platform, "clImportMemoryARM")); + if(clImportMemoryARM == nullptr){ + return false; + } + }else if(clGetExtensionFunctionAddress != nullptr){ + clImportMemoryARM = reinterpret_cast(clGetExtensionFunctionAddress("clImportMemoryARM")); + if(clImportMemoryARM == nullptr){ + return false; + } + } else{ + return false; + } + return true; } - bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) { #if defined(_WIN32) handle_ = LoadLibraryA(library_path.c_str()); @@ -203,15 +215,7 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) { if(func_name == nullptr){ \ mQcomError = true; \ } - -#define MNN_LOAD_GL_PTR(func_name) func_name = reinterpret_cast(dlsym(handle_, #func_name)); \ - if(func_name == nullptr && loadOpenCLPointer != nullptr){ \ - func_name = reinterpret_cast(loadOpenCLPointer(#func_name)); \ - } \ - if(func_name == nullptr){ \ - mGlError = true; \ - } - + #endif MNN_LOAD_FUNCTION_PTR(clGetPlatformIDs); @@ -261,10 +265,8 @@ bool OpenCLSymbols::LoadLibraryFromPath(const std::string &library_path) { MNN_LOAD_FUNCTION_PTR(clEnqueueCopyImage); MNN_LOAD_FUNCTION_PTR(clEnqueueReadImage); MNN_LOAD_FUNCTION_PTR(clEnqueueWriteImage); - MNN_LOAD_GL_PTR(clCreateFromGLBuffer); - MNN_LOAD_GL_PTR(clCreateFromGLTexture); - MNN_LOAD_GL_PTR(clEnqueueAcquireGLObjects); - MNN_LOAD_GL_PTR(clEnqueueReleaseGLObjects); + MNN_LOAD_FUNCTION_PTR(clGetExtensionFunctionAddress); + MNN_LOAD_FUNCTION_PTR(clGetExtensionFunctionAddressForPlatform); MNN_LOAD_PROP_PTR(clCreateCommandQueueWithProperties); MNN_LOAD_SVM_PTR(clSVMAlloc); @@ -671,49 +673,6 @@ cl_int CL_API_CALL clEnqueueCopyImage(cl_command_queue queue, return func(queue, src_image, dst_image, src_origin, dst_origin, region, num_events_in_wait_list, event_wait_list, event); } -cl_mem CL_API_CALL clCreateFromGLBuffer(cl_context context, - cl_mem_flags flags, - cl_GLuint bufobj, - int *errcode_ret){ - auto func = 
MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLBuffer; - MNN_CHECK_NOTNULL(func); - return func(context, flags, bufobj, errcode_ret); -} - -cl_mem CL_API_CALL clCreateFromGLTexture(cl_context context, - cl_mem_flags flags, - cl_GLenum target, - cl_GLint miplevel, - cl_GLuint texture, - cl_int *errcode_ret){ - auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateFromGLTexture; - MNN_CHECK_NOTNULL(func); - return func(context, flags, target, miplevel, texture, errcode_ret); - -} - -cl_int CL_API_CALL clEnqueueAcquireGLObjects(cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem *mem_objects, - cl_uint num_events_in_wait_list, - const cl_event *event_wait_list, - cl_event *event){ - auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clEnqueueAcquireGLObjects; - MNN_CHECK_NOTNULL(func); - return func(command_queue, num_objects, mem_objects, num_events_in_wait_list, event_wait_list, event); -} - -cl_int CL_API_CALL clEnqueueReleaseGLObjects(cl_command_queue command_queue, - cl_uint num_objects, - const cl_mem *mem_objects, - cl_uint num_events_in_wait_list, - const cl_event *event_wait_list, - cl_event *event){ - auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clEnqueueReleaseGLObjects; - MNN_CHECK_NOTNULL(func); - return func(command_queue, num_objects, mem_objects, num_events_in_wait_list, event_wait_list, event); -} - // clCreateCommandQueueWithProperties wrapper cl_command_queue CL_API_CALL clCreateCommandQueueWithProperties(cl_context context, cl_device_id device, const cl_queue_properties *properties, cl_int *errcode_ret) { auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clCreateCommandQueueWithProperties; @@ -799,5 +758,22 @@ clEnqueueRecordingSVMQCOM(cl_command_queue command_queue, cl_recording_qcom reco return func(command_queue, recording, num_args, arg_array, num_svm_args, arg_svm_array, num_global_offsets, global_offset_array, num_global_workgroups, 
global_workgroup_array, num_local_workgroups, local_workgroups_array, num_non_arg_objs, non_arg_obj_array, num_events_in_wait_list, event_wait_list, event); } +void * CL_API_CALL clGetExtensionFunctionAddress(const char *func_name){ + auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clGetExtensionFunctionAddress; + MNN_CHECK_NOTNULL(func); + return func(func_name); +} + +void * CL_API_CALL clGetExtensionFunctionAddressForPlatform(cl_platform_id platform, const char *func_name){ + auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clGetExtensionFunctionAddressForPlatform; + MNN_CHECK_NOTNULL(func); + return func(platform, func_name); +} + +cl_mem CL_API_CALL clImportMemoryARM(cl_context context, cl_mem_flags flags, const cl_import_properties_arm *properties, void *memory, size_t size, cl_int *errcode_ret){ + auto func = MNN::OpenCLSymbolsOperator::getOpenclSymbolsPtr()->clImportMemoryARM; + MNN_CHECK_NOTNULL(func); + return func(context, flags, properties, memory, size, errcode_ret); +} #endif //MNN_USE_LIB_WRAPPER diff --git a/source/backend/opencl/core/runtime/OpenCLWrapper.hpp b/source/backend/opencl/core/runtime/OpenCLWrapper.hpp index ba39a8c30..0b3fecc29 100644 --- a/source/backend/opencl/core/runtime/OpenCLWrapper.hpp +++ b/source/backend/opencl/core/runtime/OpenCLWrapper.hpp @@ -31,6 +31,10 @@ #endif #include "CL/cl_ext_qcom.h" +#include "CL/cl_ext.h" +#ifdef __ANDROID__ +#include +#endif #define MNN_CHECK_NOTNULL(X) MNN_ASSERT(X != NULL) @@ -53,7 +57,7 @@ class OpenCLSymbols { bool isSvmError(); bool isPropError(); bool isQcomError(); - bool isGlError(); + bool getFuncAddress(cl_platform_id platform, const char *func_name); using clGetPlatformIDsFunc = cl_int (CL_API_CALL *)(cl_uint, cl_platform_id *, cl_uint *); using clGetPlatformInfoFunc = cl_int (CL_API_CALL *)(cl_platform_id, cl_platform_info, size_t, void *, size_t *); @@ -148,10 +152,6 @@ class OpenCLSymbols { size_t param_value_size, void *param_value, size_t 
*param_value_size_ret); using clGetImageInfoFunc = cl_int (CL_API_CALL *)(cl_mem, cl_image_info, size_t, void *, size_t *); - using clCreateFromGLBufferFunc = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, cl_GLuint, int *); - using clCreateFromGLTextureFunc = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, cl_GLenum, cl_GLint, cl_GLuint, cl_int*); - using clEnqueueAcquireGLObjectsFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_uint, const cl_mem *, cl_uint, const cl_event *, cl_event *); - using clEnqueueReleaseGLObjectsFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_uint, const cl_mem *, cl_uint, const cl_event *, cl_event *); using clReleaseDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id); using clRetainDeviceFunc = cl_int (CL_API_CALL *)(cl_device_id); @@ -176,6 +176,10 @@ class OpenCLSymbols { size_t, const cl_offset_qcom*, size_t, const cl_workgroup_qcom*, size_t, const cl_workgroup_qcom*, size_t, const cl_array_kernel_exec_info_qcom*, cl_uint, const cl_event*, cl_event*); + using clGetExtensionFunctionAddressFunc = void *(CL_API_CALL *)(const char *); + using clGetExtensionFunctionAddressForPlatformFunc = void *(CL_API_CALL *)(cl_platform_id, const char *); + using clImportMemoryARMFunc = cl_mem (CL_API_CALL *)(cl_context, cl_mem_flags, const cl_import_properties_arm*, void*, size_t, cl_int*); + #define MNN_CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr MNN_CL_DEFINE_FUNC_PTR(clGetPlatformIDs); @@ -225,10 +229,6 @@ class OpenCLSymbols { MNN_CL_DEFINE_FUNC_PTR(clGetImageInfo); MNN_CL_DEFINE_FUNC_PTR(clEnqueueReadImage); MNN_CL_DEFINE_FUNC_PTR(clEnqueueWriteImage); - MNN_CL_DEFINE_FUNC_PTR(clCreateFromGLBuffer); - MNN_CL_DEFINE_FUNC_PTR(clCreateFromGLTexture); - MNN_CL_DEFINE_FUNC_PTR(clEnqueueAcquireGLObjects); - MNN_CL_DEFINE_FUNC_PTR(clEnqueueReleaseGLObjects); MNN_CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties); MNN_CL_DEFINE_FUNC_PTR(clSVMAlloc); @@ -243,6 +243,9 @@ class OpenCLSymbols { MNN_CL_DEFINE_FUNC_PTR(clRetainRecordingQCOM); 
MNN_CL_DEFINE_FUNC_PTR(clEnqueueRecordingQCOM); MNN_CL_DEFINE_FUNC_PTR(clEnqueueRecordingSVMQCOM); + MNN_CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddress); + MNN_CL_DEFINE_FUNC_PTR(clGetExtensionFunctionAddressForPlatform); + MNN_CL_DEFINE_FUNC_PTR(clImportMemoryARM); #undef MNN_CL_DEFINE_FUNC_PTR @@ -258,7 +261,6 @@ class OpenCLSymbols { bool mPropError{false}; bool mQcomError{false}; bool mCL_12Error{false}; - bool mGlError{false}; }; class OpenCLSymbolsOperator { diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp index db3fb2d38..00f48dbce 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp @@ -204,7 +204,7 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st } mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); - mResource->mFilter.reset(Tensor::createDevice({1, filterImageShape[1], 1, 4 * filterImageShape[0]})); + mResource->mFilter.reset(Tensor::createDevice({filterImageShape[1] * 4 * filterImageShape[0]})); mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; @@ -458,8 +458,8 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const std::pair min_cost(INT_MAX, 0);//(min_time, min_index) for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[knl_idx] != 0){ - buildOption.emplace("-DCHANNEL_LEAVE"); + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } if((outputShape.at(2) % itemW[knl_idx]) != 0){ buildOption.emplace("-DBLOCK_LEAVE"); @@ -496,13 +496,12 @@ ErrorCode ConvBufExecution::onResize(const std::vector 
&inputs, const } } - std::shared_ptr quanCommon; int min_index = min_cost.second; mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[min_index] != 0){ - buildOption.emplace("-DCHANNEL_LEAVE"); + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } if((outputShape.at(2) % itemW[min_index]) != 0){ buildOption.emplace("-DBLOCK_LEAVE"); diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp index 1ce568cbd..21db9895e 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp @@ -265,8 +265,8 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor // MNN_PRINT("Checking kernel %d.\n", knlCheck); for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[knl_idx] != 0){ - buildOption.emplace("-DCHANNEL_LEAVE"); + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } if((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0){ buildOption.emplace("-DBLOCK_LEAVE"); @@ -313,8 +313,8 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[min_index] != 0){ - buildOption.emplace("-DCHANNEL_LEAVE"); + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + 
buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); } if((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0){ buildOption.emplace("-DBLOCK_LEAVE"); diff --git a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp index 44af28f35..1dac90fc3 100644 --- a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp @@ -160,7 +160,11 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input std::vector localWorkSize[total_kernel]; std::pair min_cost(INT_MAX, 0);//(min_time, min_index) for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[knl_idx], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[knl_idx], buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; @@ -196,7 +200,11 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input int min_index = min_cost.second; mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[min_index], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[min_index] == 8 
&& outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("depthwise_conv2d_buf", kernelName[min_index], buildOption); uint32_t idx = 0; cl_int ret = CL_SUCCESS; diff --git a/source/backend/opencl/execution/cl/buffer_convert_buf.cl b/source/backend/opencl/execution/cl/buffer_convert_buf.cl index 6a4b4e220..ece688d8c 100644 --- a/source/backend/opencl/execution/cl/buffer_convert_buf.cl +++ b/source/backend/opencl/execution/cl/buffer_convert_buf.cl @@ -74,7 +74,7 @@ __kernel void buffer_copy_to_buffer(GLOBAL_SIZE_2_DIMS #endif } -// convert kernel : from buffer(oihw) to image(oc/4 h w , ic oc4) +// convert kernel : from buffer(oihw) to image(ic, oc/4, h, w, oc4) __kernel void conv2d_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS __global const FLOAT *input_ptr, __private const int output_channel, diff --git a/source/backend/opencl/execution/cl/conv_2d.cl b/source/backend/opencl/execution/cl/conv_2d.cl index 2b0bbad14..c87cb749d 100644 --- a/source/backend/opencl/execution/cl/conv_2d.cl +++ b/source/backend/opencl/execution/cl/conv_2d.cl @@ -459,6 +459,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { #if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) int kindex = (in_channel_block_idx * 4) / blockDim * out_channel_blocks * 8; + // already pack to 16, no need boundry protect COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx, dequantScaleOffset + kindex)); COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); @@ -476,7 +477,11 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS 
__read_only image2d_t input, #if (defined USE_LOW_BIT_WEIGHT_INT8) FLOAT16 weightsInt80 = CONVERT_FLOAT16(vload16(0, kernel_ptr + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #ifdef CHANNEL_BOUNDARY_PROTECT + FLOAT16 weightsInt81 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #else FLOAT16 weightsInt81 = CONVERT_FLOAT16(vload16(0, kernel_ptr + 16 + weight_ic_offset + in_channel_block_idx * weight_oc_offset)); + #endif FLOAT4 weights0 = CONVERT_FLOAT4(weightsInt80.s0123) * scale0 + offset0; FLOAT4 weights1 = CONVERT_FLOAT4(weightsInt80.s4567) * scale0 + offset0; FLOAT4 weights2 = CONVERT_FLOAT4(weightsInt80.s89ab) * scale0 + offset0; @@ -541,10 +546,17 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights2 = vload4(weights_width_base + 2, weights + weight_offset); weights3 = vload4(weights_width_base + 3, weights + weight_offset); + #ifdef CHANNEL_BOUNDARY_PROTECT + weights4 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base, weights + weight_offset1); + weights5 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base + 1, weights + weight_offset1); + weights6 = output_channel_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base + 2, weights + weight_offset1); + weights7 = output_channel_idx + 1 >= out_channel_blocks ? 
(FLOAT4)0 : vload4(weights_width_base + 3, weights + weight_offset1); + #else weights4 = vload4(weights_width_base, weights + weight_offset1); weights5 = vload4(weights_width_base + 1, weights + weight_offset1); weights6 = vload4(weights_width_base + 2, weights + weight_offset1); weights7 = vload4(weights_width_base + 3, weights + weight_offset1); + #endif #else weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx)); weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx)); @@ -1081,10 +1093,18 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights1 = mad(CONVERT_FLOAT4(charWeight1), scale0, offset0); weights2 = mad(CONVERT_FLOAT4(charWeight2), scale0, offset0); weights3 = mad(CONVERT_FLOAT4(charWeight3), scale0, offset0); + #ifdef CHANNEL_BOUNDARY_PROTECT + charWeight0 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset); + charWeight1 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = out_channel_block_idx + 1 >= out_channel_blocks ? (char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = out_channel_block_idx + 1 >= out_channel_blocks ? 
(char4)0 : vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); + + #else charWeight0 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset); charWeight1 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset); charWeight2 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2); charWeight3 = vload4(0, kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3); + #endif weights4 = mad(CONVERT_FLOAT4(charWeight0), scale1, offset1); weights5 = mad(CONVERT_FLOAT4(charWeight1), scale1, offset1); weights6 = mad(CONVERT_FLOAT4(charWeight2), scale1, offset1); @@ -1153,10 +1173,17 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, weights1 = vload4(0, weights+weight_offset+weight_ic_offset); weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2); weights3 = vload4(0, weights+weight_offset+weight_ic_offset*3); + #ifdef CHANNEL_BOUNDARY_PROTECT + weights4 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset + weight_oc_offset); + weights5 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset); + weights6 = out_channel_block_idx + 1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset); + weights7 = out_channel_block_idx + 1 >= out_channel_blocks ? 
(FLOAT4)0 : vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset); + #else weights4 = vload4(0, weights+weight_offset + weight_oc_offset); weights5 = vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset); weights6 = vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset); weights7 = vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset); + #endif weight_offset += 4; #else weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx)); diff --git a/source/backend/opencl/execution/cl/conv_2d_buf.cl b/source/backend/opencl/execution/cl/conv_2d_buf.cl index d3d34e5f4..744324e14 100644 --- a/source/backend/opencl/execution/cl/conv_2d_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_buf.cl @@ -200,25 +200,33 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = out_c_w_idx / out_w_blocks; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx const int out_w4_idx = mul24(out_w_idx, 4); - COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1, bias_ptr)); + COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias_ptr)); COMPUTE_FLOAT4 out1 = out0; COMPUTE_FLOAT4 out2 = out0; COMPUTE_FLOAT4 out3 = out0; - COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1, bias_ptr)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_block ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr)); + COMPUTE_FLOAT4 out5 = out4; + COMPUTE_FLOAT4 out6 = out4; + COMPUTE_FLOAT4 out7 = out4; + #else + COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr)); COMPUTE_FLOAT4 out5 = out4; COMPUTE_FLOAT4 out6 = out4; COMPUTE_FLOAT4 out7 = out4; + #endif const int intput_width_idx0 = out_w4_idx; int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2; - int offset = out_c_idx*8; + int offset = out_c_idx_0*4; const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { @@ -229,6 +237,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset)); COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(3, input+inp_offset)); + // output_channel at least pack to 8, no need boundry protect COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset)); COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); @@ -306,7 +315,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx_0*out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4; __global FLOAT * _tempoutput = output + out_offset; __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w*out_b; @@ -323,8 +332,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, } else if (remain == 1) { vstore4(CONVERT_FLOAT4(out0), 0, _tempoutput); } -#ifdef CHANNEL_LEAVE - if(out_c_idx*2+1 >= out_c_block) { +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 
>= out_c_block) { return; } #endif @@ -340,8 +349,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, } #else vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, _tempoutput); -#ifdef CHANNEL_LEAVE - if(out_c_idx*2+1 >= out_c_block) { +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_block) { return; } #endif @@ -368,21 +377,26 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = out_c_w_idx / out_w_blocks; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h;//equal to in_b_idx const int out_h_idx = out_b_h_idx % out_h;//equal to in_h_idx const int out_w2_idx = mul24(out_w_idx, 2); - COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1, bias_ptr)); + COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias_ptr)); COMPUTE_FLOAT4 out1 = out0; - COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1, bias_ptr)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_block ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr)); + #else + COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias_ptr)); + #endif COMPUTE_FLOAT4 out5 = out4; const int intput_width_idx0 = out_w2_idx; int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2; - int offset = out_c_idx*8; + int offset = out_c_idx_0*4; const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { @@ -437,7 +451,7 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out5 = clamp(out5, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx_0*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4; __global FLOAT * _tempoutput = output + out_offset; @@ -450,8 +464,8 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, } else if (remain == 1) { vstore4(CONVERT_FLOAT4(out0), 0, _tempoutput); } -#ifdef CHANNEL_LEAVE - if(out_c_idx*2+1 >= out_c_block) { +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_block) { return; } #endif @@ -462,8 +476,8 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, } #else vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0, out1)), 0, _tempoutput); -#ifdef CHANNEL_LEAVE - if(out_c_idx*2+1 >= out_c_block) { +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_block) { return; } #endif @@ -1071,16 +1085,21 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx const int out_h_idx = 
(out_b_h_idx % out_h_blocks) << 2; - COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out1 = out0; COMPUTE_FLOAT4 out2 = out0; COMPUTE_FLOAT4 out3 = out0; - COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #else + COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #endif COMPUTE_FLOAT4 out5 = out4; COMPUTE_FLOAT4 out6 = out4; COMPUTE_FLOAT4 out7 = out4; @@ -1100,12 +1119,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS const int weight_ic_offset = out_c_blocks * weight_oc_offset; const int in_hw_size = in_hw.x * in_hw.y; for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { - //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] + //weights NC4HW4 [ic/4, ic_4, oc/4, kh*kw, oc_4] + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] const int inp_offset_base = (out_b_idx + in_c_idx * batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y; const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y; const int in_h2_idx = (iy * dilate_hw.x + in_h2_idx_base) * in_hw.y; @@ -1142,11 +1161,18 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS out3 = mad(in3.z, weight2, out3); out3 = mad(in3.w, weight3, out3); + // weight: [ic/4, ic_4, oc/4, kh*kw, oc_4] + #ifdef CHANNEL_BOUNDARY_PROTECT + weight0 = out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); + weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); + weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); + weight3 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); + #else weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); - + #endif out4 = mad(in0.x, weight0, out4); out4 = mad(in0.y, weight1, out4); out4 = mad(in0.z, weight2, out4); @@ -1193,7 +1219,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -1211,12 +1237,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ 
vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); @@ -1237,12 +1263,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out2), 2 * out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out3), 3 * out_hw.y, output+out_offset); - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset); @@ -1273,16 +1299,21 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx const int out_h_idx = (out_b_h_idx % out_h_blocks) << 1; - COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out1 = out0; - COMPUTE_FLOAT4 out2 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT4 out2 = out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #else + COMPUTE_FLOAT4 out2 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #endif COMPUTE_FLOAT4 out3 = out2; - + const int in_w_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y); const int in_h0_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x); @@ -1298,11 +1329,11 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS // weight: [ic/4, oc, 4], loop: ic/4 for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y; const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y; @@ -1324,11 +1355,17 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS out1 = mad(in1.z, weight2, out1); out1 = mad(in1.w, weight3, out1); + #ifdef CHANNEL_BOUNDARY_PROTECT + weight0 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); + weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); + weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); + weight3 = out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); + #else weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); - + #endif out2 = mad(in0.x, weight0, out2); out2 = mad(in0.y, weight1, out2); out2 = mad(in0.z, weight2, out2); @@ -1357,7 +1394,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 2){ @@ -1366,12 +1403,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 2){ vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); @@ -1381,12 +1418,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS #else vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + 
(out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); #endif @@ -1415,17 +1452,21 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2; const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx const int out_h_idx = out_b_h_idx % out_hw.x; - COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out1 = out0; COMPUTE_FLOAT4 out2 = out0; COMPUTE_FLOAT4 out3 = out0; - - COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT4 out4 = out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #else + COMPUTE_FLOAT4 out4 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); + #endif COMPUTE_FLOAT4 out5 = out4; COMPUTE_FLOAT4 out6 = out4; COMPUTE_FLOAT4 out7 = out4; @@ -1445,8 +1486,8 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS const int weight_ic_offset = out_c_blocks * weight_oc_offset; for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { const int inp_offset_base = (((out_b_idx + in_c_idx * batch) * in_hw.x + iy) * in_hw.y + 0) * 4; @@ -1487,11 +1528,17 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS out3 = mad(in3.z, weight2, out3); out3 = mad(in3.w, weight3, out3); + #ifdef CHANNEL_BOUNDARY_PROTECT + weight0 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); + weight1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); + weight2 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); + weight3 = out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); + #else weight0 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset)); weight1 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset)); weight2 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2)); weight3 = CONVERT_COMPUTE_FLOAT4(vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3)); - + #endif out4 = mad(in0.x, weight0, out4); out4 = mad(in0.y, weight1, out4); out4 = mad(in0.z, weight2, out4); @@ -1538,7 +1585,7 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; if(remain >= 4){ @@ -1551,10 +1598,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks)return; + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); }else if(remain == 3){ @@ -1567,10 +1614,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS } #else vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output+out_offset); - #ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks)return; + #ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx + (out_c_idx + 
1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx_1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); #endif } diff --git a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl index e42398c63..f482f578d 100644 --- a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl @@ -10,7 +10,7 @@ } #define MOD_NUM 15 -#ifdef INPUT_CHANNEL_LEAVE +#ifdef INPUT_CHANNEL_BOUNDARY_PROTECT #define PADZEROSVEC(k, channel, data0, data1, data2, data3) \ data0 = (k << 2) < channel ? data0 : 0; \ data1 = (k << 2) + 1 < channel ? data1 : 0; \ @@ -674,17 +674,19 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx const int out_h_idx = (out_b_h_idx % out_h_blocks) << 2; - COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out0 = bias0; COMPUTE_FLOAT4 out1 = bias0; COMPUTE_FLOAT4 out2 = bias0; COMPUTE_FLOAT4 out3 = bias0; - COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + // bias align to 8, no need boundry protect + COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); COMPUTE_FLOAT4 out4 = bias1; COMPUTE_FLOAT4 out5 = bias1; COMPUTE_FLOAT4 out6 = bias1; @@ -706,18 +708,22 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS const int in_hw_size = in_hw.x * in_hw.y; for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { int kindex = (in_c_idx * 4) / 
blockDim * out_c_blocks * 8; - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex)); + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #else + COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #endif COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y; const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y; const int in_h2_idx = (iy * dilate_hw.x + in_h2_idx_base) * in_hw.y; @@ -791,10 +797,17 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS out3 = mad(in3.w, weight3, out3); #if (defined USE_LOW_BIT_WEIGHT_INT8) + 
#ifdef CHANNEL_BOUNDARY_PROTECT + charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset); + charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #else charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset); charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #endif weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1; weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1; weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1; @@ -878,7 +891,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -896,12 +909,12 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); 
vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); @@ -922,12 +935,12 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out2), 2 * out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out3), 3 * out_hw.y, output+out_offset); -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset); @@ -964,15 +977,17 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = out_c_w_idx % out_w_blocks; const int out_b_idx = out_b_h_idx / out_h_blocks;//equal to in_b_idx const int out_h_idx = (out_b_h_idx % out_h_blocks) << 1; - COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out0 = bias0; COMPUTE_FLOAT4 out1 = bias0; - COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + // bias align to 8, no need boundry protect + COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); COMPUTE_FLOAT4 out2 = bias1; COMPUTE_FLOAT4 out3 = bias1; @@ -991,18 +1006,22 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS // weight: [ic/4, oc, 4], loop: ic/4 for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { int kindex = (in_c_idx * 4) / blockDim * out_c_blocks * 8; - COMPUTE_FLOAT8 
ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex)); - COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex)); + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #else + COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #endif COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; const int in_h0_idx = (iy * dilate_hw.x + in_h0_idx_base) * in_hw.y; const int in_h1_idx = (iy * dilate_hw.x + in_h1_idx_base) * in_hw.y; @@ -1060,10 +1079,17 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS out1 = mad(in1.w, weight3, out1); #if (defined USE_LOW_BIT_WEIGHT_INT8) + #ifdef CHANNEL_BOUNDARY_PROTECT + charWeight0 = out_c_idx_1 >= out_c_blocks ? 
(char4)0 : vload4(0, weight+weight_offset+weight_oc_offset); + charWeight1 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #else charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset); charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #endif weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1; weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1; weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1; @@ -1128,7 +1154,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 2){ @@ -1137,12 +1163,12 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 2){ vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); @@ -1152,12 
+1178,12 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS #else vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks){ +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks){ return; } #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); #endif @@ -1192,17 +1218,19 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); - const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_0 = (out_c_w_idx / out_w_blocks) << 1; + const int out_c_idx_1 = out_c_idx_0 + 1; const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2; const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx const int out_h_idx = out_b_h_idx % out_hw.x; - COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); + COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0, bias)); COMPUTE_FLOAT4 out0 = bias0; COMPUTE_FLOAT4 out1 = bias0; COMPUTE_FLOAT4 out2 = bias0; COMPUTE_FLOAT4 out3 = bias0; - COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx + 1, bias)); + // bias align to 8, no need boundry protect + COMPUTE_FLOAT4 bias1 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1, bias)); COMPUTE_FLOAT4 out4 = bias1; COMPUTE_FLOAT4 out5 = bias1; COMPUTE_FLOAT4 out6 = bias1; @@ -1223,15 +1251,19 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS const int weight_ic_offset = out_c_blocks * weight_oc_offset; for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { int kindex = (in_c_idx * 4) / blockDim * out_c_blocks * 8; - COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex)); - 
COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx + 1, dequantScaleOffset + kindex)); + COMPUTE_FLOAT8 ScaleOffset0 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_0, dequantScaleOffset + kindex)); + #ifdef CHANNEL_BOUNDARY_PROTECT + COMPUTE_FLOAT8 ScaleOffset1 = out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #else + COMPUTE_FLOAT8 ScaleOffset1 = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1, dequantScaleOffset + kindex)); + #endif COMPUTE_FLOAT4 scale0 = (COMPUTE_FLOAT4)(ScaleOffset0.s0, ScaleOffset0.s2, ScaleOffset0.s4, ScaleOffset0.s6); COMPUTE_FLOAT4 offset0 = (COMPUTE_FLOAT4)(ScaleOffset0.s1, ScaleOffset0.s3, ScaleOffset0.s5, ScaleOffset0.s7); COMPUTE_FLOAT4 scale1 = (COMPUTE_FLOAT4)(ScaleOffset1.s0, ScaleOffset1.s2, ScaleOffset1.s4, ScaleOffset1.s6); COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 - //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; + //index: [0, 4*in_c_idx, out_c_idx_0*kh*kw + kh_start*kw + kw_start, 0] + int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx_0) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; @@ -1309,10 +1341,17 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS out3 = mad(in3.w, weight3, out3); #if (defined USE_LOW_BIT_WEIGHT_INT8) + #ifdef CHANNEL_BOUNDARY_PROTECT + charWeight0 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset); + charWeight1 = out_c_idx_1 >= out_c_blocks ? 
(char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); + charWeight2 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); + charWeight3 = out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #else charWeight0 = vload4(0, weight+weight_offset+weight_oc_offset); charWeight1 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset); charWeight2 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*2); charWeight3 = vload4(0, weight+weight_offset+weight_oc_offset+weight_ic_offset*3); + #endif weight0 = CONVERT_COMPUTE_FLOAT4(charWeight0) * scale1 + offset1; weight1 = CONVERT_COMPUTE_FLOAT4(charWeight1) * scale1 + offset1; weight2 = CONVERT_COMPUTE_FLOAT4(charWeight2) * scale1 + offset1; @@ -1396,7 +1435,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx_0*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; if(remain >= 4){ @@ -1409,10 +1448,10 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks)return; +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); }else if(remain == 3){ @@ -1425,10 +1464,10 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS } #else 
vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output+out_offset); -#ifdef CHANNEL_LEAVE - if(out_c_idx + 1 >= out_c_blocks)return; +#ifdef CHANNEL_BOUNDARY_PROTECT + if(out_c_idx_1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + out_c_idx_1*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); #endif } diff --git a/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl b/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl index c32400af9..12cac5dfc 100644 --- a/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl +++ b/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl @@ -303,14 +303,18 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, COMPUTE_FLOAT4 inValue2 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c0)); COMPUTE_FLOAT4 inValue3 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c0)); - COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1)); - COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1)); - COMPUTE_FLOAT4 inValue6 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c1)); - COMPUTE_FLOAT4 inValue7 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue6 = (in_w_start_2+kw < 0 || in_w_start_2+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue7 = (in_w_start_3+kw < 0 || in_w_start_3+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3, input+inp_offset_c1)); //NC4HW4 [1, filterShape.x*filterShape.y, 1, channelBlocks] x oc4 //index: [0, filterIdx, 0, inChannelBlockIdx] COMPUTE_FLOAT4 weights_0 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+0)*4)); + /* + weight:[kh*kw, oc/4, oc_4], memory align to 8 + no need to boundry protect + */ COMPUTE_FLOAT4 weights_1 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+1)*4)); outValue0 = mad(inValue0, weights_0, outValue0); @@ -435,12 +439,16 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, COMPUTE_FLOAT4 inValue0 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c0)); COMPUTE_FLOAT4 inValue1 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c0)); - COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1)); - COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue4 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c1)); + COMPUTE_FLOAT4 inValue5 = (in_w_start_1+kw < 0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1, input+inp_offset_c1)); //NC4HW4 [1, filterShape.x*filterShape.y, 1, channelBlocks] x oc4 //index: [0, filterIdx, 0, inChannelBlockIdx] COMPUTE_FLOAT4 weights_0 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+0)*4)); + /* + weight:[kh*kw, oc/4, oc_4], memory align to 8 + no need to boundry protect + */ COMPUTE_FLOAT4 weights_1 = CONVERT_COMPUTE_FLOAT4(vload4(0, filter+(filter_idx*c_blocks+c_idx+1)*4)); outValue0 = mad(inValue0, weights_0, outValue0); diff --git a/source/backend/opencl/execution/cl/glmem_convert.cl b/source/backend/opencl/execution/cl/glmem_convert.cl new file mode 100644 index 000000000..8288ab9fa --- /dev/null +++ b/source/backend/opencl/execution/cl/glmem_convert.cl @@ -0,0 +1,211 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + +#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, +#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ + return; \ + } + +#define MNN_DATA_FORMAT_NCHW 0 +#define MNN_DATA_FORMAT_NHWC 1 +#define MNN_DATA_FORMAT_NC4HW4 2 +#define MNN_DATA_FORMAT_C4NHW4 3 + +#define __CAT(x, y) x##y +#define CAT(x, y) __CAT(x, y) +#define OUTPUT_TYPE2 CAT(OUTPUT_TYPE, 2) +#define OUTPUT_TYPE3 CAT(OUTPUT_TYPE, 3) +__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +#ifdef SHARED_TO_CL +__kernel void gl_to_cl(GLOBAL_SIZE_3_DIMS + 
__global uchar *input_ptr, + #ifdef USE_IMAGE + __write_only image2d_t output_ptr, + #else + __global OUTPUT_TYPE *output_ptr, + #endif + __private const int4 shape // N C H W +) { + + int wblock = get_global_id(0); + int cblock = get_global_id(1); + int nh = get_global_id(2); + + DEAL_NON_UNIFORM_DIM3(wblock, cblock, nh); + const int w = wblock << 2; + const int h = nh % shape.z; + const int c = cblock << 2; + const int n = nh / shape.z; + + int idx = c * shape.w + w; // c/4*w + int idy = nh; // n*h + const int offset = idy * shape.w * 4; + OUTPUT_TYPE4 in0 = CONVERT_OUTPUT4(vload4(idx, input_ptr + offset)); + OUTPUT_TYPE4 in1 = CONVERT_OUTPUT4(vload4(idx + 1, input_ptr + offset)); + OUTPUT_TYPE4 in2 = CONVERT_OUTPUT4(vload4(idx + 2, input_ptr + offset)); + OUTPUT_TYPE4 in3 = CONVERT_OUTPUT4(vload4(idx + 3, input_ptr + offset)); + +#ifdef USE_IMAGE + WI_DATA(output_ptr, (int2)(idx, idy), in0); + if(w + 1 >= shape.w) return; + WI_DATA(output_ptr, (int2)(idx+1, idy), in1); + if(w + 2 >= shape.w) return; + WI_DATA(output_ptr, (int2)(idx+2, idy), in2); + if(w + 3 >= shape.w) return; + WI_DATA(output_ptr, (int2)(idx+3, idy), in3); +#else + #if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int output_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w; + int stride = shape.z * shape.w; + int remain = shape.w - w; + if(remain >= 4){ + vstore4((OUTPUT_TYPE4)(in0.x, in1.x, in2.x, in3.x), 0, output_ptr + output_offset); + if(c + 1 >= shape.y) return; + vstore4((OUTPUT_TYPE4)(in0.y, in1.y, in2.y, in3.y), 0, output_ptr + output_offset + stride); + if(c + 2 >= shape.y) return; + vstore4((OUTPUT_TYPE4)(in0.z, in1.z, in2.z, in3.z), 0, output_ptr + output_offset + stride + stride); + if(c + 3 >= shape.y) return; + vstore4((OUTPUT_TYPE4)(in0.w, in1.w, in2.w, in3.w), 0, output_ptr + output_offset + stride + stride + stride); + } else if(remain == 3){ + vstore3((OUTPUT_TYPE3)(in0.x, in1.x, in2.x), 0, output_ptr + output_offset); + if(c + 1 >= shape.y) return; + 
vstore3((OUTPUT_TYPE3)(in0.y, in1.y, in2.y), 0, output_ptr + output_offset + stride); + if(c + 2 >= shape.y) return; + vstore3((OUTPUT_TYPE3)(in0.z, in1.z, in2.z), 0, output_ptr + output_offset + stride + stride); + if(c + 3 >= shape.y) return; + vstore3((OUTPUT_TYPE3)(in0.w, in1.w, in2.w), 0, output_ptr + output_offset + stride + stride + stride); + } else if(remain == 2){ + vstore2((OUTPUT_TYPE2)(in0.x, in1.x), 0, output_ptr + output_offset); + if(c + 1 >= shape.y) return; + vstore2((OUTPUT_TYPE2)(in0.y, in1.y), 0, output_ptr + output_offset + stride); + if(c + 2 >= shape.y) return; + vstore2((OUTPUT_TYPE2)(in0.z, in1.z), 0, output_ptr + output_offset + stride + stride); + if(c + 3 >= shape.y) return; + vstore2((OUTPUT_TYPE2)(in0.w, in1.w), 0, output_ptr + output_offset + stride + stride + stride); + }else if(remain == 1){ + output_ptr[output_offset] = in0.x; + if(c + 1 >= shape.y) return; + output_ptr[output_offset + stride] = in0.y; + if(c + 2 >= shape.y) return; + output_ptr[output_offset + stride + stride] = in0.z; + if(c + 3 >= shape.y) return; + output_ptr[output_offset + stride + stride + stride] = in0.w; + } + #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int output_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c; + int remain = shape.y - c; + if(remain >= 4){ + vstore4(CONVERT_OUTPUT4(in0), 0, output_ptr + output_offset); + if(w + 1 >= shape.w) return; + vstore4(CONVERT_OUTPUT4(in1), 0, output_ptr + output_offset + shape.y); + if(w + 2 >= shape.w) return; + vstore4(CONVERT_OUTPUT4(in2), 0, output_ptr + output_offset + shape.y + shape.y); + if(w + 3 >= shape.w) return; + vstore4(CONVERT_OUTPUT4(in3), 0, output_ptr + output_offset + shape.y + shape.y + shape.y); + } else if(remain == 3){ + vstore3((OUTPUT_TYPE3)(in0.x, in0.y, in0.z), 0, output_ptr + output_offset); + if(w + 1 >= shape.w) return; + vstore3((OUTPUT_TYPE3)(in1.x, in1.y, in1.z), 0, output_ptr + output_offset + shape.y); + if(w + 2 >= shape.w) return; + vstore3((OUTPUT_TYPE3)(in2.x, 
in2.y, in2.z), 0, output_ptr + output_offset + shape.y + shape.y); + if(w + 3 >= shape.w) return; + vstore3((OUTPUT_TYPE3)(in3.x, in3.y, in3.z), 0, output_ptr + output_offset + shape.y + shape.y + shape.y); + } else if(remain == 2){ + vstore2((OUTPUT_TYPE2)(in0.x, in0.y), 0, output_ptr + output_offset); + if(w + 1 >= shape.w) return; + vstore2((OUTPUT_TYPE2)(in1.x, in1.y), 0, output_ptr + output_offset + shape.y); + if(w + 2 >= shape.w) return; + vstore2((OUTPUT_TYPE2)(in2.x, in2.y), 0, output_ptr + output_offset + shape.y + shape.y); + if(w + 3 >= shape.w) return; + vstore2((OUTPUT_TYPE2)(in3.x, in3.y), 0, output_ptr + output_offset + shape.y + shape.y + shape.y); + }else if(remain == 1){ /* one channel left: store lane .x of each of the four pixels in0..in3, like the wider branches */ + output_ptr[output_offset] = in0.x; + if(w + 1 >= shape.w) return; + output_ptr[output_offset + shape.y] = in1.x; + if(w + 2 >= shape.w) return; + output_ptr[output_offset + shape.y + shape.y] = in2.x; /* pixel w+2 comes from in2 (was in1.x) */ + if(w + 3 >= shape.w) return; + output_ptr[output_offset + shape.y + shape.y + shape.y] = in3.x; /* pixel w+3 comes from in3 (was in1.x) */ + } + #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int output_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4; + vstore4(in0, 0, output_ptr + output_offset); + if(w + 1 >= shape.w) return; + vstore4(in1, 0, output_ptr + output_offset + 4); + if(w + 2 >= shape.w) return; + vstore4(in2, 0, output_ptr + output_offset + 8); + if(w + 3 >= shape.w) return; + vstore4(in3, 0, output_ptr + output_offset + 12); + #endif +#endif +} +#endif + +#ifdef CL_TO_SHARED +__kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS + #ifdef USE_IMAGE + __read_only image2d_t input_ptr, + #else + __global INPUT_TYPE *input_ptr, + #endif + __global uchar *output_ptr, + __private const int4 shape // N C H W +) { + + int wblock = get_global_id(0); + int cblock = get_global_id(1); + int nh = get_global_id(2); + + DEAL_NON_UNIFORM_DIM3(wblock, cblock, nh); + const int w = wblock << 2; + const int h = nh % shape.z; + const int c = cblock << 2; + const int n = nh / shape.z; + + int idx = c * shape.w + w; // 
c/4*w + int idy = nh; // n*h +#ifdef USE_IMAGE + INPUT_TYPE4 in0 = RI_DATA(input_ptr, SAMPLER, (int2)(idx, idy)); + INPUT_TYPE4 in1 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+1, idy)); + INPUT_TYPE4 in2 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+2, idy)); + INPUT_TYPE4 in3 = RI_DATA(input_ptr, SAMPLER, (int2)(idx+3, idy)); +#else + #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w; + int stride = shape.z * shape.w; + INPUT_TYPE4 tmp0, tmp1, tmp2, tmp3; + tmp0 = vload4(0, input_ptr + input_offset); + tmp1 = vload4(0, input_ptr + input_offset + stride); + tmp2 = vload4(0, input_ptr + input_offset + stride + stride); + tmp3 = vload4(0, input_ptr + input_offset + stride + stride + stride); + INPUT_TYPE4 in0 = (INPUT_TYPE4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x); + INPUT_TYPE4 in1 = (INPUT_TYPE4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y); + INPUT_TYPE4 in2 = (INPUT_TYPE4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z); + INPUT_TYPE4 in3 = (INPUT_TYPE4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w); + #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c; + INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset); + INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + shape.y); + INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + shape.y + shape.y); + INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + shape.y + shape.y + shape.y); + #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int input_offset = (((cblock * shape.x + n) * shape.z + h) * shape.w + w) * 4; + INPUT_TYPE4 in0 = vload4(0, input_ptr + input_offset); + INPUT_TYPE4 in1 = vload4(0, input_ptr + input_offset + 4); + INPUT_TYPE4 in2 = vload4(0, input_ptr + input_offset + 8); + INPUT_TYPE4 in3 = vload4(0, input_ptr + input_offset + 12); + #endif +#endif + const int offset = idy * shape.w * 4; + vstore4(convert_uchar4(in0), idx, output_ptr + offset); + if(w + 1 >= shape.w) return; + vstore4(convert_uchar4(in1), idx+1, output_ptr + offset); + 
if(w + 2 >= shape.w) return; + vstore4(convert_uchar4(in2), idx+2, output_ptr + offset); + if(w + 3 >= shape.w) return; + vstore4(convert_uchar4(in3), idx+3, output_ptr + offset); +} +#endif diff --git a/source/backend/opencl/execution/cl/opencl_program.cc b/source/backend/opencl/execution/cl/opencl_program.cc index b66986c17..a809072dc 100644 --- a/source/backend/opencl/execution/cl/opencl_program.cc +++ b/source/backend/opencl/execution/cl/opencl_program.cc @@ -384,6 +384,7 @@ const char* conv_2d = " for (int in_channel_block_idx=0; in_channel_block_idx= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" +" #else\n" " FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" +" #endif\n" " FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n" " FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n" " FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n" @@ -464,10 +469,17 @@ const char* conv_2d = " weights1=vload4(weights_width_base+1,weights+weight_offset);\n" " weights2=vload4(weights_width_base+2,weights+weight_offset);\n" " weights3=vload4(weights_width_base+3,weights+weight_offset);\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" weights4=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base,weights+weight_offset1);\n" +" weights5=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+1,weights+weight_offset1);\n" +" weights6=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+2,weights+weight_offset1);\n" +" weights7=output_channel_idx+1 >= out_channel_blocks ? 
(FLOAT4)0 : vload4(weights_width_base+3,weights+weight_offset1);\n" +" #else\n" " weights4=vload4(weights_width_base,weights+weight_offset1);\n" " weights5=vload4(weights_width_base+1,weights+weight_offset1);\n" " weights6=vload4(weights_width_base+2,weights+weight_offset1);\n" " weights7=vload4(weights_width_base+3,weights+weight_offset1);\n" +" #endif\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx));\n" @@ -979,10 +991,18 @@ const char* conv_2d = " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" charWeight0=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" +" charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? 
(char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" \n" +" #else\n" " charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #endif\n" " weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" " weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" " weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" @@ -1051,10 +1071,18 @@ const char* conv_2d = " weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_ic_offset*3);\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" charWeight0 =\n" +" weights4=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_oc_offset);\n" +" weights5=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n" +" weights6=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n" +" weights7=out_channel_block_idx+1 >= out_channel_blocks ? 
(FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n" +" #else\n" " weights4=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights5=vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n" " weights6=vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n" " weights7=vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n" +" #endif\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n" @@ -4303,7 +4331,7 @@ const char* conv_2d_int_buf = "#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" "#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" "#define MOD_NUM 15\n" -"#ifdef INPUT_CHANNEL_LEAVE\n" +"#ifdef INPUT_CHANNEL_BOUNDARY_PROTECT\n" " #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3= out_c_blocks ? 
(COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #else\n" +" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #endif\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" " COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n" " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" -" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" +" //index: [0,4*in_c_idx,out_c_idx_0*kh*kw+kh_start*kw+kw_start,0]\n" " const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n" " for(int iy=0; iy= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset);\n" +" charWeight1=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=out_c_idx_1 >= out_c_blocks ? 
(char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #else\n" " charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n" " charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #endif\n" " weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n" " weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n" " weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n" @@ -5109,7 +5150,7 @@ const char* conv_2d_int_buf = " out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.x-out_h_idx;\n" " if(remain >= 4){\n" @@ -5127,12 +5168,12 @@ const char* conv_2d_int_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" @@ -5153,12 +5194,12 @@ const char* conv_2d_int_buf = " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" -" 
if(out_c_idx+1 >= out_c_blocks){\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n" @@ -5192,14 +5233,16 @@ const char* conv_2d_int_buf = " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n" " const int out_h_idx=(out_b_h_idx % out_h_blocks) << 1;\n" -" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" +" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias));\n" " COMPUTE_FLOAT4 out0=bias0;\n" " COMPUTE_FLOAT4 out1=bias0;\n" -" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n" +" // bias align to 8,no need boundry protect\n" +" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" " COMPUTE_FLOAT4 out2=bias1;\n" " COMPUTE_FLOAT4 out3=bias1;\n" " const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n" @@ -5216,17 +5259,21 @@ const char* conv_2d_int_buf = " // weight: [ic/4,oc,4],loop: ic/4\n" " for(ushort in_c_idx=0; in_c_idx= out_c_blocks ? 
(COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #else\n" +" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #endif\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" " COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n" " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" -" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" +" //index: [0,4*in_c_idx,out_c_idx_0*kh*kw+kh_start*kw+kw_start,0]\n" " const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n" " for(int iy=0; iy= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset);\n" +" charWeight1=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=out_c_idx_1 >= out_c_blocks ? 
(char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #else\n" " charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n" " charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #endif\n" " weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n" " weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n" " weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n" @@ -5348,7 +5402,7 @@ const char* conv_2d_int_buf = " out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.x-out_h_idx;\n" " if(remain >= 2){\n" @@ -5357,12 +5411,12 @@ const char* conv_2d_int_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 2){\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" @@ -5372,12 +5426,12 @@ const char* conv_2d_int_buf = "#else\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +"#ifdef 
CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" "#endif\n" @@ -5409,17 +5463,19 @@ const char* conv_2d_int_buf = " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n" " const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n" " const int out_h_idx=out_b_h_idx % out_hw.x;\n" " \n" -" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" +" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias));\n" " COMPUTE_FLOAT4 out0=bias0;\n" " COMPUTE_FLOAT4 out1=bias0;\n" " COMPUTE_FLOAT4 out2=bias0;\n" " COMPUTE_FLOAT4 out3=bias0;\n" -" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n" +" // bias align to 8,no need boundry protect\n" +" COMPUTE_FLOAT4 bias1=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" " COMPUTE_FLOAT4 out4=bias1;\n" " COMPUTE_FLOAT4 out5=bias1;\n" " COMPUTE_FLOAT4 out6=bias1;\n" @@ -5438,15 +5494,19 @@ const char* conv_2d_int_buf = " const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n" " for(ushort in_c_idx=0; in_c_idx= out_c_blocks ? 
(COMPUTE_FLOAT8)0 : CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #else\n" +" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx_1,dequantScaleOffset+kindex));\n" +" #endif\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" " COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n" " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" -" //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" -" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" +" //index: [0,4*in_c_idx,out_c_idx_0*kh*kw+kh_start*kw+kw_start,0]\n" +" int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx_0) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" " for(int iy=in_h_idx_start; iy= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset);\n" +" charWeight1=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" +" charWeight2=out_c_idx_1 >= out_c_blocks ? (char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" +" charWeight3=out_c_idx_1 >= out_c_blocks ? 
(char4)0 : vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #else\n" " charWeight0=vload4(0,weight+weight_offset+weight_oc_offset);\n" " charWeight1=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" +" #endif\n" " weight0=CONVERT_COMPUTE_FLOAT4(charWeight0)*scale1+offset1;\n" " weight1=CONVERT_COMPUTE_FLOAT4(charWeight1)*scale1+offset1;\n" " weight2=CONVERT_COMPUTE_FLOAT4(charWeight2)*scale1+offset1;\n" @@ -5603,7 +5670,7 @@ const char* conv_2d_int_buf = " out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.y-out_w_idx;\n" " if(remain >= 4){\n" @@ -5616,10 +5683,10 @@ const char* conv_2d_int_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks)return;\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks)return;\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" " }else if(remain == 3){\n" @@ -5632,10 +5699,10 @@ const char* conv_2d_int_buf = " }\n" "#else\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks)return;\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" 
if(out_c_idx_1 >= out_c_blocks)return;\n" "#endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+out_c_idx_1*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" "#endif\n" "}\n" @@ -8316,14 +8383,18 @@ const char* depthwise_conv2d_buf = " COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c0));\n" " COMPUTE_FLOAT4 inValue2=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset_c0));\n" " COMPUTE_FLOAT4 inValue3=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset_c0));\n" -" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n" -" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n" -" COMPUTE_FLOAT4 inValue6=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset_c1));\n" -" COMPUTE_FLOAT4 inValue7=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue6=(in_w_start_2+kw<0 || in_w_start_2+kw >= in_hw.y || c_idx+1 >= c_blocks) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+2,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue7=(in_w_start_3+kw<0 || in_w_start_3+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+3,input+inp_offset_c1));\n" " \n" " //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n" " //index: [0,filterIdx,0,inChannelBlockIdx]\n" " COMPUTE_FLOAT4 weights_0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+0)*4));\n" +" /*\n" +" weight:[kh*kw,oc/4,oc_4],memory align to 8\n" +" no need to boundry prptecy\n" +" */\n" " COMPUTE_FLOAT4 weights_1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+1)*4));\n" " outValue0=mad(inValue0,weights_0,outValue0);\n" " outValue1=mad(inValue1,weights_0,outValue1);\n" @@ -8435,11 +8506,15 @@ const char* depthwise_conv2d_buf = " const int filter_idx=mad24(kh,filter_hw.y,kw);\n" " COMPUTE_FLOAT4 inValue0=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c0));\n" " COMPUTE_FLOAT4 inValue1=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c0));\n" -" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n" -" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue4=(in_w_start_0+kw<0 || in_w_start_0+kw >= in_hw.y || c_idx+1 >= c_blocks) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c1));\n" +" COMPUTE_FLOAT4 inValue5=(in_w_start_1+kw<0 || in_w_start_1+kw >= in_hw.y || c_idx+1 >= c_blocks) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+1,input+inp_offset_c1));\n" " //NC4HW4 [1,filterShape.x*filterShape.y,1,channelBlocks] x oc4\n" " //index: [0,filterIdx,0,inChannelBlockIdx]\n" " COMPUTE_FLOAT4 weights_0=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+0)*4));\n" +" /*\n" +" weight:[kh*kw,oc/4,oc_4],memory align to 8\n" +" no need to boundry protect\n" +" */\n" " COMPUTE_FLOAT4 weights_1=CONVERT_COMPUTE_FLOAT4(vload4(0,filter+(filter_idx*c_blocks+c_idx+1)*4));\n" " outValue0=mad(inValue0,weights_0,outValue0);\n" " outValue1=mad(inValue1,weights_0,outValue1);\n" @@ -8801,6 +8876,206 @@ const char* depthwise_conv2d_buf = "}\n" ; #endif +const char* glmem_convert = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define MNN_DATA_FORMAT_NCHW 0\n" +"#define MNN_DATA_FORMAT_NHWC 1\n" +"#define MNN_DATA_FORMAT_NC4HW4 2\n" +"#define MNN_DATA_FORMAT_C4NHW4 3\n" +"#define __CAT(x,y) x##y\n" +"#define CAT(x,y) __CAT(x,y)\n" +"#define OUTPUT_TYPE2 CAT(OUTPUT_TYPE,2)\n" +"#define OUTPUT_TYPE3 CAT(OUTPUT_TYPE,3)\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" +"#ifdef SHARED_TO_CL\n" +"__kernel void gl_to_cl(GLOBAL_SIZE_3_DIMS\n" +" __global uchar *input_ptr,\n" +" #ifdef USE_IMAGE\n" +" __write_only image2d_t output_ptr,\n" +" #else\n" +" __global OUTPUT_TYPE *output_ptr,\n" +" #endif\n" +" __private const int4 shape // N C H W\n" +") {\n" +" int wblock=get_global_id(0);\n" +" int cblock=get_global_id(1);\n" +" int nh=get_global_id(2);\n" +" DEAL_NON_UNIFORM_DIM3(wblock,cblock,nh);\n" +" const int w=wblock << 2;\n" +" const int 
h=nh % shape.z;\n" +" const int c=cblock << 2;\n" +" const int n=nh/shape.z;\n" +" \n" +" int idx=c*shape.w+w; // c/4*w\n" +" int idy=nh; // n*h\n" +" const int offset=idy*shape.w*4;\n" +" OUTPUT_TYPE4 in0=CONVERT_OUTPUT4(vload4(idx,input_ptr+offset));\n" +" OUTPUT_TYPE4 in1=CONVERT_OUTPUT4(vload4(idx+1,input_ptr+offset));\n" +" OUTPUT_TYPE4 in2=CONVERT_OUTPUT4(vload4(idx+2,input_ptr+offset));\n" +" OUTPUT_TYPE4 in3=CONVERT_OUTPUT4(vload4(idx+3,input_ptr+offset));\n" +"#ifdef USE_IMAGE\n" +" WI_DATA(output_ptr,(int2)(idx,idy),in0);\n" +" if(w+1 >= shape.w) return;\n" +" WI_DATA(output_ptr,(int2)(idx+1,idy),in1);\n" +" if(w+2 >= shape.w) return;\n" +" WI_DATA(output_ptr,(int2)(idx+2,idy),in2);\n" +" if(w+3 >= shape.w) return;\n" +" WI_DATA(output_ptr,(int2)(idx+3,idy),in3);\n" +"#else\n" +" #if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" +" int output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n" +" int stride=shape.z*shape.w;\n" +" int remain=shape.w-w;\n" +" if(remain >= 4){\n" +" vstore4((OUTPUT_TYPE4)(in0.x,in1.x,in2.x,in3.x),0,output_ptr+output_offset);\n" +" if(c+1 >= shape.y) return;\n" +" vstore4((OUTPUT_TYPE4)(in0.y,in1.y,in2.y,in3.y),0,output_ptr+output_offset+stride);\n" +" if(c+2 >= shape.y) return;\n" +" vstore4((OUTPUT_TYPE4)(in0.z,in1.z,in2.z,in3.z),0,output_ptr+output_offset+stride+stride);\n" +" if(c+3 >= shape.y) return;\n" +" vstore4((OUTPUT_TYPE4)(in0.w,in1.w,in2.w,in3.w),0,output_ptr+output_offset+stride+stride+stride);\n" +" } else if(remain == 3){\n" +" vstore3((OUTPUT_TYPE3)(in0.x,in1.x,in2.x),0,output_ptr+output_offset);\n" +" if(c+1 >= shape.y) return;\n" +" vstore3((OUTPUT_TYPE3)(in0.y,in1.y,in2.y),0,output_ptr+output_offset+stride);\n" +" if(c+2 >= shape.y) return;\n" +" vstore3((OUTPUT_TYPE3)(in0.z,in1.z,in2.z),0,output_ptr+output_offset+stride+stride);\n" +" if(c+3 >= shape.y) return;\n" +" vstore3((OUTPUT_TYPE3)(in0.w,in1.w,in2.w),0,output_ptr+output_offset+stride+stride+stride);\n" +" } else if(remain == 2){\n" +" 
vstore2((OUTPUT_TYPE2)(in0.x,in1.x),0,output_ptr+output_offset);\n" +" if(c+1 >= shape.y) return;\n" +" vstore2((OUTPUT_TYPE2)(in0.y,in1.y),0,output_ptr+output_offset+stride);\n" +" if(c+2 >= shape.y) return;\n" +" vstore2((OUTPUT_TYPE2)(in0.z,in1.z),0,output_ptr+output_offset+stride+stride);\n" +" if(c+3 >= shape.y) return;\n" +" vstore2((OUTPUT_TYPE2)(in0.w,in1.w),0,output_ptr+output_offset+stride+stride+stride);\n" +" }else if(remain == 1){\n" +" output_ptr[output_offset]=in0.x;\n" +" if(c+1 >= shape.y) return;\n" +" output_ptr[output_offset+stride]=in0.y;\n" +" if(c+2 >= shape.y) return;\n" +" output_ptr[output_offset+stride+stride]=in0.z;\n" +" if(c+3 >= shape.y) return;\n" +" output_ptr[output_offset+stride+stride+stride]=in0.w;\n" +" }\n" +" #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" +" int output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n" +" int remain=shape.y-c;\n" +" if(remain >= 4){\n" +" vstore4(CONVERT_OUTPUT4(in0),0,output_ptr+output_offset);\n" +" if(w+1 >= shape.w) return;\n" +" vstore4(CONVERT_OUTPUT4(in1),0,output_ptr+output_offset+shape.y);\n" +" if(w+2 >= shape.w) return;\n" +" vstore4(CONVERT_OUTPUT4(in2),0,output_ptr+output_offset+shape.y+shape.y);\n" +" if(w+3 >= shape.w) return;\n" +" vstore4(CONVERT_OUTPUT4(in3),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n" +" } else if(remain == 3){\n" +" vstore3((OUTPUT_TYPE3)(in0.x,in0.y,in0.z),0,output_ptr+output_offset);\n" +" if(w+1 >= shape.w) return;\n" +" vstore3((OUTPUT_TYPE3)(in1.x,in1.y,in1.z),0,output_ptr+output_offset+shape.y);\n" +" if(w+2 >= shape.w) return;\n" +" vstore3((OUTPUT_TYPE3)(in2.x,in2.y,in2.z),0,output_ptr+output_offset+shape.y+shape.y);\n" +" if(w+3 >= shape.w) return;\n" +" vstore3((OUTPUT_TYPE3)(in3.x,in3.y,in3.z),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n" +" } else if(remain == 2){\n" +" vstore2((OUTPUT_TYPE2)(in0.x,in0.y),0,output_ptr+output_offset);\n" +" if(w+1 >= shape.w) return;\n" +" 
vstore2((OUTPUT_TYPE2)(in1.x,in1.y),0,output_ptr+output_offset+shape.y);\n" +" if(w+2 >= shape.w) return;\n" +" vstore2((OUTPUT_TYPE2)(in2.x,in2.y),0,output_ptr+output_offset+shape.y+shape.y);\n" +" if(w+3 >= shape.w) return;\n" +" vstore2((OUTPUT_TYPE2)(in3.x,in3.y),0,output_ptr+output_offset+shape.y+shape.y+shape.y);\n" +" }else if(remain == 1){\n" +" output_ptr[output_offset]=in0.x;\n" +" if(w+1 >= shape.w) return;\n" +" output_ptr[output_offset+shape.y]=in1.x;\n" +" if(w+2 >= shape.w) return;\n" +" output_ptr[output_offset+shape.y+shape.y]=in2.x;\n" +" if(w+3 >= shape.w) return;\n" +" output_ptr[output_offset+shape.y+shape.y+shape.y]=in3.x;\n" +" }\n" +" #elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" +" int output_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n" +" vstore4(in0,0,output_ptr+output_offset);\n" +" if(w+1 >= shape.w) return;\n" +" vstore4(in1,0,output_ptr+output_offset+4);\n" +" if(w+2 >= shape.w) return;\n" +" vstore4(in2,0,output_ptr+output_offset+8);\n" +" if(w+3 >= shape.w) return;\n" +" vstore4(in3,0,output_ptr+output_offset+12);\n" +" #endif\n" +"#endif\n" +"}\n" +"#endif\n" +"#ifdef CL_TO_SHARED\n" +"__kernel void cl_to_gl(GLOBAL_SIZE_3_DIMS\n" +" #ifdef USE_IMAGE\n" +" __read_only image2d_t input_ptr,\n" +" #else\n" +" __global INPUT_TYPE *input_ptr,\n" +" #endif\n" +" __global uchar *output_ptr,\n" +" __private const int4 shape // N C H W\n" +") {\n" +" int wblock=get_global_id(0);\n" +" int cblock=get_global_id(1);\n" +" int nh=get_global_id(2);\n" +" DEAL_NON_UNIFORM_DIM3(wblock,cblock,nh);\n" +" const int w=wblock << 2;\n" +" const int h=nh % shape.z;\n" +" const int c=cblock << 2;\n" +" const int n=nh/shape.z;\n" +" \n" +" int idx=c*shape.w+w; // c/4*w\n" +" int idy=nh; // n*h\n" +"#ifdef USE_IMAGE\n" +" INPUT_TYPE4 in0=RI_DATA(input_ptr,SAMPLER,(int2)(idx,idy));\n" +" INPUT_TYPE4 in1=RI_DATA(input_ptr,SAMPLER,(int2)(idx+1,idy));\n" +" INPUT_TYPE4 in2=RI_DATA(input_ptr,SAMPLER,(int2)(idx+2,idy));\n" +" INPUT_TYPE4
in3=RI_DATA(input_ptr,SAMPLER,(int2)(idx+3,idy));\n" +"#else\n" +" #if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" +" int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n" +" int stride=shape.z*shape.w;\n" +" INPUT_TYPE4 tmp0,tmp1,tmp2,tmp3;\n" +" tmp0=vload4(0,input_ptr+input_offset);\n" +" tmp1=vload4(0,input_ptr+input_offset+stride);\n" +" tmp2=vload4(0,input_ptr+input_offset+stride+stride);\n" +" tmp3=vload4(0,input_ptr+input_offset+stride+stride+stride);\n" +" INPUT_TYPE4 in0=(INPUT_TYPE4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n" +" INPUT_TYPE4 in1=(INPUT_TYPE4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n" +" INPUT_TYPE4 in2=(INPUT_TYPE4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n" +" INPUT_TYPE4 in3=(INPUT_TYPE4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n" +" #elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" +" int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n" +" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n" +" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+shape.y);\n" +" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+shape.y+shape.y);\n" +" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+shape.y+shape.y+shape.y);\n" +" #elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" +" int input_offset=(((cblock*shape.x+n)*shape.z+h)*shape.w+w)*4;\n" +" INPUT_TYPE4 in0=vload4(0,input_ptr+input_offset);\n" +" INPUT_TYPE4 in1=vload4(0,input_ptr+input_offset+4);\n" +" INPUT_TYPE4 in2=vload4(0,input_ptr+input_offset+8);\n" +" INPUT_TYPE4 in3=vload4(0,input_ptr+input_offset+12);\n" +" #endif\n" +"#endif\n" +" const int offset=idy*shape.w*4;\n" +" vstore4(convert_uchar4(in0),idx,output_ptr+offset);\n" +" if(w+1 >= shape.w) return;\n" +" vstore4(convert_uchar4(in1),idx+1,output_ptr+offset);\n" +" if(w+2 >= shape.w) return;\n" +" vstore4(convert_uchar4(in2),idx+2,output_ptr+offset);\n" +" if(w+3 >= shape.w) return;\n" +" vstore4(convert_uchar4(in3),idx+3,output_ptr+offset);\n" +"}\n" +"#endif\n" +; #ifndef MNN_OPENCL_BUFFER_CLOSED const char* winogradTransform_buf = "#ifdef MNN_SUPPORT_FP16\n" @@ -13607,23 
+13882,31 @@ const char* conv_2d_buf = " const int out_c_w_idx=get_global_id(0); //c/8 w/4\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=out_c_w_idx/out_w_blocks;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n" " const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n" " const int out_w4_idx=mul24(out_w_idx,4);\n" -" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1,bias_ptr));\n" +" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias_ptr));\n" " COMPUTE_FLOAT4 out1=out0;\n" " COMPUTE_FLOAT4 out2=out0;\n" " COMPUTE_FLOAT4 out3=out0;\n" " \n" -" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1,bias_ptr));\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" COMPUTE_FLOAT4 out4=out_c_idx_1 >= out_c_block ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias_ptr));\n" " COMPUTE_FLOAT4 out5=out4;\n" " COMPUTE_FLOAT4 out6=out4;\n" " COMPUTE_FLOAT4 out7=out4;\n" +" #else\n" +" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias_ptr));\n" +" COMPUTE_FLOAT4 out5=out4;\n" +" COMPUTE_FLOAT4 out6=out4;\n" +" COMPUTE_FLOAT4 out7=out4;\n" +" #endif\n" " const int intput_width_idx0=out_w4_idx;\n" " int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n" -" int offset=out_c_idx*8;\n" +" int offset=out_c_idx_0*4;\n" " const int inp_add=out_b*out_h*out_w*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= out_c_block) {\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_block) {\n" " return;\n" " }\n" "#endif\n" @@ -13737,8 +14021,8 @@ const char* conv_2d_buf = " }\n" "#else\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,_tempoutput);\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx*2+1 >= out_c_block) {\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_block) {\n" " return;\n" " }\n" "#endif\n" @@ -13760,20 +14044,25 @@ const char* conv_2d_buf = " const int out_c_w_idx=get_global_id(0); //c/8 w/4\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=out_c_w_idx/out_w_blocks;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_b_idx=out_b_h_idx/out_h;//equal to in_b_idx\n" " const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n" " \n" " const int out_w2_idx=mul24(out_w_idx,2);\n" -" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx<<1,bias_ptr));\n" +" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias_ptr));\n" " COMPUTE_FLOAT4 out1=out0;\n" " \n" -" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1,bias_ptr));\n" +" 
#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" COMPUTE_FLOAT4 out4=out_c_idx_1 >= out_c_block ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias_ptr));\n" +" #else\n" +" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias_ptr));\n" +" #endif\n" " COMPUTE_FLOAT4 out5=out4;\n" " const int intput_width_idx0=out_w2_idx;\n" " int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n" -" int offset=out_c_idx*8;\n" +" int offset=out_c_idx_0*4;\n" " const int inp_add=out_b*out_h*out_w*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= out_c_block) {\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_block) {\n" " return;\n" " }\n" "#endif\n" @@ -13844,8 +14133,8 @@ const char* conv_2d_buf = " }\n" "#else\n" " vstore8(CONVERT_FLOAT8((COMPUTE_FLOAT8)(out0,out1)),0,_tempoutput);\n" -"#ifdef CHANNEL_LEAVE\n" -" if(out_c_idx*2+1 >= out_c_block) {\n" +"#ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_block) {\n" " return;\n" " }\n" "#endif\n" @@ -14383,16 +14672,21 @@ const char* conv_2d_buf = " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n" " const int out_h_idx=(out_b_h_idx % out_h_blocks) << 2;\n" " \n" -" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" +" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias));\n" " COMPUTE_FLOAT4 out1=out0;\n" " COMPUTE_FLOAT4 out2=out0;\n" " COMPUTE_FLOAT4 out3=out0;\n" -" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" COMPUTE_FLOAT4 out4=out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #else\n" +" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #endif\n" " COMPUTE_FLOAT4 out5=out4;\n" " COMPUTE_FLOAT4 out6=out4;\n" " COMPUTE_FLOAT4 out7=out4;\n" @@ -14410,11 +14704,11 @@ const char* conv_2d_buf = " const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n" " const int in_hw_size=in_hw.x*in_hw.y;\n" " for(ushort in_c_idx=0; in_c_idx= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" +" weight1=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" +" weight2=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" +" weight3=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" +" #else\n" " weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" " weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" " weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" " weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" +" #endif\n" " out4=mad(in0.x,weight0,out4);\n" " out4=mad(in0.y,weight1,out4);\n" " out4=mad(in0.z,weight2,out4);\n" @@ -14496,7 +14798,7 @@ const char* conv_2d_buf = " out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.x-out_h_idx;\n" 
" if(remain >= 4){\n" @@ -14514,12 +14816,12 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" @@ -14540,12 +14842,12 @@ const char* conv_2d_buf = " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n" @@ -14573,15 +14875,21 @@ const char* conv_2d_buf = " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_b_idx=out_b_h_idx/out_h_blocks;//equal to in_b_idx\n" " const int out_h_idx=(out_b_h_idx % 
out_h_blocks) << 1;\n" " \n" -" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" +" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias));\n" " COMPUTE_FLOAT4 out1=out0;\n" -" COMPUTE_FLOAT4 out2=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" COMPUTE_FLOAT4 out2=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #else\n" +" COMPUTE_FLOAT4 out2=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #endif\n" " COMPUTE_FLOAT4 out3=out2;\n" +" \n" " const int in_w_idx_base=mad24(out_w_idx,stride_hw.y,-pad_hw.y);\n" " const int in_h0_idx_base=mad24(out_h_idx,stride_hw.x,-pad_hw.x);\n" " const int in_h1_idx_base=in_h0_idx_base+stride_hw.x;\n" @@ -14596,10 +14904,10 @@ const char* conv_2d_buf = " // weight: [ic/4,oc,4],loop: ic/4\n" " for(ushort in_c_idx=0; in_c_idx= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" +" weight1=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" +" weight2=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" +" weight3=out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" +" #else\n" " weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" " weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" " weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" " weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" -" \n" +" #endif\n" " out2=mad(in0.x,weight0,out2);\n" " out2=mad(in0.y,weight1,out2);\n" " out2=mad(in0.z,weight2,out2);\n" @@ -14651,7 +14965,7 @@ const char* conv_2d_buf = " out2=clamp(out2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out3=clamp(out3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.x-out_h_idx;\n" " if(remain >= 2){\n" @@ -14660,12 +14974,12 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks){\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 2){\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" @@ -14675,12 +14989,12 @@ const char* conv_2d_buf = "#else\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= 
out_c_blocks){\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks){\n" " return;\n" " }\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" "#endif\n" @@ -14706,17 +15020,21 @@ const char* conv_2d_buf = " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_0=(out_c_w_idx/out_w_blocks) << 1;\n" +" const int out_c_idx_1=out_c_idx_0+1;\n" " const int out_w_idx=(out_c_w_idx % out_w_blocks) << 2;\n" " const int out_b_idx=out_b_h_idx/out_hw.x;//equal to in_b_idx\n" " const int out_h_idx=out_b_h_idx % out_hw.x;\n" " \n" -" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" +" COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_0,bias));\n" " COMPUTE_FLOAT4 out1=out0;\n" " COMPUTE_FLOAT4 out2=out0;\n" " COMPUTE_FLOAT4 out3=out0;\n" -" \n" -" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx+1,bias));\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" COMPUTE_FLOAT4 out4=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #else\n" +" COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx_1,bias));\n" +" #endif\n" " COMPUTE_FLOAT4 out5=out4;\n" " COMPUTE_FLOAT4 out6=out4;\n" " COMPUTE_FLOAT4 out7=out4;\n" @@ -14734,8 +15052,8 @@ const char* conv_2d_buf = " const int weight_ic_offset=out_c_blocks*weight_oc_offset;\n" " for(ushort in_c_idx=0; in_c_idx= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" +" weight1=out_c_idx_1 >= out_c_blocks ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" +" weight2=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" +" weight3=out_c_idx_1 >= out_c_blocks ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" +" #else\n" " weight0=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset));\n" " weight1=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset));\n" " weight2=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*2));\n" " weight3=CONVERT_COMPUTE_FLOAT4(vload4(0,weight+weight_offset+weight_oc_offset+weight_ic_offset*3));\n" -" \n" +" #endif\n" " out4=mad(in0.x,weight0,out4);\n" " out4=mad(in0.y,weight1,out4);\n" " out4=mad(in0.z,weight2,out4);\n" @@ -14820,7 +15144,7 @@ const char* conv_2d_buf = " out6=clamp(out6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " out7=clamp(out7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((out_b_idx+out_c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" int out_offset=(((out_b_idx+out_c_idx_0*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" "#ifdef BLOCK_LEAVE\n" " const int remain=out_hw.y-out_w_idx;\n" " if(remain >= 4){\n" @@ -14833,10 +15157,10 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks)return;\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks)return;\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " 
vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" " }else if(remain == 3){\n" @@ -14849,10 +15173,10 @@ const char* conv_2d_buf = " }\n" "#else\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n" -" #ifdef CHANNEL_LEAVE\n" -" if(out_c_idx+1 >= out_c_blocks)return;\n" +" #ifdef CHANNEL_BOUNDARY_PROTECT\n" +" if(out_c_idx_1 >= out_c_blocks)return;\n" " #endif\n" -" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx_1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" "#endif\n" "}\n" @@ -18853,7 +19177,7 @@ const char* buffer_convert_buf = " }\n" "#endif\n" "}\n" -"// convert kernel : from buffer(oihw) to image(oc/4 h w ,ic oc4)\n" +"// convert kernel : from buffer(oihw) to image(ic,oc/4,h,w,oc4)\n" "__kernel void conv2d_filter_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n" " __global const FLOAT *input_ptr,\n" " __private const int output_channel,\n" diff --git a/source/backend/opencl/execution/cl/opencl_source_map.hpp b/source/backend/opencl/execution/cl/opencl_source_map.hpp index 5f6861718..a347025c2 100644 --- a/source/backend/opencl/execution/cl/opencl_source_map.hpp +++ b/source/backend/opencl/execution/cl/opencl_source_map.hpp @@ -71,6 +71,7 @@ extern const char* unary_buf; #ifndef MNN_OPENCL_BUFFER_CLOSED extern const char* depthwise_conv2d_buf; #endif +extern const char* glmem_convert; #ifndef MNN_OPENCL_BUFFER_CLOSED extern const char* winogradTransform_buf; #endif @@ -242,6 +243,7 @@ const std::map OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "depthwise_conv2d_buf", depthwise_conv2d_buf }, #endif + { "glmem_convert", glmem_convert }, #ifndef MNN_OPENCL_BUFFER_CLOSED { "winogradTransform_buf", winogradTransform_buf }, #endif diff --git a/source/backend/opencl/execution/image/ConvExecution.cpp 
b/source/backend/opencl/execution/image/ConvExecution.cpp index d2f6d288a..2f3a8e4a4 100644 --- a/source/backend/opencl/execution/image/ConvExecution.cpp +++ b/source/backend/opencl/execution/image/ConvExecution.cpp @@ -25,12 +25,12 @@ ConvCommonExecution::ConvCommonExecution(const Convolution2D *conv2dParams, Back int biasSize = conv2dParams->bias()->size(); const float *biasDataPtr = conv2dParams->bias()->data(); - int buffer_size = ALIGN_UP4(biasSize) * sizeof(float); + int buffer_size = ALIGN_UP8(biasSize) * sizeof(float); cl::Buffer biasBuffer(runtime->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); cl_int error; auto biasPtrCL = runtime->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); if(biasPtrCL != nullptr && error == CL_SUCCESS){ - ::memset(biasPtrCL, 0, ALIGN_UP4(biasSize) * sizeof(float)); + ::memset(biasPtrCL, 0, ALIGN_UP8(biasSize) * sizeof(float)); ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float)); }else{ MNN_ERROR("Map error biasPtrCL == nullptr \n"); @@ -328,7 +328,11 @@ ErrorCode ConvExecution::onEncode(const std::vector &inputs, const std std::pair min_cost(INT_MAX, 0);//(min_time, min_index) for(int knl_idx = 0; knl_idx < 1; knl_idx++) { - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * 
UP_DIV(outputShape.at(1), itemH[knl_idx]))}; @@ -363,7 +367,11 @@ ErrorCode ConvExecution::onEncode(const std::vector &inputs, const std int min_index = min_cost.second; //printf("min_index = %d %d\n", min_index, min_cost.first); mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption); uint32_t idx = 0; unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -407,7 +415,11 @@ ErrorCode ConvExecution::onEncode(const std::vector &inputs, const std std::pair min_cost(INT_MAX, 0);//(min_time, min_index) for(int knl_idx = 0; knl_idx < total_kernel; knl_idx++) { - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; @@ -446,7 +458,11 @@ ErrorCode ConvExecution::onEncode(const std::vector &inputs, const std } int min_index = min_cost.second; mGlobalWorkSize = 
{globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mResource->mBuildOptions); + std::set buildOption = mResource->mBuildOptions; + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption); uint32_t idx = 0; cl_int ret = CL_SUCCESS; diff --git a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp index ec8e9f3e2..97dd8a770 100644 --- a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp @@ -239,6 +239,9 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu if(inputChannels % 4 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); @@ -277,6 +280,9 @@ void ConvLowMemoryExecution::tune1x1CaseLowMemory(Tensor * input, Tensor * outpu if(inputChannels % 4 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption); uint32_t idx = 0; ret |= unit.kernel->get().setArg(idx++, 
mGlobalWorkSize[0]); @@ -338,6 +344,9 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o if(inputChannels % 4 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } + if(itemC[knl_idx] == 8 && outputShape.at(3) % itemC[knl_idx] > 0 && outputShape.at(3) % itemC[knl_idx] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); @@ -379,6 +388,9 @@ void ConvLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * o if(inputChannels % 4 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } + if(itemC[min_index] == 8 && outputShape.at(3) % itemC[min_index] > 0 && outputShape.at(3) % itemC[min_index] <= 4){ + buildOption.emplace("-DCHANNEL_BOUNDARY_PROTECT"); + } unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], buildOption); uint32_t idx = 0; diff --git a/source/backend/vulkan/buffer/execution/VulkanPRelu.cpp b/source/backend/vulkan/buffer/execution/VulkanPRelu.cpp index 3181524e9..751d24934 100644 --- a/source/backend/vulkan/buffer/execution/VulkanPRelu.cpp +++ b/source/backend/vulkan/buffer/execution/VulkanPRelu.cpp @@ -80,7 +80,7 @@ class VulkanReluCreator : public VulkanBackend::Creator { public: virtual VulkanBasicExecution *onCreate(const std::vector &inputs, const std::vector& outputs, const MNN::Op *op, Backend *bn) const override { if (1 == op->main_as_PRelu()->slopeCount()) { - return new VulkanUnary("RELU", bn, op->main_as_PRelu()->slope()->data()[0]); + return new VulkanUnary("RELU", bn, false, op->main_as_PRelu()->slope()->data()[0]); } return new VulkanPrelu(bn, op); } diff --git a/source/core/Interpreter.cpp b/source/core/Interpreter.cpp index c8280ceef..4b62bc93a 100644 --- a/source/core/Interpreter.cpp +++ 
b/source/core/Interpreter.cpp @@ -193,6 +193,9 @@ void Interpreter::setExternalFile(const char* file, size_t flag) { } ErrorCode Interpreter::updateCacheFile(Session *session, int flag) { + if (mNet->cacheFile.empty()) { + return NOT_SUPPORT; + } std::lock_guard _l(mNet->lock); // Backend_Auto and no Async work, then don't need updateCache diff --git a/source/core/Pipeline.cpp b/source/core/Pipeline.cpp index 8bb123a41..12807a42f 100644 --- a/source/core/Pipeline.cpp +++ b/source/core/Pipeline.cpp @@ -27,7 +27,7 @@ static bool _supportQuant(const Op* op, const std::vector& inputs, cons switch (otype) { case OpType_Convolution: case OpType_ConvolutionDepthwise: - case OpType_Deconvolution: +// case OpType_Deconvolution: if (inputs.size() > 1) { return false; } diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index 01398fb34..5c9830568 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -487,7 +487,7 @@ static bool _ClipDst(int* stride, int srcOffset, int dstOffset, const int* srcSi dx=sx-xo -> [max(0, -xo), max(0, min(sxr-xo, dxr))] dy,dz compute the same **/ - + int offsetBias = dstOffset - srcOffset; if (sizeNum == 0) { // All stride is zero, then size will be all one @@ -903,4 +903,14 @@ void TensorUtils::setTensorPad(const Tensor* tensor, int left, int right, int bo srcDes->mPads.top = std::max(srcDes->mPads.top, top); } +void TensorUtils::setSharedMem(const Tensor *tensor, Backend::MemObj *mem){ + auto srcDes = TensorUtils::getDescribe(tensor); + srcDes->mSharedMem = mem; +} + +Backend::MemObj* TensorUtils::getSharedMem(const Tensor* tensor){ + auto srcDes = TensorUtils::getDescribe(tensor); + return srcDes->mSharedMem.get(); +} + } // namespace MNN diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 442b3184a..268c8be00 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -124,6 +124,8 @@ struct Tensor::InsideDescribe { pad mPads; // For isMutable = false Tensor , 
determine whether the content can be convert to main backend uint32_t stageMask = 0; + // Use for shared memory + SharedPtr mSharedMem; }; std::shared_ptr mContent; SharedPtr mem; @@ -224,6 +226,10 @@ class MNN_PUBLIC TensorUtils { static void setTensorSupportPack(const Tensor* tensor, bool flag); static void setTensorPad(const Tensor* tensor, int left, int right, int bottom, int top); + + static void setSharedMem(const Tensor* tensor, Backend::MemObj *mem); + + static Backend::MemObj* getSharedMem(const Tensor* tensor); }; } // namespace MNN diff --git a/source/shape/ShapeConcat.cpp b/source/shape/ShapeConcat.cpp index 3eb2675b3..8eba40670 100644 --- a/source/shape/ShapeConcat.cpp +++ b/source/shape/ShapeConcat.cpp @@ -14,7 +14,7 @@ class ConcatSizeComputer : public SizeComputer { virtual bool onComputeSize(const MNN::Op* op, const std::vector& inputs, const std::vector& outputs) const override { MNN_ASSERT(1 == outputs.size()); - MNN_ASSERT(inputs.size() >= 2); + // MNN_ASSERT(inputs.size() >= 2); auto& ob = outputs[0]->buffer(); int basicAxis = 0; if (op->type() == OpType_Concat) { diff --git a/source/shape/ShapeRegister.cpp b/source/shape/ShapeRegister.cpp index f917bf39a..f4ed83802 100644 --- a/source/shape/ShapeRegister.cpp +++ b/source/shape/ShapeRegister.cpp @@ -122,6 +122,9 @@ extern void ___FmhaV2SizeComputer__OpType_FmhaV2__(); extern void ___FmhcaSizeComputer__OpType_Fmhca__(); extern void ___AttentionSizeComputer__OpType_Attention__(); #endif +#ifdef MNN_BUILD_AUDIO +extern void ___StftOpComputer__OpType_Stft__(); +#endif void registerShapeOps() { ___ShapeSizeComputer__OpType_Shape__(); ___ShapeRasterComputer__OpType_Raster__(); @@ -244,5 +247,8 @@ ___FmhaV2SizeComputer__OpType_FmhaV2__(); ___FmhcaSizeComputer__OpType_Fmhca__(); ___AttentionSizeComputer__OpType_Attention__(); #endif +#ifdef MNN_BUILD_AUDIO +___StftOpComputer__OpType_Stft__(); +#endif } } diff --git a/source/shape/ShapeStft.cpp b/source/shape/ShapeStft.cpp new file mode 100644 index 
000000000..59847ad62 --- /dev/null +++ b/source/shape/ShapeStft.cpp @@ -0,0 +1,38 @@ +// +// ShapeStft.cpp +// MNN +// +// Created by MNN on 2024/11/26. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_BUILD_AUDIO + +#include "shape/SizeComputer.hpp" +#include "core/Macro.h" +#include "core/TensorUtils.hpp" + +namespace MNN { + +class StftOpComputer : public SizeComputer { + virtual bool onComputeSize(const MNN::Op* op, const std::vector& inputs, + const std::vector& outputs) const override { + int sample_length = inputs[0]->elementSize(); + auto stft = op->main_as_StftParam(); + bool abs = stft->abs(); + int n_fft = stft->n_fft(); + int hop_length = stft->hop_length(); + int frames = (sample_length - n_fft) / hop_length + 1; + // Scalar + outputs[0]->buffer().dimensions = 2; + outputs[0]->setLength(0, frames); + outputs[0]->setLength(1, n_fft / 2 + 1); + outputs[0]->buffer().type = inputs[0]->getType(); + TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; + return true; + } +}; + +REGISTER_SHAPE_AUDIO(StftOpComputer, OpType_Stft); +} // namespace MNN +#endif // MNN_BUILD_AUDIO diff --git a/source/shape/SizeComputer.hpp b/source/shape/SizeComputer.hpp index 996561d51..7a1c95312 100644 --- a/source/shape/SizeComputer.hpp +++ b/source/shape/SizeComputer.hpp @@ -186,4 +186,13 @@ class SizeComputerRegister { #endif +#ifdef MNN_BUILD_AUDIO +#define REGISTER_SHAPE_AUDIO(name, op) \ + void ___##name##__##op##__() { \ + name* _temp = new name; \ + SizeComputerSuite* ts = SizeComputerSuite::get(); \ + ts->insert(_temp, op); \ + } +#endif + #endif diff --git a/test.sh b/test.sh index 7204117a8..168f06c40 100755 --- a/test.sh +++ b/test.sh @@ -167,7 +167,7 @@ android_static_build() { -DMNN_INTERNAL=ON \ -DMNN_USE_LOGCAT=false \ -DMNN_BUILD_BENCHMARK=ON \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_NATIVE_API_LEVEL=android-26 \ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ -DMNN_OPENGL=true \ 
-DMNN_BUILD_TRAIN=true \ @@ -198,7 +198,7 @@ android_static_build() { -DMNN_USE_LOGCAT=false \ -DMNN_BUILD_BENCHMARK=ON \ -DMNN_INTERNAL=ON \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_NATIVE_API_LEVEL=android-26 \ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ -DMNN_OPENGL=true \ -DMNN_BUILD_TRAIN=true \ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f128825a6..9aa84590a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,6 +4,10 @@ if(MNN_WITH_PLUGIN) list(APPEND TEST_DEPS plugin_matmul) endif() +if (CMAKE_SYSTEM_NAME MATCHES "^Android") + list(APPEND TEST_DEPS android) +endif() + if(APPLE) file(GLOB_RECURSE Files ${CMAKE_CURRENT_LIST_DIR}/*.cpp ${CMAKE_CURRENT_LIST_DIR}/*.mm) else() diff --git a/test/op/DeconvolutionTest.cpp b/test/op/DeconvolutionTest.cpp index e7443064b..4e1e4d4ea 100644 --- a/test/op/DeconvolutionTest.cpp +++ b/test/op/DeconvolutionTest.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include "MNNTestSuite.h" #include "TestUtils.h" @@ -17,6 +18,56 @@ using namespace std; using namespace MNN; using namespace MNN::Express; +static void reference_deconv2d(const float* input, const std::vector& weight, + const std::vector& bias, std::vector& output, int batch, int ic, int oc, + int ih, int iw, int pad_h, int pad_w, int kh, int kw, int stridew, int strideh, + int dilation) { + int oh, ow; + ow = (iw - 1) * stridew + dilation * (kw - 1) + 1 - pad_w * 2; + oh = (ih - 1) * strideh + dilation * (kh - 1) + 1 - pad_h * 2; + + if (oh <= 0 || ow <= 0) { + output.clear(); + return; + } + output.resize(batch * oh * ow * oc); + for (int b = 0; b < batch; ++b) { + for (int oz = 0; oz < oc; ++oz) { + auto outputPtr = output.data() + b * oh * ow * oc + oz * ow * oh; + for (int s=0; s= 0 && ox < ow && oy >= 0 && oy < oh) { + auto w = weight[wOffset+ky*kw+kx]; + output[destOffset + oy * ow + ox] += xValue * w; + } + } + } + } + } + } + } + } +} + + static PadMode _convertPadMode(PaddingMode mode) { switch (mode) { 
case CAFFE: @@ -72,7 +123,52 @@ class DeconvolutionCommonTest : public MNNTestCase { virtual ~DeconvolutionCommonTest() = default; protected: - static bool test(MNNForwardType type, const std::string& device_name, const std::string& test_op_name, + static bool test(const std::string& test_op_name, + int batch, int ic, int oc, int ih, int iw, int pad_h, int pad_w, int kh, + int kw, int stride, int dilation, int group, int precision) { + int ow = (iw - 1) * stride + dilation * (kw - 1) + 1 - pad_w * 2; + int oh = (ih - 1) * stride + dilation * (kh - 1) + 1 - pad_h * 2; + if (ow <=0 || oh <= 0) { + return true; + } + auto input = _Input({batch, ic, ih, iw}, NCHW, halide_type_of()); + auto inputPtr = input->writeMap(); + { + int size = input->getInfo()->size; + for (int i=0; i weightData(ic*oc*kh*kw); + for (int i=0; i biasData(oc); + for (int i=0; i rightOutData; + reference_deconv2d(inputPtr, weightData, biasData, rightOutData, batch, ic, oc, ih, iw, pad_h, pad_w, kh, kw, stride, stride, dilation); + input = _Convert(input, NC4HW4); + auto output = _Deconv(std::move(weightData), std::move(biasData), input, {ic, oc}, {kw, kh}, VALID, + {stride, stride}, {dilation, dilation}, group, {pad_w, pad_h}, false, false); + output = _Convert(output, NCHW); + if (rightOutData.size() != output->getInfo()->size) { + FUNC_PRINT(1); + return false; + } + + + // difference below 0.5% relative error is considered correct. + auto outputPtr = output->readMap(); + float errorScale = precision <= MNN::BackendConfig::Precision_High ? 
1 : 20; + if (!checkVectorByRelativeError(outputPtr, rightOutData.data(), rightOutData.size(), 0.005 * errorScale)) { + MNN_ERROR("%s test failed!\n", test_op_name.c_str()); + return false; + } + return true; + } + static bool test(const std::string& test_op_name, vector& inputData, vector& weightData, vector& biasData, vector& rightOutData, int batch, int ic, int oc, int ih, int iw, PadMode mode, int pad_h, int pad_w, int kh, int kw, int stride, int dilation, int group, int precision) { @@ -87,7 +183,7 @@ class DeconvolutionCommonTest : public MNNTestCase { auto outputPtr = output->readMap(); float errorScale = precision <= MNN::BackendConfig::Precision_High ? 1 : 20; if (!checkVectorByRelativeError(outputPtr, rightOutData.data(), rightOutData.size(), 0.005 * errorScale)) { - MNN_ERROR("%s(%s) test failed!\n", test_op_name.c_str(), device_name.c_str()); + MNN_ERROR("%s test failed!\n", test_op_name.c_str()); return false; } return true; @@ -122,6 +218,60 @@ class DeconvolutionCommonTestInt8 : public MNNTestCase { return true; } }; +class DeconvolutionFullTest : public DeconvolutionCommonTest { +public: + virtual ~DeconvolutionFullTest() = default; + virtual bool run(int precision) { + if (MNN_FORWARD_OPENCL == getCurrentType()) { + MNN_ERROR("Currently opencl run deconvolution has error, skip it\n"); + return true; + } + int ocStep = 1; + int icStep = 1; + int isStep = 3; + std::vector ocSize = { + 1, 3, 10, 17 + }; + std::vector icSize = { + 1, 4, 3, 8, 11 + }; + std::vector isSize = { + 1, 7, 9, 13 + }; + + for (int batch = 1; batch <= 2; batch++) { + for (auto oc : ocSize) { + for (auto ic : icSize) { + for (auto is : isSize) { + int ih = is; + int iw = is; + for (int kw = 1; kw <= 7 && kw <= is; kw+=2) { + for (int kh = 1; kh <= 7 && kh <= is; kh+=3) { + for (int d = 1; d <= 2; d++) { + for (int s = 1; s <= 2; s++) { + int stride = s; + for (int p = 0; p <= 1; p++) { + std::ostringstream name; + int pad_w = p; + int pad_h = p; + name << "Deconvolution: " << 
batch <<","<< oc <<","<setName("input_tensor"); // set input data @@ -39,6 +40,7 @@ class MomentsTest : public MNNTestCase { MNN_ERROR("MomentsTest test failed!\n"); return false; } +#endif return true; } }; diff --git a/test/op/PReLUTest.cpp b/test/op/PReLUTest.cpp index f6a3d1365..62b8bab95 100644 --- a/test/op/PReLUTest.cpp +++ b/test/op/PReLUTest.cpp @@ -43,7 +43,7 @@ class PreluTestInt8 : public MNNTestCase { auto input = _Input({1, 12, 4, 2}, NCHW); input->setName("input_tensor"); // set input data - input->writeScaleMap(0.03567, 1.0); + input->writeScaleMap(0.02745, -18.714); const float inpudata[] = {-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, @@ -57,7 +57,7 @@ class PreluTestInt8 : public MNNTestCase { -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0}; auto inputPtr = input->writeMap(); - memcpy(inputPtr, inpudata, 4 * sizeof(float)); + memcpy(inputPtr, inpudata, 96 * sizeof(float)); input->unMap(); input = _Convert(input, NC4HW4); auto output = _PRelu(input, {3.0, 1.5, 1.5, 1.5, 3.0, 1.5, 1.5, 1.5, 3.0, 1.5, 1.5, 1.5}); @@ -75,10 +75,32 @@ class PreluTestInt8 : public MNNTestCase { -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0 }; - output->writeScaleMap(0.03567, 1.0); + output->writeScaleMap(0.03333, 7.f); auto gotOutput = output->readMap(); - if (!checkVector(gotOutput, expectedOutput.data(), 4, 0.05)) { - MNN_ERROR("PreluTest test failed!\n"); + if (!checkVector(gotOutput, expectedOutput.data(), 96, 0.1)) { + MNN_ERROR("PreluTest test 1 failed!\n"); + return false; + } + // prelu: one slope + auto output1 = _PRelu(input, {3.0}); + output1 = _Convert(output1, NCHW); + const std::vector expectedOutput1 = {-3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + 
-3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, -9.0, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + }; + output1->writeScaleMap(0.05098, 48.54); + auto gotOutput1 = output1->readMap(); + if (!checkVector(gotOutput1, expectedOutput1.data(), 96, 0.1)) { + MNN_ERROR("PreluTest test 2 failed!\n"); return false; } return true; diff --git a/test/op/StftTest.cpp b/test/op/StftTest.cpp new file mode 100644 index 000000000..894b849c0 --- /dev/null +++ b/test/op/StftTest.cpp @@ -0,0 +1,62 @@ +// +// StftTest.cpp +// MNNTests +// +// Created by MNN on 2024/11/27. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_BUILD_AUDIO +#include +#include +#include "MNNTestSuite.h" +#include "TestUtils.h" + +using namespace MNN::Express; +class StftTest : public MNNTestCase { +public: + virtual ~StftTest() = default; + virtual bool run(int precision) { + /* + python: + import torch + freq = 5, sample_rate = 100, duration = 0.2 + t = torch.arange(0, duration, 1.0 / sample_rate) + sine_wave = torch.sin(2 * torch.pi * freq * t) + n_fft = 8, hop_length = 4, win_length = 8 + window = torch.hann_window(win_length) + stft_result = torch.stft(sine_wave, n_fft=n_fft, hop_length=hop_length, + win_length=win_length, window=window, center=False) + magnitude = torch.abs(stft_result).transpose(1, 0) + */ + auto signal = _Input({ 20 }, NCHW); + auto window = _Input({ 8 }, NCHW); + signal->setName("signal"); + window->setName("window"); + const float signalData[] = { + 0.000, 0.309, 0.588, 0.809, 0.951, 1.000, 0.951, 0.809, 0.588, 0.309, + 0.000, -0.309, -0.588, -0.809, -0.951, -1.000, -0.951, -0.809, -0.588, -0.309 + }; + const float windowData[] = { 0.000, 0.146, 0.500, 0.854, 1.000, 0.854, 0.500, 0.146 }; + auto 
signalPtr = signal->writeMap(); + auto windowPtr = window->writeMap(); + memcpy(signalPtr, signalData, 20 * sizeof(float)); + memcpy(windowPtr, windowData, 8 * sizeof(float)); + auto output = _Stft(signal, window, 8, 4); + const float expectedOutput[] = { + 3.428, 1.958, 0.203, 0.029, 0.013, 2.119, 1.501, 0.261, 0.041, 0.008, + 2.119, 1.501, 0.261, 0.041, 0.008, 3.428, 1.958, 0.203, 0.029, 0.013 + }; + auto gotOutput = output->readMap(); + for (int i = 0; i < 20; ++i) { + auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]); + if (diff > 0.01) { + MNN_ERROR("StftTest test failed: %f - %f!\n", expectedOutput[i], gotOutput[i]); + return false; + } + } + return true; + } +}; +MNNTestSuiteRegister(StftTest, "op/stft"); +#endif // MNN_BUILD_AUDIO \ No newline at end of file diff --git a/test/sharedmem/AhardWareBufferTest.cpp b/test/sharedmem/AhardWareBufferTest.cpp new file mode 100644 index 000000000..f8ea492ef --- /dev/null +++ b/test/sharedmem/AhardWareBufferTest.cpp @@ -0,0 +1,176 @@ +// +// AhardWareBufferTest.cpp +// MNNTests +// +// Created by MNN on 2019/09/10. 
+// Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef __ANDROID__
+#include
+#include "MNNTestSuite.h"
+#include "MNN_generated.h"
+#include
+#include "TestUtils.h"
+#include
+#define MNN_OPEN_TIME_TRACE
+#include
+
+using namespace MNN;
+using namespace MNN::Express;
+
+/**
+ * Allocates a single-layer R8G8B8A8 AHardwareBuffer of width x height pixels and,
+ * when data is non-null, copies width * height * 4 bytes from data into it.
+ * @param width  buffer width in pixels.
+ * @param height buffer height in pixels.
+ * @param data   optional source pixels; pass nullptr to leave the buffer unfilled.
+ * @return the allocated buffer, or nullptr when AHardwareBuffer_allocate fails.
+ */
+static AHardwareBuffer* creatAHardwareBuffer(int width, int height, void *data){
+    // Create and initialize the hardware buffer
+    AHardwareBuffer_Desc bufferDesc = {};
+    bufferDesc.width = width;
+    bufferDesc.height = height;
+    bufferDesc.layers = 1;
+    bufferDesc.format = AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM;
+    bufferDesc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN | AHARDWAREBUFFER_USAGE_GPU_SAMPLED_IMAGE;
+
+    AHardwareBuffer* buffer = nullptr;
+    int result = AHardwareBuffer_allocate(&bufferDesc, &buffer);
+    if(result != 0) {
+        MNN_ERROR("alloc AHardwareBuffer failed %d\n", result);
+        // There is nothing to lock or fill, so report the failure to the caller
+        // instead of passing an invalid handle to the lock/unlock calls below.
+        return nullptr;
+    }
+
+    if(nullptr != data){
+        void* map = nullptr;
+        ARect rect = { 0, 0, width, height }; // Define the region to lock
+        result = AHardwareBuffer_lock(buffer, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, &rect, &map);
+        if (result != 0) {
+            MNN_ERROR("Handle lock failed\n");
+        } else {
+            if (map) {
+                // NOTE(review): copies tightly packed rows; presumably relies on the
+                // buffer row stride being equal to width - confirm.
+                memcpy(map, data, width * height * 4);
+            }
+            // Unlock only a buffer that was actually locked.
+            AHardwareBuffer_unlock(buffer, nullptr);
+        }
+    }
+    return buffer;
+}
+
+/** Releases a buffer returned by creatAHardwareBuffer; a null handle is ignored. */
+static void ReleaseAHardWareBuffer(AHardwareBuffer* buffer){
+    if (nullptr == buffer) {
+        return;
+    }
+    AHardwareBuffer_release(buffer);
+}
+
+/**
+ * Locks buffer for CPU reading and copies width * height * 4 bytes into data.
+ * Does nothing when data is null; logs and returns when the lock fails.
+ */
+static void copyDataFromAHardWareBuffer(AHardwareBuffer* buffer, int width, int height, void *data){
+    if(nullptr == data){
+        return;
+    }
+    void* map = nullptr;
+    ARect rect = { 0, 0, width, height }; // Define the region to lock
+    int result = AHardwareBuffer_lock(buffer, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, -1, &rect, &map);
+    if (result != 0) {
+        MNN_ERROR("Handle lock failed\n");
+        // The buffer was never locked, so there is nothing to unlock.
+        return;
+    }
+    if (map) {
+        memcpy(data, map, width * height * 4);
+    }
+
+    AHardwareBuffer_unlock(buffer, nullptr);
+}
+
+/**
+ * Compares size reference floats against size bytes element by element.
+ * The comparison is exact; the first mismatch is logged and fails the check.
+ */
+static bool checkvalue(const float* ref, const unsigned char* out, int size){
+    for(int i = 0; i < size; ++i){
+        if(ref[i] != (float)out[i]){
+            MNN_ERROR("%d: ref %f != out %f\n", i, ref[i], (float)out[i]);
+            return false;
+        }
+    }
+    return true;
+}
+
+const int width = 1280;
+const int height = 720;
+const int channel = 3;
+/**
+ * Builds an in-memory model with one NCHW input "Input" of shape
+ * {1, channel, height, width} followed by a Transpose with permutation {0, 1, 3, 2},
+ * serializes it and loads it back as a Module.
+ */
+static std::shared_ptr _createModel() {
+    auto x = _Input({1, channel, height, width}, NCHW, halide_type_of());
+    x->setName("Input");
+    auto y = _Transpose(x, {0, 1, 3, 2});
+    y->setName("Transpose");
+    std::unique_ptr net(new NetT);
+    Variable::save({y}, net.get());
+    flatbuffers::FlatBufferBuilder builder;
+    auto len = MNN::Net::Pack(builder, net.get());
+    builder.Finish(len);
+    return std::shared_ptr(Module::load({"Input"}, {"Transpose"}, builder.GetBufferPointer(), builder.GetSize()));
+}
+// Runs the model once with a CPU-mapped input and once with AHardwareBuffer
+// input/output, checks that both results agree, then prints the average time
+// of each path. Only executed for the OpenCL forward type.
+class AhardWareBufferTest : public MNNTestCase {
+public:
+    virtual bool run(int precision) {
+        if (MNN_FORWARD_OPENCL != getCurrentType()) {
+            MNN_ERROR("Currently forwardtype[%d] run sharedmem/AhardWareBuffer has error, skip it\n", getCurrentType());
+            return true;
+        }
+        auto net = _createModel();
+        auto x = _Input({1, channel, height, width}, NCHW, halide_type_of());
+        // The two RGBA staging images need about 7 MB together, so they live on
+        // the heap rather than in this stack frame.
+        std::unique_ptr<unsigned char[]> inputStorage(new unsigned char[4 * height * width]);
+        std::unique_ptr<unsigned char[]> outputStorage(new unsigned char[4 * height * width]());
+        unsigned char* inputData = inputStorage.get();
+        unsigned char* outputData = outputStorage.get();
+        for(int i = 0; i < 4 * height * width; ++i){
+            inputData[i] = i;
+        }
+        // ahardwarebuffer default format is nc4hw4
+        {
+            auto xPtr = x->writeMap();
+            for (int i = 0; i < channel; ++i){
+                for (int j = 0; j < height * width; ++j) {
+                    xPtr[i * height * width + j] = (float)inputData[j * 4 + i];
+                }
+            }
+            x->unMap();
+        }
+
+        auto outputs = net->onForward({x});
+        outputs[0] = _Convert(outputs[0], NC4HW4);
+        auto refPtr = outputs[0]->readMap();
+        auto size = outputs[0]->getInfo()->size;
+
+        auto xShared = _Input({1, channel, height, width}, NCHW, halide_type_of());
+        auto inputAhardwareBuffer = creatAHardwareBuffer(width, height, inputData);
+        auto outputAhardwareBuffer = creatAHardwareBuffer(width, height, nullptr);
+        if (nullptr == inputAhardwareBuffer || nullptr == outputAhardwareBuffer) {
+            // Allocation already logged its own error; release whichever buffer exists.
+            MNN_ERROR("sharedmem/AhardWareBuffer test failed!\n");
+            ReleaseAHardWareBuffer(inputAhardwareBuffer);
+            ReleaseAHardWareBuffer(outputAhardwareBuffer);
+            return false;
+        }
+        volatile uint64_t inputValue = (uint64_t)inputAhardwareBuffer;
+        xShared->setDevicePtr((void*)inputValue, MNN_MEMORY_AHARDWAREBUFFER);
+        auto outputsShared = net->onForward({xShared});
+        // Write the result into the dedicated output buffer so the input image is
+        // not overwritten and the output buffer is really exercised.
+        volatile uint64_t outputValue = (uint64_t)outputAhardwareBuffer;
+        {
+            outputsShared[0]->copyToDevicePtr((void*)outputValue, MNN_MEMORY_AHARDWAREBUFFER);
+            copyDataFromAHardWareBuffer(outputAhardwareBuffer, width, height, outputData);
+            if(checkvalue(refPtr, outputData, size) == false){
+                MNN_ERROR("sharedmem/AhardWareBuffer test failed!\n");
+                // Do not leak the hardware buffers on the failure path.
+                ReleaseAHardWareBuffer(inputAhardwareBuffer);
+                ReleaseAHardWareBuffer(outputAhardwareBuffer);
+                return false;
+            }
+        }
+
+        // speed
+        const auto time = 100;
+        {
+            MNN::Timer _t;
+            for (int t = 0; t < time; ++t) {
+                x->writeMap();
+                auto outputs = net->onForward({x});
+                outputs[0]->readMap();
+            }
+            float timeCost = _t.durationInUs() / 1000.0f / (float)time;
+            MNN_PRINT("cpu copy [%d, %d, %d], Avg time: %f ms\n", channel, height, width, timeCost);
+        }
+        {
+            MNN::Timer _t;
+            for (int t = 0; t < time; ++t) {
+                xShared->setDevicePtr((void*)inputValue, MNN_MEMORY_AHARDWAREBUFFER);
+                auto outputs = net->onForward({xShared});
+                outputs[0]->copyToDevicePtr((void*)outputValue, MNN_MEMORY_AHARDWAREBUFFER);
+            }
+            float timeCost = _t.durationInUs() / 1000.0f / (float)time;
+            MNN_PRINT("shared memory copy [%d, %d, %d], Avg time: %f ms\n", channel, height, width, timeCost);
+        }
+
+        ReleaseAHardWareBuffer(inputAhardwareBuffer);
+        ReleaseAHardWareBuffer(outputAhardwareBuffer);
+        return true;
+    }
+};
+
+MNNTestSuiteRegister(AhardWareBufferTest, "sharedmem/AhardWareBuffer");
+#endif
diff --git a/test/speed/StftSpeed.cpp b/test/speed/StftSpeed.cpp
new file mode 100644
index 000000000..4d40d3112
--- /dev/null
+++ b/test/speed/StftSpeed.cpp
@@ -0,0 +1,40 @@
+//
+// StftSpeed.cpp
+// MNNTests
+//
+// Created by MNN on 2024/11/27.
+// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef MNN_BUILD_AUDIO + +#include +#include +#include +#include +#define MNN_OPEN_TIME_TRACE +#include +#include "MNNTestSuite.h" +using namespace MNN::Express; +#define SAMPLE 10240 +#define NFFT 256 +#define HOP 128 +#define TIME 100 +class StftSpeed : public MNNTestCase { +public: + virtual bool run(int precision) { + auto x = _Input({SAMPLE}, NHWC); + auto w = _Input({NFFT}, NHWC); + auto y = _Stft(x, w, NFFT, HOP); + { + AUTOTIME; + for (int i = 0; i < TIME; ++i) { + x->writeMap(); + w->writeMap(); + y->readMap(); + } + } + return true; + } +}; +MNNTestSuiteRegister(StftSpeed, "speed/stft"); +#endif // MNN_BUILD_AUDIO \ No newline at end of file diff --git a/tools/audio/CMakeLists.txt b/tools/audio/CMakeLists.txt new file mode 100644 index 000000000..e72aa450d --- /dev/null +++ b/tools/audio/CMakeLists.txt @@ -0,0 +1,44 @@ +IF(MNN_BUILD_AUDIO) + # imgproc submodules start + option(MNN_AUDIO_TEST "Enable audio test" OFF) + + SET(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/../../) + include_directories(${CMAKE_CURRENT_LIST_DIR}/include) + include_directories(${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/imageHelper/) + + # include(${CMAKE_CURRENT_LIST_DIR}/test/CMakeLists.txt) + if(${MNN_AUDIO_TEST}) + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/test) + endif() + + # include dir + include_directories(${CMAKE_CURRENT_LIST_DIR}/include/) + + # source files + FILE(GLOB AUDIO_SRCS ${CMAKE_CURRENT_LIST_DIR}/source/*.cpp) + + IF(MNN_SEP_BUILD) + IF(MNN_BUILD_SHARED_LIBS) + add_library(MNNAudio SHARED ${AUDIO_SRCS}) + target_link_libraries(MNNAudio MNN MNN_Express) + ELSE() + add_library(MNNAudio STATIC ${AUDIO_SRCS}) + ENDIF() + ELSE() + add_library(MNNAudio OBJECT ${AUDIO_SRCS}) + ENDIF() + # copy header files + IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND) + IF(NOT NATIVE_INCLUDE_OUTPUT) + set(NATIVE_INCLUDE_OUTPUT ".") + ENDIF() + add_custom_command( + TARGET 
MNNAudio + POST_BUILD + COMMAND ${CMAKE_COMMAND} + ARGS -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/include ${NATIVE_INCLUDE_OUTPUT} + ) + ELSE() + INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/audio DESTINATION include FILES_MATCHING PATTERN *.hpp) + ENDIF() +ENDIF() \ No newline at end of file diff --git a/tools/audio/README.md b/tools/audio/README.md new file mode 100644 index 000000000..f250212a7 --- /dev/null +++ b/tools/audio/README.md @@ -0,0 +1,9 @@ +# MNN audio + +MNN audio is a collection of audio processing utility functions. + +## Usage +Compile MNN with audio support using the command below: +```bash +cmake -DMNN_BUILD_AUDIO=ON .. && make -j8 +``` diff --git a/tools/audio/include/audio/audio.hpp b/tools/audio/include/audio/audio.hpp new file mode 100644 index 000000000..3e14912cb --- /dev/null +++ b/tools/audio/include/audio/audio.hpp @@ -0,0 +1,169 @@ +// +// audio.hpp +// MNN +// +// Created by MNN on 2024/11/15. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifndef MNN_AUDIO_HPP +#define MNN_AUDIO_HPP + +#include +#include +#include + +namespace MNN { +namespace AUDIO { + +using namespace Express; + +enum WINDOW_TYPE { HAMMING = 0, HANNING = 1, POVEY = 2, RECTANGULAR = 3, BLACKMAN = 4 }; + +/** + * Structure to store parameters for the `melscale_fbanks`. + */ +struct MelscaleParams { + /** Number of mel filterbanks, default is 128. */ + int n_mels = 128; + /** Number of FFT bins, default is 400. */ + int n_fft = 400; + /** Sample rate, default is 16000. */ + int sample_rate = 16000; + /** Whether to use the `htk` mel scale instead of `slaney`, default is true (`htk`). */ + bool htk = true; + /** Divide the triangular mel weights by the width of the mel band, default is false. */ + bool norm = false; + /** Minimum frequency, default is 0. */ + float f_min = 0.0; + /** Maximum frequency, default is 0 (equal to `sample_rate / 2`). */ + float f_max = 0.0; +}; + +/** + * Structure to store parameters for the `spectrogram`. 
+ */ +struct SpectrogramParams { + /** Size of the FFT window, default is 400. */ + int n_fft = 400; + + /** Hop length between frames, default is 0 (equal to `n_fft / 2`). */ + int hop_length = 0; + + /** Window length, default is 0 (equal to `n_fft`). */ + int win_length = 0; + + /** Type of window function, default is Hann window (HANNING). */ + int window_type = HANNING; + + /** Constant padding value on the left side of the input audio, default is 0. */ + int pad_left = 0; + + /** Constant padding value on the right side of the input audio, default is 0. */ + int pad_right = 0; + + /** Whether to apply center padding to the STFT input, default is false. */ + bool center = false; + + /** Whether to normalize the output, default is false. */ + bool normalized = false; + + /** Padding mode of `center = true`, default is reflect (REFLECT). */ + int pad_mode = REFLECT; + + /** Power scaling factor, default is 2.0. */ + float power = 2.0; +}; + +/** + * @brief load audio from file + * @param filename audio file path + * @param frame_offset start frame + * @param num_frames number of frames + * @return pair