diff --git a/egs/wsj/s5/steps/train_sat.sh b/egs/wsj/s5/steps/train_sat.sh
index 0211f7bcf67..92b744dc75c 100755
--- a/egs/wsj/s5/steps/train_sat.sh
+++ b/egs/wsj/s5/steps/train_sat.sh
@@ -276,4 +276,3 @@
 steps/info/gmm_dir_info.pl $dir
 echo "$0: done training SAT system in $dir"
 exit 0
-
diff --git a/egs/wsj/s5/utils/apply_map.pl b/egs/wsj/s5/utils/apply_map.pl
index ff9507fd894..a138287170b 100755
--- a/egs/wsj/s5/utils/apply_map.pl
+++ b/egs/wsj/s5/utils/apply_map.pl
@@ -9,47 +9,59 @@
 # be sequences of tokens. See the usage message.
 
-if (@ARGV > 0 && $ARGV[0] eq "-f") {
-  shift @ARGV;
-  $field_spec = shift @ARGV;
-  if ($field_spec =~ m/^\d+$/) {
-    $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
-  }
-  if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
-    if ($1 ne "") {
-      $field_begin = $1 - 1;  # Change to zero-based indexing.
+$permissive = 0;
+
+for ($x = 0; $x <= 2; $x++) {
+
+  if (@ARGV > 0 && $ARGV[0] eq "-f") {
+    shift @ARGV;
+    $field_spec = shift @ARGV;
+    if ($field_spec =~ m/^\d+$/) {
+      $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
     }
-    if ($2 ne "") {
-      $field_end = $2 - 1;  # Change to zero-based indexing.
+    if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
+      if ($1 ne "") {
+        $field_begin = $1 - 1;  # Change to zero-based indexing.
+      }
+      if ($2 ne "") {
+        $field_end = $2 - 1;  # Change to zero-based indexing.
+      }
+    }
+    if (!defined $field_begin && !defined $field_end) {
+      die "Bad argument to -f option: $field_spec";
     }
   }
-  if (!defined $field_begin && !defined $field_end) {
-    die "Bad argument to -f option: $field_spec";
-  }
-}
-# Mapping is obligatory
-$permissive = 0;
-if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
-  shift @ARGV;
-  # Mapping is optional (missing key is printed to output)
-  $permissive = 1;
+
+  if (@ARGV > 0 && $ARGV[0] eq '--permissive') {
+    shift @ARGV;
+    # Mapping is optional (missing key is printed to output)
+    $permissive = 1;
+  }
 }
 
 if(@ARGV != 1) {
   print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n";
-  print STDERR "Usage: apply_map.pl [options] map <input >output\n" .
-    "options: [-f <field-range> ]\n" .
-    "Applies the map 'map' to all input text, where each line of the map\n" .
-    "is interpreted as a map from the first field to the list of the other fields\n" .
-    "Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field\n" .
-    "range in the input to apply the map to.\n" .
-    "e.g.: echo A B | apply_map.pl a.txt\n" .
-    "where a.txt is:\n" .
-    "A a1 a2\n" .
-    "B b\n" .
-    "will produce:\n" .
-    "a1 a2 b\n";
+  print STDERR <<'EOF';
+Usage: apply_map.pl [options] map <input >output
+ options: [-f <field-range> ] [--permissive]
+  This applies a map to some specified fields of some input text:
+  For each line in the map file: the first field is the thing we
+  map from, and the remaining fields are the sequence we map it to.
+  The -f (field-range) option says which fields of the input file the map
+  should apply to.
+  If the --permissive option is supplied, fields which are not present
+  in the map will be left as they were.
+  Applies the map 'map' to all input text, where each line of the map
+  is interpreted as a map from the first field to the list of the other fields
+  Note: <field-range> can look like 4-5, or 4-, or 5-, or 1, it means the field
+  range in the input to apply the map to.
+  e.g.: echo A B | apply_map.pl a.txt
+  where a.txt is:
+  A a1 a2
+  B b
+  will produce:
+  a1 a2 b
+EOF
   exit(1);
 }
@@ -72,12 +84,12 @@
       $a = $A[$x];
       if (!defined $map{$a}) {
         if (!$permissive) {
-          die "apply_map.pl: undefined key $a in $map_file\n";
+          die "apply_map.pl: undefined key $a in $map_file\n";
         } else {
           print STDERR "apply_map.pl: warning! missing key $a in $map_file\n";
         }
       } else {
-        $A[$x] = $map{$a};
+        $A[$x] = $map{$a};
       }
     }
   }
