Skip to content

Commit

Permalink
[runtime/ios] add iOS runtime (#1549)
Browse files Browse the repository at this point in the history
* Add iOS build files and test application.

* Clean up code and add license information

* Update Podfile

* Add license

* Fix lint tab check

* Fix lint trailing whitespace

* Simplify build and fix some cpplint

* Fix some cpplint

* Add NOLINT to Objective C header file

* Merge ios_asr_model into torch_asr_model

* Fix lint

* Fix lint

* Fix code style

Co-authored-by: 马丹 <[email protected]>
  • Loading branch information
Ma-Dan and 马丹 authored Nov 14, 2022
1 parent 74d2826 commit d7fba09
Show file tree
Hide file tree
Showing 34 changed files with 2,327 additions and 10 deletions.
24 changes: 17 additions & 7 deletions runtime/core/cmake/libtorch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ if(TORCH)
add_definitions(-DUSE_GPU)
set(CUDA_NAME "cu113")
endif()
if(IOS)
add_definitions(-DIOS)
endif()
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
if(GPU)
message(FATAL_ERROR "GPU on Windows is unsupported, you can use CPU version")
Expand Down Expand Up @@ -40,17 +43,24 @@ if(TORCH)
endif()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-${PYTORCH_VERSION}.zip")
set(URL_HASH "SHA256=07cac2c36c34f13065cb9559ad5270109ecbb468252fb0aeccfd89322322a2b5")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
if(GPU)
message(FATAL_ERROR "GPU on iOS is unsupported, you can use CPU version")
endif()
else()
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()

FetchContent_Declare(libtorch
URL ${LIBTORCH_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(libtorch)
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
# iOS use LibTorch from pod install
if(NOT IOS)
FetchContent_Declare(libtorch
URL ${LIBTORCH_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(libtorch)
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
endif()

if(MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
Expand Down
6 changes: 3 additions & 3 deletions runtime/core/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ set(decoder_srcs
ctc_endpoint.cc
)

if(NOT TORCH AND NOT ONNX AND NOT XPU)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!")
if(NOT TORCH AND NOT ONNX AND NOT XPU AND NOT IOS)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU or IOS!!!")
endif()
if(TORCH)
if(TORCH OR IOS)
list(APPEND decoder_srcs torch_asr_model.cc)
endif()
if(ONNX)
Expand Down
4 changes: 4 additions & 0 deletions runtime/core/decoder/torch_asr_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
#include <stdexcept>

#include "torch/script.h"
#ifndef IOS
#include "torch/torch.h"
#endif

namespace wenet {

#ifndef IOS
void TorchAsrModel::InitEngineThreads(int num_threads) {
// For multi-thread performance
at::set_num_threads(num_threads);
Expand All @@ -36,6 +39,7 @@ void TorchAsrModel::InitEngineThreads(int num_threads) {
VLOG(1) << "Num intra-op threads: " << at::get_num_threads();
VLOG(1) << "Num inter-op threads: " << at::get_num_interop_threads();
}
#endif

void TorchAsrModel::Read(const std::string& model_path) {
torch::DeviceType device = at::kCPU;
Expand Down
4 changes: 4 additions & 0 deletions runtime/core/decoder/torch_asr_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include <vector>

#include "torch/script.h"
#ifndef IOS
#include "torch/torch.h"
#endif

#include "decoder/asr_model.h"
#include "utils/utils.h"
Expand All @@ -31,8 +33,10 @@ namespace wenet {

class TorchAsrModel : public AsrModel {
public:
#ifndef IOS
// Note: Do not call the InitEngineThreads function more than once.
static void InitEngineThreads(int num_threads = 1);
#endif

public:
using TorchModule = torch::jit::script::Module;
Expand Down
Empty file.
Loading

0 comments on commit d7fba09

Please sign in to comment.