diff --git a/src/rnnlm/Makefile b/src/rnnlm/Makefile
index dfb94ab8ea0..3a11f0556ed 100644
--- a/src/rnnlm/Makefile
+++ b/src/rnnlm/Makefile
@@ -10,15 +10,14 @@ TESTFILES = rnnlm-utils-test rnnlm-sampling-test
 OBJFILES = rnnlm-component-itf.o rnnlm-utils.o rnnlm-nnet.o rnnlm-component.o nnet-parse.o \
            rnnlm-training.o \
-           rnnlm-diagnostics.o
-# rnnlm-utils-test.o
-# rnnlm-test-utils.o
+           rnnlm-diagnostics.o \
+           arpa-sampling.o

 LIBNAME = kaldi-rnnlm

 ADDLIBS = ../nnet3/kaldi-nnet3.a ../chain/kaldi-chain.a \
           ../cudamatrix/kaldi-cudamatrix.a ../decoder/kaldi-decoder.a \
-          ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \
+          ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a ../lm/kaldi-lm.a \
           ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \
           ../tree/kaldi-tree.a ../util/kaldi-util.a ../thread/kaldi-thread.a \
           ../matrix/kaldi-matrix.a ../base/kaldi-base.a
diff --git a/src/rnnlm/arpa-sampling.cc b/src/rnnlm/arpa-sampling.cc
new file mode 100644
index 00000000000..faf380528d0
--- /dev/null
+++ b/src/rnnlm/arpa-sampling.cc
@@ -0,0 +1,359 @@
+// arpa-sampling.cc
+
+#include "arpa-sampling.h"
+#include <cmath>
+#include <cstdlib>
+#include <iterator>
+#include <sstream>
+#include <utility>
+
+namespace kaldi {
+
+// Reads one n-gram line of the ARPA file and stores its log-probability
+// and back-off weight under (history, word).
+void ArpaSampling::ConsumeNGram(const NGram& ngram) {
+  int32 cur_order = ngram.words.size();
+  int32 word = ngram.words.back();  // the predicted word is the last one
+  HistType history(ngram.words.begin(), ngram.words.begin() + cur_order - 1);
+  KALDI_ASSERT(history.size() == cur_order - 1);
+
+  // NGram::logprob and NGram::backoff hold base-10 logs as read from the
+  // ARPA file (ArpaLmCompiler multiplies them by M_LN10 when it needs
+  // natural logs), so they are stored unchanged here; everything
+  // downstream uses pow(10, .) to get back to probabilities.
+  BaseFloat log_prob = ngram.logprob;
+  BaseFloat backoff_weight = ngram.backoff;
+  std::pair<BaseFloat, BaseFloat> probs_pair =
+      std::make_pair(log_prob, backoff_weight);
+  probs_[cur_order - 1][history].insert({word, probs_pair});
+
+  // Build vocab_, the map from word string to integer, from the unigrams.
+  const fst::SymbolTable* sym = Symbols();
+  if (cur_order == 1) {
+    num_words_++;
+    std::string word_s = sym->Find(word);
+    vocab_.push_back(std::make_pair(word_s, word));
+  }
+}
+
+// Called once the ARPA header has been read: records the n-gram counts
+// and sizes the probability tables accordingly.
+void ArpaSampling::HeaderAvailable() {
+  ngram_counts_ = NgramCounts();
+  ngram_order_ = ngram_counts_.size();
+  probs_.resize(ngram_order_);
+}
+
+// Returns the (base-10) log-probability of the n-gram (history, word) at
+// the given order if that n-gram exists; otherwise it backs off to the
+// next lower order, recursively, down to the unigram model.
+BaseFloat ArpaSampling::GetProb(int32 order, int32 word,
+                                const HistType& history) {
+  BaseFloat prob = 0.0;
+  NgramType::const_iterator it = probs_[order - 1].find(history);
+  if (it != probs_[order - 1].end() &&
+      it->second.find(word) != it->second.end()) {
+    prob += it->second.find(word)->second.first;
+  } else {  // back off to the next lower order
+    order--;
+    if (order >= 1) {
+      // P(word | history) = bow(history) * P(word | history minus its
+      // first word); with base-10 logs this is the sum of the two terms.
+      HistType h(history.begin() + 1, history.end());
+      prob += GetProb(order, word, h);
+      // The back-off weight of `history` is stored under its last word.
+      HistType backoff_hist(history.begin(), history.end() - 1);
+      prob += GetBackoffWeight(order, history.back(), backoff_hist);
+    }
+  }
+  return prob;
+}
+
+// Returns the (base-10 log) back-off weight of the n-gram (history, word),
+// or 0.0 if the model has no such n-gram.
+BaseFloat ArpaSampling::GetBackoffWeight(int32 order, int32 word,
+                                         const HistType& history) {
+  BaseFloat bow = 0.0;
+  NgramType::const_iterator it = probs_[order - 1].find(history);
+  if (it != probs_[order - 1].end()) {
+    WordToProbsMap::const_iterator it2 = it->second.find(word);
+    if (it2 != it->second.end()) {
+      bow = it2->second.second;
+    }
+  }
+  return bow;
+}
+
+// Computes the estimated pdf over the whole vocabulary given a history.
+void ArpaSampling::ComputeWordPdf(const HistType& history,
+    std::vector<std::pair<int32, BaseFloat> >* pdf) {
+  int32 order = history.size();
+  pdf->clear();
+  pdf->resize(num_words_);
+  NgramType::const_iterator it = probs_[order].find(history);
+  for (int32 i = 0; i < num_words_; i++) {
+    int32 word = vocab_[i].second;
+    BaseFloat prob;
+    if (it != probs_[order].end() &&
+        it->second.find(word) != it->second.end()) {
+      // The n-gram (history, word) exists: use its probability directly.
+      prob = pow(10, it->second.find(word)->second.first);
+    } else {
+      // Otherwise back off, as in GetProb().
+      HistType h(history.begin() + 1, history.end());
+      HistType backoff_hist(history.begin(), history.end() - 1);
+      prob = pow(10, GetBackoffWeight(order, history.back(), backoff_hist) +
+                     GetProb(order, word, h));
+    }
+    (*pdf)[i].first = word;
+    (*pdf)[i].second += prob;
+  }
+}
+
+// Computes the weight of each given history, spreading each history's
+// share over its back-off sub-histories; the weights over all histories
+// sum to 1.
+HistWeightsType ArpaSampling::ComputeHistoriesWeights(
+    const std::vector<HistType>& histories) {
+  HistWeightsType hists_weights;
+  for (std::vector<HistType>::const_iterator it = histories.begin();
+       it != histories.end(); ++it) {
+    HistType history(*it);
+    KALDI_ASSERT(history.size() <= ngram_order_);
+    for (int32 i = 0; i < history.size() + 1; i++) {
+      HistType h_tmp = history;
+      BaseFloat prob = 1.0 / histories.size();
+      // Multiply in the back-off weights needed to reach the sub-history
+      // that drops the first i words.
+      while (h_tmp.size() > (history.size() - i)) {
+        HistType h(h_tmp.begin(), h_tmp.end() - 1);
+        int32 word = h_tmp.back();
+        prob *= pow(10, GetBackoffWeight(h_tmp.size(), word, h));
+        HistType h_up(h_tmp.begin() + 1, h_tmp.end());
+        h_tmp = h_up;
+      }
+      HistType h(history.begin() + i, history.end());
+      hists_weights[h] += prob;
+    }
+  }
+  return hists_weights;
+}
+
+// Computes the pdf over the vocabulary, weighted over the given histories.
+void ArpaSampling::ComputeWeightedPdf(const HistWeightsType& hists_weights,
+    std::vector<std::pair<int32, BaseFloat> >* pdf_w) {
+  BaseFloat prob = 0;
+  pdf_w->clear();
+  pdf_w->resize(num_words_);
+  for (int32 i = 0; i < num_words_; i++) {
+    int32 word = vocab_[i].second;
+    for (HistWeightsType::const_iterator it = hists_weights.begin();
+         it != hists_weights.end(); ++it) {
+      HistType h(it->first);
+      int32 order = h.size();
+      NgramType::const_iterator it_hist = probs_[order].find(h);
+      if (it_hist != probs_[order].end()) {
+        WordToProbsMap::const_iterator it_word = it_hist->second.find(word);
+        if (it_word != it_hist->second.end()) {
+          if (order > 0) {
+            // Subtract the mass this word already receives through the
+            // back-off path, so each n-gram's probability is counted once.
+            HistType h1(h.begin(), h.end() - 1);
+            HistType h2(h.begin() + 1, h.end());
+            prob = it->second * (pow(10, it_word->second.first) -
+                pow(10, GetBackoffWeight(order, h.back(), h1) +
+                        GetProb(order, word, h2)));
+          } else {
+            prob = it->second * pow(10, it_word->second.first);
+          }
+          (*pdf_w)[i].first = word;
+          (*pdf_w)[i].second += prob;
+        }
+      }
+    }  // end of loop over histories
+  }  // end of loop over words
+}
+
+// Computes the words that occur in n-grams for the given histories (and
+// their back-off sub-histories), together with their weighted probs.
+void ArpaSampling::ComputeOutputWords(const std::vector<HistType>& histories,
+    unordered_map<int32, BaseFloat>* pdf_w) {
+  HistWeightsType hists_weights = ComputeHistoriesWeights(histories);
+  BaseFloat prob = 0;
+  for (HistWeightsType::const_iterator it = hists_weights.begin();
+       it != hists_weights.end(); ++it) {
+    HistType h(it->first);
+    int32 order = h.size();
+    NgramType::const_iterator it_hist = probs_[order].find(h);
+    if (it_hist != probs_[order].end()) {
+      for (WordToProbsMap::const_iterator it_word = it_hist->second.begin();
+           it_word != it_hist->second.end(); ++it_word) {
+        int32 word = it_word->first;
+        if (order > 0) {  // unigram mass is intentionally left out here
+          HistType h1(h.begin(), h.end() - 1);
+          HistType h2(h.begin() + 1, h.end());
+          prob = it->second * (pow(10, it_word->second.first) -
+              pow(10, GetBackoffWeight(order, h.back(), h1) +
+                      GetProb(order, word, h2)));
+          // operator[] default-constructs missing entries to 0.
+          (*pdf_w)[word] += prob;
+        }
+      }
+    }
+  }
+}
+
+// Randomly generates between 5 and 1004 histories, for testing.
+std::vector<HistType> ArpaSampling::RandomGenerateHistories() {
+  std::vector<HistType> histories;
+  int32 num_histories = rand() % 1000 + 5;  // at least 5 histories
+  for (int32 i = 0; i < num_histories; i++) {
+    HistType hist;
+    // The history length is in {1, 2, ..., ngram_order_ - 1}.
+    int32 size_hist = rand() % (ngram_order_ - 1) + 1;
+    KALDI_ASSERT(size_hist <= ngram_order_);
+    for (int32 j = 0; j < size_hist; j++) {
+      // A word id cannot be 0, since 0 represents epsilon in the FST
+      // symbol table.
+      int32 word = rand() % (vocab_.size() - 1) + 1;
+      KALDI_ASSERT(word > 0 && word <= vocab_.size());
+      hist.push_back(word);
+    }
+    histories.push_back(hist);
+  }
+  return histories;
+}
+
+// Checks that the pdf estimated from weighted histories matches the
+// average of the per-history pdfs, and that both sum to 1.
+void ArpaSampling::TestPdfsEqual() {
+  std::vector<HistType> histories = RandomGenerateHistories();
+  HistWeightsType hists_weights = ComputeHistoriesWeights(histories);
+  std::vector<std::pair<int32, BaseFloat> > pdf_hist_weight;
+  ComputeWeightedPdf(hists_weights, &pdf_hist_weight);
+  // Check that the weighted pdf sums to 1.
+  BaseFloat sum = 0;
+  for (int32 i = 0; i < num_words_; i++) {
+    sum += pdf_hist_weight[i].second;
+  }
+  KALDI_ASSERT(ApproxEqual(sum, 1.0));
+  // Compute the pdf averaged over the histories.
+  std::vector<std::pair<int32, BaseFloat> > pdf(num_words_);
+  for (int32 i = 0; i < histories.size(); i++) {
+    std::vector<std::pair<int32, BaseFloat> > pdf_h;
+    ComputeWordPdf(histories[i], &pdf_h);
+    for (int32 j = 0; j < pdf_h.size(); j++) {
+      pdf[j].first = pdf_h[j].first;
+      pdf[j].second += pdf_h[j].second / histories.size();
+    }
+  }
+  // Check that the averaged pdf sums to 1.
+  sum = 0;
+  for (int32 i = 0; i < num_words_; i++) {
+    sum += pdf[i].second;
+  }
+  KALDI_ASSERT(ApproxEqual(sum, 1.0));
+  // Check that the two pdfs agree.
+  BaseFloat diff = 0;
+  for (int32 i = 0; i < num_words_; i++) {
+    diff += std::abs(pdf_hist_weight[i].second - pdf[i].second);
+  }
+  KALDI_ASSERT(ApproxEqual(diff, 0.0));
+}
+
+// Logs summary statistics of the read-in language model as a sanity check.
+void ArpaSampling::TestReadingModel() {
+  KALDI_LOG << "Testing the model-reading part...";
+  KALDI_LOG << "Vocab size is: " << vocab_.size();
+  KALDI_LOG << "Ngram order is: " << ngram_order_;
+  KALDI_ASSERT(probs_.size() == ngram_counts_.size());
+  for (int32 i = 0; i < ngram_order_; i++) {
+    int32 size_ngrams = 0;
+    KALDI_LOG << "Test: for order " << (i + 1);
+    KALDI_LOG << "Expected number of " << (i + 1) << "-grams: "
+              << ngram_counts_[i];
+    for (NgramType::const_iterator it1 = probs_[i].begin();
+         it1 != probs_[i].end(); ++it1) {
+      size_ngrams += it1->second.size();  // words seen after this history
+    }
+    KALDI_LOG << "Read number of " << (i + 1) << "-grams: " << size_ngrams;
+  }
+  KALDI_LOG << "Check that the unigram probs sum to 1...";
+  BaseFloat prob_sum = 0.0;
+  int32 count = 0;
+  for (NgramType::const_iterator it1 = probs_[0].begin();
+       it1 != probs_[0].end(); ++it1) {
+    for (WordToProbsMap::const_iterator it2 = it1->second.begin();
+         it2 != it1->second.end(); ++it2) {
+      prob_sum += pow(10.0, it2->second.first);
+      count++;
+    }
+  }
+  KALDI_LOG << "Total number of words: " << count;
+  KALDI_LOG << "Sum of unigram probs is " << prob_sum;
+
+  if (ngram_order_ > 1) {
+    KALDI_LOG << "Check that the bigram probs given a history sum to 1...";
+    prob_sum = 0.0;
+    NgramType::const_iterator it1 = probs_[1].begin();
+    HistType h(it1->first);
+    for (int32 i = 0; i < num_words_; i++) {
+      WordToProbsMap::const_iterator it2 = it1->second.find(vocab_[i].second);
+      if (it2 != it1->second.end()) {
+        prob_sum += pow(10, it2->second.first);
+      } else {
+        prob_sum += pow(10, GetProb(2, vocab_[i].second, h));
+      }
+    }
+    KALDI_LOG << "Sum of bigram probs given a history is " << prob_sum;
+  }
+}
+
+int32 ArpaSampling::GetNgramOrder() {
+  return ngram_order_;
+}
+
+// Reads histories (one per line, as word strings) from a file, mapping
+// unknown words to the unk symbol and keeping at most the most recent
+// ngram_order_ - 1 words of each history.
+std::vector<HistType> ArpaSampling::ReadHistories(std::istream &is,
+                                                  bool binary) {
+  if (binary) {
+    KALDI_ERR << "binary-mode reading is not implemented for ArpaSampling";
+  }
+  const fst::SymbolTable* sym = Symbols();
+  std::vector<HistType> histories;
+  std::string line;
+  KALDI_LOG << "Start reading histories from file...";
+  while (getline(is, line)) {
+    std::istringstream line_is(line);
+    std::istream_iterator<std::string> begin(line_is), end;
+    std::vector<std::string> tokens(begin, end);
+    HistType history;
+    int32 word;
+    for (int32 i = 0; i < tokens.size(); i++) {
+      word = sym->Find(tokens[i]);
+      if (word == fst::SymbolTable::kNoSymbol) {
+        word = sym->Find(unk_symbol_);
+      }
+      history.push_back(word);
+    }
+    if (history.size() >= ngram_order_) {
+      // Truncate to the most recent ngram_order_ - 1 words.
+      HistType h(history.end() - ngram_order_ + 1, history.end());
+      history = h;
+    }
+    histories.push_back(history);
+  }
+  KALDI_LOG << "Finished reading histories from file.";
+  return histories;
+}
+
+}  // namespace kaldi
diff --git a/src/rnnlm/arpa-sampling.h b/src/rnnlm/arpa-sampling.h
new file mode 100644
index 00000000000..5f80ca308a5
--- /dev/null
+++ b/src/rnnlm/arpa-sampling.h
@@ -0,0 +1,142 @@
+// arpa-sampling.h
+
+// Copyright 2016  Ke Li
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABILITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_RNNLM_ARPA_SAMPLING_H_
+#define KALDI_RNNLM_ARPA_SAMPLING_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "base/kaldi-common.h"
+#include "fst/fstlib.h"
+#include "lm/arpa-file-parser.h"
+#include "util/common-utils.h"
+#include "util/stl-utils.h"
+
+namespace kaldi {
+
+// Symbol-table indexes of the special symbols used by the test code.
+enum {
+  kEps = 0,
+  kDisambig,
+  kBos, kEos, kUnk
+};
+
+typedef std::vector<int32> HistType;
+// Maps a word to its (log-prob, back-off weight) pair.
+typedef unordered_map<int32, std::pair<BaseFloat, BaseFloat> > WordToProbsMap;
+// Maps a history to the words seen after it.
+typedef unordered_map<HistType, WordToProbsMap,
+                      VectorHasher<int32> > NgramType;
+// Maps a history to its weight.
+typedef unordered_map<HistType, BaseFloat,
+                      VectorHasher<int32> > HistWeightsType;
+
+class ArpaSampling : public ArpaFileParser {
+ public:
+  explicit ArpaSampling(ArpaParseOptions options, fst::SymbolTable* symbols)
+      : ArpaFileParser(options, symbols) {
+    ngram_order_ = 0;
+    num_words_ = 0;
+    bos_symbol_ = "<s>";
+    eos_symbol_ = "</s>";
+    unk_symbol_ = "<unk>";
+  }
+
+  // Returns the log-probability of an n-gram from the ARPA LM if it is
+  // found; backs off to the lower-order model when the n-gram does not
+  // exist.
+  BaseFloat GetProb(int32 order, int32 word, const HistType& history);
+
+  // Returns the back-off weight of an n-gram of the read-in model.
+  BaseFloat GetBackoffWeight(int32 order, int32 word, const HistType& history);
+
+  // Computes the non-unigram output words and corresponding probs for the
+  // given histories.
+  void ComputeOutputWords(const std::vector<HistType>& histories,
+                          unordered_map<int32, BaseFloat>* pdf_w);
+
+  // Computes the weighted pdf given all histories.
+  void ComputeWeightedPdf(const HistWeightsType& hists_weights,
+      std::vector<std::pair<int32, BaseFloat> >* weighted_pdf);
+
+  // Returns the n-gram order of the read-in model.
+  int32 GetNgramOrder();
+
+  void TestReadingModel();
+
+  void TestProbs(std::istream &is, bool binary);
+
+  void TestPdfsEqual();
+
+  std::vector<HistType> ReadHistories(std::istream &is, bool binary);
+
+ protected:
+  // ArpaFileParser overrides.
+  virtual void HeaderAvailable();
+  virtual void ConsumeNGram(const NGram& ngram);
+  virtual void ReadComplete() {}
+
+ private:
+  // For tests: randomly generates histories.
+  std::vector<HistType> RandomGenerateHistories();
+
+  // Computes the pdf over all words in the vocab given a history.
+  void ComputeWordPdf(const HistType& history,
+                      std::vector<std::pair<int32, BaseFloat> >* pdf);
+
+  // Computes the weights of the given histories.
+  HistWeightsType ComputeHistoriesWeights(
+      const std::vector<HistType>& histories);
+
+  // N-gram order of the read-in LM.
+  int32 ngram_order_;
+
+  // Number of words in the vocabulary (i.e. number of unigrams).
+  int32 num_words_;
+
+  // Begin-of-sentence symbol.
+  std::string bos_symbol_;
+
+  // End-of-sentence symbol.
+  std::string eos_symbol_;
+
+  // Unknown-word symbol.
+  std::string unk_symbol_;
+
+  // Vocabulary: (word string, word id) pairs, one per unigram.
+  std::vector<std::pair<std::string, int32> > vocab_;
+
+  // Expected count of n-grams of each order, from the ARPA header.
+  std::vector<int32> ngram_counts_;
+
+  // N-gram probabilities and back-off weights; probs_[n] holds the
+  // (n + 1)-gram table.
+  std::vector<NgramType> probs_;
+
+  // Histories' weights.
+  HistWeightsType hists_weights_;
+
+  // Test sentences.
+  std::vector<std::vector<int32> > sentences_;
+};
+
+}  // namespace kaldi
+#endif  // KALDI_RNNLM_ARPA_SAMPLING_H_
diff --git a/src/rnnlm/rnnlm-utils-test.cc b/src/rnnlm/rnnlm-utils-test.cc
index 1e5efe7d424..1a3256a818d 100644
--- a/src/rnnlm/rnnlm-utils-test.cc
+++ b/src/rnnlm/rnnlm-utils-test.cc
@@ -1,7 +1,13 @@
 // rnnlm/rnnlm-utils-test.cc
-#include <math.h>
 #include "rnnlm/rnnlm-utils.h"
+#include "arpa-sampling.h"
+
+#include <string>
+#include <vector>
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "fst/fstlib.h"

 namespace kaldi {
 namespace rnnlm {
@@ -164,7 +170,7 @@ void UnitTestSamplingTime(int iters) {
 } // end namespace rnnlm
 } // end namespace kaldi.

-int main() {
+int main(int argc, char **argv) {
   using namespace kaldi;
   using namespace rnnlm;
   int N = 1000;
@@ -173,3 +179,43 @@

   UnitTestSamplingTime(N);
+
+  // Command line for this test binary: ./test-binary arpa-file history-file,
+  // where arpa-file is an ARPA-format language model and history-file
+  // contains histories of integers, one history per line.
+  const char *usage = "Usage: rnnlm-utils-test <arpa-file> <history-file>";
+  ParseOptions po(usage);
+  po.Read(argc, argv);
+  std::string arpa_file = po.GetArg(1), history_file = po.GetArg(2);
+
+  ArpaParseOptions options;
+  fst::SymbolTable symbols;
+  // Use spaces on special symbols, so we rather fail than read them by mistake.
+  symbols.AddSymbol(" <eps>", kEps);
+  // symbols.AddSymbol(" #0", kDisambig);
+  options.bos_symbol = symbols.AddSymbol("<s>", kBos);
+  options.eos_symbol = symbols.AddSymbol("</s>", kEos);
+  options.unk_symbol = symbols.AddSymbol("<unk>", kUnk);
+  options.oov_handling = ArpaParseOptions::kAddToSymbols;
+  ArpaSampling mdl(options, &symbols);
+
+  bool binary;
+  Input k1(arpa_file, &binary);
+  mdl.Read(k1.Stream(), binary);
+  mdl.TestReadingModel();
+
+  Input k2(history_file, &binary);
+  std::vector<HistType> histories = mdl.ReadHistories(k2.Stream(), binary);
+  unordered_map<int32, BaseFloat> pdf_hist_weight;
+  mdl.ComputeOutputWords(histories, &pdf_hist_weight);
+
+  // This test can be slow:
+  /*
+  KALDI_LOG << "Start the weighted-histories test...";
+  for (int i = 0; i < N / 100; i++) {
+    mdl.TestPdfsEqual();
+  }
+  KALDI_LOG << "Successfully passed the test.";
+  */
+  return 0;
 }
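For context on the back-off rule that GetProb() implements: in an ARPA model, when the n-gram (history, word) is absent, P(word | history) = bow(history) * P(word | history minus its first word), with both factors stored as base-10 logs. Below is a minimal standalone sketch of that rule on a hand-filled toy bigram table; the file name, types, and numbers are all hypothetical illustrations, not part of this patch.

// backoff-demo.cc -- illustrative only: the table is hand-filled with toy
// values instead of being read from a real ARPA file.
#include <cmath>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// table[h][w] = (log10 P(w|h), log10 bow(h w)); "" denotes the empty history.
typedef std::map<std::string,
                 std::map<std::string, std::pair<double, double> > > Table;

// Returns log10 P(word | hist) with one level of back-off, mirroring the
// rule in ArpaSampling::GetProb for a bigram model.
double GetLogProb(const Table& t, const std::string& hist,
                  const std::string& word) {
  Table::const_iterator it = t.find(hist);
  if (it != t.end() && it->second.count(word) > 0)
    return it->second.find(word)->second.first;  // explicit bigram found
  // Back off: log10 bow(hist) + log10 P(word). Assumes every word has a
  // unigram entry, as any valid ARPA model guarantees.
  double bow = 0.0;  // a missing back-off weight means log10(1) = 0
  Table::const_iterator uni = t.find("");
  if (uni != t.end() && uni->second.count(hist) > 0)
    bow = uni->second.find(hist)->second.second;
  return bow + t.find("")->second.find(word)->second.first;
}

int main() {
  Table t;
  t[""]["a"] = std::make_pair(-0.5, -0.2);  // unigram probs and bows
  t[""]["b"] = std::make_pair(-0.7, -0.1);
  t["a"]["a"] = std::make_pair(-0.3, 0.0);  // the only explicit bigram
  std::cout << "P(a|a) = " << std::pow(10, GetLogProb(t, "a", "a")) << "\n"
            << "P(b|a) = " << std::pow(10, GetLogProb(t, "a", "b")) << "\n";
  return 0;
}

Running it prints P(a|a) = 0.501187 (the explicit bigram) and P(b|a) = 0.125893, i.e. 10^(-0.2 + (-0.7)) obtained through the back-off path; note how the probabilities stay multiplicative while the stored log10 values simply add.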