Skip to content

Commit e8f451e

Browse files
duj12dujing
and
dujing
authored
ITN runtime. (#2001)
* wenet_api and decoder_main support ITN. * remove resource dir * turn websocket cmake on * fix cpplint * fix clang-format * fix clang-format * fix clang-format * fix CRLF. * try to add wetext as subproject, but failed because wetext and wenet both have utils and utils/string.h * add wetext subproject, integrate ITN. * add fst file existence check. * include file.h in params.h --------- Co-authored-by: dujing <[email protected]>
1 parent d4c56b5 commit e8f451e

File tree

11 files changed

+182
-31
lines changed

11 files changed

+182
-31
lines changed

runtime/core/api/wenet_api.cc

+26-10
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ class Recognizer {
6565

6666
std::string fst_path = wenet::JoinPath(model_dir, "TLG.fst");
6767
if (wenet::FileExists(fst_path)) { // With LM
68-
resource_->fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
69-
fst::Fst<fst::StdArc>::Read(fst_path));
68+
resource_->fst = std::shared_ptr<fst::VectorFst<fst::StdArc>>(
69+
fst::VectorFst<fst::StdArc>::Read(fst_path));
7070

7171
std::string symbol_path = wenet::JoinPath(model_dir, "words.txt");
7272
CHECK(wenet::FileExists(symbol_path));
@@ -79,7 +79,30 @@ class Recognizer {
7979
// Context config init
8080
context_config_ = std::make_shared<wenet::ContextConfig>();
8181
decode_options_ = std::make_shared<wenet::DecodeOptions>();
82+
83+
// PostProcessor
8284
post_process_opts_ = std::make_shared<wenet::PostProcessOptions>();
85+
if (language_ == "chs") { // TODO(Binbin Zhang): CJK(chs, jp, kr)
86+
post_process_opts_->language_type = wenet::kMandarinEnglish;
87+
} else {
88+
post_process_opts_->language_type = wenet::kIndoEuropean;
89+
}
90+
resource_->post_processor =
91+
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
92+
// Optional: ITN
93+
std::string itn_tagger_path =
94+
wenet::JoinPath(model_dir, "zh_itn_tagger.fst");
95+
std::string itn_verbalizer_path =
96+
wenet::JoinPath(model_dir, "zh_itn_verbalizer.fst");
97+
if (wenet::FileExists(itn_tagger_path) &&
98+
wenet::FileExists(itn_verbalizer_path)) {
99+
LOG(INFO) << "Reading ITN fst";
100+
post_process_opts_->itn = true;
101+
auto postprocessor =
102+
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
103+
postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path);
104+
resource_->post_processor = postprocessor;
105+
}
83106
}
84107

85108
void Reset() {
@@ -101,14 +124,7 @@ class Recognizer {
101124
context_graph->BuildContextGraph(context_, resource_->symbol_table);
102125
resource_->context_graph = context_graph;
103126
}
104-
// PostProcessor
105-
if (language_ == "chs") { // TODO(Binbin Zhang): CJK(chs, jp, kr)
106-
post_process_opts_->language_type = wenet::kMandarinEnglish;
107-
} else {
108-
post_process_opts_->language_type = wenet::kIndoEuropean;
109-
}
110-
resource_->post_processor =
111-
std::make_shared<wenet::PostProcessor>(*post_process_opts_);
127+
112128
// Init decode options
113129
decode_options_->chunk_size = chunk_size_;
114130
// Init decoder

runtime/core/bin/CMakeLists.txt

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ endif()
77
add_executable(label_checker_main label_checker_main.cc)
88
target_link_libraries(label_checker_main PUBLIC decoder)
99

10-
# if(TORCH)
11-
# add_executable(api_main api_main.cc)
12-
# target_link_libraries(api_main PUBLIC wenet_api)
13-
# endif()
10+
if(TORCH)
11+
add_executable(api_main api_main.cc)
12+
target_link_libraries(api_main PUBLIC wenet_api)
13+
endif()
1414

1515
if(WEBSOCKET)
1616
add_executable(websocket_client_main websocket_client_main.cc)

runtime/core/cmake/openfst.cmake

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ if(NOT ANDROID)
2929
# To build openfst with gflags and glog, we comment out some vars of {flags, log}.h and flags.cc.
3030
set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory")
3131
FetchContent_Declare(openfst
32-
URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz
33-
URL_HASH SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e
32+
URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.7.2.1.tar.gz
33+
URL_HASH SHA256=e04e1dabcecf3a687ace699ccb43a8a27da385777a56e69da6e103344cc66bca
34+
#URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz
35+
#URL_HASH SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e
3436
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
3537
)
3638
FetchContent_MakeAvailable(openfst)
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
FetchContent_Declare(wetextprocessing
2+
GIT_REPOSITORY https://github.com/wenet-e2e/WeTextProcessing.git
3+
GIT_TAG origin/master
4+
)
5+
FetchContent_MakeAvailable(wetextprocessing)
6+
include_directories(${wetextprocessing_SOURCE_DIR}/runtime )
7+
add_subdirectory(${wetextprocessing_SOURCE_DIR}/runtime/utils)
8+
add_subdirectory(${wetextprocessing_SOURCE_DIR}/runtime/processor)
9+
10+

runtime/core/decoder/asr_decoder.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ enum DecodeState {
9292
struct DecodeResource {
9393
std::shared_ptr<AsrModel> model = nullptr;
9494
std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
95-
std::shared_ptr<fst::Fst<fst::StdArc>> fst = nullptr;
95+
std::shared_ptr<fst::VectorFst<fst::StdArc>> fst = nullptr;
9696
std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
9797
std::shared_ptr<ContextGraph> context_graph = nullptr;
9898
std::shared_ptr<PostProcessor> post_processor = nullptr;
@@ -140,7 +140,7 @@ class AsrDecoder {
140140
std::shared_ptr<PostProcessor> post_processor_;
141141
std::shared_ptr<ContextGraph> context_graph_;
142142

143-
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ = nullptr;
143+
std::shared_ptr<fst::VectorFst<fst::StdArc>> fst_ = nullptr;
144144
// output symbol table
145145
std::shared_ptr<fst::SymbolTable> symbol_table_;
146146
// e2e unit symbol table

runtime/core/decoder/params.h

+25-2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#endif
4040
#include "frontend/feature_pipeline.h"
4141
#include "post_processor/post_processor.h"
42+
#include "utils/file.h"
4243
#include "utils/flags.h"
4344
#include "utils/string.h"
4445

@@ -65,6 +66,11 @@ DEFINE_int32(sample_rate, 16000, "sample rate for audio");
6566
// TLG fst
6667
DEFINE_string(fst_path, "", "TLG fst path");
6768

69+
// ITN fst
70+
// Directory holding the FST-based ITN model. The help text names the exact
// files the loader looks for (zh_itn_tagger.fst / zh_itn_verbalizer.fst) so
// that users name their files the way the decoder expects.
DEFINE_string(itn_model_dir, "",
              "fst based ITN model dir, "
              "should contain zh_itn_tagger.fst and zh_itn_verbalizer.fst");
73+
6874
// DecodeOptions flags
6975
DEFINE_int32(chunk_size, 16, "decoding chunk size");
7076
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
@@ -203,8 +209,8 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
203209
if (!FLAGS_fst_path.empty()) { // With LM
204210
CHECK(!FLAGS_dict_path.empty());
205211
LOG(INFO) << "Reading fst " << FLAGS_fst_path;
206-
auto fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
207-
fst::Fst<fst::StdArc>::Read(FLAGS_fst_path));
212+
auto fst = std::shared_ptr<fst::VectorFst<fst::StdArc>>(
213+
fst::VectorFst<fst::StdArc>::Read(FLAGS_fst_path));
208214
CHECK(fst != nullptr);
209215
resource->fst = fst;
210216

@@ -237,6 +243,23 @@ std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
237243
post_process_opts.lowercase = FLAGS_lowercase;
238244
resource->post_processor =
239245
std::make_shared<PostProcessor>(std::move(post_process_opts));
246+
247+
if (!FLAGS_itn_model_dir.empty()) { // With ITN
248+
std::string itn_tagger_path =
249+
wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_tagger.fst");
250+
std::string itn_verbalizer_path =
251+
wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_verbalizer.fst");
252+
if (wenet::FileExists(itn_tagger_path) &&
253+
wenet::FileExists(itn_verbalizer_path)) {
254+
LOG(INFO) << "Reading ITN fst" << FLAGS_itn_model_dir;
255+
post_process_opts.itn = true;
256+
auto postprocessor =
257+
std::make_shared<wenet::PostProcessor>(std::move(post_process_opts));
258+
postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path);
259+
resource->post_processor = postprocessor;
260+
}
261+
}
262+
240263
return resource;
241264
}
242265

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
add_executable(fst_test
2+
fst_test.cc
3+
../include/fst/test/fst_test.h
4+
)
5+
target_link_libraries(fst_test fst ${CMAKE_DL_LIBS})
6+
set_target_properties(fst_test PROPERTIES FOLDER test)
7+
add_test(NAME fst_test-test COMMAND fst_test)
8+
9+
add_executable(weight_test
10+
weight_test.cc
11+
../include/fst/test/weight-tester.h
12+
)
13+
target_link_libraries(weight_test fst ${CMAKE_DL_LIBS})
14+
set_target_properties(weight_test PROPERTIES FOLDER test)
15+
add_test(NAME weight_test-test COMMAND weight_test)
16+
17+
add_executable(algo_test_log algo_test.cc ../include/fst/test/algo_test.h ../include/fst/test/rand-fst.h)
18+
target_link_libraries(algo_test_log fst ${CMAKE_DL_LIBS})
19+
target_compile_definitions(algo_test_log
20+
PRIVATE TEST_LOG=1)
21+
set_target_properties(algo_test_log PROPERTIES FOLDER test)
22+
add_test(NAME algo_test_log-test COMMAND algo_test_log)
23+
24+
25+
add_executable(algo_test_tropical algo_test.cc ../include/fst/test/algo_test.h ../include/fst/test/rand-fst.h)
26+
target_link_libraries(algo_test_tropical fst ${CMAKE_DL_LIBS})
27+
target_compile_definitions(algo_test_tropical
28+
PRIVATE TEST_TROPICAL=1)
29+
set_target_properties(algo_test_tropical PROPERTIES FOLDER test)
30+
add_test(NAME algo_test_tropical-test COMMAND algo_test_tropical)
31+
32+
33+
add_executable(algo_test_minmax algo_test.cc ../include/fst/test/algo_test.h ../include/fst/test/rand-fst.h)
34+
target_link_libraries(algo_test_minmax fst ${CMAKE_DL_LIBS})
35+
target_compile_definitions(algo_test_minmax
36+
PRIVATE TEST_MINMAX=1)
37+
set_target_properties(algo_test_minmax PROPERTIES FOLDER test)
38+
add_test(NAME algo_test_minmax-test COMMAND algo_test_minmax)
39+
40+
41+
add_executable(algo_test_lexicographic algo_test.cc ../include/fst/test/algo_test.h ../include/fst/test/rand-fst.h)
42+
target_link_libraries(algo_test_lexicographic fst ${CMAKE_DL_LIBS})
43+
target_compile_definitions(algo_test_lexicographic
44+
PRIVATE TEST_LEXICOGRAPHIC=1)
45+
set_target_properties(algo_test_lexicographic PROPERTIES FOLDER test)
46+
add_test(NAME algo_test_lexicographic-test COMMAND algo_test_lexicographic)
47+
48+
49+
add_executable(algo_test_power algo_test.cc ../include/fst/test/algo_test.h ../include/fst/test/rand-fst.h)
50+
target_link_libraries(algo_test_power fst ${CMAKE_DL_LIBS})
51+
target_compile_definitions(algo_test_power
52+
PRIVATE TEST_POWER=1)
53+
set_target_properties(algo_test_power PROPERTIES FOLDER test)
54+
add_test(NAME algo_test_power-test COMMAND algo_test_power)
55+
+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
add_library(post_processor STATIC
22
post_processor.cc
33
)
4-
target_link_libraries(post_processor PUBLIC utils)
4+
if(ITN)
5+
target_link_libraries(post_processor PUBLIC utils wetext_processor)
6+
else()
7+
target_link_libraries(post_processor PUBLIC utils)
8+
endif()

runtime/core/post_processor/post_processor.cc

+34-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) 2021 Xingchen Song [email protected]
2+
// 2023 Jing Du ([email protected])
23
//
34
// Licensed under the Apache License, Version 2.0 (the "License");
45
// you may not use this file except in compliance with the License.
@@ -13,13 +14,18 @@
1314
// limitations under the License
1415

1516
#include "post_processor/post_processor.h"
16-
1717
#include <sstream>
1818
#include <vector>
19-
19+
#include "processor/wetext_processor.h"
2020
#include "utils/string.h"
2121

2222
namespace wenet {
23+
void PostProcessor::InitITNResource(const std::string& tagger_path,
24+
const std::string& verbalizer_path) {
25+
auto itn_processor =
26+
std::make_shared<wetext::Processor>(tagger_path, verbalizer_path);
27+
itn_resource = itn_processor;
28+
}
2329

2430
std::string PostProcessor::ProcessSpace(const std::string& str) {
2531
std::string result = str;
@@ -56,10 +62,34 @@ std::string PostProcessor::ProcessSpace(const std::string& str) {
5662
return result;
5763
}
5864

65+
// Returns a copy of `str` with every occurrence of `sub` removed.
//
// Removal is repeated until no occurrence is left, so a match that only
// appears once the text around an erased match is joined is removed as well
// (e.g. removing "ab" from "aabb" yields ""). An empty `sub` matches
// everywhere and can never be consumed, so `str` is returned unchanged in
// that case instead of looping forever.
std::string del_substr(const std::string& str, const std::string& sub) {
  std::string result = str;
  if (sub.empty()) {
    return result;
  }
  // Longest piece of a match that can sit before the erase point.
  const std::string::size_type overlap = sub.size() - 1;
  // Invariant: no occurrence of `sub` begins before `start`.
  std::string::size_type start = 0;
  std::string::size_type pos = 0;
  while ((pos = result.find(sub, start)) != std::string::npos) {
    result.erase(pos, sub.size());
    // A newly formed match must span the erase point, so it cannot begin
    // earlier than `overlap` characters before it; no need to rescan from 0.
    start = pos > overlap ? pos - overlap : 0;
  }
  return result;
}
73+
74+
// Returns `str` with the markers "<unk>", "<context>" and "</context>"
// stripped out, removing them one marker at a time in that order.
std::string PostProcessor::ProcessSymbols(const std::string& str) {
  const char* const kMarkers[] = {"<unk>", "<context>", "</context>"};
  std::string cleaned = str;
  for (const char* marker : kMarkers) {
    cleaned = del_substr(cleaned, marker);
  }
  return cleaned;
}
81+
5982
std::string PostProcessor::Process(const std::string& str, bool finish) {
6083
std::string result;
61-
result = ProcessSpace(str);
62-
// TODO(xcsong): do itn/punctuation if finish == true
84+
// remove symbols with "<>" first
85+
result = ProcessSymbols(str);
86+
result = ProcessSpace(result);
87+
// TODO(xcsong): do punctuation if finish == true
88+
if (finish == true && opts_.itn) {
89+
if (nullptr != itn_resource) {
90+
result = itn_resource->Normalize(result);
91+
}
92+
}
6393
return result;
6494
}
6595

runtime/core/post_processor/post_processor.h

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Copyright (c) 2021 Xingchen Song [email protected]
2+
// 2023 Jing Du ([email protected])
23
//
34
// Licensed under the Apache License, Version 2.0 (the "License");
45
// you may not use this file except in compliance with the License.
@@ -18,7 +19,7 @@
1819
#include <memory>
1920
#include <string>
2021
#include <utility>
21-
22+
#include "processor/wetext_processor.h"
2223
#include "utils/utils.h"
2324

2425
namespace wenet {
@@ -43,10 +44,10 @@ struct PostProcessOptions {
4344
LanguageType language_type = kMandarinEnglish;
4445
// whether lowercase letters are required
4546
bool lowercase = true;
47+
bool itn = false;
4648
};
4749

48-
// TODO(xcsong): add itn/punctuation related resource
49-
struct PostProcessResource {};
50+
// TODO(xcsong): add punctuation related resource
5051

5152
// Post Processor
5253
class PostProcessor {
@@ -57,11 +58,15 @@ class PostProcessor {
5758
std::string Process(const std::string& str, bool finish);
5859
// process spaces according to configurations
5960
std::string ProcessSpace(const std::string& str);
60-
// TODO(xcsong): add itn/punctuation
61-
// void InverseTN(const std::string& str);
61+
std::string ProcessSymbols(const std::string& str);
62+
// TODO(xcsong): add punctuation
6263
// void Punctuate(const std::string& str);
6364

65+
void InitITNResource(const std::string& tagger_path,
66+
const std::string& verbalizer_path);
67+
6468
private:
69+
std::shared_ptr<wetext::Processor> itn_resource = nullptr;
6570
const PostProcessOptions opts_;
6671

6772
public:

runtime/libtorch/CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ project(wenet VERSION 0.1)
44

55
option(CXX11_ABI "whether to use CXX11_ABI libtorch" OFF)
66
option(GRAPH_TOOLS "whether to build TLG graph tools" OFF)
7-
option(BUILD_TESTING "whether to build unit test" ON)
7+
option(BUILD_TESTING "whether to build unit test" OFF)
88

99
option(GRPC "whether to build with gRPC" OFF)
1010
# TODO(Binbin Zhang): Change websocket to OFF since it depends on boost
@@ -14,7 +14,9 @@ option(HTTP "whether to build with http" OFF)
1414
option(TORCH "whether to build with Torch" ON)
1515
option(ONNX "whether to build with ONNX" OFF)
1616
option(GPU "whether to build with GPU" OFF)
17+
option(ITN "whether to use WeTextProcessing" ON)
1718

19+
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -g")
1820
set(CMAKE_VERBOSE_MAKEFILE OFF)
1921

2022
include(FetchContent)
@@ -46,6 +48,10 @@ include_directories(
4648
${CMAKE_CURRENT_SOURCE_DIR}/kaldi
4749
)
4850

51+
if(ITN)
52+
include(wetextprocessing)
53+
endif()
54+
4955
# Build all libraries
5056
add_subdirectory(utils)
5157
add_subdirectory(frontend)

0 commit comments

Comments
 (0)