diff --git a/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py b/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py
index c6bdb95cb2f..7924fc4fcf1 100755
--- a/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py
+++ b/egs/wsj/s5/utils/data/perturb_speed_to_allowed_lengths.py
@@ -66,7 +66,7 @@ class Utterance:
     """
 
     def __init__(self, uid, wavefile, speaker, transcription, dur):
-        self.wavefile = (wavefile if wavefile.rstrip().endswith('|') else
+        self.wavefile = (wavefile if wavefile.rstrip(" \t\r\n").endswith('|') else
                          'cat {} |'.format(wavefile))
         self.speaker = speaker
         self.transcription = transcription
@@ -130,7 +130,7 @@ def read_kaldi_mapfile(path):
     m = {}
     with open(path, 'r', encoding='latin-1') as f:
         for line in f:
-            line = line.strip()
+            line = line.strip(" \t\r\n")
             sp_pos = line.find(' ')
             key = line[:sp_pos]
             val = line[sp_pos+1:]
diff --git a/egs/wsj/s5/utils/lang/bpe/prepend_words.py b/egs/wsj/s5/utils/lang/bpe/prepend_words.py
index face771c7ca..4a11895a712 100755
--- a/egs/wsj/s5/utils/lang/bpe/prepend_words.py
+++ b/egs/wsj/s5/utils/lang/bpe/prepend_words.py
@@ -4,11 +4,13 @@
 # the beginning of the words for finding the initial-space of every word
 # after decoding.
 
-import sys, io
+import sys
+import io
+import re
 
+whitespace = re.compile("[ \t]+")
 infile = io.TextIOWrapper(sys.stdin.buffer, encoding='latin-1')
 output = io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')
 for line in infile:
-    output.write(' '.join([ "|"+word for word in line.split()]) + '\n')
-
-
+    words = whitespace.split(line.strip(" \t\r\n"))
+    output.write(' '.join([ "|"+word for word in words]) + '\n')
diff --git a/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py b/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py
index 5a7743badee..dc480903db4 100755
--- a/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py
+++ b/egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py
@@ -99,13 +99,13 @@ def compute_begin_prob(sub_list):
     for i in range(1, len(sub_list) - 1):
         logprob += compute_sublist_prob(sub_list[:i + 1])
     return logprob
-    
+
 # The probability is computed in this way:
 # p(word_N | word_N-1 ... word_1) = ngram_dict[word_1 ... word_N][0].
 # Here gram_dict is a dictionary stores a tuple corresponding to ngrams.
 # The first element of tuple is probablity and the second is backoff probability (if exists).
 # If the particular ngram (word_1 ... word_N) is not in the dictionary, then
-# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1) 
+# p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1)
 # If the sequence (word_(N-1) ... word_1) is not in the dictionary, then the backoff_weight gets replaced with 0.0 (log1)
 # More details can be found in https://cmusphinx.github.io/wiki/arpaformat/
 def compute_sentence_prob(sentence, ngram_order):
@@ -127,7 +127,7 @@ def compute_sentence_prob(sentence, ngram_order):
             logprob += compute_sublist_prob(cur_sublist)
 
     return logprob
-    
+
 
 def output_result(text_in_handle, output_file_handle, ngram_order):
     lines = text_in_handle.readlines()
@@ -139,8 +139,8 @@ def output_result(text_in_handle, output_file_handle, ngram_order):
         output_file_handle.write("{}\n".format(new_logprob))
     text_in_handle.close()
     output_file_handle.close()
-    
-    
+
+
 if __name__ == "__main__":
     check_args(args)
     ngram_dict, tot_num = load_model(args.arpa_lm)
@@ -149,7 +149,7 @@ def output_result(text_in_handle, output_file_handle, ngram_order):
     if not num_valid:
         sys.exit("compute_sentence_probs_arpa.py: Wrong loading model.")
     if args.ngram_order <= 0 or args.ngram_order > max_ngram_order:
-        sys.exit("compute_sentence_probs_arpa.py: " + 
+        sys.exit("compute_sentence_probs_arpa.py: " +
                  "Invalid ngram_order (either negative or greater than maximum ngram number ({}) allowed)".format(max_ngram_order))
 
     output_result(args.text_in_handle, args.prob_file_handle, args.ngram_order)
diff --git a/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py b/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py
index 1033df31ad0..f0087680a4b 100755
--- a/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py
+++ b/egs/wsj/s5/utils/lang/grammar/augment_phones_txt.py
@@ -2,6 +2,7 @@
 
 import argparse
+import re
 import os
 import sys
 
 
@@ -34,11 +35,12 @@ def read_phones_txt(filename):
     # with utf-8 encoding as well as other encodings such as gbk, as long as the
     # spaces are also spaces in ascii (which we check). It is basically how we
     # emulate the behavior of python before python3.
+    whitespace = re.compile("[ \t]+")
     with open(filename, 'r', encoding='latin-1') as f:
-        lines = [line.strip() for line in f]
+        lines = [line.strip(" \t\r\n") for line in f]
     highest_numbered_symbol = 0
     for line in lines:
-        s = line.split()
+        s = whitespace.split(line)
         try:
             i = int(s[1])
             if i > highest_numbered_symbol:
@@ -57,9 +59,9 @@ def read_nonterminals(filename):
        it has the expected format and has no duplicates, and returns the nonterminal
        symbols as a list of strings, e.g.
        ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
-    ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
+    ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
     if len(ans) == 0:
-        raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
+        raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename))
     for nonterm in ans:
         if nonterm[:9] != '#nonterm:':
             raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'"
diff --git a/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py b/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py
index 5cd6f904efe..00ab9e59eaa 100755
--- a/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py
+++ b/egs/wsj/s5/utils/lang/grammar/augment_words_txt.py
@@ -35,11 +35,12 @@ def read_words_txt(filename):
     # with utf-8 encoding as well as other encodings such as gbk, as long as the
     # spaces are also spaces in ascii (which we check). It is basically how we
     # emulate the behavior of python before python3.
+    whitespace = re.compile("[ \t]+")
     with open(filename, 'r', encoding='latin-1') as f:
-        lines = [line.strip() for line in f]
+        lines = [line.strip(" \t\r\n") for line in f]
     highest_numbered_symbol = 0
     for line in lines:
-        s = line.split()
+        s = whitespace.split(line)
         try:
             i = int(s[1])
             if i > highest_numbered_symbol:
@@ -58,9 +59,9 @@ def read_nonterminals(filename):
        it has the expected format and has no duplicates, and returns the nonterminal
        symbols as a list of strings, e.g.
        ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
-    ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
+    ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
     if len(ans) == 0:
-        raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
+        raise RuntimeError("The file {0} contains no nonterminal symbols.".format(filename))
     for nonterm in ans:
         if nonterm[:9] != '#nonterm:':
             raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'"
diff --git a/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py b/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py
index f7e0dcbdc5f..68f7b4b5639 100755
--- a/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py
+++ b/egs/wsj/s5/utils/lang/limit_arpa_unk_history.py
@@ -58,6 +58,7 @@ def get_ngram_stats(old_lm_lines):
 
 def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
     ngram_diffs = defaultdict(int)
+    whitespace_pattern = re.compile("[ \t]+")
     unk_pattern = re.compile(
         "[0-9.-]+(?:[\s\\t]\S+){1,3}[\s\\t]" + args.oov_dict_entry +
         "[\s\\t](?!-[0-9]+\.[0-9]+).*")
@@ -70,7 +71,7 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
     new_lm_lines = old_lm_lines[:skip_rows]
 
     for i in range(skip_rows, len(old_lm_lines)):
-        line = old_lm_lines[i].strip()
+        line = old_lm_lines[i].strip(" \t\r\n")
 
         if "\{}-grams:".format(3) in line:
             passed_2grams = True
@@ -101,7 +102,7 @@ def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
             if not last_ngram:
                 g_backoff = backoff_pattern.search(line)
                 if g_backoff:
-                    updated_row = g_backoff.group(0).split()[:-1]
+                    updated_row = whitespace_pattern.split(g_backoff.group(0))[:-1]
                     updated_row = updated_row[0] + \
                         "\t" + " ".join(updated_row[1:]) + "\n"
                     new_lm_lines.append(updated_row)
diff --git a/egs/wsj/s5/utils/lang/make_lexicon_fst.py b/egs/wsj/s5/utils/lang/make_lexicon_fst.py
index 67ed0ac2789..790af2f2314 100755
--- a/egs/wsj/s5/utils/lang/make_lexicon_fst.py
+++ b/egs/wsj/s5/utils/lang/make_lexicon_fst.py
@@ -72,28 +72,28 @@ def read_lexiconp(filename):
     with open(filename, 'r', encoding='latin-1') as f:
         whitespace = re.compile("[ \t]+")
         for line in f:
-            a = whitespace.split(line.strip())
+            a = whitespace.split(line.strip(" \t\r\n"))
             if len(a) < 2:
                 print("{0}: error: found bad line '{1}' in lexicon file {2} ".format(
-                    sys.argv[0], line.strip(), filename), file=sys.stderr)
+                    sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr)
                 sys.exit(1)
             word = a[0]
             if word == "<eps>":
                 # This would clash with the epsilon symbol normally used in OpenFst.
print("{0}: error: found as a word in lexicon file " - "{1}".format(line.strip(), filename), file=sys.stderr) + "{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr) sys.exit(1) try: pron_prob = float(a[1]) except: print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field " - "should be pron-prob".format(sys.argv[0], line.strip(), filename), + "should be pron-prob".format(sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) sys.exit(1) prons = a[2:] if pron_prob <= 0.0: print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {1} ".format( - sys.argv[0], line.strip(), filename), file=sys.stderr) + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) sys.exit(1) if len(prons) == 0: found_empty_prons = True @@ -324,7 +324,7 @@ def read_nonterminals(filename): it has the expected format and has no duplicates, and returns the nonterminal symbols as a list of strings, e.g. ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ - ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')] + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] if len(ans) == 0: raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename)) for nonterm in ans: @@ -338,11 +338,12 @@ def read_nonterminals(filename): def read_left_context_phones(filename): """Reads, checks, and returns a list of left-context phones, in text form, one per line. Returns a list of strings, e.g. ['a', 'ah', ..., '#nonterm_bos' ]""" - ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')] + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] if len(ans) == 0: raise RuntimeError("The file {0} contains no left-context phones.".format(filename)) + whitespace = re.compile("[ \t]+") for s in ans: - if len(s.split()) != 1: + if len(whitespace.split(s)) != 1: raise RuntimeError("The file {0} contains an invalid line '{1}'".format(filename, s) ) if len(set(ans)) != len(ans): @@ -354,7 +355,8 @@ def is_token(s): """Returns true if s is a string and is space-free.""" if not isinstance(s, str): return False - split_str = s.split() + whitespace = re.compile("[ \t\r\n]+") + split_str = whitespace.split(s); return len(split_str) == 1 and s == split_str[0] diff --git a/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py b/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py index ed88f0468c4..d5a9cc334be 100755 --- a/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py +++ b/egs/wsj/s5/utils/lang/make_lexicon_fst_silprob.py @@ -82,10 +82,10 @@ def read_silprobs(filename): with open(filename, 'r', encoding='latin-1') as f: whitespace = re.compile("[ \t]+") for line in f: - a = whitespace.split(line.strip()) + a = whitespace.split(line.strip(" \t\r\n")) if len(a) != 2: print("{0}: error: found bad line '{1}' in silprobs file {1} ".format( - sys.argv[0], line.strip(), filename), file=sys.stderr) + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) sys.exit(1) label = a[0] try: @@ -101,7 +101,7 @@ def read_silprobs(filename): raise RuntimeError() except: print("{0}: error: found bad line '{1}' in silprobs file {1}" - .format(sys.argv[0], line.strip(), filename), + .format(sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) sys.exit(1) if (silbeginprob <= 0.0 or silbeginprob > 1.0 or @@ -130,18 +130,19 @@ def read_lexiconp(filename): found_empty_prons = False found_large_pronprobs = False # See the comment near the top of this file, RE why we use latin-1. 
+    whitespace = re.compile("[ \t]+")
     with open(filename, 'r', encoding='latin-1') as f:
         for line in f:
-            a = line.split()
+            a = whitespace.split(line.strip(" \t\r\n"))
             if len(a) < 2:
                 print("{0}: error: found bad line '{1}' in lexicon file {1} ".format(
-                    sys.argv[0], line.strip(), filename), file=sys.stderr)
+                    sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr)
                 sys.exit(1)
             word = a[0]
             if word == "<eps>":
                 # This would clash with the epsilon symbol normally used in OpenFst.
                 print("{0}: error: found <eps> as a word in lexicon file "
-                      "{1}".format(line.strip(), filename), file=sys.stderr)
+                      "{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr)
                 sys.exit(1)
             try:
                 pron_prob = float(a[1])
@@ -151,13 +152,13 @@ def read_lexiconp(filename):
             except:
                 print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field "
                       "through 5th field should be numbers".format(sys.argv[0],
-                                                                    line.strip(), filename),
+                                                                    line.strip(" \t\r\n"), filename),
                       file=sys.stderr)
                 sys.exit(1)
             prons = a[5:]
             if pron_prob <= 0.0:
                 print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {2} ".format(
-                    sys.argv[0], line.strip(), filename), file=sys.stderr)
+                    sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr)
                 sys.exit(1)
             if len(prons) == 0:
                 found_empty_prons = True
@@ -357,7 +358,7 @@ def read_nonterminals(filename):
        it has the expected format and has no duplicates, and returns the nonterminal
        symbols as a list of strings, e.g.
        ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """
