From 08534651c99fa5e1af60364d172ced53f501f777 Mon Sep 17 00:00:00 2001 From: Chen Gong Date: Tue, 21 Jul 2020 17:26:52 +0800 Subject: [PATCH] perf(poet): optimize for performance in making sentences (~40% faster) --- src/rime/gear/contextual_translation.cc | 12 +- src/rime/gear/grammar.h | 8 +- src/rime/gear/poet.cc | 234 +++++++++++++++--------- src/rime/gear/poet.h | 11 +- src/rime/gear/table_translator.cc | 4 +- src/rime/gear/translator_commons.cc | 19 +- src/rime/gear/translator_commons.h | 16 +- 7 files changed, 183 insertions(+), 121 deletions(-) diff --git a/src/rime/gear/contextual_translation.cc b/src/rime/gear/contextual_translation.cc index 7aee15a8b..8e3336774 100644 --- a/src/rime/gear/contextual_translation.cc +++ b/src/rime/gear/contextual_translation.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include namespace rime { @@ -37,12 +38,13 @@ bool ContextualTranslation::Replenish() { } an ContextualTranslation::Evaluate(an phrase) { - auto sentence = New(phrase->language()); - sentence->Offset(phrase->start()); bool is_rear = phrase->end() == input_.length(); - sentence->Extend(phrase->entry(), phrase->end(), is_rear, preceding_text_, - grammar_); - phrase->set_weight(sentence->weight()); + double weight = Grammar::Evaluate(preceding_text_, + phrase->text(), + phrase->weight(), + is_rear, + grammar_); + phrase->set_weight(weight); DLOG(INFO) << "contextual suggestion: " << phrase->text() << " weight: " << phrase->weight(); return phrase; diff --git a/src/rime/gear/grammar.h b/src/rime/gear/grammar.h index 09923d8ac..9d545ed78 100644 --- a/src/rime/gear/grammar.h +++ b/src/rime/gear/grammar.h @@ -3,7 +3,6 @@ #include #include -#include namespace rime { @@ -17,12 +16,13 @@ class Grammar : public Class { bool is_rear) = 0; inline static double Evaluate(const string& context, - const DictEntry& entry, + const string& entry_text, + double entry_weight, bool is_rear, Grammar* grammar) { const double kPenalty = -18.420680743952367; // log(1e-8) - return entry.weight + - (grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty); + return entry_weight + + (grammar ? grammar->Query(context, entry_text, is_rear) : kPenalty); } }; diff --git a/src/rime/gear/poet.cc b/src/rime/gear/poet.cc index a87f385a1..7cf9ce17d 100644 --- a/src/rime/gear/poet.cc +++ b/src/rime/gear/poet.cc @@ -16,6 +16,64 @@ namespace rime { +// internal data structure used during the sentence making process. +// the output line of the algorithm is transformed to an. +struct Line { + // be sure the pointer to predecessor Line object is stable. it works since + // pointer to values stored in std::map and std::unordered_map are stable. + const Line* predecessor; + // as long as the word graph lives, pointers to entries are valid. + const DictEntry* entry; + size_t end_pos; + double weight; + + static const Line kEmpty; + + bool empty() const { + return !predecessor && !entry; + } + + string last_word() const { + return entry ? entry->text : string(); + } + + struct Components { + vector lines; + + Components(const Line* line) { + for (const Line* cursor = line; + !cursor->empty(); + cursor = cursor->predecessor) { + lines.push_back(cursor); + } + } + + decltype(lines.crbegin()) begin() const { return lines.crbegin(); } + decltype(lines.crend()) end() const { return lines.crend(); } + }; + + Components components() const { return Components(this); } + + string context() const { + // look back 2 words + return empty() ? string() : + !predecessor || predecessor->empty() ? last_word() : + predecessor->last_word() + last_word(); + } + + vector word_lengths() const { + vector lengths; + size_t last_end_pos = 0; + for (const auto* c : components()) { + lengths.push_back(c->end_pos - last_end_pos); + last_end_pos = c->end_pos; + } + return lengths; + } +}; + +const Line Line::kEmpty{nullptr, nullptr, 0, 0.0}; + inline static Grammar* create_grammar(Config* config) { if (auto* grammar = Grammar::Require("grammar")) { return grammar->Create(config); @@ -30,102 +88,103 @@ Poet::Poet(const Language* language, Config* config, Compare compare) Poet::~Poet() {} -bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) { - return one.weight() < other.weight() || ( // left associate if even - one.weight() == other.weight() && ( - one.size() > other.size() || ( // less components is more favorable - one.size() == other.size() && - std::lexicographical_compare(one.syllable_lengths().begin(), - one.syllable_lengths().end(), - other.syllable_lengths().begin(), - other.syllable_lengths().end())))); +bool Poet::CompareWeight(const Line& one, const Line& other) { + return one.weight < other.weight; +} + +// returns true if one is less than other. +bool Poet::LeftAssociateCompare(const Line& one, const Line& other) { + if (one.weight < other.weight) return true; + if (one.weight == other.weight) { + auto one_word_lens = one.word_lengths(); + auto other_word_lens = other.word_lengths(); + // less words is more favorable + if (one_word_lens.size() > other_word_lens.size()) return true; + if (one_word_lens.size() == other_word_lens.size()) { + return std::lexicographical_compare( + one_word_lens.begin(), one_word_lens.end(), + other_word_lens.begin(), other_word_lens.end()); + } + } + return false; } -// keep the best sentence candidate per last phrase -using SentenceCandidates = hash_map>; +// keep the best line candidate per last phrase +using LineCandidates = hash_map; template -static vector> find_top_candidates( - const SentenceCandidates& candidates, Poet::Compare compare) { - vector> top; +static vector find_top_candidates( + const LineCandidates& candidates, Poet::Compare compare) { + vector top; top.reserve(N + 1); for (const auto& candidate : candidates) { auto pos = std::upper_bound( - top.begin(), top.end(), candidate.second, - [&](const an& a, const an& b) { - return !compare(*a, *b); // desc - }); + top.begin(), top.end(), &candidate.second, + [&](const Line* a, const Line* b) { return compare(*b, *a); }); // desc if (pos - top.begin() >= N) continue; - top.insert(pos, candidate.second); + top.insert(pos, &candidate.second); if (top.size() > N) top.pop_back(); } return top; } -static an find_best_sentence(const SentenceCandidates& candidates, - Poet::Compare compare) { - an best = nullptr; - for (const auto& candidate : candidates) { - if (!best || compare(*best, *candidate.second)) { - best = candidate.second; - } - } - return best; -} - -using UpdateSetenceCandidate = function& candidate)>; +using UpdateLineCandidate = function; struct BeamSearch { - using State = SentenceCandidates; + using State = LineCandidates; - static constexpr int kMaxSentenceCandidates = 7; + static constexpr int kMaxLineCandidates = 7; - static void Initiate(State& initial_state, const Language* language) { - initial_state.emplace("", New(language)); + static void Initiate(State& initial_state) { + initial_state.emplace("", Line::kEmpty); } static void ForEachCandidate(const State& state, Poet::Compare compare, - UpdateSetenceCandidate update) { + UpdateLineCandidate update) { auto top_candidates = - find_top_candidates(state, compare); - for (const auto& candidate : top_candidates) { - update(candidate); + find_top_candidates(state, compare); + for (const auto* candidate : top_candidates) { + update(*candidate); } } - static an& BestSentenceToUpdate(State& state, - const an& new_sentence) { - const auto& key = new_sentence->components().back().text; + static Line& BestLineToUpdate(State& state, const Line& new_line) { + const auto& key = new_line.last_word(); return state[key]; } - static an BestSentence(const State& final_state, - Poet::Compare compare) { - return find_best_sentence(final_state, compare); + static const Line& BestLineInState(const State& final_state, + Poet::Compare compare) { + const Line* best = nullptr; + for (const auto& candidate : final_state) { + if (!best || compare(*best, candidate.second)) { + best = &candidate.second; + } + } + return best ? *best : Line::kEmpty; } }; struct DynamicProgramming { - using State = an; + using State = Line; - static void Initiate(State& initial_state, const Language* language) { - initial_state = New(language); + static void Initiate(State& initial_state) { + initial_state = Line::kEmpty; } static void ForEachCandidate(const State& state, Poet::Compare compare, - UpdateSetenceCandidate update) { + UpdateLineCandidate update) { update(state); } - static an& BestSentenceToUpdate(State& state, - const an& new_sentence) { + static Line& BestLineToUpdate(State& state, const Line& new_line) { return state; } - static an BestSentence(const State& final_state, - Poet::Compare compare) { + static const Line& BestLineInState(const State& final_state, + Poet::Compare compare) { return final_state; } }; @@ -134,47 +193,58 @@ template an Poet::MakeSentenceWithStrategy(const WordGraph& graph, size_t total_length, const string& preceding_text) { - map sentences; - Strategy::Initiate(sentences[0], language_); - for (const auto& w : graph) { - size_t start_pos = w.first; - if (sentences.find(start_pos) == sentences.end()) + map states; + Strategy::Initiate(states[0]); + for (const auto& sv : graph) { + size_t start_pos = sv.first; + if (states.find(start_pos) == states.end()) continue; DLOG(INFO) << "start pos: " << start_pos; - const auto& source(sentences[start_pos]); - Strategy::ForEachCandidate( - source, compare_, - [&](const an& candidate) { - for (const auto& x : w.second) { - size_t end_pos = x.first; + const auto& source_state = states[start_pos]; + const auto update = + [this, &states, &sv, start_pos, total_length, &preceding_text] + (const Line& candidate) { + for (const auto& ev : sv.second) { + size_t end_pos = ev.first; if (start_pos == 0 && end_pos == total_length) - continue; // exclude single words from the result + continue; // exclude single word from the result DLOG(INFO) << "end pos: " << end_pos; bool is_rear = end_pos == total_length; - auto& target(sentences[end_pos]); + auto& target_state = states[end_pos]; // extend candidates with dict entries on a valid edge. - const DictEntryList& entries(x.second); + const DictEntryList& entries = ev.second; for (const auto& entry : entries) { - auto new_sentence = New(*candidate); - new_sentence->Extend( - *entry, end_pos, is_rear, preceding_text, grammar_.get()); - auto& best_sentence = - Strategy::BestSentenceToUpdate(target, new_sentence); - if (!best_sentence || compare_(*best_sentence, *new_sentence)) { - DLOG(INFO) << "updated sentences " << end_pos << ") with " - << new_sentence->text() << " weight: " - << new_sentence->weight(); - best_sentence = std::move(new_sentence); + const string& context = + candidate.empty() ? preceding_text : candidate.context(); + double weight = candidate.weight + + Grammar::Evaluate(context, + entry->text, + entry->weight, + is_rear, + grammar_.get()); + Line new_line{&candidate, entry.get(), end_pos, weight}; + Line& best = Strategy::BestLineToUpdate(target_state, new_line); + if (best.empty() || compare_(best, new_line)) { + DLOG(INFO) << "updated line ending at " << end_pos + << " with text: ..." << new_line.last_word() + << " weight: " << new_line.weight; + best = new_line; } } } - }); + }; + Strategy::ForEachCandidate(source_state, compare_, update); } - auto found = sentences.find(total_length); - if (found == sentences.end()) + auto found = states.find(total_length); + if (found == states.end() || found->second.empty()) return nullptr; - else - return Strategy::BestSentence(found->second, compare_); + const Line& best = Strategy::BestLineInState(found->second, compare_); + auto sentence = New(language_); + for (const auto* c : best.components()) { + if (!c->entry) continue; + sentence->Extend(*c->entry, c->end_pos, c->weight); + } + return sentence; } an Poet::MakeSentence(const WordGraph& graph, diff --git a/src/rime/gear/poet.h b/src/rime/gear/poet.h index 69b498734..f67d1ae0a 100644 --- a/src/rime/gear/poet.h +++ b/src/rime/gear/poet.h @@ -22,16 +22,15 @@ using WordGraph = map; class Grammar; class Language; +struct Line; class Poet { public: - // sentence "less", used to compare sentences of the same input range. - using Compare = function; + // Line "less", used to compare composed line of the same input range. + using Compare = function; - static bool CompareWeight(const Sentence& one, const Sentence& other) { - return one.weight() < other.weight(); - } - static bool LeftAssociateCompare(const Sentence& one, const Sentence& other); + static bool CompareWeight(const Line& one, const Line& other); + static bool LeftAssociateCompare(const Line& one, const Line& other); Poet(const Language* language, Config* config, Compare compare = CompareWeight); diff --git a/src/rime/gear/table_translator.cc b/src/rime/gear/table_translator.cc index 744683bd4..13935a49c 100644 --- a/src/rime/gear/table_translator.cc +++ b/src/rime/gear/table_translator.cc @@ -383,7 +383,7 @@ Spans SentenceSyllabifier::Syllabify(const Phrase* phrase) { if (auto sentence = dynamic_cast(phrase)) { size_t stop = sentence->start(); result.AddVertex(stop); - for (size_t len : sentence->syllable_lengths()) { + for (size_t len : sentence->word_lengths()) { stop += len; result.AddVertex(stop); } @@ -501,7 +501,7 @@ void SentenceTranslation::PrepareSentence() { const string& delimiters(translator_->delimiters()); // split syllables size_t pos = 0; - for (int len : sentence_->syllable_lengths()) { + for (int len : sentence_->word_lengths()) { if (pos > 0 && delimiters.find(input_[pos - 1]) == string::npos) { preedit.insert(pos, 1, ' '); ++pos; diff --git a/src/rime/gear/translator_commons.cc b/src/rime/gear/translator_commons.cc index 8e305fcd9..ba3635c34 100644 --- a/src/rime/gear/translator_commons.cc +++ b/src/rime/gear/translator_commons.cc @@ -88,19 +88,16 @@ bool Spans::HasVertex(size_t vertex) const { // Sentence -void Sentence::Extend(const DictEntry& entry, +void Sentence::Extend(const DictEntry& another, size_t end_pos, - bool is_rear, - const string& preceding_text, - Grammar* grammar) { - const string& context = empty() ? preceding_text : text(); - entry_->weight += Grammar::Evaluate(context, entry, is_rear, grammar); - entry_->text.append(entry.text); + double new_weight) { + entry_->weight = new_weight; + entry_->text.append(another.text); entry_->code.insert(entry_->code.end(), - entry.code.begin(), - entry.code.end()); - components_.push_back(entry); - syllable_lengths_.push_back(end_pos - end()); + another.code.begin(), + another.code.end()); + components_.push_back(another); + word_lengths_.push_back(end_pos - end()); set_end(end_pos); DLOG(INFO) << "extend sentence " << end_pos << ") " << text() << " weight: " << weight(); diff --git a/src/rime/gear/translator_commons.h b/src/rime/gear/translator_commons.h index 343fd7485..de614fdf2 100644 --- a/src/rime/gear/translator_commons.h +++ b/src/rime/gear/translator_commons.h @@ -109,8 +109,6 @@ class Phrase : public Candidate { // -class Grammar; - class Sentence : public Phrase { public: Sentence(const Language* language) @@ -118,14 +116,10 @@ class Sentence : public Phrase { Sentence(const Sentence& other) : Phrase(other), components_(other.components_), - syllable_lengths_(other.syllable_lengths_) { + word_lengths_(other.word_lengths_) { entry_ = New(other.entry()); } - void Extend(const DictEntry& entry, - size_t end_pos, - bool is_rear, - const string& preceding_text, - Grammar* grammar); + void Extend(const DictEntry& another, size_t end_pos, double new_weight); void Offset(size_t offset); bool empty() const { @@ -139,13 +133,13 @@ class Sentence : public Phrase { const vector& components() const { return components_; } - const vector& syllable_lengths() const { - return syllable_lengths_; + const vector& word_lengths() const { + return word_lengths_; } protected: vector components_; - vector syllable_lengths_; + vector word_lengths_; }; //