diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/CMakeLists.txt b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/CMakeLists.txt new file mode 100755 index 0000000000..4f148869f8 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/CMakeLists.txt @@ -0,0 +1,215 @@ +project(clas_system CXX C) + +option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) +option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) +option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(WITH_TENSORRT "Compile demo with TensorRT." OFF) + +SET(PADDLE_LIB "" CACHE PATH "Location of libraries") +SET(OPENCV_DIR "" CACHE PATH "Location of libraries") +SET(CUDA_LIB "" CACHE PATH "Location of libraries") +SET(CUDNN_LIB "" CACHE PATH "Location of libraries") +SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT") + +set(DEMO_NAME "clas_system") + + +macro(safe_set_static_flag) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) +endmacro() + +if (WITH_MKL) + ADD_DEFINITIONS(-DUSE_MKL) +endif() + +if(NOT DEFINED PADDLE_LIB) + message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") +endif() + +if(NOT DEFINED OPENCV_DIR) + message(FATAL_ERROR "please set OPENCV_DIR with -DOPENCV_DIR=/path/opencv") +endif() + + +if (WIN32) + include_directories("${PADDLE_LIB}/paddle/fluid/inference") + include_directories("${PADDLE_LIB}/paddle/include") + link_directories("${PADDLE_LIB}/paddle/fluid/inference") + find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH) + +else () + find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/share/OpenCV NO_DEFAULT_PATH) + include_directories("${PADDLE_LIB}/paddle/include") + link_directories("${PADDLE_LIB}/paddle/lib") +endif () +include_directories(${OpenCV_INCLUDE_DIRS}) + +if (WIN32) + add_definitions("/DGOOGLE_GLOG_DLL_DECL=") + set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd") + set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT") + if (WITH_STATIC_LIB) + safe_set_static_flag() + add_definitions(-DSTATIC_LIB) + endif() +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -std=c++11") + set(CMAKE_STATIC_LIBRARY_PREFIX "") +endif() +message("flags" ${CMAKE_CXX_FLAGS}) + + +if (WITH_GPU) + if (NOT DEFINED CUDA_LIB OR ${CUDA_LIB} STREQUAL "") + message(FATAL_ERROR "please set CUDA_LIB with -DCUDA_LIB=/path/cuda-8.0/lib64") + endif() + if (NOT WIN32) + if (NOT DEFINED CUDNN_LIB) + message(FATAL_ERROR "please set CUDNN_LIB with -DCUDNN_LIB=/path/cudnn_v7.4/cuda/lib64") + endif() + endif(NOT WIN32) +endif() + +include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") +include_directories("${PADDLE_LIB}/third_party/install/glog/include") +include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") +include_directories("${PADDLE_LIB}/third_party/install/zlib/include") +include_directories("${PADDLE_LIB}/third_party/boost") +include_directories("${PADDLE_LIB}/third_party/eigen3") + +include_directories("${CMAKE_SOURCE_DIR}/") + +if (NOT WIN32) + if (WITH_TENSORRT AND WITH_GPU) + include_directories("${TENSORRT_DIR}/include") + link_directories("${TENSORRT_DIR}/lib") + endif() +endif(NOT WIN32) + +link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") + +link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") +link_directories("${PADDLE_LIB}/third_party/install/glog/lib") +link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") +link_directories("${PADDLE_LIB}/paddle/lib") + + +if(WITH_MKL) + include_directories("${PADDLE_LIB}/third_party/install/mklml/include") + if (WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.lib + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.lib) + else () + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + execute_process(COMMAND cp -r ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} /usr/lib) + endif () + set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") + if(EXISTS ${MKLDNN_PATH}) + include_directories("${MKLDNN_PATH}/include") + if (WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib) + else () + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) + endif () + endif() +else() + if (WIN32) + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + else () + set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif () +endif() + +# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a +if(WITH_STATIC_LIB) + if(WIN32) + set(DEPS + ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() +else() + if(WIN32) + set(DEPS + ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) + else() + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() +endif(WITH_STATIC_LIB) + +if (NOT WIN32) + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf z xxhash + ) + if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib") + set(DEPS ${DEPS} snappystream) + endif() + if (EXISTS "${PADDLE_LIB}/third_party/install/snappy/lib") + set(DEPS ${DEPS} snappy) + endif() +else() + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags_static libprotobuf xxhash) + set(DEPS ${DEPS} libcmt shlwapi) + if (EXISTS "${PADDLE_LIB}/third_party/install/snappy/lib") + set(DEPS ${DEPS} snappy) + endif() + if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib") + set(DEPS ${DEPS} snappystream) + endif() +endif(NOT WIN32) + + +if(WITH_GPU) + if(NOT WIN32) + if (WITH_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_DIR}/lib/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${CUDNN_LIB}/libcudnn${CMAKE_SHARED_LIBRARY_SUFFIX}) + else() + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDNN_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() +endif() + + +if (NOT WIN32) + set(EXTERNAL_LIB "-ldl -lrt -lgomp -lz -lm -lpthread") + set(DEPS ${DEPS} ${EXTERNAL_LIB}) +endif() + +set(DEPS ${DEPS} ${OpenCV_LIBS}) + +AUX_SOURCE_DIRECTORY(./src SRCS) +add_executable(${DEMO_NAME} ${SRCS}) + +target_link_libraries(${DEMO_NAME} ${DEPS}) + +if (WIN32 AND WITH_MKL) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.dll ./mklml.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.dll ./libiomp5md.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.dll ./mkldnn.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/mklml.dll ./release/mklml.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.dll ./release/libiomp5md.dll + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.dll ./release/mkldnn.dll + ) +endif() diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/README.md b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/README.md index e69de29bb2..8ccb26d808 100644 --- a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/README.md +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/README.md @@ -0,0 +1,243 @@ +# 服务器端C++预测 + +本教程将介绍在服务器端部署mobilenet_v3_small模型的详细步骤。 + + +## 1. 准备环境 + +### 运行准备 +- Linux环境,推荐使用docker[安装说明](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html)。 + +### 1.1 编译opencv库 + +* 首先需要从opencv官网上下载Linux环境下的源码,以3.4.7版本为例,下载及解压缩命令如下: + +``` +wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz +tar -xvf 3.4.7.tar.gz +``` + +最终可以在当前目录下看到`opencv-3.4.7/`的文件夹。 + +* 编译opencv,首先设置opencv源码路径(`root_path`)以及安装路径(`install_path`),`root_path`为下载的opencv源码路径,`install_path`为opencv的安装路径。在本例中,源码路径即为当前目录下的`opencv-3.4.7/`。 + +```shell +cd ./opencv-3.4.7 +export root_path=$PWD +export install_path=${root_path}/opencv3 +``` + +* 然后在opencv源码路径下,按照下面的命令进行编译。 + +```shell +rm -rf build +mkdir build +cd build + +cmake .. \ + -DCMAKE_INSTALL_PREFIX=${install_path} \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_IPP=OFF \ + -DBUILD_IPP_IW=OFF \ + -DWITH_LAPACK=OFF \ + -DWITH_EIGEN=OFF \ + -DCMAKE_INSTALL_LIBDIR=lib64 \ + -DWITH_ZLIB=ON \ + -DBUILD_ZLIB=ON \ + -DWITH_JPEG=ON \ + -DBUILD_JPEG=ON \ + -DWITH_PNG=ON \ + -DBUILD_PNG=ON \ + -DWITH_TIFF=ON \ + -DBUILD_TIFF=ON + +make -j +make install +``` + +* `make install`完成之后,会在该文件夹下生成opencv头文件和库文件,用于后面的代码编译。 + +以opencv3.4.7版本为例,最终在安装路径下的文件结构如下所示。**注意**:不同的opencv版本,下述的文件结构可能不同。 + +``` +opencv3/ +|-- bin :可执行文件 +|-- include :头文件 +|-- lib64 :库文件 +|-- share :部分第三方库 +``` + +### 1.2 下载或者编译Paddle预测库 + +* 有2种方式获取Paddle预测库,下面进行详细介绍。 + +#### 1.2.1 预测库源码编译 +* 如果希望获取最新预测库特性,可以从Paddle github上克隆最新代码,源码编译预测库。 +* 可以参考[Paddle预测库官网](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#id16)的说明,从github上获取Paddle代码,然后进行编译,生成最新的预测库。使用git获取代码方法如下。 + +```shell +git clone https://github.com/PaddlePaddle/Paddle.git +``` + +* 进入Paddle目录后,使用如下命令编译。 + +```shell +rm -rf build +mkdir build +cd build + +cmake .. \ + -DWITH_CONTRIB=OFF \ + -DWITH_MKL=ON \ + -DWITH_MKLDNN=ON \ + -DWITH_TESTING=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DWITH_INFERENCE_API_TEST=OFF \ + -DON_INFER=ON \ + -DWITH_PYTHON=ON +make -j +make inference_lib_dist +``` + +更多编译参数选项可以参考Paddle C++预测库官网:[https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#id16](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html#id16)。 + + +* 编译完成之后,可以在`build/paddle_inference_install_dir/`文件下看到生成了以下文件及文件夹。 + +``` +build/paddle_inference_install_dir/ +|-- CMakeCache.txt +|-- paddle +|-- third_party +|-- version.txt +``` + +其中`paddle`就是之后进行C++预测时所需的Paddle库,`version.txt`中包含当前预测库的版本信息。 + +#### 1.2.2 直接下载安装 + +* [Paddle预测库官网](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html)上提供了不同cuda版本的Linux预测库,可以在官网查看并选择合适的预测库版本。 + + 以`manylinux_cuda11.1_cudnn8.1_avx_mkl_trt7_gcc8.2`版本为例,使用下述命令下载并解压: + + +```shell +wget https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/GPU/x86-64_gcc8.2_avx_mkl_cuda11.1_cudnn8.1.1_trt7.2.3.4/paddle_inference.tgz + +tar -xvf paddle_inference.tgz +``` + + +最终会在当前的文件夹中生成`paddle_inference/`的子文件夹,文件内容和上述的paddle_inference_install_dir一样。 + + +## 2 开始运行 + +### 2.1 将模型导出为inference model + +* 可以参考[模型导出](../../tools/export_model.py),导出`inference model`,用于模型预测。得到预测模型后,假设模型文件放在`inference`目录下,则目录结构如下。 + +``` +mobilenet_v3_small_infer/ +|--inference.pdmodel +|--inference.pdiparams +|--inference.pdiparams.info +``` +**注意**:上述文件中,`inference.pdmodel`文件存储了模型结构信息,`inference.pdiparams`文件存储了模型参数信息。注意两个文件的路径需要与配置文件`tools/config.txt`中的`cls_model_path`和`cls_params_path`参数对应一致。 + +### 2.2 编译 C++预测demo + +* 编译命令如下,其中Paddle C++预测库、opencv等其他依赖库的地址需要换成自己机器上的实际地址。 + + +```shell +sh tools/build.sh +``` + +具体地,`tools/build.sh`中内容如下。 + +```shell +OPENCV_DIR=your_opencv_dir +LIB_DIR=your_paddle_inference_dir +CUDA_LIB_DIR=your_cuda_lib_dir +CUDNN_LIB_DIR=your_cudnn_lib_dir +TENSORRT_DIR=your_tensorrt_lib_dir + +BUILD_DIR=build +rm -rf ${BUILD_DIR} +mkdir ${BUILD_DIR} +cd ${BUILD_DIR} +cmake .. \ + -DPADDLE_LIB=${LIB_DIR} \ + -DWITH_MKL=ON \ + -DDEMO_NAME=clas_system \ + -DWITH_GPU=OFF \ + -DWITH_STATIC_LIB=OFF \ + -DWITH_TENSORRT=OFF \ + -DTENSORRT_DIR=${TENSORRT_DIR} \ + -DOPENCV_DIR=${OPENCV_DIR} \ + -DCUDNN_LIB=${CUDNN_LIB_DIR} \ + -DCUDA_LIB=${CUDA_LIB_DIR} \ + +make -j +``` + +上述命令中, + +* `OPENCV_DIR`为opencv编译安装的地址(本例中为`opencv-3.4.7/opencv3`文件夹的路径); + +* `LIB_DIR`为下载的Paddle预测库(`paddle_inference`文件夹),或编译生成的Paddle预测库(`build/paddle_inference_install_dir`文件夹)的路径; + +* `CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`; + +* `CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib64`。 + +* `TENSORRT_DIR`是tensorrt库文件地址,在dokcer中为`/usr/local/TensorRT-7.2.3.4/`,TensorRT需要结合GPU使用。 + +在执行上述命令,编译完成之后,会在当前路径下生成`build`文件夹,其中生成一个名为`clas_system`的可执行文件。 + + +### 运行demo +* 首先修改`tools/config.txt`中对应字段: + * use_gpu:是否使用GPU; + * gpu_id:使用的GPU卡号; + * gpu_mem:显存; + * cpu_math_library_num_threads:底层科学计算库所用线程的数量; + * use_mkldnn:是否使用MKLDNN加速; + * use_tensorrt: 是否使用tensorRT进行加速; + * use_fp16:是否使用半精度浮点数进行计算,该选项仅在use_tensorrt为true时有效; + * cls_model_path:预测模型结构文件路径; + * cls_params_path:预测模型参数文件路径; + * resize_short_size:预处理时图像缩放大小; + * crop_size:预处理时图像裁剪后的大小。 + +* 然后修改`tools/run.sh`: + * `./build/clas_system ./tools/config.txt ../../images/demo.jpg` + * 上述命令中分别为:编译得到的可执行文件`clas_system`;运行时的配置文件`config.txt`;待预测的图像。 + +* 最后执行以下命令,完成对一幅图像的分类。 + +```shell +sh tools/run.sh +``` +对于下面的图像进行预测 + +
+ +
+ +* 最终屏幕上会输出结果,如下所示 +``` +class id: 8 + +score: 0.9014717937 + +Current image path: ../../images/demo.jpg + +Current time cost: 0.0473620000 s, average time cost in all: 0.0473620000 s. + +``` + +表示预测的类别ID是`8`,置信度为`0.901`,该结果与基于训练引擎的结果完全一致。 +其中`class id`表示置信度最高的类别对应的id,score表示图片属于该类别的概率。 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls.h b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls.h new file mode 100644 index 0000000000..f7a8711e7d --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls.h @@ -0,0 +1,91 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +using namespace paddle_infer; + +namespace MobileNetV3 { + +class Classifier { +public: + explicit Classifier(const std::string &model_path, + const std::string ¶ms_path, const bool &use_gpu, + const int &gpu_id, const int &gpu_mem, + const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const bool &use_tensorrt, + const bool &use_fp16, const int &resize_short_size, + const int &crop_size) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->use_fp16_ = use_fp16; + + this->resize_short_size_ = resize_short_size; + this->crop_size_ = crop_size; + + LoadModel(model_path, params_path); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_path, const std::string ¶ms_path); + + // Run predictor + double Run(cv::Mat &img); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + bool use_tensorrt_ = false; + bool use_fp16_ = false; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + int resize_short_size_ = 256; + int crop_size_ = 224; + + // pre-process + ResizeImg resize_op_; + Normalize normalize_op_; + Permute permute_op_; + CenterCropImg crop_op_; +}; + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls_config.h b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls_config.h new file mode 100644 index 0000000000..231738b4b5 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/cls_config.h @@ -0,0 +1,88 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "include/utility.h" + +namespace MobileNetV3 { + +class ClsConfig { +public: + explicit ClsConfig(const std::string &config_file) { + config_map_ = LoadConfig(config_file); + + this->use_gpu = bool(stoi(config_map_["use_gpu"])); + + this->gpu_id = stoi(config_map_["gpu_id"]); + + this->gpu_mem = stoi(config_map_["gpu_mem"]); + + this->cpu_math_library_num_threads = + stoi(config_map_["cpu_math_library_num_threads"]); + + this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"])); + + this->use_tensorrt = bool(stoi(config_map_["use_tensorrt"])); + this->use_fp16 = bool(stoi(config_map_["use_fp16"])); + + this->cls_model_path.assign(config_map_["cls_model_path"]); + + this->cls_params_path.assign(config_map_["cls_params_path"]); + + this->resize_short_size = stoi(config_map_["resize_short_size"]); + + this->crop_size = stoi(config_map_["crop_size"]); + } + + bool use_gpu = false; + + int gpu_id = 0; + + int gpu_mem = 4000; + + int cpu_math_library_num_threads = 1; + + bool use_mkldnn = false; + + bool use_tensorrt = false; + bool use_fp16 = false; + + std::string cls_model_path; + + std::string cls_params_path; + + int resize_short_size = 256; + int crop_size = 224; + + void PrintConfigInfo(); + +private: + // Load configuration + std::map LoadConfig(const std::string &config_file); + + std::vector split(const std::string &str, + const std::string &delim); + + std::map config_map_; +}; + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/preprocess_op.h b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/preprocess_op.h new file mode 100644 index 0000000000..db86d72feb --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/preprocess_op.h @@ -0,0 +1,56 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace std; + +namespace MobileNetV3 { + +class Normalize { +public: + virtual void Run(cv::Mat *im, const std::vector &mean, + const std::vector &scale, const bool is_scale = true); +}; + +// RGB -> CHW +class Permute { +public: + virtual void Run(const cv::Mat *im, float *data); +}; + +class CenterCropImg { +public: + virtual void Run(cv::Mat &im, const int crop_size = 224); +}; + +class ResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len); +}; + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/utility.h b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/utility.h new file mode 100644 index 0000000000..b2ce841a70 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/include/utility.h @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" + +namespace MobileNetV3 { + +class Utility { +public: + static std::vector ReadDict(const std::string &path); + + // template + // inline static size_t argmax(ForwardIterator first, ForwardIterator last) + // { + // return std::distance(first, std::max_element(first, last)); + // } +}; + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls.cpp b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls.cpp new file mode 100644 index 0000000000..febf4c70c2 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace MobileNetV3 { + +void Classifier::LoadModel(const std::string &model_path, + const std::string ¶ms_path) { + paddle_infer::Config config; + config.SetModel(model_path, params_path); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + config.EnableTensorRtEngine( + 1 << 20, 1, 3, + this->use_fp16_ ? paddle_infer::Config::Precision::kHalf + : paddle_infer::Config::Precision::kFloat32, + false, false); + } + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + // cache 10 different shapes for mkldnn to avoid memory leak + config.SetMkldnnCacheCapacity(10); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePredictor(config); +} + +double Classifier::Run(cv::Mat &img) { + cv::Mat srcimg; + cv::Mat resize_img; + img.copyTo(srcimg); + + this->resize_op_.Run(img, resize_img, this->resize_short_size_); + + this->crop_op_.Run(resize_img, this->crop_size_); + + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + this->permute_op_.Run(&resize_img, input.data()); + + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto start = std::chrono::system_clock::now(); + input_t->CopyFromCpu(input.data()); + this->predictor_->Run(); + + std::vector out_data; + auto output_names = this->predictor_->GetOutputNames(); + auto output_t = this->predictor_->GetOutputHandle(output_names[0]); + std::vector output_shape = output_t->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + + out_data.resize(out_num); + output_t->CopyToCpu(out_data.data()); + auto end = std::chrono::system_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + double cost_time = double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den; + + int maxPosition = + max_element(out_data.begin(), out_data.end()) - out_data.begin(); + std::cout << "result: " << std::endl; + std::cout << "\tclass id: " << maxPosition << std::endl; + std::cout << std::fixed << std::setprecision(10) + << "\tscore: " << double(out_data[maxPosition]) << std::endl; + + return cost_time; +} + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls_config.cpp b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls_config.cpp new file mode 100644 index 0000000000..3955890970 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/cls_config.cpp @@ -0,0 +1,64 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace MobileNetV3 { + +std::vector ClsConfig::split(const std::string &str, + const std::string &delim) { + std::vector res; + if ("" == str) + return res; + char *strs = new char[str.length() + 1]; + std::strcpy(strs, str.c_str()); + + char *d = new char[delim.length() + 1]; + std::strcpy(d, delim.c_str()); + + char *p = std::strtok(strs, d); + while (p) { + std::string s = p; + res.push_back(s); + p = std::strtok(NULL, d); + } + + return res; +} + +std::map +ClsConfig::LoadConfig(const std::string &config_path) { + auto config = Utility::ReadDict(config_path); + + std::map dict; + for (int i = 0; i < config.size(); i++) { + // pass for empty line or comment + if (config[i].size() <= 1 || config[i][0] == '#') { + continue; + } + std::vector res = split(config[i], " "); + dict[res[0]] = res[1]; + } + return dict; +} + +void ClsConfig::PrintConfigInfo() { + std::cout << "=======Paddle Class inference config======" << std::endl; + for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) { + std::cout << iter->first << " : " << iter->second << std::endl; + } + std::cout << "=======End of Paddle Class inference config======" << std::endl; +} + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/main.cpp b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/main.cpp new file mode 100644 index 0000000000..919bcabe3e --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/main.cpp @@ -0,0 +1,94 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +using namespace std; +using namespace cv; +using namespace MobileNetV3; + +int main(int argc, char **argv) { + if (argc < 3) { + std::cerr << "[ERROR] usage: " << argv[0] + << " configure_filepath image_path\n"; + exit(1); + } + + ClsConfig config(argv[1]); + + config.PrintConfigInfo(); + + std::string path(argv[2]); + + std::vector img_files_list; + if (cv::utils::fs::isDirectory(path)) { + std::vector filenames; + cv::glob(path, filenames); + for (auto f : filenames) { + img_files_list.push_back(f); + } + } else { + img_files_list.push_back(path); + } + + std::cout << "img_file_list length: " << img_files_list.size() << std::endl; + + Classifier classifier(config.cls_model_path, config.cls_params_path, + config.use_gpu, config.gpu_id, config.gpu_mem, + config.cpu_math_library_num_threads, config.use_mkldnn, + config.use_tensorrt, config.use_fp16, + config.resize_short_size, config.crop_size); + + double elapsed_time = 0.0; + int warmup_iter = img_files_list.size() > 5 ? 5 : 0; + for (int idx = 0; idx < img_files_list.size(); ++idx) { + std::string img_path = img_files_list[idx]; + cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); + if (!srcimg.data) { + std::cerr << "[ERROR] image read failed! image path: " << img_path + << "\n"; + exit(-1); + } + + cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB); + + double run_time = classifier.Run(srcimg); + if (idx >= warmup_iter) { + elapsed_time += run_time; + std::cout << "Current image path: " << img_path << std::endl; + std::cout << "Current time cost: " << run_time << " s, " + << "average time cost in all: " + << elapsed_time / (idx + 1 - warmup_iter) << " s." << std::endl; + } else { + std::cout << "Current time cost: " << run_time << " s." << std::endl; + } + } + + return 0; +} diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/preprocess_op.cpp b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/preprocess_op.cpp new file mode 100644 index 0000000000..ce506bb374 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/preprocess_op.cpp @@ -0,0 +1,90 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +namespace MobileNetV3 { + +void Permute::Run(const cv::Mat *im, float *data) { + int rh = im->rows; + int rw = im->cols; + int rc = im->channels(); + for (int i = 0; i < rc; ++i) { + cv::extractChannel(*im, cv::Mat(rh, rw, CV_32FC1, data + i * rh * rw), i); + } +} + +void Normalize::Run(cv::Mat *im, const std::vector &mean, + const std::vector &scale, const bool is_scale) { + double e = 1.0; + if (is_scale) { + e /= 255.0; + } + (*im).convertTo(*im, CV_32FC3, e); + for (int h = 0; h < im->rows; h++) { + for (int w = 0; w < im->cols; w++) { + im->at(h, w)[0] = + (im->at(h, w)[0] - mean[0]) * scale[0]; + im->at(h, w)[1] = + (im->at(h, w)[1] - mean[1]) * scale[1]; + im->at(h, w)[2] = + (im->at(h, w)[2] - mean[2]) * scale[2]; + } + } +} + +void CenterCropImg::Run(cv::Mat &img, const int crop_size) { + int resize_w = img.cols; + int resize_h = img.rows; + int w_start = int((resize_w - crop_size) / 2); + int h_start = int((resize_h - crop_size) / 2); + cv::Rect rect(w_start, h_start, crop_size, crop_size); + img = img(rect); +} + +void ResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + int resize_short_size) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + if (h < w) { + ratio = float(resize_short_size) / float(h); + } else { + ratio = float(resize_short_size) / float(w); + } + + int resize_h = round(float(h) * ratio); + int resize_w = round(float(w) * ratio); + + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); +} + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/utility.cpp b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/utility.cpp new file mode 100644 index 0000000000..30d9cb8ec4 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/src/utility.cpp @@ -0,0 +1,39 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +namespace MobileNetV3 { + +std::vector Utility::ReadDict(const std::string &path) { + std::ifstream in(path); + std::string line; + std::vector m_vec; + if (in) { + while (getline(in, line)) { + m_vec.push_back(line); + } + } else { + std::cout << "no such label file: " << path << ", exit the program..." + << std::endl; + exit(1); + } + return m_vec; +} + +} // namespace MobileNetV3 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/build.sh b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/build.sh new file mode 100755 index 0000000000..b2ec278e67 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/build.sh @@ -0,0 +1,21 @@ +OPENCV_DIR=../opencv-3.4.7/opencv3/ +LIB_DIR=../paddle_inference/ +CUDA_LIB_DIR=/usr/local/cuda/lib64 +CUDNN_LIB_DIR=/usr/lib64 +TENSORRT_DIR=/usr/local/TensorRT-7.2.3.4 + +BUILD_DIR=build +rm -rf ${BUILD_DIR} +mkdir ${BUILD_DIR} +cd ${BUILD_DIR} +cmake .. \ + -DPADDLE_LIB=${LIB_DIR} \ + -DWITH_MKL=ON \ + -DWITH_GPU=OFF \ + -DWITH_STATIC_LIB=OFF \ + -DUSE_TENSORRT=OFF \ + -DOPENCV_DIR=${OPENCV_DIR} \ + -DCUDNN_LIB=${CUDNN_LIB_DIR} \ + -DCUDA_LIB=${CUDA_LIB_DIR} \ + +make -j diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/config.txt b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/config.txt new file mode 100755 index 0000000000..039cf3b521 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/config.txt @@ -0,0 +1,14 @@ +# model load config +use_gpu 0 +gpu_id 0 +gpu_mem 4000 +cpu_math_library_num_threads 10 +use_mkldnn 1 +use_tensorrt 0 +use_fp16 0 + +# cls config +cls_model_path ./mobilenet_v3_small_infer/inference.pdmodel +cls_params_path ./mobilenet_v3_small_infer/inference.pdiparams +resize_short_size 256 +crop_size 224 diff --git a/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/run.sh b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/run.sh new file mode 100755 index 0000000000..c4ceae9156 --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step6/deploy/inference_cpp/tools/run.sh @@ -0,0 +1 @@ +./build/clas_system ./tools/config.txt ../../images/demo.jpg