Skip to content

Commit

Permalink
[runtime/xpu] Support the execution of non-streaming parsing on the K…
Browse files Browse the repository at this point in the history
…unlun XPU card #1455
  • Loading branch information
panhehe committed Oct 25, 2022
1 parent 89e8d0d commit fac202e
Show file tree
Hide file tree
Showing 28 changed files with 3,406 additions and 6 deletions.
37 changes: 37 additions & 0 deletions runtime/core/cmake/xpu.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
if(NOT WIN32)
string(ASCII 27 Esc)
set(ColourReset "${Esc}[m")
set(ColourBold "${Esc}[1m")
set(Red "${Esc}[31m")
set(Green "${Esc}[32m")
set(Yellow "${Esc}[33m")
set(Blue "${Esc}[34m")
set(Magenta "${Esc}[35m")
set(Cyan "${Esc}[36m")
set(White "${Esc}[37m")
set(BoldRed "${Esc}[1;31m")
set(BoldGreen "${Esc}[1;32m")
set(BoldYellow "${Esc}[1;33m")
set(BoldBlue "${Esc}[1;34m")
set(BoldMagenta "${Esc}[1;35m")
set(BoldCyan "${Esc}[1;36m")
set(BoldWhite "${Esc}[1;37m")
endif()

if(XPU)
set(RUNTIME_XPU_PATH ${CMAKE_CURRENT_SOURCE_DIR})
message(STATUS "RUNTIME_XPU_PATH is ${RUNTIME_XPU_PATH} .\n")
set(XPU_KUNLUN_PATH ${RUNTIME_XPU_PATH}/decoder/xpu_kunlun)
if(NOT DEFINED ENV{XPU_API_PATH})
message(FATAL_ERROR "${BoldRed}NO ENV{XPU_API_PATH} in your env. Please set XPU_API_PATH.${ColourReset}\n")
else()
set(XPU_API_PATH $ENV{XPU_API_PATH})
message("set XPU_API_PATH from env_var. Val is $ENV{XPU_API_PATH}.")
endif()

include_directories(${XPU_KUNLUN_PATH}/
${XPU_API_PATH}/output/include ${XPU_API_PATH}/../runtime/include)
link_directories(${XPU_API_PATH}/output/so/ ${XPU_API_PATH}/../runtime/output/so/)

add_definitions(-DUSE_XPU)
endif()
21 changes: 18 additions & 3 deletions runtime/core/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ set(decoder_srcs
ctc_endpoint.cc
)

if(NOT TORCH AND NOT ONNX)
message(FATAL_ERROR "Please build with TORCH or ONNX!!!")
if(NOT TORCH AND NOT ONNX AND NOT XPU)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!")
endif()
if(TORCH)
list(APPEND decoder_srcs torch_asr_model.cc)
Expand All @@ -17,8 +17,23 @@ if(ONNX)
list(APPEND decoder_srcs onnx_asr_model.cc)
endif()

if(XPU)
list(APPEND decoder_srcs xpu_asr_model.cc)
list(APPEND decoder_srcs ./xpu_kunlun/xpu_conformer.cpp)
list(APPEND decoder_srcs ./xpu_kunlun/xpu_util.cpp)
message(STATUS "xpu decoder_srcs is :: ${decoder_srcs} \n")
# compile conformer_test
add_subdirectory(xpu_kunlun)
endif()

add_library(decoder STATIC ${decoder_srcs})
target_link_libraries(decoder PUBLIC kaldi-decoder frontend post_processor utils)
if(XPU)
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
post_processor utils xpuapi xpurt)
else()
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
post_processor utils)
endif()

if(ANDROID)
target_link_libraries(decoder PUBLIC ${PYTORCH_LIBRARY} ${FBJNI_LIBRARY})
Expand Down
27 changes: 24 additions & 3 deletions runtime/core/decoder/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_

Expand All @@ -29,17 +28,24 @@
#ifdef USE_TORCH
#include "decoder/torch_asr_model.h"
#endif
#ifdef USE_XPU
#include "decoder/xpu_asr_model.h"
#endif
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/flags.h"
#include "utils/string.h"

DEFINE_int32(num_threads, 1, "num threads for ASR model");
DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");

// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
"directory where the XPU model and weights is saved");

// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
Expand All @@ -66,7 +72,8 @@ DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_double(blank_skip_thresh, 1.0,
"blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(length_penalty, 0.0, "length penalty ctc wfst search, will not"
DEFINE_double(length_penalty, 0.0,
"length penalty ctc wfst search, will not"
"apply on self-loop arc, for balancing the del/ins ratio, "
"suggest set to -3.0");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
Expand Down Expand Up @@ -130,7 +137,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
#else
LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'.";
#endif
} else {
} else if (!FLAGS_model_path.empty()) {
#ifdef USE_TORCH
LOG(INFO) << "Reading torch model " << FLAGS_model_path;
TorchAsrModel::InitEngineThreads(FLAGS_num_threads);
Expand All @@ -140,6 +147,19 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
#else
LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'.";
#endif
} else if (!FLAGS_xpu_model_dir.empty()) {
#ifdef USE_XPU
LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
auto model = std::make_shared<XPUAsrModel>();
model->SetEngineThreads(FLAGS_num_threads);
model->SetDeviceId(FLAGS_device_id);
model->Read(FLAGS_xpu_model_dir);
resource->model = model;
#else
LOG(FATAL) << "Please rebuild with cmake options '-DXPU=ON'.";
#endif
} else {
LOG(FATAL) << "Please set ONNX, TORCH or XPU model path!!!";
}

LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
Expand Down Expand Up @@ -186,6 +206,7 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
post_process_opts.lowercase = FLAGS_lowercase;
resource->post_processor =
std::make_shared<PostProcessor>(std::move(post_process_opts));
LOG(INFO) << "Finish set PostProcessOptions. \n";
return resource;
}

Expand Down
Loading

0 comments on commit fac202e

Please sign in to comment.