From a38881817c1c68946d5a344621ee271107e5dbcc Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 31 May 2024 12:34:30 +0800 Subject: [PATCH] Support customize scores for hotwords (#926) * Support customize scores for hotwords * Skip blank lines --- sherpa-onnx/csrc/context-graph.h | 7 ++-- .../csrc/offline-recognizer-transducer-impl.h | 39 +++++++++++++----- .../csrc/online-recognizer-transducer-impl.h | 40 ++++++++++++++----- sherpa-onnx/csrc/text2token-test.cc | 36 ++++++++++++----- sherpa-onnx/csrc/utils.cc | 13 ++++-- sherpa-onnx/csrc/utils.h | 3 +- 6 files changed, 103 insertions(+), 35 deletions(-) diff --git a/sherpa-onnx/csrc/context-graph.h b/sherpa-onnx/csrc/context-graph.h index 14f03f4b6..e16fc4d37 100644 --- a/sherpa-onnx/csrc/context-graph.h +++ b/sherpa-onnx/csrc/context-graph.h @@ -61,10 +61,9 @@ class ContextGraph { } ContextGraph(const std::vector> &token_ids, - float context_score, const std::vector &scores = {}, - const std::vector &phrases = {}) - : ContextGraph(token_ids, context_score, 0.0f, scores, phrases, - std::vector()) {} + float context_score, const std::vector &scores = {}) + : ContextGraph(token_ids, context_score, 0.0f, scores, + std::vector(), std::vector()) {} std::tuple ForwardOneStep( const ContextState *state, int32_t token_id, diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 5051c8b65..265f42bb9 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -145,15 +145,35 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); std::istringstream is(hws); std::vector> current; + std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &current)) { + bpe_encoder_.get(), &current, &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, 
hotwords are : %s", hotwords.c_str()); } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); - auto context_graph = - std::make_shared(current, config_.hotwords_score); + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. + } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); return std::make_unique(config_.feat_config, context_graph); } @@ -226,13 +246,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_)) { + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); } - hotwords_graph_ = - std::make_shared(hotwords_, config_.hotwords_score); + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); } #if __ANDROID_API__ >= 9 @@ -250,13 +270,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_)) { + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); } - hotwords_graph_ = - std::make_shared(hotwords_, 
config_.hotwords_score); + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); } #endif @@ -264,6 +284,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { OfflineRecognizerConfig config_; SymbolTable symbol_table_; std::vector> hotwords_; + std::vector boost_scores_; ContextGraphPtr hotwords_graph_; std::unique_ptr bpe_encoder_; std::unique_ptr model_; diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 60e3aa2b9..16c44b9de 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -182,14 +182,35 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); std::istringstream is(hws); std::vector> current; + std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &current)) { + bpe_encoder_.get(), &current, &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); - auto context_graph = - std::make_shared(current, config_.hotwords_score); + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. 
+ } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); auto stream = std::make_unique(config_.feat_config, context_graph); InitOnlineStream(stream.get()); @@ -376,13 +397,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_)) { + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); } - hotwords_graph_ = - std::make_shared(hotwords_, config_.hotwords_score); + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); } #if __ANDROID_API__ >= 9 @@ -400,13 +421,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_)) { + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); } - hotwords_graph_ = - std::make_shared(hotwords_, config_.hotwords_score); + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); } #endif @@ -428,6 +449,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { private: OnlineRecognizerConfig config_; std::vector> hotwords_; + std::vector boost_scores_; ContextGraphPtr hotwords_graph_; std::unique_ptr bpe_encoder_; std::unique_ptr model_; diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index ef07797db..0ad912df8 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -35,17 +35,21 @@ TEST(TEXT2TOKEN, TEST_cjkchar) { auto sym_table = SymbolTable(tokens); - std::string text = "世界人民大团结\n中国 V S 美国"; + std::string text = + "世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also std::istringstream iss(text); 
std::vector> ids; + std::vector scores; - auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids); + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores); std::vector> expected_ids( {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); EXPECT_EQ(ids, expected_ids); + + EXPECT_EQ(scores.size(), 0); } TEST(TEXT2TOKEN, TEST_bpe) { @@ -68,17 +72,22 @@ TEST(TEXT2TOKEN, TEST_bpe) { auto sym_table = SymbolTable(tokens); auto bpe_processor = std::make_unique(bpe); - std::string text = "HELLO WORLD\nI LOVE YOU"; + std::string text = "HELLO WORLD\nI LOVE YOU :2.0"; std::istringstream iss(text); std::vector> ids; + std::vector scores; - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + auto r = + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); std::vector> expected_ids( {{22, 58, 24, 425}, {19, 370, 47}}); EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({0, 2.0}); + EXPECT_EQ(scores, expected_scores); } TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { @@ -101,19 +110,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { auto sym_table = SymbolTable(tokens); auto bpe_processor = std::make_unique(bpe); - std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国"; + std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5"; std::istringstream iss(text); std::vector> ids; + std::vector scores; - auto r = - EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids); + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), + &ids, &scores); std::vector> expected_ids( {{1368, 1392, 557, 680, 275, 178, 475}, {685, 736, 275, 178, 179, 921, 736}}); EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({1.5, 0.5}); + EXPECT_EQ(scores, expected_scores); } TEST(TEXT2TOKEN, TEST_bbpe) { @@ -136,17 +149,22 @@ TEST(TEXT2TOKEN, TEST_bbpe) { auto sym_table = SymbolTable(tokens); auto bpe_processor = std::make_unique(bpe); - std::string text = 
"频繁\n李鞑靼"; + std::string text = "频繁 :1.0\n李鞑靼"; std::istringstream iss(text); std::vector> ids; + std::vector scores; - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + auto r = + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); std::vector> expected_ids( {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({1.0, 0}); + EXPECT_EQ(scores, expected_scores); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index 6363f03c4..93de43e73 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -103,7 +103,8 @@ static bool EncodeBase(const std::vector &lines, bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, const SymbolTable &symbol_table, const ssentencepiece::Ssentencepiece *bpe_encoder, - std::vector> *hotwords) { + std::vector> *hotwords, + std::vector *boost_scores) { std::vector lines; std::string line; std::string word; @@ -131,7 +132,12 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, break; } } - phrase = oss.str().substr(1); + phrase = oss.str(); + if (phrase.empty()) { + continue; + } else { + phrase = phrase.substr(1); + } std::istringstream piss(phrase); oss.clear(); oss.str(""); @@ -177,7 +183,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, } lines.push_back(oss.str()); } - return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr); + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores, + nullptr); } bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index a3189a20a..a9d59e8a2 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -29,7 +29,8 @@ namespace sherpa_onnx { bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, const 
SymbolTable &symbol_table, const ssentencepiece::Ssentencepiece *bpe_encoder_, - std::vector> *hotwords_id); + std::vector> *hotwords_id, + std::vector *boost_scores); /* Encode the keywords in an input stream to be tokens ids. *