From b3f4005d35a80e08b4b85234237c6777cb91c1a3 Mon Sep 17 00:00:00 2001 From: Chen Gong Date: Fri, 19 Apr 2019 02:22:17 +0800 Subject: [PATCH] feat(poet): find best sentence candidates --- src/rime/gear/poet.cc | 84 +++++++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/src/rime/gear/poet.cc b/src/rime/gear/poet.cc index 5a5f87038..abf5ea60f 100644 --- a/src/rime/gear/poet.cc +++ b/src/rime/gear/poet.cc @@ -41,35 +41,73 @@ bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) { other.syllable_lengths().end())))); } +// keep the best sentence candidate per last phrase +using SentenceCandidates = hash_map>; + +static vector> top_candidates(const SentenceCandidates& candidates, + size_t n, + Poet::Compare& compare) { + vector> top; + top.reserve(n + 1); + for (const auto& candidate : candidates) { + auto pos = std::upper_bound( + top.begin(), top.end(), candidate.second, + [&](const an& a, const an& b) { + return !compare(*a, *b); // desc + }); + if (pos - top.begin() >= n) continue; + top.insert(pos, candidate.second); + if (top.size() > n) top.pop_back(); + } + return top; +} + +an find_best_sentence(const SentenceCandidates& candidates, + Poet::Compare& compare) { + an best = nullptr; + for (const auto& candidate : candidates) { + if (!best || compare(*best, *candidate.second)) { + best = candidate.second; + } + } + return best; +} + +constexpr int kMaxSentenceCandidates = 7; + an Poet::MakeSentence(const WordGraph& graph, size_t total_length, const string& preceding_text) { - // TODO: save more intermediate sentence candidates - map> sentences; - sentences[0] = New(language_); - // dynamic programming + map sentences; + sentences[0].emplace("", New(language_)); for (const auto& w : graph) { size_t start_pos = w.first; - DLOG(INFO) << "start pos: " << start_pos; if (sentences.find(start_pos) == sentences.end()) continue; - for (const auto& x : w.second) { - size_t end_pos = x.first; - if (start_pos == 0 && end_pos == total_length) - continue; // exclude single words from the result - DLOG(INFO) << "end pos: " << end_pos; - bool is_rear = end_pos == total_length; - const DictEntryList& entries(x.second); - for (const auto& entry : entries) { - auto new_sentence = New(*sentences[start_pos]); - new_sentence->Extend( - *entry, end_pos, is_rear, preceding_text, grammar_.get()); - if (sentences.find(end_pos) == sentences.end() || - compare_(*sentences[end_pos], *new_sentence)) { - DLOG(INFO) << "updated sentences " << end_pos << ") with " - << new_sentence->text() << " weight: " - << new_sentence->weight(); - sentences[end_pos] = std::move(new_sentence); + DLOG(INFO) << "start pos: " << start_pos; + auto top = top_candidates( + sentences[start_pos], kMaxSentenceCandidates, compare_); + for (const auto& candidate : top) { + for (const auto& x : w.second) { + size_t end_pos = x.first; + if (start_pos == 0 && end_pos == total_length) + continue; // exclude single words from the result + DLOG(INFO) << "end pos: " << end_pos; + bool is_rear = end_pos == total_length; + auto& target(sentences[end_pos]); + const DictEntryList& entries(x.second); + for (const auto& entry : entries) { + auto new_sentence = New(*candidate); + new_sentence->Extend( + *entry, end_pos, is_rear, preceding_text, grammar_.get()); + const auto& key = new_sentence->components().back().text; + auto& best_sentence = target[key]; + if (!best_sentence || compare_(*best_sentence, *new_sentence)) { + DLOG(INFO) << "updated sentences " << end_pos << ") with " + << new_sentence->text() << " weight: " + << new_sentence->weight(); + best_sentence = std::move(new_sentence); + } } } } @@ -77,7 +115,7 @@ an Poet::MakeSentence(const WordGraph& graph, if (sentences.find(total_length) == sentences.end()) return nullptr; else - return sentences[total_length]; + return find_best_sentence(sentences[total_length], compare_); } } // namespace rime