From 97fbd57f16be765b737edc199a170fd231170b15 Mon Sep 17 00:00:00 2001 From: saikiranvalluri Date: Wed, 27 Feb 2019 09:53:26 +0000 Subject: [PATCH 1/7] RNNLM scripts support for UTF-8 encoded text --- scripts/rnnlm/choose_features.py | 12 +++------- scripts/rnnlm/get_best_model.py | 28 +++++++++++------------- scripts/rnnlm/get_embedding_dim.py | 4 ++-- scripts/rnnlm/get_num_splits.sh | 2 +- scripts/rnnlm/get_special_symbol_opts.py | 8 ++----- scripts/rnnlm/get_unigram_probs.py | 18 ++++++--------- scripts/rnnlm/get_vocab.py | 11 ++++------ scripts/rnnlm/get_word_features.py | 15 +++++-------- scripts/rnnlm/lmrescore.sh | 6 ----- scripts/rnnlm/lmrescore_nbest.sh | 4 ++-- scripts/rnnlm/lmrescore_pruned.sh | 17 ++++---------- scripts/rnnlm/prepare_rnnlm_dir.sh | 9 ++------ scripts/rnnlm/prepare_split_data.py | 13 +++++------ scripts/rnnlm/rnnlm_cleanup.py | 2 +- scripts/rnnlm/show_word_features.py | 19 +++++----------- scripts/rnnlm/train_rnnlm.sh | 2 +- scripts/rnnlm/validate_features.py | 7 ++---- scripts/rnnlm/validate_text_dir.py | 11 ++++------ scripts/rnnlm/validate_word_features.py | 11 ++++------ 19 files changed, 68 insertions(+), 131 deletions(-) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index c6621e04494..799f6b6dcc8 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -10,12 +10,6 @@ from collections import defaultdict sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) -# because this script splits inside words, we cannot use latin-1; we actually need to know what -# what the encoding is. By default we make this utf-8; to handle encodings that are not compatible -# with utf-8 (e.g. gbk), we'll eventually have to make the encoding an option to this script. - -import re -tab_or_space = re.compile('[ \t]+') parser = argparse.ArgumentParser(description="This script chooses the sparse feature representation of words. " "To be more specific, it chooses the set of features-- you compute " @@ -90,9 +84,9 @@ # and 'wordlist' is a list indexed by integer id, that returns the string-valued word. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="utf-8") as f: + with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -121,7 +115,7 @@ def read_unigram_probs(unigram_probs_file): unigram_probs = [] with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): diff --git a/scripts/rnnlm/get_best_model.py b/scripts/rnnlm/get_best_model.py index 333ed8dbfc7..45487b18b0c 100755 --- a/scripts/rnnlm/get_best_model.py +++ b/scripts/rnnlm/get_best_model.py @@ -3,14 +3,14 @@ # Copyright 2017 Johns Hopkins University (author: Daniel Povey) # License: Apache 2.0. +import os import argparse -import glob -import re import sys +import re parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training " - "based on dev-set perplexity, and prints the number corresponding " - "to that iteration", + "based on dev-set perplexity, and prints the number corresponding " + "to that iteration", epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -19,9 +19,10 @@ args = parser.parse_args() -num_iters = None + +num_iters=None try: - with open(args.rnnlm_dir + "/info.txt", encoding="latin-1") as f: + with open(args.rnnlm_dir + "/info.txt", encoding="utf-8") as f: for line in f: a = line.split("=") if a[0] == "num_iters": @@ -35,15 +36,15 @@ sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format( args.rnnlm_dir)) -best_objf = -2000 -best_iter = -1 -for i in range(1, num_iters): +best_objf=-2000 +best_iter=-1 +for i in range(num_iters): this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i) try: - f = open(this_logfile, 'r', encoding='latin-1') + f = open(this_logfile, 'r', encoding='utf-8') except: sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile)) - this_objf = -1000 + this_objf=-1000 for line in f: m = re.search('Overall objf .* (\S+)$', str(line)) if m is not None: @@ -52,10 +53,6 @@ except Exception as e: sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format( this_logfile, line, str(e))) - # verify this iteration still has model files present - if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0: - # this iteration has log files, but model files have been cleaned up, skip it - continue if this_objf == -1000: print(sys.argv[0] + ": warning: could not parse objective function from {0}".format( this_logfile), file=sys.stderr) @@ -66,4 +63,5 @@ if best_iter == -1: sys.exit(sys.argv[0] + ": error: could not get best iteration.") + print(str(best_iter)) diff --git a/scripts/rnnlm/get_embedding_dim.py b/scripts/rnnlm/get_embedding_dim.py index 63eaf307498..b6810ef2cbf 100755 --- a/scripts/rnnlm/get_embedding_dim.py +++ b/scripts/rnnlm/get_embedding_dim.py @@ -45,7 +45,7 @@ left_context=0 right_context=0 for line in out_lines: - line = line.decode('latin-1') + line = line.decode('utf-8') m = re.search(r'input-node name=input dim=(\d+)', line) if m is not None: try: @@ -101,4 +101,4 @@ "nnet '{0}': {1} != {2}".format( args.nnet, input_dim, output_dim)) -print('{}'.format(input_dim)) +print(str(input_dim)) diff --git a/scripts/rnnlm/get_num_splits.sh b/scripts/rnnlm/get_num_splits.sh index 974fd8bf204..93d1f7f169c 100755 --- a/scripts/rnnlm/get_num_splits.sh +++ b/scripts/rnnlm/get_num_splits.sh @@ -65,7 +65,7 @@ tot_with_multiplicities=0 for f in $text/*.counts; do if [ "$f" != "$text/dev.counts" ]; then - this_tot=$(cat $f | awk '{tot += $2} END{printf("%d", tot)}') + this_tot=$(cat $f | awk '{tot += $2} END{print tot}') if ! [ $this_tot -gt 0 ]; then echo "$0: there were no counts in counts file $f" 1>&2 exit 1 diff --git a/scripts/rnnlm/get_special_symbol_opts.py b/scripts/rnnlm/get_special_symbol_opts.py index 4310b116ad7..13fe497faf9 100755 --- a/scripts/rnnlm/get_special_symbol_opts.py +++ b/scripts/rnnlm/get_special_symbol_opts.py @@ -8,9 +8,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="This script checks whether the special symbols " "appear in words.txt with expected values, if not, it will " "print out the options with correct value to stdout, which may look like " @@ -28,10 +25,9 @@ lower_ids = {} upper_ids = {} -input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='latin-1') +input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors='replace') for line in input_stream: - fields = re.split(tab_or_space, line) - assert(len(fields) == 2) + fields = line.split() sym = fields[0] if sym in special_symbols: assert sym not in lower_ids diff --git a/scripts/rnnlm/get_unigram_probs.py b/scripts/rnnlm/get_unigram_probs.py index ab3f9bb382f..32b01728ca3 100755 --- a/scripts/rnnlm/get_unigram_probs.py +++ b/scripts/rnnlm/get_unigram_probs.py @@ -7,9 +7,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="This script gets the unigram probabilities of words.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " "--data-weights-file=exp/rnnlm/data_weights.txt data/rnnlm/data " @@ -77,10 +74,10 @@ def get_all_data_sources_except_dev(text_dir): # value is a tuple (repeated_times_per_epoch, weight) def read_data_weights(weights_file, data_sources): data_weights = {} - with open(weights_file, 'r', encoding="latin-1") as f: + with open(weights_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -102,9 +99,9 @@ def read_data_weights(weights_file, data_sources): # return the vocab, which is a dict mapping the word to a integer id. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="latin-1") as f: + with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -131,11 +128,10 @@ def get_counts(data_sources, data_weights, vocab): if weight == 0.0: continue - with open(counts_file, 'r', encoding="latin-1") as f: + with open(counts_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) - if len(fields) != 2: print("Warning, should be 2 cols:", fields, line, file=sys.stderr); - assert(len(fields) == 2) + fields = line.split() + assert len(fields) == 2 word = fields[0] count = fields[1] if word not in vocab: diff --git a/scripts/rnnlm/get_vocab.py b/scripts/rnnlm/get_vocab.py index 1502e915f9c..f290ef721c1 100755 --- a/scripts/rnnlm/get_vocab.py +++ b/scripts/rnnlm/get_vocab.py @@ -6,10 +6,7 @@ import os import argparse import sys -sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) - -import re -tab_or_space = re.compile('[ \t]+') +sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) parser = argparse.ArgumentParser(description="This script get a vocab from unigram counts " "of words produced by get_unigram_counts.sh", @@ -28,10 +25,10 @@ # Add the count for every word in counts_file # the result is written into word_counts def add_counts(word_counts, counts_file): - with open(counts_file, 'r', encoding="latin-1") as f: + with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - line = line.strip(" \t\r\n") - word_and_count = re.split(tab_or_space, line) + line = line.strip() + word_and_count = line.split() assert len(word_and_count) == 2 if word_and_count[0] in word_counts: word_counts[word_and_count[0]] += int(word_and_count[1]) diff --git a/scripts/rnnlm/get_word_features.py b/scripts/rnnlm/get_word_features.py index aeb7a3ec6ae..8bdb553b9c8 100755 --- a/scripts/rnnlm/get_word_features.py +++ b/scripts/rnnlm/get_word_features.py @@ -9,9 +9,6 @@ import math from collections import defaultdict -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="This script turns the words into the sparse feature representation, " "using features from rnnlm/choose_features.py.", epilog="E.g. " + sys.argv[0] + " --unigram-probs=exp/rnnlm/unigram_probs.txt " @@ -41,9 +38,9 @@ # return the vocab, which is a dict mapping the word to a integer id. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="latin-1") as f: + with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -62,9 +59,9 @@ def read_vocab(vocab_file): # return a list of unigram_probs, indexed by word id def read_unigram_probs(unigram_probs_file): unigram_probs = [] - with open(unigram_probs_file, 'r', encoding="latin-1") as f: + with open(unigram_probs_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): @@ -103,9 +100,9 @@ def read_features(features_file): feats['min_ngram_order'] = 10000 feats['max_ngram_order'] = -1 - with open(features_file, 'r', encoding="latin-1") as f: + with open(features_file, 'r', encoding="utf-8", errors='replace') as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) diff --git a/scripts/rnnlm/lmrescore.sh b/scripts/rnnlm/lmrescore.sh index 9da22ae75a2..cd0cf793d8d 100755 --- a/scripts/rnnlm/lmrescore.sh +++ b/scripts/rnnlm/lmrescore.sh @@ -72,12 +72,6 @@ awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) { print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \ || exit 1; -if ! head -n -1 $rnnlm_dir/config/words.txt | cmp $oldlang/words.txt -; then - # the last word of the RNNLM word list is an added word - echo "$0: Word lists mismatch for lattices and RNNLM." - exit 1 -fi - oldlm_command="fstproject --project_output=true $oldlm |" special_symbol_opts=$(cat $rnnlm_dir/special_symbol_opts.txt) diff --git a/scripts/rnnlm/lmrescore_nbest.sh b/scripts/rnnlm/lmrescore_nbest.sh index 58b19b9fa79..f50a3c909f0 100755 --- a/scripts/rnnlm/lmrescore_nbest.sh +++ b/scripts/rnnlm/lmrescore_nbest.sh @@ -29,7 +29,7 @@ if [ $# != 6 ]; then echo "This version applies an RNNLM and mixes it with the LM scores" echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" echo "" - echo "Usage: $0 [options] " + echo "Usage: utils/rnnlmrescore.sh " echo "Main options:" echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." echo " # for N-best list generation... note, we'll score at different acwt's" @@ -177,7 +177,7 @@ fi if [ $stage -le 6 ]; then echo "$0: invoking rnnlm/compute_sentence_scores.sh which calls rnnlm to get RNN LM scores." $cmd JOB=1:$nj $dir/log/rnnlm_compute_scores.JOB.log \ - rnnlm/compute_sentence_scores.sh $rnndir $adir.JOB/temp \ + local/rnnlm/compute_sentence_scores.sh $rnndir $adir.JOB/temp \ $adir.JOB/words_text $adir.JOB/lmwt.rnn fi if [ $stage -le 7 ]; then diff --git a/scripts/rnnlm/lmrescore_pruned.sh b/scripts/rnnlm/lmrescore_pruned.sh index 9ba78415708..46ee5846424 100755 --- a/scripts/rnnlm/lmrescore_pruned.sh +++ b/scripts/rnnlm/lmrescore_pruned.sh @@ -16,18 +16,16 @@ max_ngram_order=4 # Approximate the lattice-rescoring by limiting the max-ngram- # the same ngram history and this prevents the lattice from # exploding exponentially. Details of the n-gram approximation # method are described in section 2.3 of the paper - # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdf -max_arcs= # limit the max arcs in lattice while rescoring. E.g., 20000 + # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdm +max_arcs=499 # limit the max arcs in lattice while rescoring. E.g., 20000 -acwt=0.1 -weight=0.5 # Interpolation weight for RNNLM. +acwt=1 +weight=1 # Interpolation weight for RNNLM. normalize=false # If true, we add a normalization step to the output of the RNNLM # so that it adds up to *exactly* 1. Note that this is not necessary # as in our RNNLM setup, a properly trained network would automatically # have its normalization term close to 1. The details of this # could be found at http://www.danielpovey.com/files/2018_icassp_rnnlm.pdf -lattice_prune_beam=4 # Beam used in pruned lattice composition - # This option affects speed and how large the composed lattice may be # End configuration section. @@ -75,12 +73,6 @@ awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) { print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \ || exit 1; -if ! head -n -1 $rnnlm_dir/config/words.txt | cmp $oldlang/words.txt -; then - # the last word of the RNNLM word list is an added word - echo "$0: Word lists mismatch for lattices and RNNLM." - exit 1 -fi - normalize_opt= if $normalize; then normalize_opt="--normalize-probs=true" @@ -105,7 +97,6 @@ cp $indir/num_jobs $outdir $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore-kaldi-rnnlm-pruned --lm-scale=$weight $special_symbol_opts \ - --lattice-compose-beam=$lattice_prune_beam \ --acoustic-scale=$acwt --max-ngram-order=$max_ngram_order $normalize_opt $max_arcs_opt \ $carpa_option $oldlm $word_embedding "$rnnlm_dir/final.raw" \ "ark:gunzip -c $indir/lat.JOB.gz|" "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; diff --git a/scripts/rnnlm/prepare_rnnlm_dir.sh b/scripts/rnnlm/prepare_rnnlm_dir.sh index e101822d983..d3ee44f1f95 100755 --- a/scripts/rnnlm/prepare_rnnlm_dir.sh +++ b/scripts/rnnlm/prepare_rnnlm_dir.sh @@ -23,7 +23,7 @@ if [ $# != 3 ]; then echo "Usage: $0 [options] " echo "Sets up the directory for RNNLM training as done by" echo "rnnlm/train_rnnlm.sh, and initializes the model." - echo " is as validated by rnnlm/validate_text_dir.py" + echo " is as validated by rnnlm/validate_data_dir.py" echo " is as validated by rnnlm/validate_config_dir.sh." exit 1 fi @@ -34,7 +34,6 @@ config_dir=$2 dir=$3 set -e -. ./path.sh if [ $stage -le 0 ]; then echo "$0: validating input" @@ -53,13 +52,9 @@ if [ $stage -le 1 ]; then echo "$0: copying config directory" mkdir -p $dir/config # copy expected things from $config_dir to $dir/config. - for f in words.txt data_weights.txt oov.txt xconfig; do + for f in words.txt features.txt data_weights.txt oov.txt xconfig; do cp $config_dir/$f $dir/config done - # features.txt is optional, check separately - if [ -f $config_dir/features.txt ]; then - cp $config_dir/features.txt $dir/config - fi fi rnnlm/get_special_symbol_opts.py < $dir/config/words.txt > $dir/special_symbol_opts.txt diff --git a/scripts/rnnlm/prepare_split_data.py b/scripts/rnnlm/prepare_split_data.py index cceac48313e..9cc4f69d09f 100755 --- a/scripts/rnnlm/prepare_split_data.py +++ b/scripts/rnnlm/prepare_split_data.py @@ -8,9 +8,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="This script prepares files containing integerized text, " "for consumption by nnet3-get-egs.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " @@ -66,10 +63,10 @@ def get_all_data_sources_except_dev(text_dir): # value is a tuple (repeated_times_per_epoch, weight) def read_data_weights(weights_file, data_sources): data_weights = {} - with open(weights_file, 'r', encoding="latin-1") as f: + with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -97,7 +94,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): num_outputs = len(output_filehandles) n = 0 try: - f = open(source_filename, 'r', encoding="latin-1") + f = open(source_filename, 'r', encoding="utf-8") except Exception as e: sys.exit(sys.argv[0] + ": failed to open file {0} for reading: {1} ".format( source_filename, str(e))) @@ -124,7 +121,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): os.makedirs(args.split_dir + "/info") # set up the 'num_splits' file, which contains an integer. -with open("{0}/info/num_splits".format(args.split_dir), 'w', encoding="latin-1") as f: +with open("{0}/info/num_splits".format(args.split_dir), 'w', encoding="utf-8") as f: print(args.num_splits, file=f) # e.g. set temp_files = [ 'foo/1.tmp', 'foo/2.tmp', ..., 'foo/5.tmp' ] @@ -136,7 +133,7 @@ def distribute_to_outputs(source_filename, weight, output_filehandles): temp_filehandles = [] for fname in temp_files: try: - temp_filehandles.append(open(fname, 'w', encoding="latin-1")) + temp_filehandles.append(open(fname, 'w', encoding="utf-8")) except Exception as e: sys.exit(sys.argv[0] + ": failed to open file: " + str(e) + ".. if this is a max-open-filehandles limitation, you may " diff --git a/scripts/rnnlm/rnnlm_cleanup.py b/scripts/rnnlm/rnnlm_cleanup.py index 40cbee7a496..6a304f7f4cb 100644 --- a/scripts/rnnlm/rnnlm_cleanup.py +++ b/scripts/rnnlm/rnnlm_cleanup.py @@ -69,7 +69,7 @@ def get_compute_prob_info(log_file): compute_prob_done = False # roughly based on code in get_best_model.py try: - f = open(log_file, "r", encoding="latin-1") + f = open(log_file, "r", encoding="utf-8") except: print(script_name + ": warning: compute_prob log not found for iteration " + str(iter) + ". Skipping", diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index 89b134adaf9..89d84d53f3e 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -6,16 +6,7 @@ import os import argparse import sys - -# The use of latin-1 encoding does not preclude reading utf-8. latin-1 encoding -# means "treat words as sequences of bytes", and it is compatible with utf-8 -# encoding as well as other encodings such as gbk, as long as the spaces are -# also spaces in ascii (which we check). It is basically how we emulate the -# behavior of python before python3. -sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) - -import re -tab_or_space = re.compile('[ \t]+') +sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) parser = argparse.ArgumentParser(description="This script turns the word features to a human readable format.", epilog="E.g. " + sys.argv[0] + "exp/rnnlm/word_feats.txt exp/rnnlm/features.txt " @@ -36,9 +27,9 @@ def read_feature_type_and_key(features_file): feat_types = {} - with open(features_file, 'r', encoding="latin-1") as f: + with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [2, 3, 4]) feat_id = int(fields[0]) @@ -53,9 +44,9 @@ def read_feature_type_and_key(features_file): feat_type_and_key = read_feature_type_and_key(args.features_file) num_word_feats = 0 -with open(args.word_features_file, 'r', encoding="latin-1") as f: +with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) % 2 == 1 print(int(fields[0]), end='\t') diff --git a/scripts/rnnlm/train_rnnlm.sh b/scripts/rnnlm/train_rnnlm.sh index 013e9a56c2f..f056d096120 100755 --- a/scripts/rnnlm/train_rnnlm.sh +++ b/scripts/rnnlm/train_rnnlm.sh @@ -41,7 +41,7 @@ use_gpu_for_diagnostics=false # set true to use GPU for compute_prob_*.log # optional cleanup options cleanup=false # add option --cleanup true to enable automatic cleanup of old models cleanup_strategy="keep_latest" # determines cleanup strategy, use either "keep_latest" or "keep_best" -cleanup_keep_iters=3 # number of iterations that will have their models retained +cleanup_keep_iters=100 # number of iterations that will have their models retained trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM . utils/parse_options.sh diff --git a/scripts/rnnlm/validate_features.py b/scripts/rnnlm/validate_features.py index 2a077da4758..a650092b086 100755 --- a/scripts/rnnlm/validate_features.py +++ b/scripts/rnnlm/validate_features.py @@ -7,9 +7,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="Validates features file, produced by rnnlm/choose_features.py.", epilog="E.g. " + sys.argv[0] + " exp/rnnlm/features.txt", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -24,7 +21,7 @@ if not os.path.isfile(args.features_file): sys.exit(sys.argv[0] + ": Expected file {0} to exist".format(args.features_file)) -with open(args.features_file, 'r', encoding="latin-1") as f: +with open(args.features_file, 'r', encoding="utf-8") as f: has_unigram = False has_length = False idx = 0 @@ -33,7 +30,7 @@ final_feats = {} word_feats = {} for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) assert idx == int(fields[0]) diff --git a/scripts/rnnlm/validate_text_dir.py b/scripts/rnnlm/validate_text_dir.py index 903e720bdf4..d644d77911e 100755 --- a/scripts/rnnlm/validate_text_dir.py +++ b/scripts/rnnlm/validate_text_dir.py @@ -7,9 +7,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="Validates data directory containing text " "files from one or more data sources, including dev.txt.", epilog="E.g. " + sys.argv[0] + " data/rnnlm/data", @@ -40,7 +37,7 @@ def check_text_file(text_file): - with open(text_file, 'r', encoding="latin-1") as f: + with open(text_file, 'r', encoding="utf-8") as f: found_nonempty_line = False lineno = 0 if args.allow_internal_eos == 'true': @@ -54,7 +51,7 @@ def check_text_file(text_file): lineno += 1 if args.spot_check == 'true' and lineno > 10: break - words = re.split(tab_or_space, line) + words = line.split() if len(words) != 0: found_nonempty_line = True for word in words: @@ -76,9 +73,9 @@ def check_text_file(text_file): # with some kind of utterance-id first_field_set = set() other_fields_set = set() - with open(text_file, 'r', encoding="latin-1") as f: + with open(text_file, 'r', encoding="utf-8") as f: for line in f: - array = re.split(tab_or_space, line) + array = line.split() if len(array) > 0: first_word = array[0] if first_word in first_field_set or first_word in other_fields_set: diff --git a/scripts/rnnlm/validate_word_features.py b/scripts/rnnlm/validate_word_features.py index 205b934ae1b..3dc9b23aa41 100755 --- a/scripts/rnnlm/validate_word_features.py +++ b/scripts/rnnlm/validate_word_features.py @@ -7,9 +7,6 @@ import argparse import sys -import re -tab_or_space = re.compile('[ \t]+') - parser = argparse.ArgumentParser(description="Validates word features file, produced by rnnlm/get_word_features.py.", epilog="E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt " "exp/rnnlm/word_feats.txt", @@ -28,9 +25,9 @@ unigram_feat_id = -1 length_feat_id = -1 max_feat_id = -1 -with open(args.features_file, 'r', encoding="latin-1") as f: +with open(args.features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) @@ -52,9 +49,9 @@ if feat_id > max_feat_id: max_feat_id = feat_id -with open(args.word_features_file, 'r', encoding="latin-1") as f: +with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) > 0 and len(fields) % 2 == 1 word_id = int(fields[0]) From d1cfa337d794a4525f7fee6645bf13d01e255a4d Mon Sep 17 00:00:00 2001 From: saikiranvalluri Date: Sat, 2 Mar 2019 04:07:44 +0000 Subject: [PATCH 2/7] RNNLM scripts taken from latest kaldi checkout --- scripts/rnnlm/choose_features.py | 12 +++++++++--- scripts/rnnlm/get_best_model.py | 24 +++++++++++++----------- scripts/rnnlm/get_embedding_dim.py | 2 +- scripts/rnnlm/get_num_splits.sh | 2 +- scripts/rnnlm/get_special_symbol_opts.py | 8 ++++++-- scripts/rnnlm/get_unigram_probs.py | 18 +++++++++++------- scripts/rnnlm/get_vocab.py | 7 +++++-- scripts/rnnlm/get_word_features.py | 15 +++++++++------ scripts/rnnlm/lmrescore.sh | 6 ++++++ scripts/rnnlm/lmrescore_nbest.sh | 4 ++-- scripts/rnnlm/lmrescore_pruned.sh | 17 +++++++++++++---- scripts/rnnlm/prepare_rnnlm_dir.sh | 9 +++++++-- scripts/rnnlm/prepare_split_data.py | 5 ++++- scripts/rnnlm/show_word_features.py | 13 +++++++++++-- scripts/rnnlm/train_rnnlm.sh | 2 +- scripts/rnnlm/validate_features.py | 5 ++++- scripts/rnnlm/validate_text_dir.py | 7 +++++-- scripts/rnnlm/validate_word_features.py | 7 +++++-- 18 files changed, 113 insertions(+), 50 deletions(-) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index 799f6b6dcc8..12d69b545e5 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -10,6 +10,12 @@ from collections import defaultdict sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) +# because this script splits inside words, we cannot use utf-8; we actually need to know what +# what the encoding is. By default we make this utf-8; to handle encodings that are not compatible +# with utf-8 (e.g. gbk), we'll eventually have to make the encoding an option to this script. + +import re +tab_or_space = re.compile('[ \t]+') parser = argparse.ArgumentParser(description="This script chooses the sparse feature representation of words. " "To be more specific, it chooses the set of features-- you compute " @@ -84,9 +90,9 @@ # and 'wordlist' is a list indexed by integer id, that returns the string-valued word. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -115,7 +121,7 @@ def read_unigram_probs(unigram_probs_file): unigram_probs = [] with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): diff --git a/scripts/rnnlm/get_best_model.py b/scripts/rnnlm/get_best_model.py index 45487b18b0c..ed266346e06 100755 --- a/scripts/rnnlm/get_best_model.py +++ b/scripts/rnnlm/get_best_model.py @@ -3,14 +3,14 @@ # Copyright 2017 Johns Hopkins University (author: Daniel Povey) # License: Apache 2.0. -import os import argparse -import sys +import glob import re +import sys parser = argparse.ArgumentParser(description="Works out the best iteration of RNNLM training " - "based on dev-set perplexity, and prints the number corresponding " - "to that iteration", + "based on dev-set perplexity, and prints the number corresponding " + "to that iteration", epilog="E.g. " + sys.argv[0] + " exp/rnnlm_a", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -19,8 +19,7 @@ args = parser.parse_args() - -num_iters=None +num_iters = None try: with open(args.rnnlm_dir + "/info.txt", encoding="utf-8") as f: for line in f: @@ -36,15 +35,15 @@ sys.exit(sys.argv[0] + ": could not get num_iters from {0}/info.txt".format( args.rnnlm_dir)) -best_objf=-2000 -best_iter=-1 -for i in range(num_iters): +best_objf = -2000 +best_iter = -1 +for i in range(1, num_iters): this_logfile = "{0}/log/compute_prob.{1}.log".format(args.rnnlm_dir, i) try: f = open(this_logfile, 'r', encoding='utf-8') except: sys.exit(sys.argv[0] + ": could not open log-file {0}".format(this_logfile)) - this_objf=-1000 + this_objf = -1000 for line in f: m = re.search('Overall objf .* (\S+)$', str(line)) if m is not None: @@ -53,6 +52,10 @@ except Exception as e: sys.exit(sys.argv[0] + ": line in file {0} could not be parsed: {1}, error is: {2}".format( this_logfile, line, str(e))) + # verify this iteration still has model files present + if len(glob.glob("{0}/{1}.raw".format(args.rnnlm_dir, i))) == 0: + # this iteration has log files, but model files have been cleaned up, skip it + continue if this_objf == -1000: print(sys.argv[0] + ": warning: could not parse objective function from {0}".format( this_logfile), file=sys.stderr) @@ -63,5 +66,4 @@ if best_iter == -1: sys.exit(sys.argv[0] + ": error: could not get best iteration.") - print(str(best_iter)) diff --git a/scripts/rnnlm/get_embedding_dim.py b/scripts/rnnlm/get_embedding_dim.py index b6810ef2cbf..1d516e0edf5 100755 --- a/scripts/rnnlm/get_embedding_dim.py +++ b/scripts/rnnlm/get_embedding_dim.py @@ -101,4 +101,4 @@ "nnet '{0}': {1} != {2}".format( args.nnet, input_dim, output_dim)) -print(str(input_dim)) +print('{}'.format(input_dim)) diff --git a/scripts/rnnlm/get_num_splits.sh b/scripts/rnnlm/get_num_splits.sh index 93d1f7f169c..974fd8bf204 100755 --- a/scripts/rnnlm/get_num_splits.sh +++ b/scripts/rnnlm/get_num_splits.sh @@ -65,7 +65,7 @@ tot_with_multiplicities=0 for f in $text/*.counts; do if [ "$f" != "$text/dev.counts" ]; then - this_tot=$(cat $f | awk '{tot += $2} END{print tot}') + this_tot=$(cat $f | awk '{tot += $2} END{printf("%d", tot)}') if ! [ $this_tot -gt 0 ]; then echo "$0: there were no counts in counts file $f" 1>&2 exit 1 diff --git a/scripts/rnnlm/get_special_symbol_opts.py b/scripts/rnnlm/get_special_symbol_opts.py index 13fe497faf9..0cf8e10feca 100755 --- a/scripts/rnnlm/get_special_symbol_opts.py +++ b/scripts/rnnlm/get_special_symbol_opts.py @@ -8,6 +8,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script checks whether the special symbols " "appear in words.txt with expected values, if not, it will " "print out the options with correct value to stdout, which may look like " @@ -25,9 +28,10 @@ lower_ids = {} upper_ids = {} -input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors='replace') +input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') for line in input_stream: - fields = line.split() + fields = re.split(tab_or_space, line) + assert(len(fields) == 2) sym = fields[0] if sym in special_symbols: assert sym not in lower_ids diff --git a/scripts/rnnlm/get_unigram_probs.py b/scripts/rnnlm/get_unigram_probs.py index 32b01728ca3..d115b6f54bf 100755 --- a/scripts/rnnlm/get_unigram_probs.py +++ b/scripts/rnnlm/get_unigram_probs.py @@ -7,6 +7,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script gets the unigram probabilities of words.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " "--data-weights-file=exp/rnnlm/data_weights.txt data/rnnlm/data " @@ -74,10 +77,10 @@ def get_all_data_sources_except_dev(text_dir): # value is a tuple (repeated_times_per_epoch, weight) def read_data_weights(weights_file, data_sources): data_weights = {} - with open(weights_file, 'r', encoding="utf-8", errors='replace') as f: + with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -99,9 +102,9 @@ def read_data_weights(weights_file, data_sources): # return the vocab, which is a dict mapping the word to a integer id. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -128,10 +131,11 @@ def get_counts(data_sources, data_weights, vocab): if weight == 0.0: continue - with open(counts_file, 'r', encoding="utf-8", errors='replace') as f: + with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() - assert len(fields) == 2 + fields = re.split(tab_or_space, line) + if len(fields) != 2: print("Warning, should be 2 cols:", fields, line, file=sys.stderr); + assert(len(fields) == 2) word = fields[0] count = fields[1] if word not in vocab: diff --git a/scripts/rnnlm/get_vocab.py b/scripts/rnnlm/get_vocab.py index f290ef721c1..d65f8e3669b 100755 --- a/scripts/rnnlm/get_vocab.py +++ b/scripts/rnnlm/get_vocab.py @@ -8,6 +8,9 @@ import sys sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script get a vocab from unigram counts " "of words produced by get_unigram_counts.sh", epilog="E.g. " + sys.argv[0] + " data/rnnlm/data > data/rnnlm/vocab/words.txt", @@ -27,8 +30,8 @@ def add_counts(word_counts, counts_file): with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - line = line.strip() - word_and_count = line.split() + line = line.strip(" \t\r\n") + word_and_count = re.split(tab_or_space, line) assert len(word_and_count) == 2 if word_and_count[0] in word_counts: word_counts[word_and_count[0]] += int(word_and_count[1]) diff --git a/scripts/rnnlm/get_word_features.py b/scripts/rnnlm/get_word_features.py index 8bdb553b9c8..7555b774b83 100755 --- a/scripts/rnnlm/get_word_features.py +++ b/scripts/rnnlm/get_word_features.py @@ -9,6 +9,9 @@ import math from collections import defaultdict +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script turns the words into the sparse feature representation, " "using features from rnnlm/choose_features.py.", epilog="E.g. " + sys.argv[0] + " --unigram-probs=exp/rnnlm/unigram_probs.txt " @@ -38,9 +41,9 @@ # return the vocab, which is a dict mapping the word to a integer id. def read_vocab(vocab_file): vocab = {} - with open(vocab_file, 'r', encoding="utf-8", errors='replace') as f: + with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -59,9 +62,9 @@ def read_vocab(vocab_file): # return a list of unigram_probs, indexed by word id def read_unigram_probs(unigram_probs_file): unigram_probs = [] - with open(unigram_probs_file, 'r', encoding="utf-8", errors='replace') as f: + with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): @@ -100,9 +103,9 @@ def read_features(features_file): feats['min_ngram_order'] = 10000 feats['max_ngram_order'] = -1 - with open(features_file, 'r', encoding="utf-8", errors='replace') as f: + with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) diff --git a/scripts/rnnlm/lmrescore.sh b/scripts/rnnlm/lmrescore.sh index cd0cf793d8d..9da22ae75a2 100755 --- a/scripts/rnnlm/lmrescore.sh +++ b/scripts/rnnlm/lmrescore.sh @@ -72,6 +72,12 @@ awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) { print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \ || exit 1; +if ! head -n -1 $rnnlm_dir/config/words.txt | cmp $oldlang/words.txt -; then + # the last word of the RNNLM word list is an added word + echo "$0: Word lists mismatch for lattices and RNNLM." + exit 1 +fi + oldlm_command="fstproject --project_output=true $oldlm |" special_symbol_opts=$(cat $rnnlm_dir/special_symbol_opts.txt) diff --git a/scripts/rnnlm/lmrescore_nbest.sh b/scripts/rnnlm/lmrescore_nbest.sh index f50a3c909f0..58b19b9fa79 100755 --- a/scripts/rnnlm/lmrescore_nbest.sh +++ b/scripts/rnnlm/lmrescore_nbest.sh @@ -29,7 +29,7 @@ if [ $# != 6 ]; then echo "This version applies an RNNLM and mixes it with the LM scores" echo "previously in the lattices., controlled by the first parameter (rnnlm-weight)" echo "" - echo "Usage: utils/rnnlmrescore.sh " + echo "Usage: $0 [options] " echo "Main options:" echo " --inv-acwt # default 12. e.g. --inv-acwt 17. Equivalent to LM scale to use." echo " # for N-best list generation... note, we'll score at different acwt's" @@ -177,7 +177,7 @@ fi if [ $stage -le 6 ]; then echo "$0: invoking rnnlm/compute_sentence_scores.sh which calls rnnlm to get RNN LM scores." $cmd JOB=1:$nj $dir/log/rnnlm_compute_scores.JOB.log \ - local/rnnlm/compute_sentence_scores.sh $rnndir $adir.JOB/temp \ + rnnlm/compute_sentence_scores.sh $rnndir $adir.JOB/temp \ $adir.JOB/words_text $adir.JOB/lmwt.rnn fi if [ $stage -le 7 ]; then diff --git a/scripts/rnnlm/lmrescore_pruned.sh b/scripts/rnnlm/lmrescore_pruned.sh index 46ee5846424..9ba78415708 100755 --- a/scripts/rnnlm/lmrescore_pruned.sh +++ b/scripts/rnnlm/lmrescore_pruned.sh @@ -16,16 +16,18 @@ max_ngram_order=4 # Approximate the lattice-rescoring by limiting the max-ngram- # the same ngram history and this prevents the lattice from # exploding exponentially. Details of the n-gram approximation # method are described in section 2.3 of the paper - # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdm -max_arcs=499 # limit the max arcs in lattice while rescoring. E.g., 20000 + # http://www.danielpovey.com/files/2018_icassp_lattice_pruning.pdf +max_arcs= # limit the max arcs in lattice while rescoring. E.g., 20000 -acwt=1 -weight=1 # Interpolation weight for RNNLM. +acwt=0.1 +weight=0.5 # Interpolation weight for RNNLM. normalize=false # If true, we add a normalization step to the output of the RNNLM # so that it adds up to *exactly* 1. Note that this is not necessary # as in our RNNLM setup, a properly trained network would automatically # have its normalization term close to 1. The details of this # could be found at http://www.danielpovey.com/files/2018_icassp_rnnlm.pdf +lattice_prune_beam=4 # Beam used in pruned lattice composition + # This option affects speed and how large the composed lattice may be # End configuration section. @@ -73,6 +75,12 @@ awk -v n=$0 -v w=$weight 'BEGIN {if (w < 0 || w > 1) { print n": Interpolation weight should be in the range of [0, 1]"; exit 1;}}' \ || exit 1; +if ! head -n -1 $rnnlm_dir/config/words.txt | cmp $oldlang/words.txt -; then + # the last word of the RNNLM word list is an added word + echo "$0: Word lists mismatch for lattices and RNNLM." + exit 1 +fi + normalize_opt= if $normalize; then normalize_opt="--normalize-probs=true" @@ -97,6 +105,7 @@ cp $indir/num_jobs $outdir $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore-kaldi-rnnlm-pruned --lm-scale=$weight $special_symbol_opts \ + --lattice-compose-beam=$lattice_prune_beam \ --acoustic-scale=$acwt --max-ngram-order=$max_ngram_order $normalize_opt $max_arcs_opt \ $carpa_option $oldlm $word_embedding "$rnnlm_dir/final.raw" \ "ark:gunzip -c $indir/lat.JOB.gz|" "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; diff --git a/scripts/rnnlm/prepare_rnnlm_dir.sh b/scripts/rnnlm/prepare_rnnlm_dir.sh index d3ee44f1f95..e101822d983 100755 --- a/scripts/rnnlm/prepare_rnnlm_dir.sh +++ b/scripts/rnnlm/prepare_rnnlm_dir.sh @@ -23,7 +23,7 @@ if [ $# != 3 ]; then echo "Usage: $0 [options] " echo "Sets up the directory for RNNLM training as done by" echo "rnnlm/train_rnnlm.sh, and initializes the model." - echo " is as validated by rnnlm/validate_data_dir.py" + echo " is as validated by rnnlm/validate_text_dir.py" echo " is as validated by rnnlm/validate_config_dir.sh." exit 1 fi @@ -34,6 +34,7 @@ config_dir=$2 dir=$3 set -e +. ./path.sh if [ $stage -le 0 ]; then echo "$0: validating input" @@ -52,9 +53,13 @@ if [ $stage -le 1 ]; then echo "$0: copying config directory" mkdir -p $dir/config # copy expected things from $config_dir to $dir/config. - for f in words.txt features.txt data_weights.txt oov.txt xconfig; do + for f in words.txt data_weights.txt oov.txt xconfig; do cp $config_dir/$f $dir/config done + # features.txt is optional, check separately + if [ -f $config_dir/features.txt ]; then + cp $config_dir/features.txt $dir/config + fi fi rnnlm/get_special_symbol_opts.py < $dir/config/words.txt > $dir/special_symbol_opts.txt diff --git a/scripts/rnnlm/prepare_split_data.py b/scripts/rnnlm/prepare_split_data.py index 9cc4f69d09f..adcb164771d 100755 --- a/scripts/rnnlm/prepare_split_data.py +++ b/scripts/rnnlm/prepare_split_data.py @@ -8,6 +8,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script prepares files containing integerized text, " "for consumption by nnet3-get-egs.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " @@ -66,7 +69,7 @@ def read_data_weights(weights_file, data_sources): with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index 89d84d53f3e..c5ed72a2899 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -6,8 +6,17 @@ import os import argparse import sys + +# The use of utf-8 encoding does not preclude reading utf-8. utf-8 encoding +# means "treat words as sequences of bytes", and it is compatible with utf-8 +# encoding as well as other encodings such as gbk, as long as the spaces are +# also spaces in ascii (which we check). It is basically how we emulate the +# behavior of python before python3. sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script turns the word features to a human readable format.", epilog="E.g. " + sys.argv[0] + "exp/rnnlm/word_feats.txt exp/rnnlm/features.txt " "> exp/rnnlm/word_feats.str.txt", @@ -29,7 +38,7 @@ def read_feature_type_and_key(features_file): with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert(len(fields) in [2, 3, 4]) feat_id = int(fields[0]) @@ -46,7 +55,7 @@ def read_feature_type_and_key(features_file): num_word_feats = 0 with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) % 2 == 1 print(int(fields[0]), end='\t') diff --git a/scripts/rnnlm/train_rnnlm.sh b/scripts/rnnlm/train_rnnlm.sh index f056d096120..013e9a56c2f 100755 --- a/scripts/rnnlm/train_rnnlm.sh +++ b/scripts/rnnlm/train_rnnlm.sh @@ -41,7 +41,7 @@ use_gpu_for_diagnostics=false # set true to use GPU for compute_prob_*.log # optional cleanup options cleanup=false # add option --cleanup true to enable automatic cleanup of old models cleanup_strategy="keep_latest" # determines cleanup strategy, use either "keep_latest" or "keep_best" -cleanup_keep_iters=100 # number of iterations that will have their models retained +cleanup_keep_iters=3 # number of iterations that will have their models retained trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM . utils/parse_options.sh diff --git a/scripts/rnnlm/validate_features.py b/scripts/rnnlm/validate_features.py index a650092b086..939e634592c 100755 --- a/scripts/rnnlm/validate_features.py +++ b/scripts/rnnlm/validate_features.py @@ -7,6 +7,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates features file, produced by rnnlm/choose_features.py.", epilog="E.g. " + sys.argv[0] + " exp/rnnlm/features.txt", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -30,7 +33,7 @@ final_feats = {} word_feats = {} for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert(len(fields) in [3, 4, 5]) assert idx == int(fields[0]) diff --git a/scripts/rnnlm/validate_text_dir.py b/scripts/rnnlm/validate_text_dir.py index d644d77911e..61914e4836a 100755 --- a/scripts/rnnlm/validate_text_dir.py +++ b/scripts/rnnlm/validate_text_dir.py @@ -7,6 +7,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates data directory containing text " "files from one or more data sources, including dev.txt.", epilog="E.g. " + sys.argv[0] + " data/rnnlm/data", @@ -51,7 +54,7 @@ def check_text_file(text_file): lineno += 1 if args.spot_check == 'true' and lineno > 10: break - words = line.split() + words = re.split(tab_or_space, line) if len(words) != 0: found_nonempty_line = True for word in words: @@ -75,7 +78,7 @@ def check_text_file(text_file): other_fields_set = set() with open(text_file, 'r', encoding="utf-8") as f: for line in f: - array = line.split() + array = re.split(tab_or_space, line) if len(array) > 0: first_word = array[0] if first_word in first_field_set or first_word in other_fields_set: diff --git a/scripts/rnnlm/validate_word_features.py b/scripts/rnnlm/validate_word_features.py index 3dc9b23aa41..303daf28bb1 100755 --- a/scripts/rnnlm/validate_word_features.py +++ b/scripts/rnnlm/validate_word_features.py @@ -7,6 +7,9 @@ import argparse import sys +import re +tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates word features file, produced by rnnlm/get_word_features.py.", epilog="E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt " "exp/rnnlm/word_feats.txt", @@ -27,7 +30,7 @@ max_feat_id = -1 with open(args.features_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) @@ -51,7 +54,7 @@ with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = line.split() + fields = re.split(tab_or_space, line) assert len(fields) > 0 and len(fields) % 2 == 1 word_id = int(fields[0]) From 9c2d04424a76e49c78d988c132e3aa7a03e74c8a Mon Sep 17 00:00:00 2001 From: saikiranvalluri <41471921+saikiranvalluri@users.noreply.github.com> Date: Sat, 2 Mar 2019 17:29:23 +0530 Subject: [PATCH 3/7] Update choose_features.py --- scripts/rnnlm/choose_features.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index 12d69b545e5..e8db64e6635 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -10,9 +10,9 @@ from collections import defaultdict sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) -# because this script splits inside words, we cannot use utf-8; we actually need to know what +# because this script splits inside words, we cannot use latin-1; we actually need to know what # what the encoding is. By default we make this utf-8; to handle encodings that are not compatible -# with utf-8 (e.g. gbk), we'll eventually have to make the encoding an option to this script. +# with latin-1 (e.g. gbk), we'll eventually have to make the encoding an option to this script. import re tab_or_space = re.compile('[ \t]+') From 37bdc0a034908bb813404d786c6728bc8b69f460 Mon Sep 17 00:00:00 2001 From: saikiranvalluri <41471921+saikiranvalluri@users.noreply.github.com> Date: Sat, 2 Mar 2019 17:30:45 +0530 Subject: [PATCH 4/7] Update choose_features.py --- scripts/rnnlm/choose_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index e8db64e6635..c6621e04494 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -12,7 +12,7 @@ # because this script splits inside words, we cannot use latin-1; we actually need to know what # what the encoding is. By default we make this utf-8; to handle encodings that are not compatible -# with latin-1 (e.g. gbk), we'll eventually have to make the encoding an option to this script. +# with utf-8 (e.g. gbk), we'll eventually have to make the encoding an option to this script. import re tab_or_space = re.compile('[ \t]+') From 5bc9385a472e94a1acf07151a5af75d05804d888 Mon Sep 17 00:00:00 2001 From: saikiranvalluri <41471921+saikiranvalluri@users.noreply.github.com> Date: Sat, 2 Mar 2019 17:33:04 +0530 Subject: [PATCH 5/7] Update show_word_features.py --- scripts/rnnlm/show_word_features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index c5ed72a2899..8b69fbb7d8a 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -7,7 +7,7 @@ import argparse import sys -# The use of utf-8 encoding does not preclude reading utf-8. utf-8 encoding +# The use of latin-1 encoding does not preclude reading utf-8. latin-1 encoding # means "treat words as sequences of bytes", and it is compatible with utf-8 # encoding as well as other encodings such as gbk, as long as the spaces are # also spaces in ascii (which we check). It is basically how we emulate the From 033e7b5497b9af9c4f578a0c4f500c98af15748c Mon Sep 17 00:00:00 2001 From: saikiranvalluri <41471921+saikiranvalluri@users.noreply.github.com> Date: Sun, 3 Mar 2019 09:44:25 +0530 Subject: [PATCH 6/7] Update show_word_features.py --- scripts/rnnlm/show_word_features.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index 8b69fbb7d8a..d1051ddfcca 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -7,15 +7,10 @@ import argparse import sys -# The use of latin-1 encoding does not preclude reading utf-8. latin-1 encoding -# means "treat words as sequences of bytes", and it is compatible with utf-8 -# encoding as well as other encodings such as gbk, as long as the spaces are -# also spaces in ascii (which we check). It is basically how we emulate the -# behavior of python before python3. + sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re -tab_or_space = re.compile('[ \t]+') parser = argparse.ArgumentParser(description="This script turns the word features to a human readable format.", epilog="E.g. " + sys.argv[0] + "exp/rnnlm/word_feats.txt exp/rnnlm/features.txt " @@ -38,7 +33,7 @@ def read_feature_type_and_key(features_file): with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [2, 3, 4]) feat_id = int(fields[0]) @@ -55,7 +50,7 @@ def read_feature_type_and_key(features_file): num_word_feats = 0 with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) % 2 == 1 print(int(fields[0]), end='\t') From 7b5472846a58f887a0ad3bd0decdfce53d39b667 Mon Sep 17 00:00:00 2001 From: saikiranvalluri Date: Sun, 3 Mar 2019 04:27:56 +0000 Subject: [PATCH 7/7] Dan's changes resolved --- scripts/rnnlm/choose_features.py | 10 +++------- scripts/rnnlm/get_special_symbol_opts.py | 4 ++-- scripts/rnnlm/get_unigram_probs.py | 8 ++++---- scripts/rnnlm/get_vocab.py | 4 ++-- scripts/rnnlm/get_word_features.py | 8 ++++---- scripts/rnnlm/prepare_split_data.py | 4 ++-- scripts/rnnlm/show_word_features.py | 2 +- scripts/rnnlm/validate_features.py | 4 ++-- scripts/rnnlm/validate_text_dir.py | 6 +++--- scripts/rnnlm/validate_word_features.py | 6 +++--- 10 files changed, 26 insertions(+), 30 deletions(-) diff --git a/scripts/rnnlm/choose_features.py b/scripts/rnnlm/choose_features.py index c6621e04494..595c1d85bc1 100755 --- a/scripts/rnnlm/choose_features.py +++ b/scripts/rnnlm/choose_features.py @@ -10,12 +10,8 @@ from collections import defaultdict sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) -# because this script splits inside words, we cannot use latin-1; we actually need to know what -# what the encoding is. By default we make this utf-8; to handle encodings that are not compatible -# with utf-8 (e.g. gbk), we'll eventually have to make the encoding an option to this script. - import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script chooses the sparse feature representation of words. " "To be more specific, it chooses the set of features-- you compute " @@ -92,7 +88,7 @@ def read_vocab(vocab_file): vocab = {} with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -121,7 +117,7 @@ def read_unigram_probs(unigram_probs_file): unigram_probs = [] with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): diff --git a/scripts/rnnlm/get_special_symbol_opts.py b/scripts/rnnlm/get_special_symbol_opts.py index 0cf8e10feca..7ee0ca54c9a 100755 --- a/scripts/rnnlm/get_special_symbol_opts.py +++ b/scripts/rnnlm/get_special_symbol_opts.py @@ -9,7 +9,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script checks whether the special symbols " "appear in words.txt with expected values, if not, it will " @@ -30,7 +30,7 @@ upper_ids = {} input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') for line in input_stream: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) == 2) sym = fields[0] if sym in special_symbols: diff --git a/scripts/rnnlm/get_unigram_probs.py b/scripts/rnnlm/get_unigram_probs.py index d115b6f54bf..e3189b26a92 100755 --- a/scripts/rnnlm/get_unigram_probs.py +++ b/scripts/rnnlm/get_unigram_probs.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script gets the unigram probabilities of words.", epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt " @@ -80,7 +80,7 @@ def read_data_weights(weights_file, data_sources): with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " @@ -104,7 +104,7 @@ def read_vocab(vocab_file): vocab = {} with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -133,7 +133,7 @@ def get_counts(data_sources, data_weights, vocab): with open(counts_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() if len(fields) != 2: print("Warning, should be 2 cols:", fields, line, file=sys.stderr); assert(len(fields) == 2) word = fields[0] diff --git a/scripts/rnnlm/get_vocab.py b/scripts/rnnlm/get_vocab.py index d65f8e3669b..baafcb3a131 100755 --- a/scripts/rnnlm/get_vocab.py +++ b/scripts/rnnlm/get_vocab.py @@ -9,7 +9,7 @@ sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script get a vocab from unigram counts " "of words produced by get_unigram_counts.sh", @@ -31,7 +31,7 @@ def add_counts(word_counts, counts_file): with open(counts_file, 'r', encoding="utf-8") as f: for line in f: line = line.strip(" \t\r\n") - word_and_count = re.split(tab_or_space, line) + word_and_count = line.split() assert len(word_and_count) == 2 if word_and_count[0] in word_counts: word_counts[word_and_count[0]] += int(word_and_count[1]) diff --git a/scripts/rnnlm/get_word_features.py b/scripts/rnnlm/get_word_features.py index 7555b774b83..cdcc0a77734 100755 --- a/scripts/rnnlm/get_word_features.py +++ b/scripts/rnnlm/get_word_features.py @@ -10,7 +10,7 @@ from collections import defaultdict import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script turns the words into the sparse feature representation, " "using features from rnnlm/choose_features.py.", @@ -43,7 +43,7 @@ def read_vocab(vocab_file): vocab = {} with open(vocab_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 if fields[0] in vocab: sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}" @@ -64,7 +64,7 @@ def read_unigram_probs(unigram_probs_file): unigram_probs = [] with open(unigram_probs_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 2 idx = int(fields[0]) if idx >= len(unigram_probs): @@ -105,7 +105,7 @@ def read_features(features_file): with open(features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) diff --git a/scripts/rnnlm/prepare_split_data.py b/scripts/rnnlm/prepare_split_data.py index adcb164771d..427f043df98 100755 --- a/scripts/rnnlm/prepare_split_data.py +++ b/scripts/rnnlm/prepare_split_data.py @@ -9,7 +9,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="This script prepares files containing integerized text, " "for consumption by nnet3-get-egs.", @@ -69,7 +69,7 @@ def read_data_weights(weights_file, data_sources): with open(weights_file, 'r', encoding="utf-8") as f: for line in f: try: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) == 3 if fields[0] in data_weights: raise Exception("duplicated data source({0}) specified in " diff --git a/scripts/rnnlm/show_word_features.py b/scripts/rnnlm/show_word_features.py index d1051ddfcca..4335caed5d8 100755 --- a/scripts/rnnlm/show_word_features.py +++ b/scripts/rnnlm/show_word_features.py @@ -7,11 +7,11 @@ import argparse import sys - sys.stdout = open(1, 'w', encoding='utf-8', closefd=False) import re + parser = argparse.ArgumentParser(description="This script turns the word features to a human readable format.", epilog="E.g. " + sys.argv[0] + "exp/rnnlm/word_feats.txt exp/rnnlm/features.txt " "> exp/rnnlm/word_feats.str.txt", diff --git a/scripts/rnnlm/validate_features.py b/scripts/rnnlm/validate_features.py index 939e634592c..e67f03207bb 100755 --- a/scripts/rnnlm/validate_features.py +++ b/scripts/rnnlm/validate_features.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates features file, produced by rnnlm/choose_features.py.", epilog="E.g. " + sys.argv[0] + " exp/rnnlm/features.txt", @@ -33,7 +33,7 @@ final_feats = {} word_feats = {} for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) assert idx == int(fields[0]) diff --git a/scripts/rnnlm/validate_text_dir.py b/scripts/rnnlm/validate_text_dir.py index 61914e4836a..1f250d4c2f8 100755 --- a/scripts/rnnlm/validate_text_dir.py +++ b/scripts/rnnlm/validate_text_dir.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates data directory containing text " "files from one or more data sources, including dev.txt.", @@ -54,7 +54,7 @@ def check_text_file(text_file): lineno += 1 if args.spot_check == 'true' and lineno > 10: break - words = re.split(tab_or_space, line) + words = line.split() if len(words) != 0: found_nonempty_line = True for word in words: @@ -78,7 +78,7 @@ def check_text_file(text_file): other_fields_set = set() with open(text_file, 'r', encoding="utf-8") as f: for line in f: - array = re.split(tab_or_space, line) + array = line.split() if len(array) > 0: first_word = array[0] if first_word in first_field_set or first_word in other_fields_set: diff --git a/scripts/rnnlm/validate_word_features.py b/scripts/rnnlm/validate_word_features.py index 303daf28bb1..372286d8d12 100755 --- a/scripts/rnnlm/validate_word_features.py +++ b/scripts/rnnlm/validate_word_features.py @@ -8,7 +8,7 @@ import sys import re -tab_or_space = re.compile('[ \t]+') + parser = argparse.ArgumentParser(description="Validates word features file, produced by rnnlm/get_word_features.py.", epilog="E.g. " + sys.argv[0] + " --features-file=exp/rnnlm/features.txt " @@ -30,7 +30,7 @@ max_feat_id = -1 with open(args.features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert(len(fields) in [3, 4, 5]) feat_id = int(fields[0]) @@ -54,7 +54,7 @@ with open(args.word_features_file, 'r', encoding="utf-8") as f: for line in f: - fields = re.split(tab_or_space, line) + fields = line.split() assert len(fields) > 0 and len(fields) % 2 == 1 word_id = int(fields[0])