-    ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
+    ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
     if len(ans) == 0:
         raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename))
     for nonterm in ans:
@@ -371,7 +372,7 @@ def read_left_context_phones(filename):
     """Reads, checks, and returns a list of left-context phones, in text form,
     one per line.  Returns a list of strings, e.g.
     ['a', 'ah', ..., '#nonterm_bos' ]"""
-    ans = [line.strip() for line in open(filename, 'r', encoding='latin-1')]
+    ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')]
     if len(ans) == 0:
         raise RuntimeError("The file {0} contains no left-context phones.".format(filename))
     for s in ans:
diff --git a/scripts/rnnlm/get_vocab.py b/scripts/rnnlm/get_vocab.py
index 5036db0ed2a..1502e915f9c 100755
--- a/scripts/rnnlm/get_vocab.py
+++ b/scripts/rnnlm/get_vocab.py
@@ -30,7 +30,7 @@
 def add_counts(word_counts, counts_file):
     with open(counts_file, 'r', encoding="latin-1") as f:
         for line in f:
-            line = line.strip()
+            line = line.strip(" \t\r\n")
             word_and_count = re.split(tab_or_space, line)
             assert len(word_and_count) == 2
             if word_and_count[0] in word_counts:
diff --git a/src/chain/language-model.cc b/src/chain/language-model.cc
index 41e06116ea8..dd69340a6b8 100644
--- a/src/chain/language-model.cc
+++ b/src/chain/language-model.cc
@@ -129,7 +129,6 @@ int32 LanguageModelEstimator::FindOrCreateLmStateIndexForHistory(
     int32 backoff_lm_state = FindOrCreateLmStateIndexForHistory(
         backoff_hist);
     lm_states_[ans].backoff_lmstate_index = backoff_lm_state;
-    hist_to_lmstate_index_[backoff_hist] = backoff_lm_state;
   }
   return ans;
 }
@@ -298,7 +297,7 @@ int32 LanguageModelEstimator::AssignFstStates() {
 void LanguageModelEstimator::Estimate(fst::StdVectorFst *fst) {
   KALDI_LOG << "Estimating language model with --no-prune-ngram-order="
             << opts_.no_prune_ngram_order << ", --ngram-order="
-            << opts_.ngram_order << ", --num-extra-lm-state="
+            << opts_.ngram_order << ", --num-extra-lm-states="
             << opts_.num_extra_lm_states;
   SetParentCounts();
   num_basic_lm_states_ = CheckActiveStates();
@@ -408,5 +407,3 @@ void LanguageModelEstimator::OutputToFst(
 
 }  // namespace chain
 }  // namespace kaldi
-
-