diff --git a/egs/ami/s5/cmd.sh b/egs/ami/s5/cmd.sh index dd7145dff60..4d0e0fe0f6b 100644 --- a/egs/ami/s5/cmd.sh +++ b/egs/ami/s5/cmd.sh @@ -12,6 +12,7 @@ export train_cmd="queue.pl --mem 1G" export decode_cmd="queue.pl --mem 2G" +export tensorflow_cmd="queue.pl -l hostname=b*" # the use of cuda_cmd is deprecated but it is sometimes still used in nnet1 # scripts. export cuda_cmd="queue.pl --gpu 1 --mem 20G" diff --git a/egs/ami/s5/local/tensorflow/lstm.py b/egs/ami/s5/local/tensorflow/lstm.py new file mode 100644 index 00000000000..1aba92b129b --- /dev/null +++ b/egs/ami/s5/local/tensorflow/lstm.py @@ -0,0 +1,384 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Modified by Hainan Xu to be used in Kaldi for lattice rescoring 2017 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +sys.path.insert(0,"/home/hxu/.local/lib/python2.7/site-packages/") + +import inspect +import time + +import numpy as np +import tensorflow as tf + +import reader + +flags = tf.flags +logging = tf.logging + +flags.DEFINE_string( + "model", "small", + "A type of model. 
Possible options are: small, medium, large.") +flags.DEFINE_string("data_path", None, + "Where the training/test data is stored.") +flags.DEFINE_string("vocab_path", None, + "Where the wordlist file is stored.") +flags.DEFINE_string("save_path", None, + "Model output directory.") +flags.DEFINE_bool("use_fp16", False, + "Train using 16-bit floats instead of 32bit floats") + +FLAGS = flags.FLAGS + + +def data_type(): + return tf.float16 if FLAGS.use_fp16 else tf.float32 + + +class RNNLMInput(object): + """The input data.""" + + def __init__(self, config, data, name=None): + self.batch_size = batch_size = config.batch_size + self.num_steps = num_steps = config.num_steps + self.epoch_size = ((len(data) // batch_size) - 1) // num_steps + self.input_data, self.targets = reader.rnnlm_producer( + data, batch_size, num_steps, name=name) + + +class RNNLMModel(object): + """The RNNLM model.""" + + def __init__(self, is_training, config, input_): + self._input = input_ + + batch_size = input_.batch_size + num_steps = input_.num_steps + size = config.hidden_size + vocab_size = config.vocab_size + + # Slightly better results can be obtained with forget gate biases + # initialized to 1 but the hyperparameters of the model would need to be + # different than reported in the paper. + def lstm_cell(): + # With the latest TensorFlow source code (as of Mar 27, 2017), + # the BasicLSTMCell will need a reuse parameter which is unfortunately not + # defined in TensorFlow 1.0. 
To maintain backwards compatibility, we add + # an argument check here: + if 'reuse' in inspect.getargspec( + tf.contrib.rnn.BasicLSTMCell.__init__).args: + return tf.contrib.rnn.BasicLSTMCell( + size, forget_bias=0.0, state_is_tuple=True, + reuse=tf.get_variable_scope().reuse) + else: + return tf.contrib.rnn.BasicLSTMCell( + size, forget_bias=0.0, state_is_tuple=True) + attn_cell = lstm_cell + if is_training and config.keep_prob < 1: + def attn_cell(): + return tf.contrib.rnn.DropoutWrapper( + lstm_cell(), output_keep_prob=config.keep_prob) + self.cell = tf.contrib.rnn.MultiRNNCell( + [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True) + + self._initial_state = self.cell.zero_state(batch_size, data_type()) + self._initial_state_single = self.cell.zero_state(1, data_type()) + + self.initial = tf.reshape(tf.stack(axis=0, values=self._initial_state_single), [config.num_layers, 2, 1, size], name="test_initial_state") + + + # first implement the less efficient version + test_word_in = tf.placeholder(tf.int32, [1, 1], name="test_word_in") + + state_placeholder = tf.placeholder(tf.float32, [config.num_layers, 2, 1, size], name="test_state_in") + # unpacking the input state context + l = tf.unstack(state_placeholder, axis=0) + test_input_state = tuple( + [tf.contrib.rnn.LSTMStateTuple(l[idx][0],l[idx][1]) + for idx in range(config.num_layers)] + ) + + with tf.device("/cpu:0"): + self.embedding = tf.get_variable( + "embedding", [vocab_size, size], dtype=data_type()) + + inputs = tf.nn.embedding_lookup(self.embedding, input_.input_data) + test_inputs = tf.nn.embedding_lookup(self.embedding, test_word_in) + + # test time + with tf.variable_scope("RNN"): + (test_cell_output, test_output_state) = self.cell(test_inputs[:, 0, :], test_input_state) + + test_state_out = tf.reshape(tf.stack(axis=0, values=test_output_state), [config.num_layers, 2, 1, size], name="test_state_out") + test_cell_out = tf.reshape(test_cell_output, [1, size], name="test_cell_out") + # 
above is the first part of the graph for test + # test-word-in + # > ---- > test-state-out + # test-state-in > test-cell-out + + + # below is the 2nd part of the graph for test + # test-word-out + # > prob(word | test-word-out) + # test-cell-in + + test_word_out = tf.placeholder(tf.int32, [1, 1], name="test_word_out") + cellout_placeholder = tf.placeholder(tf.float32, [1, size], name="test_cell_in") + + softmax_w = tf.get_variable( + "softmax_w", [size, vocab_size], dtype=data_type()) + softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type()) + + test_logits = tf.matmul(cellout_placeholder, softmax_w) + softmax_b + test_softmaxed = tf.nn.log_softmax(test_logits) + + p_word = test_softmaxed[0, test_word_out[0,0]] + test_out = tf.identity(p_word, name="test_out") + + if is_training and config.keep_prob < 1: + inputs = tf.nn.dropout(inputs, config.keep_prob) + + # Simplified version of models/tutorials/rnn/rnn.py's rnn(). + # This builds an unrolled LSTM for tutorial purposes only. + # In general, use the rnn() or state_saving_rnn() from rnn.py. 
+ # + # The alternative version of the code below is: + # + # inputs = tf.unstack(inputs, num=num_steps, axis=1) + # outputs, state = tf.contrib.rnn.static_rnn( + # cell, inputs, initial_state=self._initial_state) + outputs = [] + state = self._initial_state + with tf.variable_scope("RNN"): + for time_step in range(num_steps): + if time_step > -1: tf.get_variable_scope().reuse_variables() + (cell_output, state) = self.cell(inputs[:, time_step, :], state) + outputs.append(cell_output) + + output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size]) + logits = tf.matmul(output, softmax_w) + softmax_b + loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example( + [logits], + [tf.reshape(input_.targets, [-1])], + [tf.ones([batch_size * num_steps], dtype=data_type())]) + self._cost = cost = tf.reduce_sum(loss) / batch_size + self._final_state = state + + if not is_training: + return + + self._lr = tf.Variable(0.0, trainable=False) + tvars = tf.trainable_variables() + grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), + config.max_grad_norm) + optimizer = tf.train.GradientDescentOptimizer(self._lr) + self._train_op = optimizer.apply_gradients( + zip(grads, tvars), + global_step=tf.contrib.framework.get_or_create_global_step()) + + self._new_lr = tf.placeholder( + tf.float32, shape=[], name="new_learning_rate") + self._lr_update = tf.assign(self._lr, self._new_lr) + + def assign_lr(self, session, lr_value): + session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) + + @property + def input(self): + return self._input + + @property + def initial_state(self): + return self._initial_state + + @property + def cost(self): + return self._cost + + @property + def final_state(self): + return self._final_state + + @property + def lr(self): + return self._lr + + @property + def train_op(self): + return self._train_op + +class TestConfig(object): + """Tiny config, for testing.""" + init_scale = 0.1 + learning_rate = 1.0 + max_grad_norm = 1 + num_layers = 1 + 
num_steps = 2 + hidden_size = 2 + max_epoch = 1 + max_max_epoch = 1 + keep_prob = 1.0 + lr_decay = 0.5 + batch_size = 20 + +class SmallConfig(object): + """Small config.""" + init_scale = 0.1 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 20 + hidden_size = 200 + max_epoch = 4 + max_max_epoch = 13 + keep_prob = 1.0 + lr_decay = 0.5 + batch_size = 64 + + +class MediumConfig(object): + """Medium config.""" + init_scale = 0.05 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 35 + hidden_size = 650 + max_epoch = 6 + max_max_epoch = 39 + keep_prob = 0.5 + lr_decay = 0.8 + batch_size = 20 + + +class LargeConfig(object): + """Large config.""" + init_scale = 0.04 + learning_rate = 1.0 + max_grad_norm = 10 + num_layers = 2 + num_steps = 35 + hidden_size = 1500 + max_epoch = 14 + max_max_epoch = 55 + keep_prob = 0.35 + lr_decay = 1 / 1.15 + batch_size = 20 + + + +def run_epoch(session, model, eval_op=None, verbose=False): + """Runs the model on the given data.""" + start_time = time.time() + costs = 0.0 + iters = 0 + state = session.run(model.initial_state) + + fetches = { + "cost": model.cost, + "final_state": model.final_state, + } + if eval_op is not None: + fetches["eval_op"] = eval_op + + for step in range(model.input.epoch_size): + feed_dict = {} + for i, (c, h) in enumerate(model.initial_state): + feed_dict[c] = state[i].c + feed_dict[h] = state[i].h + + vals = session.run(fetches, feed_dict) + cost = vals["cost"] + state = vals["final_state"] + + costs += cost + iters += model.input.num_steps + + if verbose and step % (model.input.epoch_size // 10) == 10: + print("%.3f perplexity: %.3f speed: %.0f wps" % + (step * 1.0 / model.input.epoch_size, np.exp(costs / iters), + iters * model.input.batch_size / (time.time() - start_time))) + + return np.exp(costs / iters) + + +def get_config(): + if FLAGS.model == "small": + return SmallConfig() + elif FLAGS.model == "medium": + return MediumConfig() + elif FLAGS.model == 
"large": + return LargeConfig() + elif FLAGS.model == "test": + return TestConfig() + else: + raise ValueError("Invalid model: %s", FLAGS.model) + + +def main(_): + if not FLAGS.data_path: + raise ValueError("Must set --data_path to RNNLM data directory") + + raw_data = reader.rnnlm_raw_data(FLAGS.data_path, FLAGS.vocab_path) + train_data, valid_data, _, word_map = raw_data + + config = get_config() + config.vocab_size = len(word_map) + eval_config = get_config() + eval_config.batch_size = 1 + eval_config.num_steps = 1 + + with tf.Graph().as_default(): + initializer = tf.random_uniform_initializer(-config.init_scale, + config.init_scale) + + with tf.name_scope("Train"): + train_input = RNNLMInput(config=config, data=train_data, name="TrainInput") + with tf.variable_scope("Model", reuse=None, initializer=initializer): + m = RNNLMModel(is_training=True, config=config, input_=train_input) + tf.summary.scalar("Training Loss", m.cost) + tf.summary.scalar("Learning Rate", m.lr) + + with tf.name_scope("Valid"): + valid_input = RNNLMInput(config=config, data=valid_data, name="ValidInput") + with tf.variable_scope("Model", reuse=True, initializer=initializer): + mvalid = RNNLMModel(is_training=False, config=config, input_=valid_input) + tf.summary.scalar("Validation Loss", mvalid.cost) + + sv = tf.train.Supervisor(logdir=FLAGS.save_path) + with sv.managed_session() as session: + for i in range(config.max_max_epoch): + lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0) + m.assign_lr(session, config.learning_rate * lr_decay) + + print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) + train_perplexity = run_epoch(session, m, eval_op=m.train_op, + verbose=True) + + print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity)) + valid_perplexity = run_epoch(session, mvalid) + print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity)) + + if FLAGS.save_path: + print("Saving model to %s." 
% FLAGS.save_path) + sv.saver.save(session, FLAGS.save_path) + +if __name__ == "__main__": + tf.app.run() diff --git a/egs/ami/s5/local/tensorflow/lstm_fast.py b/egs/ami/s5/local/tensorflow/lstm_fast.py new file mode 100644 index 00000000000..e5b7bcc91a2 --- /dev/null +++ b/egs/ami/s5/local/tensorflow/lstm_fast.py @@ -0,0 +1,407 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Modified by Hainan Xu to be used in Kaldi for lattice rescoring 2017 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +sys.path.insert(0,"/home/hxu/.local/lib/python2.7/site-packages/") + +import inspect +import time + +import numpy as np +import tensorflow as tf + +import reader + +flags = tf.flags +logging = tf.logging + +flags.DEFINE_string( + "model", "small", + "A type of model. 
Possible options are: small, medium, large.") +flags.DEFINE_string("data_path", None, + "Where the training/test data is stored.") +flags.DEFINE_string("vocab_path", None, + "Where the wordlist file is stored.") +flags.DEFINE_string("save_path", None, + "Model output directory.") +flags.DEFINE_bool("use_fp16", False, + "Train using 16-bit floats instead of 32bit floats") + +FLAGS = flags.FLAGS + + +def data_type(): + return tf.float16 if FLAGS.use_fp16 else tf.float32 + +# this function does the following: +# return exp(x) if x < 0 +# x if x >= 0 +def f(x): + x1 = tf.minimum(0.0, x) + x2 = tf.maximum(0.0, x) + return tf.exp(x1) + x2 + +def new_softmax(labels, logits): + target = tf.reshape(labels, [-1]) + f_logits = tf.exp(logits) +# f_logits = f(logits) + row_sums = tf.reduce_sum(f_logits, 1) # this is the negative part of the objf + + t2 = tf.expand_dims(target, 1) + range = tf.expand_dims(tf.range(tf.shape(target)[0]), 1) + ind = tf.concat([range, t2], 1) + res = tf.gather_nd(logits, ind) + + return -res + row_sums - 1 +# return -res + tf.log(row_sums) # this is the original softmax + +class RNNLMInput(object): + """The input data.""" + + def __init__(self, config, data, name=None): + self.batch_size = batch_size = config.batch_size + self.num_steps = num_steps = config.num_steps + self.epoch_size = ((len(data) // batch_size) - 1) // num_steps + self.input_data, self.targets = reader.rnnlm_producer( + data, batch_size, num_steps, name=name) + + +class RNNLMModel(object): + """The RNNLM model.""" + + def __init__(self, is_training, config, input_): + self._input = input_ + + batch_size = input_.batch_size + num_steps = input_.num_steps + size = config.hidden_size + vocab_size = config.vocab_size + + # Slightly better results can be obtained with forget gate biases + # initialized to 1 but the hyperparameters of the model would need to be + # different than reported in the paper. 
+ def lstm_cell(): + # With the latest TensorFlow source code (as of Mar 27, 2017), + # the BasicLSTMCell will need a reuse parameter which is unfortunately not + # defined in TensorFlow 1.0. To maintain backwards compatibility, we add + # an argument check here: + if 'reuse' in inspect.getargspec( + tf.contrib.rnn.BasicLSTMCell.__init__).args: + return tf.contrib.rnn.BasicLSTMCell( + size, forget_bias=0.0, state_is_tuple=True, + reuse=tf.get_variable_scope().reuse) + else: + return tf.contrib.rnn.BasicLSTMCell( + size, forget_bias=0.0, state_is_tuple=True) + attn_cell = lstm_cell + if is_training and config.keep_prob < 1: + def attn_cell(): + return tf.contrib.rnn.DropoutWrapper( + lstm_cell(), output_keep_prob=config.keep_prob) + self.cell = tf.contrib.rnn.MultiRNNCell( + [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True) + + self._initial_state = self.cell.zero_state(batch_size, data_type()) + self._initial_state_single = self.cell.zero_state(1, data_type()) + + self.initial = tf.reshape(tf.stack(axis=0, values=self._initial_state_single), [config.num_layers, 2, 1, size], name="test_initial_state") + + + # first implement the less efficient version + test_word_in = tf.placeholder(tf.int32, [1, 1], name="test_word_in") + + state_placeholder = tf.placeholder(tf.float32, [config.num_layers, 2, 1, size], name="test_state_in") + # unpacking the input state context + l = tf.unstack(state_placeholder, axis=0) + test_input_state = tuple( + [tf.contrib.rnn.LSTMStateTuple(l[idx][0],l[idx][1]) + for idx in range(config.num_layers)] + ) + + with tf.device("/cpu:0"): + self.embedding = tf.get_variable( + "embedding", [vocab_size, size], dtype=data_type()) + + inputs = tf.nn.embedding_lookup(self.embedding, input_.input_data) + test_inputs = tf.nn.embedding_lookup(self.embedding, test_word_in) + + # test time + with tf.variable_scope("RNN"): + (test_cell_output, test_output_state) = self.cell(test_inputs[:, 0, :], test_input_state) + + test_state_out = 
tf.reshape(tf.stack(axis=0, values=test_output_state), [config.num_layers, 2, 1, size], name="test_state_out") + test_cell_out = tf.reshape(test_cell_output, [1, size], name="test_cell_out") + # above is the first part of the graph for test + # test-word-in + # > ---- > test-state-out + # test-state-in > test-cell-out + + + # below is the 2nd part of the graph for test + # test-word-out + # > prob(word | test-word-out) + # test-cell-in + + test_word_out = tf.placeholder(tf.int32, [1, 1], name="test_word_out") + cellout_placeholder = tf.placeholder(tf.float32, [1, size], name="test_cell_in") + + softmax_w = tf.get_variable( + "softmax_w", [size, vocab_size], dtype=data_type()) + softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type()) + softmax_b = softmax_b - 9.0 + + test_logits = tf.matmul(cellout_placeholder, tf.transpose(tf.nn.embedding_lookup(tf.transpose(softmax_w), test_word_out[0]))) + softmax_b[test_word_out[0,0]] + + p_word = test_logits[0, 0] + test_out = tf.identity(p_word, name="test_out") + + if is_training and config.keep_prob < 1: + inputs = tf.nn.dropout(inputs, config.keep_prob) + + # Simplified version of models/tutorials/rnn/rnn.py's rnn(). + # This builds an unrolled LSTM for tutorial purposes only. + # In general, use the rnn() or state_saving_rnn() from rnn.py. 
+ # + # The alternative version of the code below is: + # + # inputs = tf.unstack(inputs, num=num_steps, axis=1) + # outputs, state = tf.contrib.rnn.static_rnn( + # cell, inputs, initial_state=self._initial_state) + outputs = [] + state = self._initial_state + with tf.variable_scope("RNN"): + for time_step in range(num_steps): + if time_step > -1: tf.get_variable_scope().reuse_variables() + (cell_output, state) = self.cell(inputs[:, time_step, :], state) + outputs.append(cell_output) + + output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size]) + logits = tf.matmul(output, softmax_w) + softmax_b + loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example( + [logits], + [tf.reshape(input_.targets, [-1])], + [tf.ones([batch_size * num_steps], dtype=data_type())], + softmax_loss_function=new_softmax) + self._cost = cost = tf.reduce_sum(loss) / batch_size + self._final_state = state + + if not is_training: + return + + self._lr = tf.Variable(0.0, trainable=False) + tvars = tf.trainable_variables() + grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), + config.max_grad_norm) + optimizer = tf.train.GradientDescentOptimizer(self._lr) + self._train_op = optimizer.apply_gradients( + zip(grads, tvars), + global_step=tf.contrib.framework.get_or_create_global_step()) + + self._new_lr = tf.placeholder( + tf.float32, shape=[], name="new_learning_rate") + self._lr_update = tf.assign(self._lr, self._new_lr) + + def assign_lr(self, session, lr_value): + session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) + + @property + def input(self): + return self._input + + @property + def initial_state(self): + return self._initial_state + + @property + def cost(self): + return self._cost + + @property + def final_state(self): + return self._final_state + + @property + def lr(self): + return self._lr + + @property + def train_op(self): + return self._train_op + +class TestConfig(object): + """Tiny config, for testing.""" + init_scale = 0.1 + learning_rate = 1.0 + 
max_grad_norm = 1 + num_layers = 1 + num_steps = 2 + hidden_size = 2 + max_epoch = 1 + max_max_epoch = 1 + keep_prob = 1.0 + lr_decay = 0.5 + batch_size = 20 + +class SmallConfig(object): + """Small config.""" + init_scale = 0.1 + learning_rate = 1 + max_grad_norm = 5 + num_layers = 2 + num_steps = 20 + hidden_size = 200 + max_epoch = 4 + max_max_epoch = 13 + keep_prob = 1.0 + lr_decay = 0.8 + batch_size = 64 + + +class MediumConfig(object): + """Medium config.""" + init_scale = 0.05 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 35 + hidden_size = 650 + max_epoch = 6 + max_max_epoch = 39 + keep_prob = 0.5 + lr_decay = 0.8 + batch_size = 20 + + +class LargeConfig(object): + """Large config.""" + init_scale = 0.04 + learning_rate = 1.0 + max_grad_norm = 10 + num_layers = 2 + num_steps = 35 + hidden_size = 1500 + max_epoch = 14 + max_max_epoch = 55 + keep_prob = 0.35 + lr_decay = 1 / 1.15 + batch_size = 20 + + + +def run_epoch(session, model, eval_op=None, verbose=False): + """Runs the model on the given data.""" + start_time = time.time() + costs = 0.0 + iters = 0 + state = session.run(model.initial_state) + + fetches = { + "cost": model.cost, + "final_state": model.final_state, + } + if eval_op is not None: + fetches["eval_op"] = eval_op + + for step in range(model.input.epoch_size): + feed_dict = {} + for i, (c, h) in enumerate(model.initial_state): + feed_dict[c] = state[i].c + feed_dict[h] = state[i].h + + vals = session.run(fetches, feed_dict) + cost = vals["cost"] + state = vals["final_state"] + + + costs += cost + iters += model.input.num_steps + + if verbose and step % (model.input.epoch_size // 10) == 10: + print("%.3f perplexity: %.3f speed: %.0f wps" % + (step * 1.0 / model.input.epoch_size, np.exp(costs / iters), + iters * model.input.batch_size / (time.time() - start_time))) + + return np.exp(costs / iters) + + +def get_config(): + if FLAGS.model == "small": + return SmallConfig() + elif FLAGS.model == "medium": + return 
MediumConfig() + elif FLAGS.model == "large": + return LargeConfig() + elif FLAGS.model == "test": + return TestConfig() + else: + raise ValueError("Invalid model: %s", FLAGS.model) + + +def main(_): + if not FLAGS.data_path: + raise ValueError("Must set --data_path to RNNLM data directory") + + raw_data = reader.rnnlm_raw_data(FLAGS.data_path, FLAGS.vocab_path) + train_data, valid_data, _, word_map = raw_data + + config = get_config() + config.vocab_size = len(word_map) + eval_config = get_config() + eval_config.batch_size = 1 + eval_config.num_steps = 1 + + with tf.Graph().as_default(): + initializer = tf.random_uniform_initializer(-config.init_scale, + config.init_scale) + + with tf.name_scope("Train"): + train_input = RNNLMInput(config=config, data=train_data, name="TrainInput") + with tf.variable_scope("Model", reuse=None, initializer=initializer): + m = RNNLMModel(is_training=True, config=config, input_=train_input) + tf.summary.scalar("Training Loss", m.cost) + tf.summary.scalar("Learning Rate", m.lr) + + with tf.name_scope("Valid"): + valid_input = RNNLMInput(config=config, data=valid_data, name="ValidInput") + with tf.variable_scope("Model", reuse=True, initializer=initializer): + mvalid = RNNLMModel(is_training=False, config=config, input_=valid_input) + tf.summary.scalar("Validation Loss", mvalid.cost) + + sv = tf.train.Supervisor(logdir=FLAGS.save_path) + with sv.managed_session() as session: + for i in range(config.max_max_epoch): + lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0) + m.assign_lr(session, config.learning_rate * lr_decay) + + print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) + train_perplexity = run_epoch(session, m, eval_op=m.train_op, + verbose=True) + + print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity)) + valid_perplexity = run_epoch(session, mvalid) + print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity)) + + if FLAGS.save_path: + print("Saving model to %s." 
% FLAGS.save_path) + sv.saver.save(session, FLAGS.save_path) + +if __name__ == "__main__": + tf.app.run() diff --git a/egs/ami/s5/local/tensorflow/prep_data.sh b/egs/ami/s5/local/tensorflow/prep_data.sh new file mode 100755 index 00000000000..49825781c7c --- /dev/null +++ b/egs/ami/s5/local/tensorflow/prep_data.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +#set -v +set -e + +train_text=data/ihm/train/text +nwords=9999 + +. path.sh +. cmd.sh + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: $0 [options] " + echo "For options, see top of script file" + exit 1; +fi + +dir=$1 +srcdir=data/local/dict + +mkdir -p $dir + +cat $srcdir/lexicon.txt | awk '{print $1}' | sort -u | grep -v -w '!SIL' > $dir/wordlist.all + +# Get training data with OOV words (w.r.t. our current vocab) replaced with . +cat $train_text | awk -v w=$dir/wordlist.all \ + 'BEGIN{while((getline0) v[$1]=1;} + {for (i=2;i<=NF;i++) if ($i in v) printf $i" ";else printf " ";print ""}' | sed 's=$= =g' \ + | perl -e ' use List::Util qw(shuffle); @A=<>; print join("", shuffle(@A)); ' \ + | gzip -c > $dir/all.gz + +echo "Splitting data into train and validation sets." +heldout_sent=10000 +gunzip -c $dir/all.gz | head -n $heldout_sent > $dir/valid.in # validation data +gunzip -c $dir/all.gz | tail -n +$heldout_sent > $dir/train.in # training data + + +cat $dir/train.in $dir/wordlist.all | \ + awk '{ for(x=1;x<=NF;x++) count[$x]++; } END{for(w in count){print count[w], w;}}' | \ + sort -nr > $dir/unigram.counts + +total_nwords=`wc -l $dir/unigram.counts | awk '{print $1}'` + +head -$nwords $dir/unigram.counts | awk '{print $2}' | tee $dir/wordlist.rnn | awk '{print NR-1, $1}' > $dir/wordlist.rnn.id + +tail -n +$nwords $dir/unigram.counts > $dir/unk_class.counts + +for type in train valid; do + cat $dir/$type.in | awk -v w=$dir/wordlist.rnn 'BEGIN{while((getline0)d[$1]=1}{for(i=1;i<=NF;i++){if(d[$i]==1){s=$i}else{s=""} printf("%s ",s)} print""}' > $dir/$type +done + +# OK we'll train the RNNLM on this data. 
+ +cat $dir/unk_class.counts | awk '{print $2, $1}' > $dir/unk.probs # dummy file, not used for cued-rnnlm + +cp $dir/wordlist.rnn $dir/wordlist.rnn.final + +has_oos=`grep "" $dir/wordlist.rnn.final | wc -l | awk '{print $1}'` +if [ $has_oos == "0" ]; then +# n=`wc -l $dir/wordlist.rnn.final | awk '{print $1}'` +# echo n is $n + echo "" >> $dir/wordlist.rnn.final +fi + + +echo "data preparation finished" + diff --git a/egs/ami/s5/local/tensorflow/reader.py b/egs/ami/s5/local/tensorflow/reader.py new file mode 100644 index 00000000000..5458b93ea31 --- /dev/null +++ b/egs/ami/s5/local/tensorflow/reader.py @@ -0,0 +1,133 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Modified by Hainan Xu to be used in Kaldi for lattice rescoring 2017 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +"""Utilities for parsing RNNLM text files.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +import tensorflow as tf + +def _read_words(filename): + with tf.gfile.GFile(filename, "r") as f: + return f.read().decode("utf-8").split() +# return f.read().decode("utf-8").replace("\n", "").split() + +def _build_vocab(filename): +# data = _read_words(filename) +# +# counter = collections.Counter(data) +# count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) +# +# words, _ = list(zip(*count_pairs)) +# word_to_id = dict(zip(words, range(len(words)))) + +# word_to_id = {} +# new_id = 0 +# with open(filename, "r") as f: +# for word in f: +# word_to_id[word] = new_id +# new_id = new_id + 1 +# return word_to_id + + words = _read_words(filename) + word_to_id = dict(zip(words, range(len(words)))) + return word_to_id + + +def _file_to_word_ids(filename, word_to_id): + data = _read_words(filename) + return [word_to_id[word] for word in data if word in word_to_id] + + +def rnnlm_raw_data(data_path, vocab_path): + """Load RNNLM raw data from data directory "data_path". + + Reads RNNLM text files, converts strings to integer ids, + and performs mini-batching of the inputs. + + The RNNLM dataset comes from Tomas Mikolov's webpage: + + http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + + Args: + data_path: string path to the directory where simple-examples.tgz has + been extracted. + + Returns: + tuple (train_data, valid_data, test_data, vocabulary) + where each of the data objects can be passed to RNNLMIterator. 
+ """ + + train_path = os.path.join(data_path, "train") + valid_path = os.path.join(data_path, "valid") +# test_path = os.path.join(data_path, "eval.txt") + + word_to_id = _build_vocab(vocab_path) + train_data = _file_to_word_ids(train_path, word_to_id) + valid_data = _file_to_word_ids(valid_path, word_to_id) +# test_data = _file_to_word_ids(test_path, word_to_id) + vocabulary = len(word_to_id) + return train_data, valid_data, vocabulary, word_to_id +# return train_data, valid_data, test_data, vocabulary, word_to_id + + +def rnnlm_producer(raw_data, batch_size, num_steps, name=None): + """Iterate on the raw RNNLM data. + + This chunks up raw_data into batches of examples and returns Tensors that + are drawn from these batches. + + Args: + raw_data: one of the raw data outputs from rnnlm_raw_data. + batch_size: int, the batch size. + num_steps: int, the number of unrolls. + name: the name of this operation (optional). + + Returns: + A pair of Tensors, each shaped [batch_size, num_steps]. The second element + of the tuple is the same data time-shifted to the right by one. + + Raises: + tf.errors.InvalidArgumentError: if batch_size or num_steps are too high. 
+ """ + with tf.name_scope(name, "RNNLMProducer", [raw_data, batch_size, num_steps]): + raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32) + + data_len = tf.size(raw_data) + batch_len = data_len // batch_size + data = tf.reshape(raw_data[0 : batch_size * batch_len], + [batch_size, batch_len]) + + epoch_size = (batch_len - 1) // num_steps + assertion = tf.assert_positive( + epoch_size, + message="epoch_size == 0, decrease batch_size or num_steps") + with tf.control_dependencies([assertion]): + epoch_size = tf.identity(epoch_size, name="epoch_size") + + i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue() + x = tf.strided_slice(data, [0, i * num_steps], + [batch_size, (i + 1) * num_steps]) + x.set_shape([batch_size, num_steps]) + y = tf.strided_slice(data, [0, i * num_steps + 1], + [batch_size, (i + 1) * num_steps + 1]) + y.set_shape([batch_size, num_steps]) + return x, y diff --git a/egs/ami/s5/local/tensorflow/rnnlm.py b/egs/ami/s5/local/tensorflow/rnnlm.py new file mode 120000 index 00000000000..86c615508a3 --- /dev/null +++ b/egs/ami/s5/local/tensorflow/rnnlm.py @@ -0,0 +1 @@ +lstm.py \ No newline at end of file diff --git a/egs/ami/s5/local/tensorflow/run.sh b/egs/ami/s5/local/tensorflow/run.sh new file mode 100755 index 00000000000..b1aa2d06614 --- /dev/null +++ b/egs/ami/s5/local/tensorflow/run.sh @@ -0,0 +1,45 @@ +#!/bin/bash +mic=ihm +ngram_order=4 +model_type=small +stage=1 +weight=0.5 + +. ./utils/parse_options.sh +. ./cmd.sh +. 
./path.sh + +set -e + +dir=data/tensorflow/$model_type +mkdir -p $dir + +if [ $stage -le 1 ]; then + local/tensorflow/prep_data.sh $dir +fi + +mkdir -p $dir/ +if [ $stage -le 2 ]; then + $decode_cmd $dir/train.log python local/tensorflow/rnnlm.py --data_path=$dir --model=$model_type --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final +fi + +final_lm=ami_fsh.o3g.kn +LM=$final_lm.pr1-7 + +if [ $stage -le 3 ]; then + for decode_set in dev eval; do + basedir=exp/$mic/nnet3/tdnn_sp/ + decode_dir=${basedir}/decode_${decode_set} + + # Lattice rescoring + steps/lmrescore_rnnlm_lat.sh \ + --cmd "$tensorflow_cmd --mem 16G" \ + --rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \ + data/lang_$LM $dir \ + data/$mic/${decode_set}_hires ${decode_dir} \ + ${decode_dir}.tfrnnlm.lat.${ngram_order}gram.$weight & + + done +fi + +wait diff --git a/egs/ami/s5/local/tensorflow/run_fast.sh b/egs/ami/s5/local/tensorflow/run_fast.sh new file mode 100755 index 00000000000..629a7e064fc --- /dev/null +++ b/egs/ami/s5/local/tensorflow/run_fast.sh @@ -0,0 +1,49 @@ +#!/bin/bash +mic=ihm +ngram_order=3 +model_type=small +stage=1 +weight=0.5 + +. ./utils/parse_options.sh +. ./cmd.sh +. 
./path.sh + +set -e + +dir=data/fast_tensorflow/$model_type +mkdir -p $dir + +if [ $stage -le 1 ]; then + local/tensorflow/prep_data.sh $dir +fi + +mkdir -p $dir/ +if [ $stage -le 2 ]; then + python local/tensorflow/lstm_fast.py --data_path=$dir --model=$model_type --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final +# $decode_cmd $dir/train.log python local/tensorflow/lstm_fast.py --data_path=$dir --model=$model_type --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final +fi + +final_lm=ami_fsh.o3g.kn +LM=$final_lm.pr1-7 + +date +if [ $stage -le 3 ]; then +# for decode_set in dev; do + for decode_set in dev eval; do + basedir=exp/$mic/nnet3/tdnn_sp/ + decode_dir=${basedir}/decode_${decode_set} + + # Lattice rescoring + steps/lmrescore_rnnlm_lat.sh \ + --cmd "$tensorflow_cmd --mem 16G" \ + --rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \ + data/lang_$LM $dir \ + data/$mic/${decode_set}_hires ${decode_dir} \ + ${decode_dir}.unk.fast.tfrnnlm.lat.${ngram_order}gram.$weight & + + done +fi + +wait +date diff --git a/egs/ami/s5/local/tensorflow/run_vannila.sh b/egs/ami/s5/local/tensorflow/run_vannila.sh new file mode 100755 index 00000000000..71ecd7340ba --- /dev/null +++ b/egs/ami/s5/local/tensorflow/run_vannila.sh @@ -0,0 +1,46 @@ +#!/bin/bash +mic=ihm +ngram_order=3 +model_type=small +stage=1 +weight=0.5 + +. ./utils/parse_options.sh +. ./cmd.sh +. 
./path.sh + +set -e + +dir=data/vannila_tensorflow_200/$model_type +mkdir -p $dir + +if [ $stage -le 1 ]; then + local/tensorflow/prep_data.sh $dir +fi + +if [ $stage -le 2 ]; then + mkdir -p $dir/ + python local/tensorflow/vanilla_rnnlm.py --data_path=$dir --model=$model_type --save_path=$dir/rnnlm --vocab_path=$dir/wordlist.rnn.final +fi + +final_lm=ami_fsh.o3g.kn +LM=$final_lm.pr1-7 + +if [ $stage -le 3 ]; then +# for decode_set in dev; do + for decode_set in dev eval; do + basedir=exp/$mic/nnet3/tdnn_sp/ + decode_dir=${basedir}/decode_${decode_set} + + # Lattice rescoring + steps/lmrescore_rnnlm_lat.sh \ + --cmd "$tensorflow_cmd --mem 16G" \ + --rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \ + data/lang_$LM $dir \ + data/$mic/${decode_set}_hires ${decode_dir} \ + ${decode_dir}.vanilla.tfrnnlm.lat.${ngram_order}gram.$weight & + + done +fi + +wait diff --git a/egs/ami/s5/local/tensorflow/vanilla_rnnlm.py b/egs/ami/s5/local/tensorflow/vanilla_rnnlm.py new file mode 100644 index 00000000000..6e5c72f6adb --- /dev/null +++ b/egs/ami/s5/local/tensorflow/vanilla_rnnlm.py @@ -0,0 +1,376 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Modified by Hainan Xu to be used in Kaldi for lattice rescoring 2017 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +sys.path.insert(0,"/home/hxu/.local/lib/python2.7/site-packages/") + +import inspect +import time + +import numpy as np +import tensorflow as tf + +import reader + +flags = tf.flags +logging = tf.logging + +flags.DEFINE_string( + "model", "small", + "A type of model. Possible options are: small, medium, large.") +flags.DEFINE_string("data_path", None, + "Where the training/test data is stored.") +flags.DEFINE_string("vocab_path", None, + "Where the wordlist file is stored.") +flags.DEFINE_string("save_path", None, + "Model output directory.") +flags.DEFINE_bool("use_fp16", False, + "Train using 16-bit floats instead of 32bit floats") + +FLAGS = flags.FLAGS + + +def data_type(): + return tf.float16 if FLAGS.use_fp16 else tf.float32 + + +class RNNLMInput(object): + """The input data.""" + + def __init__(self, config, data, name=None): + self.batch_size = batch_size = config.batch_size + self.num_steps = num_steps = config.num_steps + self.epoch_size = ((len(data) // batch_size) - 1) // num_steps + self.input_data, self.targets = reader.rnnlm_producer( + data, batch_size, num_steps, name=name) + +class RNNLMModel(object): + """The RNNLM model.""" + + def __init__(self, is_training, config, input_): + self._input = input_ + + batch_size = input_.batch_size + num_steps = input_.num_steps + size = config.hidden_size + vocab_size = config.vocab_size + + def rnn_cell(): + # With the latest TensorFlow source code (as of Mar 27, 2017), + # the BasicLSTMCell will need a reuse parameter which is unfortunately not + # defined in TensorFlow 1.0. 
To maintain backwards compatibility, we add + # an argument check here: + if 'reuse' in inspect.getargspec( + tf.contrib.rnn.BasicRNNCell.__init__).args: + return tf.contrib.rnn.BasicRNNCell(size, + reuse=tf.get_variable_scope().reuse) + else: + return tf.contrib.rnn.BasicRNNCell(size) + attn_cell = rnn_cell + + if is_training and config.keep_prob < 1: + def attn_cell(): + return tf.contrib.rnn.DropoutWrapper( + rnn_cell(), output_keep_prob=config.keep_prob) + + self.cell = tf.contrib.rnn.MultiRNNCell( + [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True) + + self._initial_state = self.cell.zero_state(batch_size, data_type()) + self._initial_state_single = self.cell.zero_state(1, data_type()) + + self.initial = tf.reshape(tf.stack(axis=0, values=self._initial_state_single), [config.num_layers, 1, size], name="test_initial_state") + + # first implement the less efficient version + test_word_in = tf.placeholder(tf.int32, [1, 1], name="test_word_in") + + state_placeholder = tf.placeholder(tf.float32, [config.num_layers, 1, size], name="test_state_in") + # unpacking the input state context + l = tf.unstack(state_placeholder, axis=0) + test_input_state = tuple( + [l[idx] for idx in range(config.num_layers)] + ) + + with tf.device("/cpu:0"): + self.embedding = tf.get_variable( + "embedding", [vocab_size, size], dtype=data_type()) + + inputs = tf.nn.embedding_lookup(self.embedding, input_.input_data) + test_inputs = tf.nn.embedding_lookup(self.embedding, test_word_in) + + # test time + with tf.variable_scope("RNN"): + (test_cell_output, test_output_state) = self.cell(test_inputs[:, 0, :], test_input_state) + + test_state_out = tf.reshape(tf.stack(axis=0, values=test_output_state), [config.num_layers, 1, size], name="test_state_out") + test_cell_out = tf.reshape(test_cell_output, [1, size], name="test_cell_out") + # above is the first part of the graph for test + # test-word-in + # > ---- > test-state-out + # test-state-in > test-cell-out + + + # below is 
the 2nd part of the graph for test
+    # test-word-out
+    #               > prob(word | test-word-out)
+    # test-cell-in
+
+    test_word_out = tf.placeholder(tf.int32, [1, 1], name="test_word_out")
+    cellout_placeholder = tf.placeholder(tf.float32, [1, size], name="test_cell_in")
+
+    softmax_w = tf.get_variable(
+        "softmax_w", [size, vocab_size], dtype=data_type())
+    softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
+
+    test_logits = tf.matmul(cellout_placeholder, softmax_w) + softmax_b
+    test_softmaxed = tf.nn.log_softmax(test_logits)
+
+    p_word = test_softmaxed[0, test_word_out[0,0]]
+    test_out = tf.identity(p_word, name="test_out")
+
+    if is_training and config.keep_prob < 1:
+      inputs = tf.nn.dropout(inputs, config.keep_prob)
+
+    # Simplified version of models/tutorials/rnn/rnn.py's rnn().
+    # This builds an unrolled RNN for tutorial purposes only.
+    # In general, use the rnn() or state_saving_rnn() from rnn.py.
+    #
+    # The alternative version of the code below is:
+    #
+    # inputs = tf.unstack(inputs, num=num_steps, axis=1)
+    # outputs, state = tf.contrib.rnn.static_rnn(
+    #     cell, inputs, initial_state=self._initial_state)
+    outputs = []
+    state = self._initial_state
+    with tf.variable_scope("RNN"):
+      for time_step in range(num_steps):
+        if time_step > -1: tf.get_variable_scope().reuse_variables()
+        (cell_output, state) = self.cell(inputs[:, time_step, :], state)
+        outputs.append(cell_output)
+
+    output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
+    logits = tf.matmul(output, softmax_w) + softmax_b
+    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
+        [logits],
+        [tf.reshape(input_.targets, [-1])],
+        [tf.ones([batch_size * num_steps], dtype=data_type())])
+    self._cost = cost = tf.reduce_sum(loss) / batch_size
+    self._final_state = state
+
+    if not is_training:
+      return
+
+    self._lr = tf.Variable(0.0, trainable=False)
+    tvars = tf.trainable_variables()
+    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
+                                      
config.max_grad_norm) +# optimizer = tf.train.AdamOptimizer() # TODO + optimizer = tf.train.MomentumOptimizer(self._lr, 0.9) # TODO +# optimizer = tf.train.GradientDescentOptimizer(self._lr) # TODO + self._train_op = optimizer.apply_gradients( + zip(grads, tvars), + global_step=tf.contrib.framework.get_or_create_global_step()) + + self._new_lr = tf.placeholder( + tf.float32, shape=[], name="new_learning_rate") + self._lr_update = tf.assign(self._lr, self._new_lr) + + def assign_lr(self, session, lr_value): + session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) + + @property + def input(self): + return self._input + + @property + def initial_state(self): + return self._initial_state + + @property + def cost(self): + return self._cost + + @property + def final_state(self): + return self._final_state + + @property + def lr(self): + return self._lr + + @property + def train_op(self): + return self._train_op + +class TestConfig(object): + """Tiny config, for testing.""" + init_scale = 0.1 + learning_rate = 1.0 + max_grad_norm = 1 + num_layers = 1 + num_steps = 2 + hidden_size = 2 + max_epoch = 1 + max_max_epoch = 1 + keep_prob = 1.0 + lr_decay = 0.5 + batch_size = 20 + +class SmallConfig(object): + """Small config.""" + init_scale = 0.1 + learning_rate = 0.2 + max_grad_norm = 1 + num_layers = 1 + num_steps = 20 + hidden_size = 200 + max_epoch = 4 + max_max_epoch = 20 + keep_prob = 1 + lr_decay = 0.95 + batch_size = 64 + +class MediumConfig(object): + """Medium config.""" + init_scale = 0.05 + learning_rate = 1.0 + max_grad_norm = 5 + num_layers = 2 + num_steps = 35 + hidden_size = 650 + max_epoch = 6 + max_max_epoch = 39 + keep_prob = 0.5 + lr_decay = 0.8 + batch_size = 20 + +class LargeConfig(object): + """Large config.""" + init_scale = 0.04 + learning_rate = 1.0 + max_grad_norm = 10 + num_layers = 2 + num_steps = 35 + hidden_size = 1500 + max_epoch = 14 + max_max_epoch = 55 + keep_prob = 0.35 + lr_decay = 1 / 1.15 + batch_size = 20 + +def 
run_epoch(session, model, eval_op=None, verbose=False): + """Runs the model on the given data.""" + start_time = time.time() + costs = 0.0 + iters = 0 + state = session.run(model.initial_state) + + fetches = { + "cost": model.cost, + "final_state": model.final_state, + } + if eval_op is not None: + fetches["eval_op"] = eval_op + + for step in range(model.input.epoch_size): + feed_dict = {} + for i, h in enumerate(model.initial_state): + feed_dict[h] = state[i] + + vals = session.run(fetches, feed_dict) + cost = vals["cost"] + state = vals["final_state"] + + costs += cost + iters += model.input.num_steps + + if verbose and step % (model.input.epoch_size // 10) == 10: + print("%.3f perplexity: %.3f speed: %.0f wps" % + (step * 1.0 / model.input.epoch_size, np.exp(costs / iters), + iters * model.input.batch_size / (time.time() - start_time))) + + return np.exp(costs / iters) + + +def get_config(): + if FLAGS.model == "small": + return SmallConfig() + elif FLAGS.model == "medium": + return MediumConfig() + elif FLAGS.model == "large": + return LargeConfig() + elif FLAGS.model == "test": + return TestConfig() + else: + raise ValueError("Invalid model: %s", FLAGS.model) + + +def main(_): + if not FLAGS.data_path: + raise ValueError("Must set --data_path to RNNLM data directory") + + raw_data = reader.rnnlm_raw_data(FLAGS.data_path, FLAGS.vocab_path) + train_data, valid_data, _, word_map = raw_data + + config = get_config() + config.vocab_size = len(word_map) + eval_config = get_config() + eval_config.batch_size = 1 + eval_config.num_steps = 1 + + with tf.Graph().as_default(): + initializer = tf.random_uniform_initializer(-config.init_scale, + config.init_scale) + + with tf.name_scope("Train"): + train_input = RNNLMInput(config=config, data=train_data, name="TrainInput") + with tf.variable_scope("Model", reuse=None, initializer=initializer): + m = RNNLMModel(is_training=True, config=config, input_=train_input) + tf.summary.scalar("Training Loss", m.cost) + 
tf.summary.scalar("Learning Rate", m.lr) + + with tf.name_scope("Valid"): + valid_input = RNNLMInput(config=config, data=valid_data, name="ValidInput") + with tf.variable_scope("Model", reuse=True, initializer=initializer): + mvalid = RNNLMModel(is_training=False, config=config, input_=valid_input) + tf.summary.scalar("Validation Loss", mvalid.cost) + + sv = tf.train.Supervisor(logdir=FLAGS.save_path) + with sv.managed_session() as session: + for i in range(config.max_max_epoch): + lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0) + + m.assign_lr(session, config.learning_rate * lr_decay) + + print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr))) + train_perplexity = run_epoch(session, m, eval_op=m.train_op, + verbose=True) + + print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity)) + valid_perplexity = run_epoch(session, mvalid) + print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity)) + + if FLAGS.save_path: + print("Saving model to %s." % FLAGS.save_path) + sv.saver.save(session, FLAGS.save_path) + +if __name__ == "__main__": + tf.app.run() diff --git a/egs/ami/s5/path.sh b/egs/ami/s5/path.sh index ad2c93b309b..4f627ff81ff 100644 --- a/egs/ami/s5/path.sh +++ b/egs/ami/s5/path.sh @@ -4,6 +4,7 @@ export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . 
$KALDI_ROOT/tools/config/common_path.sh export LC_ALL=C +export LD_LIBRARY_PATH=$KALDI_ROOT/tools/tensorflow/bazel-bin/tensorflow/ LMBIN=$KALDI_ROOT/tools/irstlm/bin SRILM=$KALDI_ROOT/tools/srilm/bin/i686-m64 diff --git a/egs/tedlium/s5_r2/local/run_segmentation_long_utts.sh b/egs/tedlium/s5_r2/local/run_segmentation_long_utts.sh new file mode 100644 index 00000000000..560d6fb450f --- /dev/null +++ b/egs/tedlium/s5_r2/local/run_segmentation_long_utts.sh @@ -0,0 +1,270 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +# This script demonstrates how to re-segment long audios into short segments. +# The basic idea is to decode with an existing in-domain acoustic model, and a +# bigram language model built from the reference, and then work out the +# segmentation from a ctm like file. + +## %WER results. + +## Baseline results +# %WER 18.1 | 507 17783 | 84.7 10.7 4.6 2.8 18.1 91.1 | -0.073 | exp/tri3/decode_dev_rescore/score_16_0.0/ctm.filt.filt.sys +# %WER 16.6 | 1155 27500 | 85.7 10.7 3.6 2.4 16.6 86.0 | -0.041 | exp/tri3/decode_test_rescore/score_16_0.0/ctm.filt.filt.sys + +## With Cleanup +# %WER 18.0 | 507 17783 | 85.0 10.6 4.4 3.0 18.0 90.9 | -0.064 | exp/tri3_cleaned/decode_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# %WER 16.6 | 1155 27500 | 85.9 10.8 3.3 2.5 16.6 86.6 | -0.050 | exp/tri3_cleaned/decode_test_rescore/score_14_0.0/ctm.filt.filt.sys + +## Segmentation results +# %WER 18.9 | 507 17783 | 83.9 11.1 5.0 2.8 18.9 92.9 | -0.103 | exp/tri3_reseg_a/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# %WER 17.6 | 1155 27500 | 84.6 11.3 4.1 2.2 17.6 86.8 | -0.005 | exp/tri3_reseg_a/decode_nosp_test_rescore/score_14_0.0/ctm.filt.filt.sys + +## Segmentation + Cleanup + +# cleaned - +# Default segmentation-opts "--max-junk-proportion=1 --max-deleted-words-kept-when-merging=1 --min-split-point-duration=0.1" +# cleaned_b - +# "--max-junk-proportion=0.5 --max-deleted-words-kept-when-merging=10" +# cleaned_c - +# "--max-junk-proportion=0.2 
--max-deleted-words-kept-when-merging=6 --min-split-point-duration=0.3" + +# %WER 18.7 | 507 17783 | 84.0 11.0 5.0 2.8 18.7 91.7 | -0.119 | exp/tri3_reseg_a_cleaned/decode_nosp_dev_rescore/score_15_0.0/ctm.filt.filt.sys +# %WER 18.6 | 507 17783 | 84.0 11.0 4.9 2.7 18.6 91.5 | -0.092 | exp/tri3_reseg_a_cleaned_b/decode_nosp_dev_rescore/score_15_0.0/ctm.filt.filt.sys +# %WER 18.6 | 507 17783 | 84.1 10.8 5.0 2.7 18.6 92.1 | -0.114 | exp/tri3_reseg_a_cleaned_c/decode_nosp_dev_rescore/score_15_0.0/ctm.filt.filt.sys + +# %WER 17.7 | 1155 27500 | 84.5 11.4 4.0 2.2 17.7 86.8 | -0.020 | exp/tri3_reseg_a_cleaned/decode_nosp_test_rescore/score_14_0.0/ctm.filt.filt.sys +# %WER 17.3 | 1155 27500 | 84.8 11.2 4.1 2.1 17.3 86.8 | -0.002 | exp/tri3_reseg_a_cleaned_b/decode_nosp_test_rescore/score_15_0.0/ctm.filt.filt.sys +# %WER 17.7 | 1155 27500 | 84.6 11.4 4.1 2.3 17.7 86.6 | -0.018 | exp/tri3_reseg_a_cleaned_c/decode_nosp_test_rescore/score_14_0.0/ctm.filt.filt.sys + +## Use silence and pronunciation probs estimated from resegmented data +# %WER 18.2 | 507 17783 | 84.6 10.8 4.5 2.9 18.2 92.5 | -0.037 | exp/tri3_reseg_a/decode_a_dev_rescore/score_16_0.0/ctm.filt.filt.sys +# %WER 16.9 | 1155 27500 | 85.5 11.0 3.5 2.4 16.9 86.1 | -0.024 | exp/tri3_reseg_a/decode_a_test_rescore/score_14_0.0/ctm.filt.filt.sys + +## Use silence and pronunciation probs estimated from resegmented and cleaned up data +# %WER 18.2 | 507 17783 | 84.4 10.8 4.9 2.6 18.2 92.5 | -0.074 | exp/tri3_reseg_a_cleaned_b/decode_a_cleaned_b_dev_rescore/score_15_0.5/ctm.filt.filt.sys +# %WER 16.8 | 1155 27500 | 85.4 10.7 3.9 2.1 16.8 86.8 | -0.046 | exp/tri3_reseg_a_cleaned_b/decode_a_cleaned_b_test_rescore/score_14_0.5/ctm.filt.filt.sys + +. ./cmd.sh +. ./path.sh + +set -e -o pipefail -u + +segment_stage=-9 +cleanup_stage=-1 +cleanup_affix=cleaned_b +affix=_a + +decode_nj=8 # note: should not be >38 which is the number of speakers in the dev set + # after applying --seconds-per-spk-max 180. 
We decode with 4 threads, so + # this will be too many jobs if you're using run.pl. + +############################################################################### +# Simulate unsegmented data directory. +############################################################################### +utils/data/convert_data_dir_to_whole.sh data/train data/train_long + +############################################################################### +# Train system on a small subset of 2000 utterances that are +# manually segmented. +############################################################################### + +utils/subset_data_dir.sh --speakers data/train 2000 data/train_2k +utils/subset_data_dir.sh --shortest data/train_2k 500 data/train_2k_short500 + +steps/make_mfcc.sh --cmd "$train_cmd" --nj 32 \ + data/train_long exp/make_mfcc/train_long mfcc || exit 1 +steps/compute_cmvn_stats.sh data/train_long \ + exp/make_mfcc/train_long mfcc + + +steps/train_mono.sh --nj 20 --cmd "$train_cmd" \ + data/train_2k_short500 data/lang_nosp exp/mono_a + +steps/align_si.sh --nj 20 --cmd "$train_cmd" \ + data/train_2k data/lang_nosp exp/mono_a exp/mono_a_ali_2k + +steps/train_deltas.sh --cmd "$train_cmd" \ + 500 5000 data/train_2k data/lang_nosp exp/mono_a_ali_2k exp/tri1a + +steps/align_si.sh --nj 20 --cmd "$train_cmd" \ + data/train_2k data/lang_nosp exp/tri1a exp/tri1a_ali + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 1000 10000 data/train_2k data/lang_nosp exp/tri1a_ali exp/tri1b + +############################################################################### +# Segment long recordings using TF-IDF retrieval of reference text +# for uniformly segmented audio chunks based on Smith-Waterman alignment. 
+# Use a model trained on train_2k (tri1b) +############################################################################### + +steps/cleanup/segment_long_utterances.sh --cmd "$train_cmd" \ + --stage $segment_stage --nj 80 \ + --max-bad-proportion 0.5 \ + exp/tri1b data/lang_nosp data/train_long data/train_reseg${affix} \ + exp/segment_long_utts${affix}_train + +steps/compute_cmvn_stats.sh data/train_reseg${affix} \ + exp/make_mfcc/train_reseg${affix} mfcc +utils/fix_data_dir.sh data/train_reseg${affix} + +############################################################################### +# Train new model on segmented data directory starting from the same model +# used for segmentation. (tri2_reseg) +############################################################################### + +# Align tri1b system with reseg${affix} data +steps/align_si.sh --nj 40 --cmd "$train_cmd" \ + data/train_reseg${affix} \ + data/lang_nosp exp/tri1b exp/tri1b_ali_reseg${affix} || exit 1; + +# Train LDA+MLLT system on reseg${affix} data +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + 4000 50000 data/train_reseg${affix} data/lang_nosp \ + exp/tri1b_ali_reseg${affix} exp/tri2_reseg${affix} + +( +utils/mkgraph.sh data/lang_nosp exp/tri2_reseg${affix} \ + exp/tri2_reseg${affix}/graph_nosp +for dset in dev test; do + steps/decode.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri2_reseg${affix}/graph_nosp data/${dset} \ + exp/tri2_reseg${affix}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp \ + data/lang_nosp_rescore \ + data/${dset} exp/tri2_reseg${affix}/decode_nosp_${dset} \ + exp/tri2_reseg${affix}/decode_nosp_${dset}_rescore +done +) & + +############################################################################### +# Train SAT model on segmented data directory +############################################################################### + +# Train SAT system on reseg${affix} data +steps/train_sat.sh --cmd "$train_cmd" 5000 100000 
\ + data/train_reseg${affix} data/lang_nosp \ + exp/tri2_reseg${affix} exp/tri3_reseg${affix} + +( +utils/mkgraph.sh data/lang_nosp exp/tri3_reseg${affix} \ + exp/tri3_reseg${affix}/graph_nosp +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri3_reseg${affix}/graph_nosp data/${dset} \ + exp/tri3_reseg${affix}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp \ + data/lang_nosp_rescore \ + data/${dset} exp/tri3_reseg${affix}/decode_nosp_${dset} \ + exp/tri3_reseg${affix}/decode_nosp_${dset}_rescore +done +) & + +############################################################################### +# Clean and segmented data +############################################################################### + +segmentation_opts=( +--max-junk-proportion=0.5 +--max-deleted-words-kept-when-merging=10 +) +opts="${segmentation_opts[@]}" + +steps/cleanup/clean_and_segment_data.sh --nj 40 --cmd "$train_cmd" \ + --segmentation-opts "$opts" \ + data/train_reseg${affix} data/lang_nosp exp/tri3_reseg${affix} \ + exp/tri3_reseg${affix}_${cleanup_affix}_work \ + data/train_reseg${affix}_${cleanup_affix} + +############################################################################### +# Train new SAT model on cleaned data directory +############################################################################### + +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + data/train_reseg${affix}_${cleanup_affix} data/lang_nosp \ + exp/tri3_reseg${affix} exp/tri3_reseg${affix}_ali_${cleanup_affix} + +steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 data/train_reseg${affix}_${cleanup_affix} data/lang_nosp \ + exp/tri3_reseg${affix}_ali_${cleanup_affix} \ + exp/tri3_reseg${affix}_$cleanup_affix + +( +utils/mkgraph.sh data/lang_nosp exp/tri3_reseg${affix}_$cleanup_affix \ + exp/tri3_reseg${affix}_$cleanup_affix/graph_nosp +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd 
"$decode_cmd" --num-threads 4 \ + exp/tri3_reseg${affix}_$cleanup_affix/graph_nosp data/${dset} \ + exp/tri3_reseg${affix}_$cleanup_affix/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp \ + data/lang_nosp_rescore \ + data/${dset} exp/tri3_reseg${affix}_$cleanup_affix/decode_nosp_${dset} \ + exp/tri3_reseg${affix}_$cleanup_affix/decode_nosp_${dset}_rescore +done +) & + +steps/get_prons.sh --cmd "$train_cmd" \ + data/train_reseg${affix}_${cleanup_affix} \ + data/lang_nosp exp/tri3_reseg${affix}_$cleanup_affix +utils/dict_dir_add_pronprobs.sh --max-normalize true \ + data/local/dict_nosp \ + exp/tri3_reseg${affix}_$cleanup_affix/{pron,sil,pron_bigram}_counts_nowb.txt \ + data/local/dict${affix}_$cleanup_affix + +utils/prepare_lang.sh data/local/dict${affix}_$cleanup_affix \ + "" data/local/lang data/lang${affix}_$cleanup_affix +cp -rT data/lang${affix}_$cleanup_affix data/lang${affix}_${cleanup_affix}_rescore +cp data/lang_nosp/G.fst data/lang${affix}_$cleanup_affix/ +cp data/lang_nosp_rescore/G.carpa data/lang${affix}_${cleanup_affix}_rescore/ + +( +utils/mkgraph.sh data/lang${affix}_${cleanup_affix} \ + exp/tri3_reseg${affix}_$cleanup_affix{,/graph${affix}_${cleanup_affix}} + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri3_reseg${affix}_$cleanup_affix/graph${affix}_${cleanup_affix} \ + data/${dset} \ + exp/tri3_reseg${affix}_$cleanup_affix/decode${affix}_${cleanup_affix}_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang${affix}_${cleanup_affix} \ + data/lang${affix}_${cleanup_affix}_rescore \ + data/${dset} exp/tri3_reseg${affix}_$cleanup_affix/decode${affix}_${cleanup_affix}_${dset} \ + exp/tri3_reseg${affix}_$cleanup_affix/decode${affix}_${cleanup_affix}_${dset}_rescore +done +) & + +steps/get_prons.sh --cmd "$train_cmd" \ + data/train_reseg${affix} \ + data/lang_nosp exp/tri3_reseg${affix} +utils/dict_dir_add_pronprobs.sh --max-normalize 
true \ + data/local/dict_nosp \ + exp/tri3_reseg${affix}/{pron,sil,pron_bigram}_counts_nowb.txt \ + data/local/dict${affix} + +utils/prepare_lang.sh data/local/dict${affix} \ + "" data/local/lang data/lang${affix} +cp -rT data/lang${affix} data/lang${affix}_rescore +cp data/lang_nosp/G.fst data/lang${affix}/ +cp data/lang_nosp_rescore/G.carpa data/lang${affix}_rescore/ + +( +utils/mkgraph.sh data/lang${affix} \ + exp/tri3_reseg${affix}{,/graph${affix}} + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri3_reseg${affix}/graph${affix} \ + data/${dset} \ + exp/tri3_reseg${affix}/decode${affix}_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang${affix} \ + data/lang${affix}_rescore \ + data/${dset} exp/tri3_reseg${affix}/decode${affix}_${dset} \ + exp/tri3_reseg${affix}/decode${affix}_${dset}_rescore +done +) & + +wait +exit 0 diff --git a/egs/tedlium/s5_r2_wsj/RESULTS b/egs/tedlium/s5_r2_wsj/RESULTS new file mode 100644 index 00000000000..ec4b9c24a12 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/RESULTS @@ -0,0 +1,149 @@ +# Results based on the Tedlium Release 2 Paper using the original LM given by the Lium Team +# PAPER Results: 10.1 / 11.1 +# http://www.lrec-conf.org/proceedings/lrec2014/pdf/1104_Paper.pdf + + +# steps/info/gmm_dir_info.pl exp/mono exp/tri{1,2,3,3_cleaned} +# exp/mono: nj=20 align prob=-96.77 over 4.65h [retry=1.1%, fail=0.1%] states=127 gauss=1001 +# exp/tri1: nj=35 align prob=-96.06 over 210.65h [retry=5.0%, fail=0.3%] states=1986 gauss=30087 tree-impr=3.84 +# exp/tri2: nj=35 align prob=-50.05 over 210.21h [retry=6.1%, fail=0.5%] states=3342 gauss=50121 tree-impr=4.81 lda-sum=18.73 mllt:impr,logdet=0.93,1.51 +# exp/tri3: nj=35 align prob=-49.01 over 210.04h [retry=4.4%, fail=0.5%] states=4177 gauss=100136 fmllr-impr=2.98 over 172.09h tree-impr=7.26 +# exp/tri3_cleaned: nj=100 align prob=-49.05 over 202.62h [retry=1.5%, fail=0.0%] states=4186 gauss=100093 fmllr-impr=0.46 
over 171.37h tree-impr=7.81 + +# steps/info/nnet3_dir_info.pl exp/nnet3{,_cleaned}/tdnn_sp +# exp/nnet3/tdnn_sp: num-iters=250 nj=2..12 num-params=11.0M dim=40+100->4177 combine=-1.06->-1.05 loglike:train/valid[165,249,final]=(-1.16,-1.08,-1.08/-1.29,-1.28,-1.27) accuracy:train/valid[165,249,final]=(0.69,0.70,0.70/0.65,0.66,0.66) +# exp/nnet3_cleaned/tdnn_sp: num-iters=240 nj=2..12 num-params=11.0M dim=40+100->4186 combine=-0.97->-0.96 loglike:train/valid[159,239,final]=(-1.06,-0.98,-0.98/-1.19,-1.16,-1.16) accuracy:train/valid[159,239,final]=(0.70,0.72,0.72/0.66,0.67,0.68) + +# steps/info/chain_dir_info.pl exp/chain{,_cleaned}/tdnn_sp_bi +# exp/chain/tdnn_sp_bi: num-iters=264 nj=2..12 num-params=7.0M dim=40+100->3615 combine=-0.11->-inf xent:train/valid[175,263,final]=(-1.42,-1.35,-1.35/-1.48,-1.44,-1.44) logprob:train/valid[175,263,final]=(-0.10,-0.09,-0.09/-0.11,-0.12,-0.12) +# exp/chain_cleaned/tdnn_sp_bi: num-iters=253 nj=2..12 num-params=7.0M dim=40+100->3589 combine=-0.10->-0.10 xent:train/valid[167,252,final]=(-1.37,-1.30,-1.30/-1.43,-1.38,-1.38) logprob:train/valid[167,252,final]=(-0.10,-0.09,-0.09/-0.11,-0.11,-0.10) + +######### tri1 results ######## +for d in exp/tri1/decode_*; do grep Sum $d/*ore*/*ys | utils/best_wer.sh ; done +# small LM +%WER 27.8 | 507 17783 | 75.7 17.5 6.8 3.4 27.8 96.6 | 0.071 | exp/tri1/decode_nosp_dev/score_10_0.0/ctm.filt.filt.sys +%WER 27.3 | 1155 27500 | 75.3 18.4 6.3 2.7 27.3 93.0 | 0.119 | exp/tri1/decode_nosp_test/score_11_0.0/ctm.filt.filt.sys +# big LM +%WER 26.3 | 507 17783 | 76.8 16.1 7.1 3.1 26.3 95.9 | 0.080 | exp/tri1/decode_nosp_dev_rescore/score_11_0.0/ctm.filt.filt.sys +%WER 26.2 | 1155 27500 | 76.6 17.3 6.1 2.8 26.2 92.6 | 0.081 | exp/tri1/decode_nosp_test_rescore/score_11_0.0/ctm.filt.filt.sys + +####### tri2 results ########## +#for d in exp/tri2/decode_*; do grep Sum $d/score*/*ys | utils/best_wer.sh ; done + +# small LM +%WER 23.6 | 507 17783 | 79.6 14.8 5.6 3.2 23.6 95.1 | 0.024 | 
exp/tri2/decode_nosp_dev/score_12_0.0/ctm.filt.filt.sys +%WER 23.2 | 1155 27500 | 79.5 15.5 5.0 2.7 23.2 91.1 | 0.070 | exp/tri2/decode_nosp_test/score_12_0.0/ctm.filt.filt.sys +# big LM +%WER 22.3 | 507 17783 | 80.7 13.5 5.8 3.0 22.3 93.7 | -0.002 | exp/tri2/decode_nosp_dev_rescore/score_13_0.0/ctm.filt.filt.sys +%WER 21.9 | 1155 27500 | 80.7 14.6 4.7 2.6 21.9 90.2 | 0.026 | exp/tri2/decode_nosp_test_rescore/score_12_0.0/ctm.filt.filt.sys + +# small LM with silence and pronunciation probs. +%WER 22.5 | 507 17783 | 80.5 14.0 5.5 3.1 22.5 94.7 | 0.092 | exp/tri2/decode_dev/score_15_0.0/ctm.filt.filt.sys +%WER 22.1 | 1155 27500 | 80.7 14.9 4.3 2.8 22.1 90.6 | 0.089 | exp/tri2/decode_test/score_13_0.0/ctm.filt.filt.sys + +# big LM with silence and pronunciation probs. +%WER 21.3 | 507 17783 | 81.8 13.1 5.1 3.1 21.3 93.7 | 0.038 | exp/tri2/decode_dev_rescore/score_14_0.0/ctm.filt.filt.sys +%WER 20.9 | 1155 27500 | 81.9 14.0 4.1 2.8 20.9 90.5 | 0.046 | exp/tri2/decode_test_rescore/score_13_0.0/ctm.filt.filt.sys + +####### tri3 results ########## +# small LM +%WER 18.7 | 507 17783 | 83.9 11.4 4.7 2.6 18.7 92.3 | -0.006 | exp/tri3/decode_dev/score_17_0.0/ctm.filt.filt.sys +%WER 17.6 | 1155 27500 | 84.7 11.6 3.7 2.4 17.6 87.2 | 0.013 | exp/tri3/decode_test/score_15_0.0/ctm.filt.filt.sys + +# big LM +%WER 17.6 | 507 17783 | 85.0 10.5 4.4 2.6 17.6 90.5 | -0.030 | exp/tri3/decode_dev_rescore/score_16_0.0/ctm.filt.filt.sys +%WER 16.7 | 1155 27500 | 85.7 10.9 3.4 2.4 16.7 86.4 | -0.044 | exp/tri3/decode_test_rescore/score_14_0.0/ctm.filt.filt.sys + + +for d in exp/tri3_cleaned/decode_*; do grep Sum $d/score*/*ys | utils/best_wer.sh ; done +# tri3 after cleaning, small LM. +# +%WER 19.0 | 507 17783 | 83.9 11.4 4.7 2.9 19.0 92.1 | -0.054 | exp/tri3_cleaned/decode_dev/score_13_0.5/ctm.filt.filt.sys +%WER 17.6 | 1155 27500 | 84.8 11.7 3.5 2.4 17.6 87.6 | 0.001 | exp/tri3_cleaned/decode_test/score_15_0.0/ctm.filt.filt.sys + +# tri3 after cleaning, large LM. 
+%WER 17.9 | 507 17783 | 85.1 10.5 4.4 3.0 17.9 90.9 | -0.055 | exp/tri3_cleaned/decode_dev_rescore/score_15_0.0/ctm.filt.filt.sys +%WER 16.6 | 1155 27500 | 85.8 10.9 3.4 2.4 16.6 86.4 | -0.058 | exp/tri3_cleaned/decode_test_rescore/score_15_0.0/ctm.filt.filt.sys + + +########## nnet3+chain systems +# +# chain+TDNN, small LM +%WER 9.7 | 507 17783 | 91.7 5.8 2.5 1.4 9.7 78.7 | 0.097 | exp/chain_cleaned/tdnn_sp_bi/decode_dev/score_10_0.0/ctm.filt.filt.sys +%WER 9.5 | 1155 27500 | 91.7 5.8 2.5 1.2 9.5 72.5 | 0.079 | exp/chain_cleaned/tdnn_sp_bi/decode_test/score_10_0.0/ctm.filt.filt.sys + +# chain+TDNN, large LM +%WER 9.0 | 507 17783 | 92.3 5.3 2.4 1.3 9.0 76.7 | 0.067 | exp/chain_cleaned/tdnn_sp_bi/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys +%WER 9.0 | 1155 27500 | 92.2 5.3 2.5 1.2 9.0 71.3 | 0.064 | exp/chain_cleaned/tdnn_sp_bi/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys + + # chain+TDNN systems ran without cleanup, using the command: + # local/chain/run_tdnn.sh --train-set train --gmm tri3 --nnet3-affix "" + # for d in exp/chain/tdnn_sp_bi/decode_*; do grep Sum $d/*/*ys | utils/best_wer.sh; done + # This is about 0.1 (dev) / 0.4 (test) % worse than the corresponding results with cleanup. 
+ %WER 9.8 | 507 17783 | 91.6 6.0 2.4 1.5 9.8 80.1 | -0.038 | exp/chain/tdnn_sp_bi/decode_dev/score_8_0.0/ctm.filt.filt.sys + %WER 9.9 | 1155 27500 | 91.4 5.7 2.9 1.3 9.9 74.9 | 0.083 | exp/chain/tdnn_sp_bi/decode_test/score_9_0.0/ctm.filt.filt.sys + %WER 9.1 | 507 17783 | 92.3 5.5 2.3 1.4 9.1 77.5 | 0.011 | exp/chain/tdnn_sp_bi/decode_dev_rescore/score_8_0.0/ctm.filt.filt.sys + %WER 9.4 | 1155 27500 | 91.9 5.6 2.5 1.4 9.4 72.7 | 0.018 | exp/chain/tdnn_sp_bi/decode_test_rescore/score_8_0.0/ctm.filt.filt.sys +#################################################################################################################### +For the record, results with unpruned LM: +%WER 8.2 | 507 17783 | 92.8 4.5 2.6 1.1 8.2 70.8 | -0.036 | exp/chain/tdnn_sp_bi/decode_dev_1848_rescore/score_9_0.0/ctm.filt.filt.sys +%WER 9.3 | 1155 27500 | 91.8 5.1 3.0 1.2 9.3 71.7 | -0.008 | exp/chain/tdnn_sp_bi/decode_test_1848_rescore/score_9_0.0/ctm.filt.filt.sys + + +##################################################################################################################### +# BELOW FOR REFERENCE, old results with the Cantab LM -- including Nnet3 results tdnn + blstm +##################################################################################################################### + +####### nnet3 results ##### + +# tdnn, small LM +for x in exp/nnet3_cleaned/tdnn_sp/decode_*; do grep Sum $x/*ore*/*ys | utils/best_wer.sh; done +%WER 12.5 | 507 17783 | 89.6 7.4 2.9 2.2 12.5 83.6 | -0.118 | exp/nnet3_cleaned/tdnn_sp/decode_dev/score_10_0.0/ctm.filt.filt.sys +%WER 11.4 | 1155 27500 | 90.0 7.2 2.8 1.4 11.4 78.1 | -0.056 | exp/nnet3_cleaned/tdnn_sp/decode_test/score_11_0.0/ctm.filt.filt.sys + +# tdnn, large LM +%WER 11.9 | 507 17783 | 90.0 7.0 3.0 1.9 11.9 81.9 | -0.072 | exp/nnet3_cleaned/tdnn_sp/decode_dev_rescore/score_11_0.0/ctm.filt.filt.sys +%WER 10.8 | 1155 27500 | 90.6 6.7 2.7 1.4 10.8 76.6 | -0.101 | exp/nnet3_cleaned/tdnn_sp/decode_test_rescore/score_11_0.0/ctm.filt.filt.sys + 
+# BLSTM small LM +# The results are with ClipGradientComponent and without deriv_time fix, so it may not reflect the latest changes +# for x in exp/nnet3_cleaned/lstm_bidirectional_sp/decode_*; do grep Sum $x/*ore*/*ys | utils/best_wer.sh; done +%WER 11.1 | 507 17783 | 90.5 6.8 2.7 1.6 11.1 80.7 | -0.251 | exp/nnet3_cleaned/lstm_bidirectional_sp/decode_dev/score_10_0.0/ctm.filt.filt.sys +%WER 10.2 | 1155 27500 | 91.0 6.4 2.6 1.2 10.2 75.5 | -0.278 | exp/nnet3_cleaned/lstm_bidirectional_sp/decode_test/score_10_0.0/ctm.filt.filt.sys + +# BLSTM large LM +%WER 10.6 | 507 17783 | 91.0 6.5 2.5 1.6 10.6 79.3 | -0.275 | exp/nnet3_cleaned/lstm_bidirectional_sp/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys +%WER 9.9 | 1155 27500 | 91.3 6.1 2.6 1.2 9.9 74.1 | -0.306 | exp/nnet3_cleaned/lstm_bidirectional_sp/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys + + # nnet3 results without cleanup, run with: + # local/nnet3/run_tdnn.sh --train-set train --gmm tri3 --nnet3-affix "" + # This is only about 0.1% worse than the baseline with cleanup... the cleanup helps + # mostly for the chain models. 
+ for d in exp/nnet3/tdnn_sp/decode_*; do grep Sum $d/*/*ys | utils/best_wer.sh; done + %WER 12.6 | 507 17783 | 89.6 7.4 3.1 2.1 12.6 83.6 | -0.051 | exp/nnet3/tdnn_sp/decode_dev/score_10_0.0/ctm.filt.filt.sys + %WER 11.5 | 1155 27500 | 90.0 7.2 2.8 1.5 11.5 79.5 | -0.141 | exp/nnet3/tdnn_sp/decode_test/score_10_0.0/ctm.filt.filt.sys + + %WER 11.9 | 507 17783 | 90.0 6.9 3.1 1.9 11.9 82.4 | -0.032 | exp/nnet3/tdnn_sp/decode_dev_rescore/score_11_0.0/ctm.filt.filt.sys + %WER 10.9 | 1155 27500 | 90.4 6.7 2.9 1.4 10.9 77.1 | -0.109 | exp/nnet3/tdnn_sp/decode_test_rescore/score_11_0.0/ctm.filt.filt.sys + + +########## nnet3+chain systems + +# chain+TDNN, small LM +%WER 10.4 | 507 17783 | 91.1 6.3 2.6 1.5 10.4 80.5 | 0.052 | exp/chain_cleaned/tdnn_sp_bi/decode_dev/score_10_0.0/ctm.filt.filt.sys +%WER 9.8 | 1155 27500 | 91.4 6.0 2.6 1.1 9.8 73.5 | 0.048 | exp/chain_cleaned/tdnn_sp_bi/decode_test/score_10_0.0/ctm.filt.filt.sys + +# chain+TDNN, large LM +%WER 9.8 | 507 17783 | 91.6 5.8 2.6 1.5 9.8 78.7 | 0.022 | exp/chain_cleaned/tdnn_sp_bi/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys +%WER 9.3 | 1155 27500 | 91.8 5.5 2.7 1.1 9.3 71.7 | 0.001 | exp/chain_cleaned/tdnn_sp_bi/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys + + + # chain+TDNN systems ran without cleanup, using the command: + # local/chain/run_tdnn.sh --train-set train --gmm tri3 --nnet3-affix "" + # for d in exp/chain/tdnn_sp_bi/decode_*; do grep Sum $d/*/*ys | utils/best_wer.sh; done + # This is about 0.6% worse than the corresponding results with cleanup. 
+ + %WER 11.0 | 507 17783 | 90.9 6.5 2.6 1.9 11.0 80.5 | 0.004 | exp/chain/tdnn_sp_bi/decode_dev/score_8_0.0/ctm.filt.filt.sys + %WER 10.1 | 1155 27500 | 91.2 6.0 2.8 1.3 10.1 75.5 | -0.004 | exp/chain/tdnn_sp_bi/decode_test/score_8_0.0/ctm.filt.filt.sys + %WER 10.6 | 507 17783 | 90.7 5.5 3.8 1.3 10.6 79.3 | 0.070 | exp/chain/tdnn_sp_bi/decode_dev_rescore/score_10_0.0/ctm.filt.filt.sys + %WER 9.8 | 1155 27500 | 91.2 5.2 3.7 1.0 9.8 73.2 | 0.055 | exp/chain/tdnn_sp_bi/decode_test_rescore/score_10_0.0/ctm.filt.filt.sys diff --git a/egs/tedlium/s5_r2_wsj/cmd.sh b/egs/tedlium/s5_r2_wsj/cmd.sh new file mode 100755 index 00000000000..87eab1892e8 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/cmd.sh @@ -0,0 +1,29 @@ +# "queue.pl" uses qsub. The options to it are +# options to qsub. If you have GridEngine installed, +# change this to a queue you have access to. +# Otherwise, use "run.pl", which will run jobs locally +# (make sure your --num-jobs options are no more than +# the number of cpus on your machine. + +# Run locally: +#export train_cmd=run.pl +#export decode_cmd=run.pl + +# JHU cluster (or most clusters using GridEngine, with a suitable +# conf/queue.conf). 
+export train_cmd="queue.pl" +export decode_cmd="queue.pl --mem 4G" + +host=$(hostname -f) +if [ ${host#*.} == "fit.vutbr.cz" ]; then + # BUT cluster: + queue="all.q@@blade,all.q@@speech" + gpu_queue="long.q@@gpu" + storage="matylda5" + export train_cmd="queue.pl -q $queue -l ram_free=1500M,mem_free=1500M,${storage}=1" + export decode_cmd="queue.pl -q $queue -l ram_free=2500M,mem_free=2500M,${storage}=0.5" +elif [ ${host#*.} == "cm.cluster" ]; then + # MARCC bluecrab cluster: + export train_cmd="slurm.pl --time 4:00:00 " + export decode_cmd="slurm.pl --mem 4G --time 4:00:00 " +fi diff --git a/egs/tedlium/s5_r2_wsj/conf/decode.config b/egs/tedlium/s5_r2_wsj/conf/decode.config new file mode 100644 index 00000000000..7ba966f2b83 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/decode.config @@ -0,0 +1 @@ +# empty config, just use the defaults. diff --git a/egs/tedlium/s5_r2_wsj/conf/decode_dnn.config b/egs/tedlium/s5_r2_wsj/conf/decode_dnn.config new file mode 100644 index 00000000000..ab8dcc1dc08 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/decode_dnn.config @@ -0,0 +1,2 @@ +beam=13.0 # beam for decoding. Was 13.0 in the scripts. +lattice_beam=8.0 # this has most effect on size of the lattices. 
diff --git a/egs/tedlium/s5_r2_wsj/conf/fbank.conf b/egs/tedlium/s5_r2_wsj/conf/fbank.conf new file mode 100644 index 00000000000..4c57f8a8765 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/fbank.conf @@ -0,0 +1,5 @@ +--window-type=hamming # disable Dans window, use the standard +--use-energy=false # only fbank outputs +--dither=1 +--num-mel-bins=40 # 8 filters/octave, 40 filters/16Khz as used by IBM +--htk-compat=true # try to make it compatible with HTK diff --git a/egs/tedlium/s5_r2_wsj/conf/mfcc.conf b/egs/tedlium/s5_r2_wsj/conf/mfcc.conf new file mode 100644 index 00000000000..32988403b00 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/mfcc.conf @@ -0,0 +1,2 @@ +--use-energy=false +--sample-frequency=16000 diff --git a/egs/tedlium/s5_r2_wsj/conf/mfcc_hires.conf b/egs/tedlium/s5_r2_wsj/conf/mfcc_hires.conf new file mode 100644 index 00000000000..434834a6725 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. 
+--high-freq=-400 # high cutoff frequency, relative to Nyquist of 8000 (=7600) diff --git a/egs/tedlium/s5_r2_wsj/conf/no_k20.conf b/egs/tedlium/s5_r2_wsj/conf/no_k20.conf new file mode 100644 index 00000000000..f0cba4df971 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/no_k20.conf @@ -0,0 +1,13 @@ +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q +option gpu=* -l gpu=$0 -q g.q +default allow_k20=true +option allow_k20=true +option allow_k20=false -l 'hostname=!g01*&!g02*&!b06*' diff --git a/egs/tedlium/s5_r2_wsj/conf/online_cmvn.conf b/egs/tedlium/s5_r2_wsj/conf/online_cmvn.conf new file mode 100644 index 00000000000..7748a4a4dd3 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/online_cmvn.conf @@ -0,0 +1 @@ +# configuration file for apply-cmvn-online, used in the script ../local/run_online_decoding.sh diff --git a/egs/tedlium/s5_r2_wsj/conf/pitch.conf b/egs/tedlium/s5_r2_wsj/conf/pitch.conf new file mode 100644 index 00000000000..bba51335be3 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/conf/pitch.conf @@ -0,0 +1,2 @@ +--nccf-ballast-online=true # helps for online operation. 
+ diff --git a/egs/tedlium/s5_r2_wsj/local/dict b/egs/tedlium/s5_r2_wsj/local/dict new file mode 120000 index 00000000000..384304fdf2a --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/dict @@ -0,0 +1 @@ +../../../wsj/s5/local/dict/ \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/download_data.sh b/egs/tedlium/s5_r2_wsj/local/download_data.sh new file mode 120000 index 00000000000..bdbe175f21f --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/download_data.sh @@ -0,0 +1 @@ +../../s5_r2/local/download_data.sh \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/find_transcripts.pl b/egs/tedlium/s5_r2_wsj/local/find_transcripts.pl new file mode 120000 index 00000000000..1f7b8dc414c --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/find_transcripts.pl @@ -0,0 +1 @@ +../../../wsj/s5/local/find_transcripts.pl \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/flist2scp.pl b/egs/tedlium/s5_r2_wsj/local/flist2scp.pl new file mode 120000 index 00000000000..58b414114e7 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/flist2scp.pl @@ -0,0 +1 @@ +../../../wsj/s5/local/flist2scp.pl \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/format_lms.sh b/egs/tedlium/s5_r2_wsj/local/format_lms.sh new file mode 100755 index 00000000000..5ba0d8c4c12 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/format_lms.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# +# Copyright 2014 Nickolay V. Shmyrev +# Apache 2.0 + +if [ -f path.sh ]; then . path.sh; fi + +small_arpa_lm=data/local/local_lm/data/arpa/4gram_small.arpa.gz +big_arpa_lm=data/local/local_lm/data/arpa/4gram_big.arpa.gz +lang=lang_nosp + +. utils/parse_options.sh + +for f in $small_arpa_lm $big_arpa_lm data/${lang}/words.txt; do + [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1 +done + + +set -e + +if [ -f data/${lang}/G.fst ] && [ data/${lang}/G.fst -nt $small_arpa_lm ]; then + echo "$0: not regenerating data/${lang}/G.fst as it already exists and " + echo ".. 
is newer than the source LM." +else + arpa2fst --disambig-symbol=#0 --read-symbol-table=data/${lang}/words.txt \ + "gunzip -c $small_arpa_lm|" data/${lang}/G.fst + echo "$0: Checking how stochastic G is (the first of these numbers should be small):" + fstisstochastic data/${lang}/G.fst || true + utils/validate_lang.pl --skip-determinization-check data/${lang} +fi + + + +if [ -f data/${lang}_rescore/G.carpa ] && [ data/${lang}_rescore/G.carpa -nt $big_arpa_lm ] && \ + [ data/${lang}_rescore/G.carpa -nt data/${lang}/words.txt ]; then + echo "$0: not regenerating data/${lang}_rescore/ as it seems to already be up to date." +else + utils/build_const_arpa_lm.sh $big_arpa_lm data/${lang} data/${lang}_rescore || exit 1; +fi + +exit 0; diff --git a/egs/tedlium/s5_r2_wsj/local/join_suffix.py b/egs/tedlium/s5_r2_wsj/local/join_suffix.py new file mode 120000 index 00000000000..ed457448f2d --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/join_suffix.py @@ -0,0 +1 @@ +../../s5_r2/local/join_suffix.py \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py b/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py new file mode 100755 index 00000000000..6338cbbf875 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/lm/merge_word_counts.py @@ -0,0 +1,30 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +"""This script merges pocolm word_counts and writes a new word_counts file. +A min-count argument is required to only write counts that are above the +specified minimum count. 
+""" + +import sys + + +def main(): + if len(sys.argv) != 2: + sys.stderr.write("Usage: {0} \n".format(sys.argv[0])) + raise SystemExit(1) + + words = {} + for line in sys.stdin.readlines(): + parts = line.strip().split() + words[parts[1]] = words.get(parts[1], 0) + int(parts[0]) + + for word, count in words.iteritems(): + if count >= int(sys.argv[1]): + print ("{0} {1}".format(count, word)) + + +if __name__ == '__main__': + main() diff --git a/egs/tedlium/s5_r2_wsj/local/ndx2flist.pl b/egs/tedlium/s5_r2_wsj/local/ndx2flist.pl new file mode 120000 index 00000000000..c7909a824a9 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/ndx2flist.pl @@ -0,0 +1 @@ +../../../wsj/s5/local/ndx2flist.pl \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/normalize_transcript.pl b/egs/tedlium/s5_r2_wsj/local/normalize_transcript.pl new file mode 120000 index 00000000000..5546706fce5 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/normalize_transcript.pl @@ -0,0 +1 @@ +../../../wsj/s5/local/normalize_transcript.pl \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/prepare_data.sh b/egs/tedlium/s5_r2_wsj/local/prepare_data.sh new file mode 100755 index 00000000000..6b852616eb6 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/prepare_data.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# +# Copyright 2014 Nickolay V. Shmyrev +# 2014 Brno University of Technology (Author: Karel Vesely) +# 2016 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +# To be run from one directory above this script. + +. 
path.sh + +export LC_ALL=C + +set -e +set -o pipefail +set -u + +# Prepare: test, train, +for set in dev test train; do + dir=data/$set.orig + mkdir -p $dir + + # Merge transcripts into a single 'stm' file, do some mappings: + # - -> : map dev stm labels to be coherent with train + test, + # - -> : --||-- + # - (2) -> null : remove pronunciation variants in transcripts, keep in dictionary + # - -> null : remove marked , it is modelled implicitly (in kaldi) + # - (...) -> null : remove utterance names from end-lines of train + # - it 's -> it's : merge words that contain apostrophe (if compound in dictionary, local/join_suffix.py) + { # Add STM header, so sclite can prepare the '.lur' file + echo ';; +;; LABEL "o" "Overall" "Overall results" +;; LABEL "f0" "f0" "Wideband channel" +;; LABEL "f2" "f2" "Telephone channel" +;; LABEL "male" "Male" "Male Talkers" +;; LABEL "female" "Female" "Female Talkers" +;;' + # Process the STMs + cat db/TEDLIUM_release2/$set/stm/*.stm | sort -k1,1 -k2,2 -k4,4n | \ + sed -e 's:::' \ + -e 's:::' \ + -e 's:([0-9])::g' \ + -e 's:::g' \ + -e 's:([^ ]*)$::' | \ + awk '{ $2 = "A"; print $0; }' + } | local/join_suffix.py > data/$set.orig/stm + + # Prepare 'text' file + # - {NOISE} -> [NOISE] : map the tags to match symbols in dictionary + cat $dir/stm | grep -v -e 'ignore_time_segment_in_scoring' -e ';;' | \ + awk '{ printf ("%s-%07d-%07d", $1, $4*100, $5*100); + for (i=7;i<=NF;i++) { printf(" %s", $i); } + printf("\n"); + }' | tr '{}' '[]' | sort -k1,1 > $dir/text.orig + + cat $dir/text.orig | awk '{if (NF > 1) print $0}' | \ + local/normalize_transcript.pl '' | awk '{if (NF > 1) print $0}' \ + > $dir/text || exit 1 + + # Prepare 'segments', 'utt2spk', 'spk2utt' + cat $dir/text | cut -d" " -f 1 | awk -F"-" '{printf("%s %s %07.2f %07.2f\n", $0, $1, $2/100.0, $3/100.0)}' > $dir/segments + cat $dir/segments | awk '{print $1, $2}' > $dir/utt2spk + cat $dir/utt2spk | utils/utt2spk_to_spk2utt.pl > $dir/spk2utt + + # Prepare 'wav.scp', 
'reco2file_and_channel' + cat $dir/spk2utt | awk -v set=$set -v pwd=$PWD '{ printf("%s sph2pipe -f wav -p %s/db/TEDLIUM_release2/%s/sph/%s.sph |\n", $1, pwd, set, $1); }' > $dir/wav.scp + cat $dir/wav.scp | awk '{ print $1, $1, "A"; }' > $dir/reco2file_and_channel + + # Create empty 'glm' file + echo ';; empty.glm + [FAKE] => %HESITATION / [ ] __ [ ] ;; hesitation token + ' > data/$set.orig/glm + + # The training set seems to not have enough silence padding in the segmentations, + # especially at the beginning of segments. Extend the times. + if [ $set == "train" ]; then + mv data/$set.orig/segments data/$set.orig/segments.temp + utils/data/extend_segment_times.py --start-padding=0.15 \ + --end-padding=0.1 data/$set.orig/segments || exit 1 + rm data/$set.orig/segments.temp + fi + + # Check that data dirs are okay! + utils/validate_data_dir.sh --no-feats $dir || exit 1 +done + diff --git a/egs/tedlium/s5_r2_wsj/local/prepare_dict.sh b/egs/tedlium/s5_r2_wsj/local/prepare_dict.sh new file mode 100755 index 00000000000..7742ee32c4f --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/prepare_dict.sh @@ -0,0 +1,201 @@ +#!/bin/bash + +# Copyright 2010-2012 Microsoft Corporation +# 2012-2014 Johns Hopkins University (Author: Daniel Povey) +# 2015 Guoguo Chen +# 2016 Vimal Manohar +# Apache 2.0 + +# Call this script from one level above, e.g. from the s3/ directory. It puts +# its output in data/local/. + +# The parts of the output of this that will be needed are +# [in data/local/dict/ ] +# lexicon.txt +# extra_questions.txt +# nonsilence_phones.txt +# optional_silence.txt +# silence_phones.txt + +. path.sh +. cmd.sh + +set -e +set -o pipefail +set -u + +# run this from ../ +dict_suffix= +stage=-1 + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g. : $0 data/local/local_lm/data/work/wordlist" + exit 1 +fi + +wordlist=$1 + +dir=data/local/dict${dict_suffix} +mkdir -p $dir + +if [ ! 
-d $dir/cmudict ]; then + # (1) Get the CMU dictionary + svn co https://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict \ + $dir/cmudict || exit 1; +fi + +# can add -r 10966 for strict compatibility. + + +#(2) Dictionary preparation: + + +if [ $stage -le 0 ]; then + # Make phones symbol-table (adding in silence and verbal and non-verbal noises at this point). + # We are adding suffixes _B, _E, _S for beginning, ending, and singleton phones. + + # silence phones, one per line. + (echo SIL; echo SPN; echo NSN; echo UNK;) > $dir/silence_phones.txt + echo SIL > $dir/optional_silence.txt + + # nonsilence phones; on each line is a list of phones that correspond + # really to the same base phone. + cat $dir/cmudict/cmudict.0.7a.symbols | perl -ane 's:\r::; print;' | \ + perl -e 'while(<>){ + chop; m:^([^\d]+)(\d*)$: || die "Bad phone $_"; + $phones_of{$1} .= "$_ "; } + foreach $list (values %phones_of) {print $list . "\n"; } ' \ + > $dir/nonsilence_phones.txt || exit 1; + + # A few extra questions that will be added to those obtained by automatically clustering + # the "real" phones. These ask about stress; there's also one for silence. + cat $dir/silence_phones.txt| awk '{printf("%s ", $1);} END{printf "\n";}' > $dir/extra_questions.txt || exit 1; + cat $dir/nonsilence_phones.txt | perl -e 'while(<>){ foreach $p (split(" ", $_)) { + $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \ + >> $dir/extra_questions.txt || exit 1; + + grep -v ';;;' $dir/cmudict/cmudict.0.7a | \ + perl -ane 'if(!m:^;;;:){ s:(\S+)\(\d+\) :$1 :; print; }' \ + > $dir/dict.cmu || exit 1; + + # Add to cmudict the silences, noises etc. + + (echo '!SIL SIL'; echo ' SPN'; echo ' UNK'; echo ' NSN'; ) | \ + cat - $dir/dict.cmu > $dir/lexicon2_raw.txt + awk '{print $1}' $dir/lexicon2_raw.txt > $dir/orig_wordlist + + cat <$dir/silence_phones.txt +SIL +SPN +NSN +UNK +EOF + +fi + + +if [ $stage -le 2 ]; then + if [ ! 
-f exp/g2p/.done ]; then + steps/dict/train_g2p.sh --cmd "$train_cmd" \ + --silence-phones $dir/silence_phones.txt \ + $dir/dict.cmu exp/g2p + touch exp/g2p/.done + fi +fi + +export PATH=$PATH:`pwd`/local/dict + +if [ $stage -le 3 ]; then + cat $wordlist | python -c ' +import sys + +words = {} +for line in open(sys.argv[1]).readlines(): + words[line.strip()] = 1 + +oovs = {} +for line in sys.stdin.readlines(): + word = line.strip() + if word not in words: + oovs[word] = 1 + +for oov in oovs: + print (oov)' $dir/orig_wordlist | sort -u > $dir/oovlist + + cat $dir/oovlist | \ + get_acronym_prons.pl $dir/lexicon2_raw.txt > $dir/dict.acronyms +fi + +mkdir -p $dir/f $dir/b # forward, backward directions of rules... + +if [ $stage -le 4 ]; then + # forward is normal suffix + # rules, backward is reversed (prefix rules). These + # dirs contain stuff we create while making the rule-based + # extensions to the dictionary. + + # Remove ; and , from words, if they are present; these + # might crash our scripts, as they are used as separators there. + filter_dict.pl $dir/dict.cmu > $dir/f/dict + cat $dir/oovlist | filter_dict.pl > $dir/f/oovs + reverse_dict.pl $dir/f/dict > $dir/b/dict + reverse_dict.pl $dir/f/oovs > $dir/b/oovs +fi + +if [ $stage -le 5 ]; then + # The next stage takes a few minutes. + # Note: the forward stage takes longer, as English is + # mostly a suffix-based language, and there are more rules + # that it finds. + for d in $dir/f $dir/b; do + ( + cd $d + cat dict | get_rules.pl 2>get_rules.log >rules + get_rule_hierarchy.pl rules >hierarchy + awk '{print $1}' dict | get_candidate_prons.pl rules dict | \ + limit_candidate_prons.pl hierarchy | \ + score_prons.pl dict | \ + count_rules.pl >rule.counts + # the sort command below is just for convenience of reading. 
+ score_rules.pl rules.with_scores + get_candidate_prons.pl rules.with_scores dict oovs | \ + limit_candidate_prons.pl hierarchy > oovs.candidates + ) & + done + wait +fi + +if [ $stage -le 6 ]; then + # Merge the candidates. + reverse_candidates.pl $dir/b/oovs.candidates | cat - $dir/f/oovs.candidates | sort > $dir/oovs.candidates + select_candidate_prons.pl <$dir/oovs.candidates | awk -F';' '{printf("%s %s\n", $1, $2);}' \ + > $dir/dict.oovs + + cat $dir/dict.acronyms $dir/dict.oovs | sort | uniq > $dir/dict.oovs_merged + awk '{print $1}' $dir/dict.oovs_merged | uniq > $dir/oovlist.handled + sort $dir/oovlist | { diff - $dir/oovlist.handled || true; } | grep -v 'd' | sed 's:< ::' > $dir/oovlist.not_handled +fi + +if [ $stage -le 7 ]; then + steps/dict/apply_g2p.sh --cmd "$train_cmd" \ + $dir/oovlist.not_handled exp/g2p exp/g2p/oov_lex + cat exp/g2p/oov_lex/lexicon.lex | cut -f 1,3 | awk '{if (NF > 1) print $0}' > \ + $dir/dict.oovs_g2p +fi + +if [ $stage -le 8 ]; then + # the sort | uniq is to remove a duplicated pron from cmudict. + cat $dir/lexicon2_raw.txt $dir/dict.oovs_merged $dir/dict.oovs_g2p | sort | uniq > \ + $dir/lexicon.txt || exit 1; + # lexicon.txt is without the _B, _E, _S, _I markers. 
+ + rm $dir/lexiconp.txt 2>/dev/null || true +fi + +echo "Dictionary preparation succeeded" + + diff --git a/egs/tedlium/s5_r2_wsj/local/run_segmentation_wsj.sh b/egs/tedlium/s5_r2_wsj/local/run_segmentation_wsj.sh new file mode 120000 index 00000000000..3017d949d27 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/run_segmentation_wsj.sh @@ -0,0 +1 @@ +tuning/run_segmentation_wsj_e.sh \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/score.sh b/egs/tedlium/s5_r2_wsj/local/score.sh new file mode 120000 index 00000000000..d89286dc25a --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/score.sh @@ -0,0 +1 @@ +score_sclite.sh \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/local/score_basic.sh b/egs/tedlium/s5_r2_wsj/local/score_basic.sh new file mode 100755 index 00000000000..d840bd9c981 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/score_basic.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0 + +[ -f ./path.sh ] && . ./path.sh + +# begin configuration section. +cmd=run.pl +min_lmwt=7 +max_lmwt=17 +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang_or_graph=$2 +dir=$3 + +symtab=$lang_or_graph/words.txt + +for f in $symtab $dir/lat.1.gz $data/text; do + [ ! 
-f $f ] && echo "score.sh: no such file $f" && exit 1; +done + +mkdir -p $dir/scoring/log + +cat $data/text > $dir/scoring/test_filt.txt + +$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/best_path.LMWT.log \ + lattice-best-path --lm-scale=LMWT --word-symbol-table=$symtab \ + "ark:gunzip -c $dir/lat.*.gz|" ark,t:$dir/scoring/LMWT.tra || exit 1; + +# Note: the double level of quoting for the sed command + +$cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.log \ + cat $dir/scoring/LMWT.tra \| \ + utils/int2sym.pl -f 2- $symtab \| \ + sed 's:\\\\]::g' \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark,p:- ">&" $dir/wer_LMWT || exit 1; + +# Show results +for f in $dir/wer_*; do echo $f; egrep '(WER)|(SER)' < $f; done + +exit 0; diff --git a/egs/tedlium/s5_r2_wsj/local/score_sclite.sh b/egs/tedlium/s5_r2_wsj/local/score_sclite.sh new file mode 100755 index 00000000000..16c8b30e52f --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/score_sclite.sh @@ -0,0 +1,96 @@ +#!/bin/bash +# +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012, +# Brno University of Technology (Author: Karel Vesely) 2014, +# Apache 2.0 +# + +# begin configuration section. +cmd=run.pl +stage=0 +decode_mbr=true +beam=7 # speed-up, but may affect MBR confidences. +word_ins_penalty=0.0,0.5,1.0 +min_lmwt=7 +max_lmwt=17 +iter=final +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: local/score_sclite.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --min_lmwt # minumum LM-weight for lattice rescoring " + echo " --max_lmwt # maximum LM-weight for lattice rescoring " + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. 
+dir=$3 + +model=$dir/../$iter.mdl # assume model one level up from decoding dir. + +hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl +[ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; +hubdir=`dirname $hubscr` + +for f in $data/stm $data/glm $lang/words.txt $lang/phones/word_boundary.int \ + $model $data/segments $data/reco2file_and_channel $dir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +# name=`basename $data`; # e.g. eval2000 +nj=$(cat $dir/num_jobs) + +mkdir -p $dir/scoring/log + +if [ -f $dir/../frame_shift ]; then + frame_shift_opt="--frame-shift=$(cat $dir/../frame_shift)" + echo "$0: $dir/../frame_shift exists, using $frame_shift_opt" +elif [ -f $dir/../frame_subsampling_factor ]; then + factor=$(cat $dir/../frame_subsampling_factor) || exit 1 + frame_shift_opt="--frame-shift=0.0$factor" + echo "$0: $dir/../frame_subsampling_factor exists, using $frame_shift_opt" +fi + +if [ $stage -le 0 ]; then + for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/get_ctm.LMWT.${wip}.log \ + set -e -o pipefail \; \ + mkdir -p $dir/score_LMWT_${wip}/ '&&' \ + lattice-scale --inv-acoustic-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \ + lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \ + lattice-prune --beam=$beam ark:- ark:- \| \ + lattice-align-words --output-error-lats=true --max-expand=10.0 --test=false \ + $lang/phones/word_boundary.int $model ark:- ark:- \| \ + lattice-to-ctm-conf --decode-mbr=$decode_mbr $frame_shift_opt ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \| \ + utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \| \ + sort -k1,1 -k2,2 -k3,3nb '>' $dir/score_LMWT_${wip}/ctm || exit 1; + done +fi + +if [ $stage -le 1 ]; then + # Remove some stuff we don't want to score, from the ctm. + for x in $dir/score_*/ctm; do + # `-i` is not needed in the following. 
It is added for robustness in case this code is copy-pasted + # into another script that, e.g., uses instead of + grep -v -w -i '' <$x > ${x}.filt || exit 1; + done +fi + +# Score the set... +if [ $stage -le 2 ]; then + for wip in $(echo $word_ins_penalty | sed 's/,/ /g'); do + $cmd LMWT=$min_lmwt:$max_lmwt $dir/scoring/log/score.LMWT.${wip}.log \ + cp $data/stm $dir/score_LMWT_${wip}/ '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/score_LMWT_${wip}/stm $dir/score_LMWT_${wip}/ctm.filt || exit 1; + done +fi + +exit 0 diff --git a/egs/tedlium/s5_r2_wsj/local/train_lm.sh b/egs/tedlium/s5_r2_wsj/local/train_lm.sh new file mode 100755 index 00000000000..2e8f8de11f9 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/train_lm.sh @@ -0,0 +1,189 @@ +#!/bin/bash + +# Copyright 2016 Vincent Nguyen +# 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 +# +# This script trains a LM on the Cantab-Tedlium text data and tedlium acoustic training data. +# It is based on the example scripts distributed with PocoLM + +# It will first check if pocolm is installed and if not will proceed with installation +# It will then get the source data from the pre-downloaded Cantab-Tedlium files +# and the pre-prepared data/train text source. + + +set -e +stage=0 +cmd=run.pl + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +dir=data/local/local_lm +lm_dir=${dir}/data + +mkdir -p $dir +. ./path.sh || exit 1; # for KALDI_ROOT +export PATH=$KALDI_ROOT/tools/pocolm/scripts:$PATH +( # First make sure the pocolm toolkit is installed. + cd $KALDI_ROOT/tools || exit 1; + if [ -d pocolm ]; then + echo Not installing the pocolm toolkit since it is already there. 
+ else + echo "$0: Please install the PocoLM toolkit with: " + echo " cd ../../../tools; extras/install_pocolm.sh; cd -" + exit 1; + fi +) || exit 1; + +num_dev_sentences=10000 +bypass_metaparam_optim_opt= + +if [ $stage -le 0 ]; then + mkdir -p ${dir}/data + mkdir -p ${dir}/data/text + + echo "$0: Getting the Data sources" + + rm ${dir}/data/text/* 2>/dev/null || true + + # Unzip TEDLIUM 6 data sources, normalize apostrophe+suffix to previous word, gzip the result. + gunzip -c db/TEDLIUM_release2/LM/*.en.gz | sed 's/ <\/s>//g' | \ + local/join_suffix.py | awk '{print "foo "$0}' | \ + local/normalize_transcript.pl '' | cut -d ' ' -f 2- | gzip -c > ${dir}/data/text/train.txt.gz + # use a subset of the annotated training data as the dev set . + # Note: the name 'dev' is treated specially by pocolm, it automatically + # becomes the dev set. + head -n $num_dev_sentences < data/train/text | cut -d " " -f 2- > ${dir}/data/text/dev.txt + # .. and the rest of the training data as an additional data source. + # we can later fold the dev data into this. + tail -n +$[$num_dev_sentences+1] < data/train/text | cut -d " " -f 2- > ${dir}/data/text/ted.txt + + cat data/train_si284/text | cut -d " " -f 2- > ${dir}/data/text/wsj_si284.txt + + # for reporting perplexities, we'll use the "real" dev set. + # (a subset of the training data is used as ${dir}/data/text/ted.txt to work + # out interpolation weights. + # note, we can't put it in ${dir}/data/text/, because then pocolm would use + # it as one of the data sources. + cut -d " " -f 2- < data/dev/text > ${dir}/data/real_dev_set.txt +fi + +if [ $stage -le 1 ]; then + mkdir -p $dir/data/work + get_word_counts.py $dir/data/text $dir/data/work/word_counts + touch $dir/data/work/word_counts/.done +fi + +if [ $stage -le 2 ]; then + # decide on the vocabulary. 
+ + cat $dir/data/work/word_counts/{ted,dev}.counts | \ + local/lm/merge_word_counts.py 2 > $dir/data/work/ted.wordlist_counts + + cat $dir/data/work/word_counts/train.counts | \ + local/lm/merge_word_counts.py 5 > $dir/data/work/train.wordlist_counts + + cat $dir/data/work/word_counts/wsj_si284.counts | \ + local/lm/merge_word_counts.py 2 > $dir/data/work/wsj_si284.wordlist_counts + + cat $dir/data/work/{ted,train,wsj_si284}.wordlist_counts | \ + perl -ane 'if ($F[1] =~ m/[A-Za-z]/) { print "$F[0] $F[1]\n"; }' | \ + local/lm/merge_word_counts.py 1 | sort -k 1,1nr > $dir/data/work/final.wordlist_counts + + if [ ! -z "$vocab_size" ]; then + awk -v sz=$vocab_size 'BEGIN{count=-1;} + { i+=1; + if (i == int(sz)) { + count = $1; + }; + if (count > 0 && count != $1) { + exit(0); + } + print $0; + }' $dir/data/work/final.wordlist_counts + else + cat $dir/data/work/final.wordlist_counts + fi | awk '{print $2}' > $dir/data/work/wordlist +fi + +order=4 +wordlist=${dir}/data/work/wordlist +min_counts='train=2 ted=1 wsj_si284=5' + +# Uncomment these if you want to remove WSJ data from LM. It should not +# affect much. WSJ data improves perplexity by a couple of points. +# min_counts='train=2 ted=1' +# [ -f $dir/data/text/wsj_si284.txt ] && mv $dir/data/text/wsj_si284.txt $dir/data/ +# [ -f $dir/data/work/word_counts/wsj_si284.counts ] && mv $dir/data/work/word_counts/wsj_si284.counts $dir/data/work + +lm_name="`basename ${wordlist}`_${order}" +if [ -n "${min_counts}" ]; then + lm_name+="_`echo ${min_counts} | tr -s "[:blank:]" "_" | tr "," "." 
| tr "=" "-"`"
+ + + mkdir -p ${dir}/data/arpa + format_arpa_lm.py ${dir}/data/lm_${order}_prune_big | gzip -c > ${dir}/data/arpa/${order}gram_big.arpa.gz +fi + +if [ $stage -le 5 ]; then + echo "$0: pruning the LM (to smaller size)" + # Using 2 million n-grams for a smaller LM for graph building. Prune from the + # bigger-pruned LM, it'll be faster. + size=2000000 + $cmd ${dir}/data/lm_${order}_prune_small/log/prune_lm.log \ + prune_lm_dir.py --target-num-ngrams=$size ${dir}/data/lm_${order}_prune_big ${dir}/data/lm_${order}_prune_small + + for x in real_dev_set; do + $cmd ${dir}/data/lm_${order}_prune_small/log/compute_data_prob_${x}.log \ + get_data_prob.py ${dir}/data/${x}.txt ${dir}/data/lm_${order}_prune_small + + cat ${dir}/data/lm_${order}_prune_small/log/compute_data_prob_${x}.log | grep -F '[perplexity' + done + + # current results, after adding --limit-unk-history=true (needed for modeling OOVs and not blowing up LG.fst): + # get_data_prob.py: log-prob of data/local/local_lm/data/real_dev_set.txt given model data/local/local_lm/data/lm_4_prune_small was -5.29432352378 per word [perplexity = 199.202824404 over 18290.0 words. + + format_arpa_lm.py ${dir}/data/lm_${order}_prune_small | gzip -c > ${dir}/data/arpa/${order}gram_small.arpa.gz +fi diff --git a/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_d.sh b/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_d.sh new file mode 100644 index 00000000000..000a09c9159 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_d.sh @@ -0,0 +1,152 @@ +#! /bin/bash + +. ./cmd.sh +. ./path.sh + +set -e -o pipefail -u + +# This differs from _e by using the option --align-full-hyp true, which +# retrieves the best match (based on Levenshtein distance) of the +# reference with the full hypothesis for the segment, +# as against the best matching subsequence using Smith-Waterman alignment. 
+ +segment_stage=-10 +affix=_1d +decode_nj=30 +cleanup_stage=-10 + +############################################################################### +# Segment long recordings using TF-IDF retrieval of reference text +# for uniformly segmented audio chunks based on +# modified Levenshtein alignment. +# Use a model trained on WSJ train_si84 (tri2b) +############################################################################### + +### +# STAGE 0 +### + +utils/data/convert_data_dir_to_whole.sh data/train data/train_long +steps/make_mfcc.sh --nj 40 --cmd "$train_cmd" \ + data/train_long exp/make_mfcc/train_long mfcc +steps/compute_cmvn_stats.sh \ + data/train_long exp/make_mfcc/train_long mfcc +utils/fix_data_dir.sh data/train_long + +steps/cleanup/segment_long_utterances.sh \ + --cmd "$train_cmd" --nj 80 \ + --stage $segment_stage \ + --max-bad-proportion 0.5 --align-full-hyp true \ + exp/wsj_tri2b data/lang_nosp data/train_long data/train_long/text data/train_reseg${affix} \ + exp/segment_wsj_long_utts${affix}_train + +steps/compute_cmvn_stats.sh \ + data/train_reseg${affix} exp/make_mfcc/train_reseg${affix} mfcc +utils/fix_data_dir.sh data/train_reseg${affix} + +rm -r data/train_reseg${affix}/split20 || true +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_reseg${affix} data/lang_nosp exp/wsj_tri4a exp/wsj_tri4${affix}_ali_train_reseg${affix} || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" 5000 100000 \ + data/train_reseg${affix} data/lang_nosp \ + exp/wsj_tri4${affix}_ali_train_reseg${affix} exp/tri4${affix} || exit 1; + +( +utils/mkgraph.sh data/lang_nosp exp/tri4${affix} exp/tri4${affix}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri4${affix}/graph_nosp data/${dset} exp/tri4${affix}/decode_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} exp/tri4${affix}/decode_${dset} \ + 
exp/tri4${affix}/decode_${dset}_rescore +done +) & + +new_affix=`echo $affix | perl -ne 'm/(\S+)([0-9])(\S+)/; print $1 . ($2+1) . $3;'` + +### +# STAGE 1 +### + +steps/cleanup/segment_long_utterances.sh \ + --cmd "$train_cmd" --nj 80 \ + --stage $segment_stage \ + --max-bad-proportion 0.75 --align-full-hyp true \ + exp/tri4${affix} data/lang_nosp data/train_long data/train_long/text data/train_reseg${new_affix} \ + exp/segment_long_utts${new_affix}_train + +steps/compute_cmvn_stats.sh data/train_reseg${new_affix} +utils/fix_data_dir.sh data/train_reseg${new_affix} + +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_reseg${new_affix} data/lang_nosp exp/tri4${affix} exp/tri4${affix}_ali_train_reseg${new_affix} || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" 5000 100000 \ + data/train_reseg${new_affix} data/lang_nosp \ + exp/tri4${affix}_ali_train_reseg${new_affix} exp/tri5${new_affix} || exit 1; + +utils/mkgraph.sh data/lang_nosp exp/tri5${new_affix} exp/tri5${new_affix}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri5${new_affix}/graph_nosp data/${dset} exp/tri5${new_affix}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} exp/tri5${new_affix}/decode_nosp_${dset} \ + exp/tri5${new_affix}/decode_nosp_${dset}_rescore +done + +### +# STAGE 2 +### + +srcdir=exp/tri5${new_affix} +cleanup_affix=cleaned +cleaned_data=data/train_reseg${new_affix}_${cleanup_affix} +cleaned_dir=${srcdir}_${cleanup_affix} + +steps/cleanup/clean_and_segment_data.sh --stage $cleanup_stage --nj 80 \ + --cmd "$train_cmd" \ + data/train_reseg${new_affix} data/lang_nosp $srcdir \ + ${srcdir}_${cleanup_affix}_work $cleaned_data + +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + $cleaned_data data/lang_nosp $srcdir ${srcdir}_ali_${cleanup_affix} + +steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 $cleaned_data 
data/lang_nosp ${srcdir}_ali_${cleanup_affix} \ + ${cleaned_dir} + +utils/mkgraph.sh data/lang_nosp $cleaned_dir ${cleaned_dir}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + ${cleaned_dir}/graph_nosp data/${dset} ${cleaned_dir}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} ${cleaned_dir}/decode_nosp_${dset} \ + ${cleaned_dir}/decode_nosp_${dset}_rescore +done + +exit 0 + +# Baseline | %WER 17.9 | 507 17783 | 85.1 10.5 4.4 3.0 17.9 90.9 | -0.055 | exp/tri3_cleaned/decode_dev_rescore/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 19.5 | 507 17783 | 83.8 10.7 5.5 3.3 19.5 93.1 | -0.141 | exp/tri4_1d/decode_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 19.2 | 507 17783 | 84.0 10.7 5.3 3.2 19.2 91.5 | -0.166 | exp/tri5_2d/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 19.1 | 507 17783 | 84.1 10.7 5.2 3.2 19.1 91.1 | -0.193 | exp/tri5_2d_cleaned/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys + +# Baseline | %WER 16.6 | 1155 27500 | 85.8 10.9 3.4 2.4 16.6 86.4 | -0.058 | exp/tri3_cleaned/decode_test_rescore/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 18.1 | 1155 27500 | 84.2 11.7 4.1 2.3 18.1 87.3 | -0.034 | exp/tri4_1d/decode_test_rescore/score_13_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 18.0 | 1155 27500 | 84.4 11.6 4.0 2.4 18.0 87.1 | -0.057 | exp/tri5_2d/decode_nosp_test_rescore/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 17.7 | 1155 27500 | 84.6 11.4 3.9 2.3 17.7 87.4 | -0.076 | exp/tri5_2d_cleaned/decode_nosp_test_rescore/score_13_0.0/ctm.filt.filt.sys + +# Baseline | %WER 19.0 | 507 17783 | 83.9 11.4 4.7 2.9 19.0 92.1 | -0.054 | exp/tri3_cleaned/decode_dev/score_13_0.5/ctm.filt.filt.sys +# STAGE 0 | %WER 20.5 | 507 17783 | 83.0 11.6 5.3 3.5 20.5 94.5 | -0.103 | exp/tri4_1d/decode_dev/score_13_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 20.2 | 507 17783 | 83.2 
11.7 5.1 3.4 20.2 94.9 | -0.128 | exp/tri5_2d/decode_nosp_dev/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 20.1 | 507 17783 | 83.4 11.7 4.8 3.6 20.1 92.9 | -0.159 | exp/tri5_2d_cleaned/decode_nosp_dev/score_12_0.0/ctm.filt.filt.sys + +# Baseline | %WER 17.6 | 1155 27500 | 84.8 11.7 3.5 2.4 17.6 87.6 | 0.001 | exp/tri3_cleaned/decode_test/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 19.2 | 1155 27500 | 83.3 12.7 4.0 2.5 19.2 88.0 | -0.011 | exp/tri4_1d/decode_test/score_12_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 19.1 | 1155 27500 | 83.4 12.5 4.2 2.4 19.1 88.8 | 0.004 | exp/tri5_2d/decode_nosp_test/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 18.8 | 1155 27500 | 83.7 12.4 3.9 2.5 18.8 88.7 | -0.049 | exp/tri5_2d_cleaned/decode_nosp_test/score_12_0.0/ctm.filt.filt.sys + diff --git a/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_e.sh b/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_e.sh new file mode 100644 index 00000000000..59bac197202 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/tuning/run_segmentation_wsj_e.sh @@ -0,0 +1,189 @@ +#! /bin/bash + +. ./cmd.sh +. ./path.sh + +set -e -o pipefail -u + +# This differs from _d by using the --align-full-hyp false, which +# gets best matching subsequence of reference and hypothesis using +# Smith-Waterman alignment, +# as against using Levenshtein distance w.r.t. full hypothesis. 
+ +# _d +# STAGE 2 | %WER 19.1 | 507 17783 | 84.1 10.7 5.2 3.2 19.1 91.1 | -0.193 | exp/tri5_2d_cleaned/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 17.7 | 1155 27500 | 84.6 11.4 3.9 2.3 17.7 87.4 | -0.076 | exp/tri5_2d_cleaned/decode_nosp_test_rescore/score_13_0.0/ctm.filt.filt.sys + +# _e +# STAGE 3 | %WER 17.4 | 1155 27500 | 85.0 11.4 3.7 2.4 17.4 87.2 | -0.086 | exp/tri5_2e_cleaned/decode_nosp_test_rescore/score_12_0.0/ctm.filt.filt.sys +# STAGE 3 | %WER 18.8 | 507 17783 | 84.4 10.7 4.9 3.2 18.8 91.3 | -0.162 | exp/tri5_2e_cleaned/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys + +# Note: Better results can be obtained by using silence and pronunciation +# probs as seen in STAGE 2. + +segment_stage=-10 +affix=_1e +decode_nj=30 +cleanup_stage=-10 + +############################################################################### +# Segment long recordings using TF-IDF retrieval of reference text +# for uniformly segmented audio chunks based on Smith-Waterman alignment. 
+# Use a model trained on WSJ train_si84 (tri2b) +############################################################################### + +### +# STAGE 0 +### + +utils/data/convert_data_dir_to_whole.sh data/train data/train_long +steps/make_mfcc.sh --nj 40 --cmd "$train_cmd" \ + data/train_long exp/make_mfcc/train_long mfcc +steps/compute_cmvn_stats.sh \ + data/train_long exp/make_mfcc/train_long mfcc +utils/fix_data_dir.sh data/train_long + +steps/cleanup/segment_long_utterances.sh \ + --cmd "$train_cmd" --nj 80 \ + --stage $segment_stage \ + --max-bad-proportion 0.5 --align-full-hyp false \ + exp/wsj_tri2b data/lang_nosp data/train_long data/train_long/text data/train_reseg${affix} \ + exp/segment_wsj_long_utts${affix}_train + +steps/compute_cmvn_stats.sh \ + data/train_reseg${affix} exp/make_mfcc/train_reseg${affix} mfcc +utils/fix_data_dir.sh data/train_reseg${affix} + +rm -r data/train_reseg${affix}/split20 || true +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_reseg${affix} data/lang_nosp exp/wsj_tri4a exp/wsj_tri4${affix}_ali_train_reseg${affix} || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" 5000 100000 \ + data/train_reseg${affix} data/lang_nosp \ + exp/wsj_tri4${affix}_ali_train_reseg${affix} exp/tri4${affix} || exit 1; + +utils/mkgraph.sh data/lang_nosp exp/tri4${affix} exp/tri4${affix}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri4${affix}/graph_nosp data/${dset} exp/tri4${affix}/decode_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} exp/tri4${affix}/decode_${dset} \ + exp/tri4${affix}/decode_${dset}_rescore +done + +new_affix=`echo $affix | perl -ne 'm/(\S+)([0-9])(\S+)/; print $1 . ($2+1) . 
$3;'` + +### +# STAGE 1 +### + +steps/cleanup/segment_long_utterances.sh \ + --cmd "$train_cmd" --nj 80 \ + --stage $segment_stage \ + --max-bad-proportion 0.75 --align-full-hyp false \ + exp/tri4${affix} data/lang_nosp data/train_long data/train_long/text data/train_reseg${new_affix} \ + exp/segment_long_utts${new_affix}_train + +steps/compute_cmvn_stats.sh data/train_reseg${new_affix} +utils/fix_data_dir.sh data/train_reseg${new_affix} + +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_reseg${new_affix} data/lang_nosp exp/tri4${affix} exp/tri4${affix}_ali_train_reseg${new_affix} || exit 1; + +steps/train_sat.sh --cmd "$train_cmd" 5000 100000 \ + data/train_reseg${new_affix} data/lang_nosp \ + exp/tri4${affix}_ali_train_reseg${new_affix} exp/tri5${new_affix} || exit 1; + +utils/mkgraph.sh data/lang_nosp exp/tri5${new_affix} exp/tri5${new_affix}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/tri5${new_affix}/graph_nosp data/${dset} exp/tri5${new_affix}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} exp/tri5${new_affix}/decode_nosp_${dset} \ + exp/tri5${new_affix}/decode_nosp_${dset}_rescore +done + +### +# STAGE 2 +### + +steps/get_prons.sh --cmd "$train_cmd" data/train_reseg${new_affix} \ + data/lang_nosp exp/tri5${new_affix} +utils/dict_dir_add_pronprobs.sh --max-normalize true \ + data/local/dict_nosp exp/tri5${new_affix}/pron_counts_nowb.txt \ + exp/tri5${new_affix}/sil_counts_nowb.txt \ + exp/tri5${new_affix}/pron_bigram_counts_nowb.txt data/local/dict + +utils/prepare_lang.sh data/local/dict "" data/local/lang data/lang +cp -rT data/lang data/lang_rescore +cp data/lang_nosp/G.fst data/lang/ +cp data/lang_nosp_rescore/G.carpa data/lang_rescore/ + +utils/mkgraph.sh data/lang exp/tri5${new_affix} exp/tri5${new_affix}/graph + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj 
--cmd "$decode_cmd" --num-threads 4 \ + exp/tri5${new_affix}/graph data/${dset} exp/tri5${new_affix}/decode_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \ + data/${dset} exp/tri5${new_affix}/decode_${dset} \ + exp/tri5${new_affix}/decode_${dset}_rescore +done + +### +# STAGE 3 +### + +srcdir=exp/tri5${new_affix} +cleanup_affix=cleaned +cleaned_data=data/train_reseg${new_affix}_${cleanup_affix} +cleaned_dir=${srcdir}_${cleanup_affix} + +steps/cleanup/clean_and_segment_data.sh --stage $cleanup_stage --nj 80 \ + --cmd "$train_cmd" \ + data/train_reseg${new_affix} data/lang_nosp $srcdir \ + ${srcdir}_${cleanup_affix}_work $cleaned_data + +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + $cleaned_data data/lang_nosp $srcdir ${srcdir}_ali_${cleanup_affix} + +steps/train_sat.sh --cmd "$train_cmd" \ + 5000 100000 $cleaned_data data/lang_nosp ${srcdir}_ali_${cleanup_affix} \ + ${cleaned_dir} + +utils/mkgraph.sh data/lang_nosp $cleaned_dir ${cleaned_dir}/graph_nosp + +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + ${cleaned_dir}/graph_nosp data/${dset} ${cleaned_dir}/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} ${cleaned_dir}/decode_nosp_${dset} \ + ${cleaned_dir}/decode_nosp_${dset}_rescore +done + +exit 0 + +# Baseline | %WER 17.9 | 507 17783 | 85.1 10.5 4.4 3.0 17.9 90.9 | -0.055 | exp/tri3_cleaned/decode_dev_rescore/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 19.3 | 507 17783 | 83.9 10.8 5.2 3.2 19.3 92.3 | -0.178 | exp/tri4_1e/decode_dev_rescore/score_14_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 18.8 | 507 17783 | 84.4 10.7 4.9 3.2 18.8 91.7 | -0.199 | exp/tri5_2e/decode_nosp_dev_rescore/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 18.4 | 507 17783 | 84.7 10.4 4.8 3.2 18.4 91.7 | -0.192 | exp/tri5_2e/decode_dev_rescore/score_14_0.5/ctm.filt.filt.sys +# STAGE 3 | %WER 18.8 | 507 
17783 | 84.4 10.7 4.9 3.2 18.8 91.3 | -0.162 | exp/tri5_2e_cleaned/decode_nosp_dev_rescore/score_14_0.0/ctm.filt.filt.sys + +# Baseline | %WER 16.6 | 1155 27500 | 85.8 10.9 3.4 2.4 16.6 86.4 | -0.058 | exp/tri3_cleaned/decode_test_rescore/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 18.0 | 1155 27500 | 84.4 11.7 3.9 2.4 18.0 87.5 | -0.038 | exp/tri4_1e/decode_test_rescore/score_13_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 17.7 | 1155 27500 | 84.7 11.4 3.9 2.3 17.7 87.0 | -0.044 | exp/tri5_2e/decode_nosp_test_rescore/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 16.8 | 1155 27500 | 85.7 11.0 3.3 2.5 16.8 86.6 | -0.066 | exp/tri5_2e/decode_test_rescore/score_14_0.0/ctm.filt.filt.sys +# STAGE 3 | %WER 17.4 | 1155 27500 | 85.0 11.4 3.7 2.4 17.4 87.2 | -0.086 | exp/tri5_2e_cleaned/decode_nosp_test_rescore/score_12_0.0/ctm.filt.filt.sys + +# Baseline | %WER 19.0 | 507 17783 | 83.9 11.4 4.7 2.9 19.0 92.1 | -0.054 | exp/tri3_cleaned/decode_dev/score_13_0.5/ctm.filt.filt.sys +# STAGE 0 | %WER 20.5 | 507 17783 | 82.8 11.9 5.3 3.3 20.5 94.1 | -0.098 | exp/tri4_1e/decode_dev/score_14_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 19.8 | 507 17783 | 83.3 11.5 5.2 3.2 19.8 94.7 | -0.133 | exp/tri5_2e/decode_nosp_dev/score_14_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 19.5 | 507 17783 | 83.9 11.4 4.7 3.5 19.5 94.1 | -0.120 | exp/tri5_2e/decode_dev/score_16_0.0/ctm.filt.filt.sys +# STAGE 3 | %WER 20.0 | 507 17783 | 83.5 11.7 4.8 3.5 20.0 93.3 | -0.111 | exp/tri5_2e_cleaned/decode_nosp_dev/score_13_0.0/ctm.filt.filt.sys + +# Baseline | %WER 17.6 | 1155 27500 | 84.8 11.7 3.5 2.4 17.6 87.6 | 0.001 | exp/tri3_cleaned/decode_test/score_15_0.0/ctm.filt.filt.sys +# STAGE 0 | %WER 19.1 | 1155 27500 | 83.4 12.5 4.1 2.5 19.1 88.6 | 0.022 | exp/tri4_1e/decode_test/score_13_0.0/ctm.filt.filt.sys +# STAGE 1 | %WER 18.7 | 1155 27500 | 83.7 12.2 4.1 2.4 18.7 88.1 | 0.007 | exp/tri5_2e/decode_nosp_test/score_13_0.0/ctm.filt.filt.sys +# STAGE 2 | %WER 17.9 | 1155 27500 | 84.8 11.9 3.3 2.7 17.9 87.5 | 
-0.015 | exp/tri5_2e/decode_test/score_13_0.0/ctm.filt.filt.sys +# STAGE 3 | %WER 18.4 | 1155 27500 | 83.9 12.1 4.0 2.3 18.4 88.1 | -0.015 | exp/tri5_2e_cleaned/decode_nosp_test/score_13_0.0/ctm.filt.filt.sys diff --git a/egs/tedlium/s5_r2_wsj/local/wsj_data_prep.sh b/egs/tedlium/s5_r2_wsj/local/wsj_data_prep.sh new file mode 100755 index 00000000000..62174ec4349 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/wsj_data_prep.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + + +if [ $# -le 3 ]; then + echo "Arguments should be a list of WSJ directories, see ../run.sh for example." + exit 1; +fi + + +dir=`pwd`/data/local/data +lmdir=`pwd`/data/local/nist_lm +mkdir -p $dir $lmdir +local=`pwd`/local +utils=`pwd`/utils + +. ./path.sh # Needed for KALDI_ROOT +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +if [ ! -x $sph2pipe ]; then + echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; + exit 1; +fi + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +cd $dir +# Make directory of links to the WSJ disks such as 11-13.1. This relies on the command +# line arguments being absolute pathnames. +rm -r links/ 2>/dev/null +mkdir links/ +ln -s $* links + +# Do some basic checks that we have what we expected. +if [ ! -d links/11-13.1 -o ! -d links/13-34.1 -o ! 
-d links/11-2.1 ]; then + echo "wsj_data_prep.sh: Spot check of command line arguments failed" + echo "Command line arguments must be absolute pathnames to WSJ directories" + echo "with names like 11-13.1." + echo "Note: if you have old-style WSJ distribution," + echo "local/cstr_wsj_data_prep.sh may work instead, see run.sh for example." + exit 1; +fi + +# This version for SI-84 + +cat links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si84.flist + +nl=`cat train_si84.flist | wc -l` +[ "$nl" -eq 7138 ] || echo "Warning: expected 7138 lines in train_si84.flist, got $nl" + +# This version for SI-284 +cat links/13-34.1/wsj1/doc/indices/si_tr_s.ndx \ + links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si284.flist + +nl=`cat train_si284.flist | wc -l` +[ "$nl" -eq 37416 ] || echo "Warning: expected 37416 lines in train_si284.flist, got $nl" + +# Now for the test sets. +# links/13-34.1/wsj1/doc/indices/readme.doc +# describes all the different test sets. +# Note: each test-set seems to come in multiple versions depending +# on different vocabulary sizes, verbalized vs. non-verbalized +# pronunciations, etc. We use the largest vocab and non-verbalized +# pronunciations. +# The most normal one seems to be the "baseline 60k test set", which +# is h1_p0. + +# Nov'92 (333 utts) +# These index files have a slightly different format; +# have to add .wv1 +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92.flist + +# Nov'92 (330 utts, 5k vocab) +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92_5k.flist + +# Nov'93: (213 utts) +# Have to replace a wrong disk-id. 
+cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93.flist + +# Nov'93: (213 utts, 5k) +cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93_5k.flist + +# Dev-set for Nov'93 (503 utts) +cat links/13-34.1/wsj1/doc/indices/h1_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93.flist + +# Dev-set for Nov'93 (513 utts, 5k vocab) +cat links/13-34.1/wsj1/doc/indices/h2_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93_5k.flist + + +# Dev-set Hub 1,2 (503, 913 utterances) + +# Note: the ???'s below match WSJ and SI_DT, or wsj and si_dt. +# Sometimes this gets copied from the CD's with upcasing, don't know +# why (could be older versions of the disks). +find `readlink links/13-16.1`/???1/??_??_20 -print | grep -i ".wv1" | sort > dev_dt_20.flist +find `readlink links/13-16.1`/???1/??_??_05 -print | grep -i ".wv1" | sort > dev_dt_05.flist + + +# Finding the transcript files: +for x in $*; do find -L $x -iname '*.dot'; done > dot_files.flist + +# Convert the transcripts into our format (no normalization yet) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + $local/flist2scp.pl $x.flist | sort > ${x}_sph.scp + cat ${x}_sph.scp | awk '{print $1}' | $local/find_transcripts.pl dot_files.flist > $x.trans1 +done + +# Do some basic normalization steps. At this point we don't remove OOVs-- +# that will be done inside the training scripts, as we'd like to make the +# data-preparation stage independent of the specific lexicon used. +noiseword=""; +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat $x.trans1 | $local/normalize_transcript.pl $noiseword | sort > $x.txt || exit 1; +done + +# Create scp's with wav's. 
(the wv1 in the distribution is not really wav, it is sph.) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp +done + +# Make the utt2spk and spk2utt files. +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat ${x}_sph.scp | awk '{print $1}' | perl -ane 'chop; m:^...:; print "$_ $&\n";' > $x.utt2spk + cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; +done + + +#in case we want to limit lm's on most frequent words, copy lm training word frequency list +cp links/13-32.1/wsj1/doc/lng_modl/vocab/wfl_64.lst $lmdir +chmod u+w $lmdir/*.lst # had weird permissions on source. + +echo "Data preparation succeeded" diff --git a/egs/tedlium/s5_r2_wsj/local/wsj_format_data.sh b/egs/tedlium/s5_r2_wsj/local/wsj_format_data.sh new file mode 100755 index 00000000000..4bbfd3d3014 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/local/wsj_format_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Copyright 2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# 2015 Guoguo Chen +# 2016 Vimal Manohar +# Apache 2.0 + +# This script takes data prepared in a corpus-dependent way +# in data/local/, and converts it into the "canonical" form, +# in various subdirectories of data/, e.g. +# data/train_si284, data/train_si84, etc. + +# Don't bother doing train_si84 separately (although we have the file lists +# in data/local/) because it's just the first 7138 utterances in train_si284. +# We'll create train_si84 after doing the feature extraction. + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. 
./path.sh || exit 1; + +echo "Preparing train and test data" +srcdir=data/local/data + +for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + mkdir -p data/$x + cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; + cp $srcdir/$x.txt data/$x/text || exit 1; + cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; + cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; +done + +echo "Succeeded in formatting data." + diff --git a/egs/tedlium/s5_r2_wsj/path.sh b/egs/tedlium/s5_r2_wsj/path.sh new file mode 100755 index 00000000000..92b43679edf --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/path.sh @@ -0,0 +1,7 @@ +export KALDI_ROOT=`pwd`/../../.. +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH:$KALDI_ROOT/tools/sph2pipe_v2.5 +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export PATH=$PATH:/home/vmanoha1/kaldi-asr-diarization/src/segmenterbin +export LC_ALL=C diff --git a/egs/tedlium/s5_r2_wsj/results.sh b/egs/tedlium/s5_r2_wsj/results.sh new file mode 100755 index 00000000000..11b23294d80 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/results.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +filter_regexp=. 
+[ $# -ge 1 ] && filter_regexp=$1 + +for x in exp/*/decode*; do + [ -d $x ] && grep WER $x/wer_* | utils/best_wer.sh; +done 2>/dev/null + +for x in exp/*{mono,tri,sgmm,nnet,dnn,lstm,chain}*/decode*; do + [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; +done 2>/dev/null | grep $filter_regexp + +for x in exp/*{nnet,dnn,lstm,chain}*/*/decode*; do + [ -d $x ] && grep Sum $x/score_*/*.sys | utils/best_wer.sh; +done 2>/dev/null | grep $filter_regexp + +exit 0 + diff --git a/egs/tedlium/s5_r2_wsj/run.sh b/egs/tedlium/s5_r2_wsj/run.sh new file mode 100755 index 00000000000..70d9a1d6927 --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/run.sh @@ -0,0 +1,127 @@ +#!/bin/bash +# +# This recipe uses WSJ models and TED-LIUM audio with un-aligned transcripts. +# +# http://www-lium.univ-lemans.fr/en/content/ted-lium-corpus +# http://www.openslr.org/resources (Mirror). +# +# The data is distributed under 'Creative Commons BY-NC-ND 3.0' license, +# which allow free non-commercial use, while only a citation is required. +# +# Copyright 2014 Nickolay V. Shmyrev +# 2014 Brno University of Technology (Author: Karel Vesely) +# 2016 Vincent Nguyen +# 2016 Johns Hopkins University (Author: Daniel Povey) +# +# Apache 2.0 +# + +. cmd.sh +. path.sh + + +set -e -o pipefail -u + +nj=35 +decode_nj=30 # note: should not be >38 which is the number of speakers in the dev set + # after applying --seconds-per-spk-max 180. We decode with 4 threads, so + # this will be too many jobs if you're using run.pl. + +. utils/parse_options.sh # accept options + +# Data preparation +local/download_data.sh + +wsj0=/export/corpora5/LDC/LDC93S6B +wsj1=/export/corpora5/LDC/LDC94S13B +local/wsj_data_prep.sh $wsj0/??-{?,??}.? $wsj1/??-{?,??}.? || exit 1; + +local/wsj_format_data.sh + +local/prepare_data.sh + +# Split speakers up into 3-minute chunks. This doesn't hurt adaptation, and +# lets us use more jobs for decoding etc. 
+# [we chose 3 minutes because that gives us 38 speakers for the dev data, which is +# more than our normal 30 jobs.] +for dset in dev test train; do + utils/data/modify_speaker_info.sh --seconds-per-spk-max 180 data/${dset}.orig data/${dset} +done + +local/train_lm.sh + +local/prepare_dict.sh --dict-suffix "_nosp" \ + data/local/local_lm/data/work/wordlist + +utils/prepare_lang.sh data/local/dict_nosp \ + "" data/local/lang_nosp data/lang_nosp + +local/format_lms.sh + +# Feature extraction +for set in train_si284; do + dir=data/$set + steps/make_mfcc.sh --nj 30 --cmd "$train_cmd" $dir + steps/compute_cmvn_stats.sh $dir + utils/fix_data_dir.sh $dir +done + +utils/subset_data_dir.sh --first data/train_si284 7138 data/train_si84 || exit 1 + +# Now make subset with the shortest 2k utterances from si-84. +utils/subset_data_dir.sh --shortest data/train_si84 2000 data/train_si84_2kshort || exit 1; + +# Now make subset with half of the data from si-84. +utils/subset_data_dir.sh data/train_si84 3500 data/train_si84_half || exit 1; + +# Note: the --boost-silence option should probably be omitted by default +# for normal setups. It doesn't always help. [it's to discourage non-silence +# models from modeling silence.] 
+steps/train_mono.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + data/train_si84_2kshort data/lang_nosp exp/wsj_mono0a || exit 1; + +steps/align_si.sh --boost-silence 1.25 --nj 10 --cmd "$train_cmd" \ + data/train_si84_half data/lang_nosp exp/wsj_mono0a exp/wsj_mono0a_ali || exit 1; + +steps/train_deltas.sh --boost-silence 1.25 --cmd "$train_cmd" 2000 10000 \ + data/train_si84_half data/lang_nosp exp/wsj_mono0a_ali exp/wsj_tri1 || exit 1; + +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + data/train_si84 data/lang_nosp exp/wsj_tri1 exp/wsj_tri1_ali_si84 || exit 1; + +steps/train_lda_mllt.sh --cmd "$train_cmd" \ + --splice-opts "--left-context=3 --right-context=3" 2500 15000 \ + data/train_si84 data/lang_nosp exp/wsj_tri1_ali_si84 exp/wsj_tri2b || exit 1; + +# Align tri2b system with si84 data. +steps/align_si.sh --nj 10 --cmd "$train_cmd" \ + --use-graphs true data/train_si84 \ + data/lang_nosp exp/wsj_tri2b exp/wsj_tri2b_ali_si84 || exit 1; + +# From 2b system, train 3b which is LDA + MLLT + SAT. +steps/train_sat.sh --cmd "$train_cmd" 2500 15000 \ + data/train_si84 data/lang_nosp exp/wsj_tri2b_ali_si84 exp/wsj_tri3b || exit 1; + +# From 3b system, align all si284 data. +steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ + data/train_si284 data/lang_nosp exp/wsj_tri3b exp/wsj_tri3b_ali_si284 || exit 1; + +# From 3b system, train another SAT system (tri4a) with all the si284 data. 
+steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284 data/lang_nosp exp/wsj_tri3b_ali_si284 exp/wsj_tri4a || exit 1; + +utils/mkgraph.sh data/lang_nosp exp/wsj_tri4a exp/wsj_tri4a/graph_nosp + +( +for dset in dev test; do + steps/decode_fmllr.sh --nj $decode_nj --cmd "$decode_cmd" --num-threads 4 \ + exp/wsj_tri4a/graph_nosp data/${dset} exp/wsj_tri4a/decode_nosp_${dset} + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp data/lang_nosp_rescore \ + data/${dset} exp/wsj_tri4a/decode_nosp_${dset} exp/wsj_tri4a/decode_nosp_${dset}_rescore +done +) & + +wait + +echo "$0: success." +exit 0 diff --git a/egs/tedlium/s5_r2_wsj/steps b/egs/tedlium/s5_r2_wsj/steps new file mode 120000 index 00000000000..6e99bf5b5ad --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/steps @@ -0,0 +1 @@ +../../wsj/s5/steps \ No newline at end of file diff --git a/egs/tedlium/s5_r2_wsj/utils b/egs/tedlium/s5_r2_wsj/utils new file mode 120000 index 00000000000..b240885218f --- /dev/null +++ b/egs/tedlium/s5_r2_wsj/utils @@ -0,0 +1 @@ +../../wsj/s5/utils \ No newline at end of file diff --git a/egs/wsj/s5/local/run_segmentation.sh b/egs/wsj/s5/local/run_segmentation.sh index 458536162cb..3c0b8e5b0a8 100755 --- a/egs/wsj/s5/local/run_segmentation.sh +++ b/egs/wsj/s5/local/run_segmentation.sh @@ -7,9 +7,14 @@ # The basic idea is to decode with an existing in-domain acoustic model, and a # bigram language model built from the reference, and then work out the # segmentation from a ctm like file. +# See the script local/run_segmentation_long_utts.sh for +# a more sophesticated approach using Smith-Waterman alignment +# to align decoded hypothesis and parts of imperfect long-transcripts # retrieved using TF-IDF document similarities. stage=0 +. utils/parse_options.sh + . ./cmd.sh . 
./path.sh @@ -24,13 +29,13 @@ if [ $stage -le 1 ]; then steps/make_mfcc.sh --cmd "$train_cmd" --nj 64 \ data/train_si284_split exp/make_mfcc/train_si284_split mfcc || exit 1; steps/compute_cmvn_stats.sh data/train_si284_split \ - exp/make_mfcc/train_si284_split mfcc || exit 1; + exp/make_mfcc/train_si284_split mfcc || exit 1; fi if [ $stage -le 2 ]; then steps/cleanup/make_segmentation_graph.sh \ --cmd "$mkgraph_cmd" --nj 32 \ - data/train_si284_split/ data/lang exp/tri2b/ \ + data/train_si284_split/ data/lang_nosp exp/tri2b/ \ exp/tri2b/graph_train_si284_split || exit 1; fi @@ -50,31 +55,31 @@ if [ $stage -le 5 ]; then steps/cleanup/make_segmentation_data_dir.sh --wer-cutoff 0.9 \ --min-sil-length 0.5 --max-seg-length 15 --min-seg-length 1 \ exp/tri2b/decode_train_si284_split/score_10/train_si284_split.ctm \ - data/train_si284_split data/train_si284_reseg + data/train_si284_split data/train_si284_reseg_a fi # Now, use the re-segmented data for training. if [ $stage -le 6 ]; then steps/make_mfcc.sh --cmd "$train_cmd" --nj 64 \ - data/train_si284_reseg exp/make_mfcc/train_si284_reseg mfcc || exit 1; - steps/compute_cmvn_stats.sh data/train_si284_reseg \ - exp/make_mfcc/train_si284_reseg mfcc || exit 1; + data/train_si284_reseg_a exp/make_mfcc/train_si284_reseg_a mfcc || exit 1; + steps/compute_cmvn_stats.sh data/train_si284_reseg_a \ + exp/make_mfcc/train_si284_reseg_a mfcc || exit 1; fi if [ $stage -le 7 ]; then steps/align_fmllr.sh --nj 20 --cmd "$train_cmd" \ - data/train_si284_reseg data/lang exp/tri3b exp/tri3b_ali_si284_reseg || exit 1; + data/train_si284_reseg_a data/lang_nosp exp/tri3b exp/tri3b_ali_si284_reseg_a || exit 1; fi if [ $stage -le 8 ]; then steps/train_sat.sh --cmd "$train_cmd" \ - 4200 40000 data/train_si284_reseg \ - data/lang exp/tri3b_ali_si284_reseg exp/tri4c || exit 1; + 4200 40000 data/train_si284_reseg_a \ + data/lang_nosp exp/tri3b_ali_si284_reseg_a exp/tri4c || exit 1; fi if [ $stage -le 9 ]; then - utils/mkgraph.sh data/lang_test_tgpr 
exp/tri4c exp/tri4c/graph_tgpr || exit 1; + utils/mkgraph.sh data/lang_nosp_test_tgpr exp/tri4c exp/tri4c/graph_tgpr || exit 1; steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ exp/tri4c/graph_tgpr data/test_dev93 exp/tri4c/decode_tgpr_dev93 || exit 1; steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ diff --git a/egs/wsj/s5/local/run_segmentation_long_utts.sh b/egs/wsj/s5/local/run_segmentation_long_utts.sh new file mode 100644 index 00000000000..b2f4362edcb --- /dev/null +++ b/egs/wsj/s5/local/run_segmentation_long_utts.sh @@ -0,0 +1,233 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e -o pipefail + +# This script demonstrates how to re-segment long audios into short segments. +# The basic idea is to decode with an existing in-domain acoustic model, and a +# bigram language model built from the reference, and then work out the +# segmentation from a ctm like file. +# This is similar to the script local/run_segmentation.sh, but +# uses a more sophesticated approach using Smith-Waterman alignment +# to align decoded hypothesis and parts of imperfect long-transcripts # retrieved using TF-IDF document similarities. + +## %WER results. 
+ +## Baseline with manual transcripts +# %WER 7.87 [ 444 / 5643, 114 ins, 25 del, 305 sub ] exp/tri4a/decode_nosp_tgpr_eval92/wer_13_1.0 +# %WER 11.84 [ 975 / 8234, 187 ins, 107 del, 681 sub ] exp/tri4a/decode_nosp_tgpr_dev93/wer_17_0.5 + +## Baseline using local/run_segmentation.sh +# %WER 7.76 [ 438 / 5643, 119 ins, 22 del, 297 sub ] exp/tri4c/decode_tgpr_eval92/wer_14_0.5 +# %WER 12.41 [ 1022 / 8234, 216 ins, 96 del, 710 sub ] exp/tri4c/decode_tgpr_dev93/wer_17_0.0 + +## Training directly on segmented data directory train_si284_reseg +# %WER 7.69 [ 434 / 5643, 105 ins, 27 del, 302 sub ] exp/tri3c_reseg_d/decode_nosp_tgpr_eval92/wer_15_0.5 +# %WER 7.78 [ 439 / 5643, 105 ins, 20 del, 314 sub ] exp/tri4c_reseg_d/decode_nosp_tgpr_eval92/wer_15_0.5 +# %WER 7.43 [ 419 / 5643, 95 ins, 29 del, 295 sub ] exp/tri4c_reseg_e/decode_nosp_tgpr_eval92/wer_16_1.0 + +# %WER 12.04 [ 991 / 8234, 187 ins, 119 del, 685 sub ] exp/tri4c_reseg_d/decode_nosp_tgpr_dev93/wer_16_1.0 +# %WER 12.29 [ 1012 / 8234, 224 ins, 105 del, 683 sub ] exp/tri3c_reseg_d/decode_nosp_tgpr_dev93/wer_14_0.5 +# %WER 12.08 [ 995 / 8234, 199 ins, 113 del, 683 sub ] exp/tri4c_reseg_e/decode_nosp_tgpr_dev93/wer_16_0.5 + +## Using additional stage of cleanup. +# %WER 7.71 [ 435 / 5643, 100 ins, 33 del, 302 sub ] exp/tri4d_e_cleaned_a/decode_nosp_tgpr_eval92/wer_16_1.0 +# %WER 7.78 [ 439 / 5643, 109 ins, 18 del, 312 sub ] exp/tri4d_e_cleaned_c/decode_nosp_tgpr_eval92/wer_15_0.5 +# %WER 7.73 [ 436 / 5643, 116 ins, 21 del, 299 sub ] exp/tri4d_e_cleaned_b/decode_nosp_tgpr_eval92/wer_15_0.5 + +# %WER 11.97 [ 986 / 8234, 190 ins, 110 del, 686 sub ] exp/tri4d_e_cleaned_c/decode_nosp_tgpr_dev93/wer_15_1.0 +# %WER 12.13 [ 999 / 8234, 211 ins, 102 del, 686 sub ] exp/tri4d_e_cleaned_a/decode_nosp_tgpr_dev93/wer_15_0.5 +# %WER 12.67 [ 1043 / 8234, 217 ins, 121 del, 705 sub ] exp/tri4d_e_cleaned_b/decode_nosp_tgpr_dev93/wer_15_1.0 + +. ./cmd.sh +. 
./path.sh + +segment_stage=-1 +affix=_e + +############################################################################### +## Simulate unsegmented data directory. +############################################################################### +local/append_utterances.sh data/train_si284 data/train_si284_long + +steps/make_mfcc.sh --cmd "$train_cmd" --nj 32 \ + data/train_si284_long exp/make_mfcc/train_si284_long mfcc || exit 1 +steps/compute_cmvn_stats.sh data/train_si284_long \ + exp/make_mfcc/train_si284_long mfcc + +############################################################################### +# Segment long recordings using TF-IDF retrieval of reference text +# for uniformly segmented audio chunks based on Smith-Waterman alignment. +# Use a model trained on train_si84 (tri2b) +############################################################################### +steps/cleanup/segment_long_utterances.sh --cmd "$train_cmd" \ + --stage $segment_stage \ + --config conf/segment_long_utts.conf \ + --max-segment-duration 30 --overlap-duration 5 \ + --num-neighbors-to-search 0 --nj 80 \ + exp/tri2b data/lang_nosp data/train_si284_long data/train_si284_reseg${affix} \ + exp/segment_long_utts${affix}_train_si284 + +steps/compute_cmvn_stats.sh data/train_si284_reseg${affix} \ + exp/make_mfcc/train_si284_reseg${affix} mfcc +utils/fix_data_dir.sh data/train_si284_reseg${affix} + +############################################################################### +# Train new model on segmented data directory starting from the same model +# used for segmentation. 
(tri2b) +############################################################################### + +# Align tri2b system with reseg${affix} data +steps/align_si.sh --nj 40 --cmd "$train_cmd" \ + data/train_si284_reseg${affix} \ + data/lang_nosp exp/tri2b exp/tri2b_ali_si284_reseg${affix} || exit 1; + +# Train SAT system on reseg data +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284_reseg${affix} data/lang_nosp \ + exp/tri2b_ali_si284_reseg${affix} exp/tri3c_reseg${affix} + +( +utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri3c_reseg${affix} exp/tri3c_reseg${affix}/graph_nosp_tgpr || exit 1; +steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri3c_reseg${affix}/graph_nosp_tgpr data/test_dev93 \ + exp/tri3c_reseg${affix}/decode_nosp_tgpr_dev93 || exit 1; +steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri3c_reseg${affix}/graph_nosp_tgpr data/test_eval92 \ + exp/tri3c_reseg${affix}/decode_nosp_tgpr_eval92 || exit 1; +) & + +############################################################################### +# Train new model on segmented data directory starting from a better model +# (tri3b) +############################################################################### + +# Align tri3b system with reseg data +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + data/train_si284_reseg${affix} data/lang_nosp exp/tri3b \ + exp/tri3b_ali_si284_reseg${affix} + +# Train SAT system on reseg data +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284_reseg${affix} data/lang_nosp \ + exp/tri3b_ali_si284_reseg${affix} exp/tri4c_reseg${affix} + +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4c_reseg${affix} exp/tri4c_reseg${affix}/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4c_reseg${affix}/graph_nosp_tgpr data/test_dev93 \ + exp/tri4c_reseg${affix}/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + 
exp/tri4c_reseg${affix}/graph_nosp_tgpr data/test_eval92 \ + exp/tri4c_reseg${affix}/decode_nosp_tgpr_eval92 || exit 1; +) & + +############################################################################### +# cleaned_a : Cleanup the segmented data directory using tri3b model. +############################################################################### + +steps/cleanup/clean_and_segment_data.sh --cmd "$train_cmd" \ + --nj 80 \ + data/train_si284_reseg${affix} data/lang_nosp \ + exp/tri3b_ali_si284_reseg${affix} exp/tri3b_work_si284_reseg${affix} data/train_si284_reseg${affix}_cleaned_a + +############################################################################### +# Train new model on the cleaned_a data directory +############################################################################### + +# Align tri3b system with cleaned data +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + data/train_si284_reseg${affix}_cleaned_a data/lang_nosp exp/tri3b \ + exp/tri3b_ali_si284_reseg${affix}_cleaned_a + +# Train SAT system on cleaned data +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284_reseg${affix}_cleaned_a data/lang_nosp \ + exp/tri3b_ali_si284_reseg${affix}_cleaned_a exp/tri4d${affix}_cleaned_a + +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4d${affix}_cleaned_a exp/tri4d${affix}_cleaned_a/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_a/graph_nosp_tgpr data/test_dev93 \ + exp/tri4d${affix}_cleaned_a/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_a/graph_nosp_tgpr data/test_eval92 \ + exp/tri4d${affix}_cleaned_a/decode_nosp_tgpr_eval92 || exit 1; +) & + +############################################################################### +# cleaned_b : Cleanup the segmented data directory using the tri3c_reseg +# model, which is a like a first-pass model trained on the resegmented data. 
+############################################################################### + +steps/cleanup/clean_and_segment_data.sh --cmd "$train_cmd" \ + --nj 80 \ + data/train_si284_reseg${affix} data/lang_nosp \ + exp/tri3c_reseg${affix} exp/tri3c_reseg${affix}_work_si284_reseg${affix} \ + data/train_si284_reseg${affix}_cleaned_b + +############################################################################### +# Train new model on the cleaned_b data directory +############################################################################### + +# Align tri3c_reseg system with cleaned data +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + data/train_si284_reseg${affix}_cleaned_b data/lang_nosp exp/tri3c_reseg${affix} \ + exp/tri3c_reseg${affix}_ali_si284_reseg${affix}_cleaned_b + +# Train SAT system on cleaned data +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284_reseg${affix}_cleaned_b data/lang_nosp \ + exp/tri3c_reseg${affix}_ali_si284_reseg${affix}_cleaned_b exp/tri4d${affix}_cleaned_b + +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4d${affix}_cleaned_b exp/tri4d${affix}_cleaned_b/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_b/graph_nosp_tgpr data/test_dev93 \ + exp/tri4d${affix}_cleaned_b/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_b/graph_nosp_tgpr data/test_eval92 \ + exp/tri4d${affix}_cleaned_b/decode_nosp_tgpr_eval92 || exit 1; +) & + +############################################################################### +# cleaned_c : Cleanup the segmented data directory using the tri4c_reseg +# model, which is a like a first-pass model trained on the resegmented data. 
+############################################################################### + +steps/cleanup/clean_and_segment_data.sh --cmd "$train_cmd" \ + --nj 80 \ + data/train_si284_reseg${affix} data/lang_nosp \ + exp/tri4c_reseg${affix} exp/tri4c_reseg${affix}_work_si284_reseg${affix} \ + data/train_si284_reseg${affix}_cleaned_c + +############################################################################### +# Train new model on the cleaned_c data directory +############################################################################### + +# Align tri4c_reseg system with cleaned data +steps/align_fmllr.sh --nj 40 --cmd "$train_cmd" \ + data/train_si284_reseg${affix}_cleaned_c data/lang_nosp exp/tri4c_reseg${affix} \ + exp/tri4c_reseg${affix}_ali_si284_reseg${affix}_cleaned_c + +# Train SAT system on cleaned data +steps/train_sat.sh --cmd "$train_cmd" 4200 40000 \ + data/train_si284_reseg${affix}_cleaned_c data/lang_nosp \ + exp/tri4c_reseg${affix}_ali_si284_reseg${affix}_cleaned_c exp/tri4d${affix}_cleaned_c + +( + utils/mkgraph.sh data/lang_nosp_test_tgpr \ + exp/tri4d${affix}_cleaned_c exp/tri4d${affix}_cleaned_c/graph_nosp_tgpr || exit 1; + steps/decode_fmllr.sh --nj 10 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_c/graph_nosp_tgpr data/test_dev93 \ + exp/tri4d${affix}_cleaned_c/decode_nosp_tgpr_dev93 || exit 1; + steps/decode_fmllr.sh --nj 8 --cmd "$decode_cmd" \ + exp/tri4d${affix}_cleaned_c/graph_nosp_tgpr data/test_eval92 \ + exp/tri4d${affix}_cleaned_c/decode_nosp_tgpr_eval92 || exit 1; +) & diff --git a/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh b/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh index a3babeca598..a523de30e6f 100755 --- a/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh +++ b/egs/wsj/s5/steps/cleanup/clean_and_segment_data.sh @@ -82,7 +82,7 @@ if [ $stage -le 1 ]; then echo "$0: Building biased-language-model decoding graphs..." 
steps/cleanup/make_biased_lm_graphs.sh $graph_opts \ --nj $nj --cmd "$cmd" \ - $data $lang $dir + $data $lang $dir $dir/graphs fi if [ $stage -le 2 ]; then @@ -100,7 +100,7 @@ if [ $stage -le 2 ]; then steps/cleanup/decode_segmentation.sh \ --beam 15.0 --nj $nj --cmd "$cmd --mem 4G" $transform_opt \ --skip-scoring true --allow-partial false \ - $dir $data $dir/lats + $dir/graphs $data $dir/lats # the following is for diagnostics, e.g. it will give us the lattice depth. steps/diagnostic/analyze_lats.sh --cmd "$cmd" $lang $dir/lats diff --git a/egs/wsj/s5/steps/cleanup/decode_fmllr_segmentation.sh b/egs/wsj/s5/steps/cleanup/decode_fmllr_segmentation.sh new file mode 100755 index 00000000000..1dfa74ab3f6 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/decode_fmllr_segmentation.sh @@ -0,0 +1,257 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen, 2015 GoVivace Inc. (Nagendra Goel) +# 2017 Vimal Manohar +# Apache 2.0 + +# Similar to steps/cleanup/decode_segmentation.sh, but does fMLLR adaptation. +# Decoding script with per-utterance graph that does fMLLR adaptation. +# This can be on top of delta+delta-delta, or LDA+MLLT features. + +# There are 3 models involved potentially in this script, +# and for a standard, speaker-independent system they will all be the same. +# The "alignment model" is for the 1st-pass decoding and to get the +# Gaussian-level alignments for the "adaptation model" the first time we +# do fMLLR. The "adaptation model" is used to estimate fMLLR transforms +# and to generate state-level lattices. The lattices are then rescored +# with the "final model". + +# The following table explains where we get these 3 models from. +# Note: $srcdir is one level up from the decoding directory. 
+# +# Model Default source: +# +# "alignment model" $srcdir/final.alimdl --alignment-model +# (or $srcdir/final.mdl if alimdl absent) +# "adaptation model" $srcdir/final.mdl --adapt-model +# "final model" $srcdir/final.mdl --final-model + +set -e +set -o pipefail + +# Begin configuration section +first_beam=10.0 # Beam used in initial, speaker-indep. pass +first_max_active=2000 # max-active used in initial pass. +alignment_model= +adapt_model= +final_model= +stage=0 +acwt=0.083333 # Acoustic weight used in getting fMLLR transforms, and also in + # lattice generation. +max_active=7000 +beam=13.0 +lattice_beam=6.0 +nj=4 +silence_weight=0.01 +cmd=run.pl +si_dir= +fmllr_update_type=full +num_threads=1 # if >1, will use gmm-latgen-faster-parallel +parallel_opts= # ignored now. +skip_scoring=false +scoring_opts= +max_fmllr_jobs=25 # I've seen the fMLLR jobs overload NFS badly if the decoding + # was started with a lot of many jobs, so we limit the number of + # parallel jobs to 25 by default. End configuration section +allow_partial=true +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# != 3 ]; then + echo "$0: This is a special decoding script for segmentation where we" + echo "use one decoding graph per segment. We assume a file HCLG.fsts.scp exists" + echo "which is the scp file of the graphs for each segment." + echo "This will normally be obtained by steps/cleanup/make_biased_lm_graphs.sh." + echo "" + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/tri2b/graph_train_si284_split \\" + echo " data/train_si284_split exp/tri2b/decode_train_si284_split" + echo "" + echo "where is assumed to be a sub-directory of the directory" + echo "where the model is." 
+ echo "" + echo "main options (for others, see top of script file)" + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --adapt-model # Model to compute transforms with" + echo " --alignment-model # Model to get Gaussian-level alignments for" + echo " # 1st pass of transform computation." + echo " --final-model # Model to finally decode with" + echo " --si-dir # use this to skip 1st pass of decoding" + echo " # Caution-- must be with same tree" + echo " --acwt # default 0.08333 ... used to get posteriors" + echo " --num-threads # number of threads to use, default 1." + echo " --scoring-opts # options to local/score.sh" + exit 1; +fi + + +graphdir=$1 +data=$2 +dir=`echo $3 | sed 's:/$::g'` # remove any trailing slash. + +srcdir=`dirname $dir`; # Assume model directory one level up from decoding directory. +sdata=$data/split$nj; + +thread_string= +[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads" + + +mkdir -p $dir/log +split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` || true # frame-splicing options. +cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null` +delta_opts=`cat $srcdir/delta_opts 2>/dev/null` || true + +silphonelist=`cat $graphdir/phones/silence.csl` || exit 1; + +utils/lang/check_phones_compatible.sh $graph_dir/phones.txt $srcdir/phones.txt + +# Some checks. Note: we don't need $srcdir/tree but we expect +# it should exist, given the current structure of the scripts. +for f in $graphdir/HCLG.fsts.scp $data/feats.scp $srcdir/tree; do + [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; +done + +# Split HCLG.fsts.scp by input utterance +n1=$(cat $graphdir/HCLG.fsts.scp | wc -l) +n2=$(cat $data/feats.scp | wc -l) +if [ $n1 != $n2 ]; then + echo "$0: expected $n2 graphs in $graphdir/HCLG.fsts.scp, got $n1" +fi + +mkdir -p $dir/split_fsts +utils/filter_scps.pl --no-warn -f 1 JOB=1:$nj \ + $sdata/JOB/feats.scp $graphdir/HCLG.fsts.scp $dir/split_fsts/HCLG.fsts.JOB.scp +HCLG=scp:$dir/split_fsts/HCLG.fsts.JOB.scp + + +## Work out name of alignment model. ## +if [ -z "$alignment_model" ]; then + if [ -f "$srcdir/final.alimdl" ]; then alignment_model=$srcdir/final.alimdl; + else alignment_model=$srcdir/final.mdl; fi +fi +[ ! -f "$alignment_model" ] && echo "$0: no alignment model $alignment_model " && exit 1; +## + +## Do the speaker-independent decoding, if --si-dir option not present. ## +if [ -z "$si_dir" ]; then # we need to do the speaker-independent decoding pass. + si_dir=${dir}.si # Name it as our decoding dir, but with suffix ".si". + if [ $stage -le 0 ]; then + if [ -f "$graphdir/num_pdfs" ]; then + [ "`cat $graphdir/num_pdfs`" -eq `am-info --print-args=false $alignment_model | grep pdfs | awk '{print $NF}'` ] || \ + { echo "Mismatch in number of pdfs with $alignment_model"; exit 1; } + fi + steps/cleanup/decode_segmentation.sh --scoring-opts "$scoring_opts" \ + --num-threads $num_threads --skip-scoring $skip_scoring \ + --acwt $acwt --nj $nj --cmd "$cmd" --beam $first_beam \ + --model $alignment_model --max-active \ + $first_max_active $graphdir $data $si_dir || exit 1; + fi +fi +## + +## Some checks, and setting of defaults for variables. +[ "$nj" -ne "`cat $si_dir/num_jobs`" ] && echo "Mismatch in #jobs with si-dir" && exit 1; +[ ! -f "$si_dir/lat.1.gz" ] && echo "No such file $si_dir/lat.1.gz" && exit 1; +[ -z "$adapt_model" ] && adapt_model=$srcdir/final.mdl +[ -z "$final_model" ] && final_model=$srcdir/final.mdl +for f in $adapt_model $final_model; do + [ ! 
-f $f ] && echo "$0: no such file $f" && exit 1; +done +## + +## Set up the unadapted features "$sifeats" +if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=delta; fi +echo "$0: feature type is $feat_type"; +case $feat_type in + delta) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | add-deltas $delta_opts ark:- ark:- |";; + lda) sifeats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |";; + *) echo "Invalid feature type $feat_type" && exit 1; +esac +## + +## Now get the first-pass fMLLR transforms. +if [ $stage -le 1 ]; then + echo "$0: getting first-pass fMLLR transforms." + $cmd --max-jobs-run $max_fmllr_jobs JOB=1:$nj $dir/log/fmllr_pass1.JOB.log \ + gunzip -c $si_dir/lat.JOB.gz \| \ + lattice-to-post --acoustic-scale=$acwt ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $alignment_model ark:- ark:- \| \ + gmm-post-to-gpost $alignment_model "$sifeats" ark:- ark:- \| \ + gmm-est-fmllr-gpost --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $adapt_model "$sifeats" ark,s,cs:- \ + ark:$dir/pre_trans.JOB || exit 1; +fi +## + +pass1feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/pre_trans.JOB ark:- ark:- |" + +## Do the main lattice generation pass. Note: we don't determinize the lattices at +## this stage, as we're going to use them in acoustic rescoring with the larger +## model, and it's more correct to store the full state-level lattice for this purpose. 
+if [ $stage -le 2 ]; then + echo "$0: doing main lattice generation phase" + if [ -f "$graphdir/num_pdfs" ]; then + [ "`cat $graphdir/num_pdfs`" -eq `am-info --print-args=false $adapt_model | grep pdfs | awk '{print $NF}'` ] || \ + { echo "Mismatch in number of pdfs with $adapt_model"; exit 1; } + fi + $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \ + gmm-latgen-faster$thread_string --max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \ + --acoustic-scale=$acwt --determinize-lattice=false \ + --allow-partial=$allow_partial --word-symbol-table=$graphdir/words.txt \ + $adapt_model "$HCLG" "$pass1feats" "ark:|gzip -c > $dir/lat.tmp.JOB.gz" +fi +## + +## Do a second pass of estimating the transform-- this time with the lattices +## generated from the alignment model. Compose the transforms to get +## $dir/trans.1, etc. +if [ $stage -le 3 ]; then + echo "$0: estimating fMLLR transforms a second time." + $cmd --max-jobs-run $max_fmllr_jobs JOB=1:$nj $dir/log/fmllr_pass2.JOB.log \ + lattice-determinize-pruned$thread_string --acoustic-scale=$acwt --beam=4.0 \ + "ark:gunzip -c $dir/lat.tmp.JOB.gz|" ark:- \| \ + lattice-to-post --acoustic-scale=$acwt ark:- ark:- \| \ + weight-silence-post $silence_weight $silphonelist $adapt_model ark:- ark:- \| \ + gmm-est-fmllr --fmllr-update-type=$fmllr_update_type \ + --spk2utt=ark:$sdata/JOB/spk2utt $adapt_model "$pass1feats" \ + ark,s,cs:- ark:$dir/trans_tmp.JOB '&&' \ + compose-transforms --b-is-affine=true ark:$dir/trans_tmp.JOB ark:$dir/pre_trans.JOB \ + ark:$dir/trans.JOB || exit 1; +fi +## + +feats="$sifeats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$dir/trans.JOB ark:- ark:- |" + +# Rescore the state-level lattices with the final adapted features, and the final model +# (which by default is $srcdir/final.mdl, but which may be specified on the command line, +# useful in case of discriminatively trained systems). 
+# At this point we prune and determinize the lattices and write them out, ready for +# language model rescoring. + +if [ $stage -le 4 ]; then + echo "$0: doing a final pass of acoustic rescoring." + $cmd --num-threads $num_threads JOB=1:$nj $dir/log/acoustic_rescore.JOB.log \ + gmm-rescore-lattice $final_model "ark:gunzip -c $dir/lat.tmp.JOB.gz|" "$feats" ark:- \| \ + lattice-determinize-pruned$thread_string --acoustic-scale=$acwt --beam=$lattice_beam ark:- \ + "ark:|gzip -c > $dir/lat.JOB.gz" '&&' rm $dir/lat.tmp.JOB.gz || exit 1; +fi + +if ! $skip_scoring ; then + [ ! -x local/score.sh ] && \ + echo "$0: Not scoring because local/score.sh does not exist or not executable." && exit 1; + local/score.sh --cmd "$cmd" $scoring_opts $data $graphdir $dir || + { echo "$0: Scoring failed. (ignore by '--skip-scoring true')"; exit 1; } +fi + +rm $dir/{trans_tmp,pre_trans}.* + +exit 0; + diff --git a/egs/wsj/s5/steps/cleanup/decode_segmentation.sh b/egs/wsj/s5/steps/cleanup/decode_segmentation.sh index a140b26173f..e07105e8e08 100755 --- a/egs/wsj/s5/steps/cleanup/decode_segmentation.sh +++ b/egs/wsj/s5/steps/cleanup/decode_segmentation.sh @@ -1,8 +1,14 @@ #!/bin/bash # Copyright 2014 Guoguo Chen, 2015 GoVivace Inc. (Nagendra Goel) +# 2017 Vimal Manohar # Apache 2.0 +# Some basic error checking, similar to steps/decode.sh is added. + +set -e +set -o pipefail + # Begin configuration section. transform_dir= # this option won't normally be used, but it can be used if you # want to supply existing fMLLR transforms when decoding. @@ -83,10 +89,22 @@ if [ -z "$model" ]; then # if --model was not specified on the command lin else model=$srcdir/$iter.mdl; fi fi +if [ $(basename $model) != final.alimdl ] ; then + # Do not use the $srcpath -- look at the path where the model is + if [ -f $(dirname $model)/final.alimdl ] && [ -z "$transform_dir" ]; then + echo -e '\n\n' + echo $0 'WARNING: Running speaker independent system decoding using a SAT model!' 
+    echo $0 'WARNING: This is OK if you know what you are doing...'
+    echo -e '\n\n'
+  fi
+fi
+
for f in $sdata/1/feats.scp $sdata/1/cmvn.scp $model $graphdir/HCLG.fsts.scp; do
[ ! -f $f ] && echo "$0: no such file $f" && exit 1; done
+utils/lang/check_phones_compatible.sh $graphdir/phones.txt $srcdir/phones.txt
+
# Split HCLG.fsts.scp by input utterance
n1=$(cat $graphdir/HCLG.fsts.scp | wc -l)
n2=$(cat $data/feats.scp | wc -l)
@@ -96,15 +114,16 @@ fi
mkdir -p $dir/split_fsts
-utils/filter_scps.pl --no-warn -f 1 JOB=1:$nj $sdata/JOB/feats.scp $graphdir/HCLG.fsts.scp $dir/split_fsts/HCLG.fsts.JOB.scp
+utils/filter_scps.pl --no-warn -f 1 JOB=1:$nj \
+  $sdata/JOB/feats.scp $graphdir/HCLG.fsts.scp $dir/split_fsts/HCLG.fsts.JOB.scp
HCLG=scp:$dir/split_fsts/HCLG.fsts.JOB.scp
if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=delta; fi
echo "$0: feature type is $feat_type";
-splice_opts=`cat $srcdir/splice_opts 2>/dev/null` # frame-splicing options.
-cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null`
-delta_opts=`cat $srcdir/delta_opts 2>/dev/null`
+splice_opts=`cat $srcdir/splice_opts 2>/dev/null` || true # frame-splicing options.
+cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null` || true
+delta_opts=`cat $srcdir/delta_opts 2>/dev/null` || true
thread_string=
[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads"
@@ -145,8 +164,8 @@ fi
if ! $skip_scoring ; then
[ ! -x local/score.sh ] && \
-    echo "Not scoring because local/score.sh does not exist or not executable." && exit 1;
-  steps/score_kaldi.sh --cmd "$cmd" $scoring_opts $data $graphdir $dir ||
+    echo "$0: Not scoring because local/score.sh does not exist or not executable." && exit 1;
+  local/score.sh --cmd "$cmd" $scoring_opts $data $graphdir $dir ||
+    { echo "$0: Scoring failed. 
(ignore by '--skip-scoring true')"; exit 1; } fi diff --git a/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py b/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py new file mode 100755 index 00000000000..8994fb7fde1 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/align_ctm_ref.py @@ -0,0 +1,615 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +"""This module aligns a hypothesis (CTM or text) with a reference to +find the best matching sub-sequence in the reference for the hypothesis +using Smith-Waterman like alignment. + +e.g.: align_ctm_ref.py --hyp-format=CTM --ref=data/train/text --hyp=foo/ctm + --output=foo/ctm_edits +""" + +from __future__ import print_function +import argparse +import logging +import sys + +sys.path.insert(0, 'steps') +import libs.common as common_lib + +logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.setLevel(logging.DEBUG) + +verbose_level = 0 + + +def get_args(): + parser = argparse.ArgumentParser(description=""" + This module aligns a hypothesis (CTM or text) with a reference to find the + best matching sub-sequence in the reference for the hypothesis using + Smith-Waterman like alignment. + + e.g.: align_ctm_ref.py --align-full-hyp=false --hyp-format=CTM + --reco2file-and-channel=data/foo/reco2file_and_channel --ref=data/train/text + --hyp=foo/ctm --output=foo/ctm_edits + """) + + parser.add_argument("--hyp-format", type=str, choices=["Text", "CTM"], + default="CTM", + help="Format used for the hypothesis") + parser.add_argument("--reco2file-and-channel", type=argparse.FileType('r'), + help="""reco2file_and_channel file. + This will be used to match references that are usually + indexed by the recording-id with the CTM lines that have + file and channel. 
This option is typically not + required.""") + parser.add_argument("--eps-symbol", type=str, default="-", + help="Symbol used to contain alignment " + "to empty symbol") + parser.add_argument("--oov-word", type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="Symbol of OOV word in hypothesis") + parser.add_argument("--symbol-table", type=argparse.FileType('r'), + help="""Symbol table for words in vocabulary. Used + to determine if a word is a OOV or not""") + + parser.add_argument("--correct-score", type=int, default=1, + help="Score for correct matches") + parser.add_argument("--substitution-penalty", type=int, default=1, + help="Penalty for substitution errors") + parser.add_argument("--deletion-penalty", type=int, default=1, + help="Penalty for deletion errors") + parser.add_argument("--insertion-penalty", type=int, default=1, + help="Penalty for insertion errors") + + parser.add_argument("--align-full-hyp", type=str, + action=common_lib.StrToBoolAction, + choices=["true", "false"], default=True, + help="""Align full hypothesis i.e. trackback from + the end to get the alignment. This is different + from the normal Smith-Waterman alignment, where the + traceback will be from the maximum score.""") + + parser.add_argument("--debug-only", type=str, default="false", + choices=["true", "false"], + help="Run test functions only") + parser.add_argument("--verbose", type=int, default=0, + choices=[0, 1, 2, 3], + help="Use larger value for more verbose logging.") + + parser.add_argument("--ref", dest='ref_in_file', + type=argparse.FileType('r'), required=True, + help="Reference text file") + parser.add_argument("--hyp", dest='hyp_in_file', required=True, + type=argparse.FileType('r'), + help="Hypothesis text or CTM file") + parser.add_argument("--output", dest='alignment_out_file', required=True, + type=argparse.FileType('w'), + help="""File to write output alignment. 
+ If hyp-format=CTM, then the output is in the form of + CTM, but with two additional columns of Edit-type and + Reference-word matched to the hypothesis.""") + + args = parser.parse_args() + + if args.hyp_format == "CTM" and args.reco2file_and_channel is None: + raise RuntimeError( + "--reco2file-and-channel must be provided for " + "hyp-format=CTM") + + args.debug_only = bool(args.debug_only == "true") + + global verbose_level + verbose_level = args.verbose + if args.verbose > 2: + handler.setLevel(logging.DEBUG) + else: + handler.setLevel(logging.INFO) + logger.addHandler(handler) + + return args + + +def read_text(text_file): + """Reads a kaldi-format text file and yield elements of a dictionary + { utterane_id : transcript (as a list of words) } + + The first-column of the text file is the utterance-id, which will be + used as the key to index the dictionary elements. + The remaining columns of the file are text of the transcript and they are + returned as a list of words. + """ + for line in text_file: + parts = line.strip().split() + if len(parts) <= 2: + raise RuntimeError( + "Did not get enough columns; line {0} in {1}" + "".format(line, text_file.name)) + yield parts[0], parts[1:] + text_file.close() + + +def read_ctm(ctm_file, file_and_channel2reco=None): + """Reads a CTM file and yields elements of a dictionary + { utterance-id : CTM for the utterance }, + where CTM for the utterance is stored as a list of lines + from a CTM correponding to the utterance. + + Note: *_reco in the variables usually correspond to utterances rather + than recordings. + """ + prev_reco = "" + ctm_lines = [] + for line in ctm_file: + try: + parts = line.strip().split() + parts[2] = float(parts[2]) + parts[3] = float(parts[3]) + + if len(parts) == 5: + parts.append(1.0) # confidence defaults to 1.0. 
+ + if len(parts) != 6: + raise ValueError("CTM must have 6 fields.") + + if file_and_channel2reco is None: + reco = parts[0] + if parts[1] != '1': + raise ValueError("Channel should be 1, " + "got {0}".format(parts[1])) + else: + reco = file_and_channel2reco[(parts[0], parts[1])] + if prev_reco != "" and reco != prev_reco: + # New recording + yield prev_reco, ctm_lines + ctm_lines = [] + ctm_lines.append(parts[2:]) + prev_reco = reco + except Exception: + logger.error("Error in processing CTM line {0}".format(line)) + raise + if prev_reco != "" and len(ctm_lines) > 0: + yield prev_reco, ctm_lines + ctm_file.close() + + +def smith_waterman_alignment(ref, hyp, similarity_score_function, + del_score, ins_score, + eps_symbol="", align_full_hyp=True): + """Does Smith-Waterman alignment of reference sequence and hypothesis + sequence. + This is a special case of the Smith-Waterman alignment that assumes that + the deletion and insertion costs are linear with number of incorrect words. + + If align_full_hyp is True, then the traceback of the alignment + is started at the end of the hypothesis. This is when we want the + reference that aligns with the full hypothesis. + This differs from the normal Smith-Waterman alignment, where the traceback + is from the highest score in the alignment score matrix. This + can be obtained by setting align_full_hyp as False. This gets only the + sub-sequence of the hypothesis that best matches with a + sub-sequence of the reference. 
+ + Returns a list of tuples where each tuple has the format: + (ref_word, hyp_word, ref_word_from_index, hyp_word_from_index, + ref_word_to_index, hyp_word_to_index) + """ + output = [] + + ref_len = len(ref) + hyp_len = len(hyp) + + bp = [[] for x in range(ref_len+1)] + + # Score matrix of size (ref_len + 1) x (hyp_len + 1) + # The index m, n in this matrix corresponds to the score + # of the best matching sub-sequence pair between reference and hypothesis + # ending with the reference word ref[m-1] and hypothesis word hyp[n-1]. + # If align_full_hyp is True, then the hypothesis sub-sequence is from + # the 0th word i.e. hyp[0]. + H = [[] for x in range(ref_len+1)] + + for ref_index in range(ref_len+1): + if align_full_hyp: + H[ref_index] = [-(hyp_len+2) for x in range(hyp_len+1)] + H[ref_index][0] = 0 + else: + H[ref_index] = [0 for x in range(hyp_len+1)] + bp[ref_index] = [(0, 0) for x in range(hyp_len+1)] + + if align_full_hyp and ref_index == 0: + for hyp_index in range(1, hyp_len+1): + H[0][hyp_index] = H[0][hyp_index-1] + ins_score + bp[ref_index][hyp_index] = (ref_index, hyp_index-1) + logger.debug( + "({0},{1}) -> ({2},{3}): {4}" + "".format(ref_index, hyp_index-1, ref_index, hyp_index, + H[ref_index][hyp_index])) + + max_score = -float("inf") + max_score_element = (0, 0) + + for ref_index in range(1, ref_len+1): # Reference + for hyp_index in range(1, hyp_len+1): # Hypothesis + sub_or_ok = (H[ref_index-1][hyp_index-1] + + similarity_score_function(ref[ref_index-1], + hyp[hyp_index-1])) + + if ((not align_full_hyp and sub_or_ok > 0) + or (align_full_hyp + and sub_or_ok >= H[ref_index][hyp_index])): + H[ref_index][hyp_index] = sub_or_ok + bp[ref_index][hyp_index] = (ref_index-1, hyp_index-1) + logger.debug( + "({0},{1}) -> ({2},{3}): {4} ({5},{6})" + "".format(ref_index-1, hyp_index-1, ref_index, hyp_index, + H[ref_index][hyp_index], + ref[ref_index-1], hyp[hyp_index-1])) + + if H[ref_index-1][hyp_index] + del_score > H[ref_index][hyp_index]: + 
H[ref_index][hyp_index] = H[ref_index-1][hyp_index] + del_score + bp[ref_index][hyp_index] = (ref_index-1, hyp_index) + logger.debug( + "({0},{1}) -> ({2},{3}): {4}" + "".format(ref_index-1, hyp_index, ref_index, hyp_index, + H[ref_index][hyp_index])) + + if H[ref_index][hyp_index-1] + ins_score > H[ref_index][hyp_index]: + H[ref_index][hyp_index] = H[ref_index][hyp_index-1] + ins_score + bp[ref_index][hyp_index] = (ref_index, hyp_index-1) + logger.debug( + "({0},{1}) -> ({2},{3}): {4}" + "".format(ref_index, hyp_index-1, ref_index, hyp_index, + H[ref_index][hyp_index])) + + #if hyp_index == hyp_len and H[ref_index][hyp_index] >= max_score: + if ((not align_full_hyp or hyp_index == hyp_len) + and H[ref_index][hyp_index] >= max_score): + max_score = H[ref_index][hyp_index] + max_score_element = (ref_index, hyp_index) + + ref_index, hyp_index = max_score_element + score = max_score + logger.debug("Alignment score: %s for (%d, %d)", + score, ref_index, hyp_index) + + while ((not align_full_hyp and score >= 0) + or (align_full_hyp and hyp_index > 0)): + try: + prev_ref_index, prev_hyp_index = bp[ref_index][hyp_index] + + if ((prev_ref_index, prev_hyp_index) == (ref_index, hyp_index) + or (prev_ref_index, prev_hyp_index) == (0, 0)): + ref_index, hyp_index = (prev_ref_index, prev_hyp_index) + score = H[ref_index][hyp_index] + break + + if (ref_index == prev_ref_index + 1 + and hyp_index == prev_hyp_index + 1): + # Substitution or correct + output.append( + (ref[ref_index-1] if ref_index > 0 else eps_symbol, + hyp[hyp_index-1] if hyp_index > 0 else eps_symbol, + prev_ref_index, prev_hyp_index, ref_index, hyp_index)) + elif (prev_hyp_index == hyp_index): + # Deletion + assert prev_ref_index == ref_index - 1 + output.append( + (ref[ref_index-1] if ref_index > 0 else eps_symbol, + eps_symbol, + prev_ref_index, prev_hyp_index, ref_index, hyp_index)) + elif (prev_ref_index == ref_index): + # Insertion + assert prev_hyp_index == hyp_index - 1 + output.append( + (eps_symbol, + 
hyp[hyp_index-1] if hyp_index > 0 else eps_symbol, + prev_ref_index, prev_hyp_index, ref_index, hyp_index)) + else: + raise RuntimeError + + ref_index, hyp_index = (prev_ref_index, prev_hyp_index) + score = H[ref_index][hyp_index] + except Exception: + logger.error("Unexpected entry (%d,%d) -> (%d,%d), %s, %s", + prev_ref_index, prev_hyp_index, ref_index, hyp_index, + ref[prev_ref_index], hyp[prev_hyp_index]) + raise RuntimeError("Unexpected result: Bug in code!!") + + assert (align_full_hyp or score == 0) + + output.reverse() + + if verbose_level > 2: + for ref_index in range(ref_len+1): + for hyp_index in range(hyp_len+1): + print ("{0} ".format(H[ref_index][hyp_index]), end='', + file=sys.stderr) + print ("", file=sys.stderr) + + logger.debug("Aligned output:") + logger.debug(" - ".join(["({0},{1})".format(x[4], x[5]) + for x in output])) + logger.debug("REF: ") + logger.debug(" ".join(str(x[0]) for x in output)) + logger.debug("HYP:") + logger.debug(" ".join(str(x[1]) for x in output)) + + return (output, max_score) + + +def print_alignment(recording, alignment, out_file_handle): + out_text = [recording] + for line in alignment: + try: + out_text.append(line[1]) + except Exception: + logger.error("Something wrong with alignment. 
" + "Invalid line {0}".format(line)) + raise + print (" ".join(out_text), file=out_file_handle) + + +def get_edit_type(hyp_word, ref_word, duration=-1, eps_symbol='', + oov_word=None, symbol_table=None): + if hyp_word == ref_word and hyp_word != eps_symbol: + return 'cor' + if hyp_word != eps_symbol and ref_word == eps_symbol: + return 'ins' + if hyp_word == eps_symbol and ref_word != eps_symbol and duration == 0.0: + return 'del' + if (hyp_word == oov_word and symbol_table is not None + and len(symbol_table) > 0 and ref_word not in symbol_table): + return 'cor' # this special case is treated as correct + if hyp_word == eps_symbol and ref_word == eps_symbol and duration > 0.0: + # silence in hypothesis; we don't match this up with any reference + # word. + return 'sil' + # The following assertion is because, based on how get_ctm_edits() + # works, we shouldn't hit this case. + assert hyp_word != eps_symbol and ref_word != eps_symbol + return 'sub' + + +def get_ctm_edits(alignment_output, ctm_array, eps_symbol="", + oov_word=None, symbol_table=None): + """ + This function takes two lists + alignment_output = The output of smith_waterman_alignment() which is a + list of tuples (ref_word, hyp_word, ref_word_from_index, + hyp_word_from_index, ref_word_to_index, hyp_word_to_index) + ctm_array = [ [ start1, duration1, hyp_word1, confidence1 ], ... ] + and pads them with new list elements so that the entries 'match up'. + + Returns CTM edits lines, which are CTM lines appended with reference word + and edit type. + + What we are aiming for is that for each i, ctm_array[i][2] == + alignment_output[i][1]. The reasons why this is not automatically true + are: + + (1) There may be insertions in the hypothesis sequence that are not + aligned with any reference words in the beginning of the + alignment_output. + (2) There may be deletions in the end of the alignment_output that + do not correspond to any additional hypothesis CTM lines. 
+
+    We introduce suitable entries into alignment_output and ctm_array as
+    necessary to make them 'match up'.
+    """
+    ctm_edits = []
+    ali_len = len(alignment_output)
+    ctm_len = len(ctm_array)
+    ali_pos = 0
+    ctm_pos = 0
+
+    # current_time is the end of the last ctm segment we processed.
+    current_time = ctm_array[0][0] if ctm_len > 0 else 0.0
+
+    for (ref_word, hyp_word, ref_prev_i, hyp_prev_i,
+         ref_i, hyp_i) in alignment_output:
+        try:
+            ctm_pos = hyp_prev_i
+            # This is true because we cannot have errors at the end because
+            # that will decrease the smith-waterman alignment score.
+            assert ctm_pos < ctm_len
+            assert len(ctm_array[ctm_pos]) == 4
+
+            if hyp_prev_i == hyp_i:
+                assert hyp_word == eps_symbol
+                # These are deletions as there are no CTM entries
+                # corresponding to these alignments.
+                edit_type = get_edit_type(
+                    hyp_word=eps_symbol, ref_word=ref_word,
+                    duration=0.0, eps_symbol=eps_symbol,
+                    oov_word=oov_word, symbol_table=symbol_table)
+                ctm_line = [current_time, 0.0, eps_symbol, 1.0,
+                            ref_word, edit_type]
+                ctm_edits.append(ctm_line)
+            else:
+                assert hyp_i == hyp_prev_i + 1
+                assert hyp_word == ctm_array[ctm_pos][2]
+                # This is the normal case, where there are 2 entries where
+                # the hyp-words match up.
+                ctm_line = list(ctm_array[ctm_pos])
+                if hyp_word == eps_symbol and ref_word != eps_symbol:
+                    # This is a silence in hypothesis aligned with a reference
+                    # word. We split this into two ctm edit lines where the
+                    # first one is a deletion of duration 0 and the second
+                    # one is a silence of duration given by the ctm line. 
+ edit_type = get_edit_type( + hyp_word=eps_symbol, ref_word=ref_word, + duration=0.0, eps_symbol=eps_symbol, + oov_word=oov_word, symbol_table=symbol_table) + assert edit_type == 'del' + ctm_edits.append([current_time, 0.0, eps_symbol, 1.0, + ref_word, edit_type]) + + edit_type = get_edit_type( + hyp_word=eps_symbol, ref_word=eps_symbol, + duration=ctm_line[1], eps_symbol=eps_symbol, + oov_word=oov_word, symbol_table=symbol_table) + assert edit_type == 'sil' + ctm_line.extend([eps_symbol, edit_type]) + ctm_edits.append(ctm_line) + else: + edit_type = get_edit_type( + hyp_word=hyp_word, ref_word=ref_word, + duration=ctm_line[1], eps_symbol=eps_symbol, + oov_word=oov_word, symbol_table=symbol_table) + ctm_line.extend([ref_word, edit_type]) + ctm_edits.append(ctm_line) + current_time = (ctm_array[ctm_pos][0] + + ctm_array[ctm_pos][1]) + except Exception: + logger.error("Could not get ctm edits for " + "edits@{edits_pos} = {0}, ctm@{ctm_pos} = {1}".format( + ("NONE" if ali_pos >= ali_len + else alignment_output[ali_pos]), + ("NONE" if ctm_pos >= ctm_len + else ctm_array[ctm_pos]), + edits_pos=ali_pos, ctm_pos=ctm_pos)) + logger.error("alignment = {0}".format(alignment_output)) + raise + return ctm_edits + + +def ctm_line_to_string(ctm_line): + if len(ctm_line) != 8: + raise RuntimeError("len(ctm_line) expected to be {0}. 
" + "Invalid line {1}".format(8, ctm_line)) + + return " ".join([str(x) for x in ctm_line]) + + +def test_alignment(): + hyp = "ACACACTA" + ref = "AGCACACA" + + output, score = smith_waterman_alignment( + ref, hyp, similarity_score_function=lambda x, y: 2 if (x == y) else -1, + del_score=-1, ins_score=-1, eps_symbol="-", align_full_hyp=True) + + print_alignment("Alignment", output, out_file_handle=sys.stderr) + + +def run(args): + if args.debug_only: + test_alignment() + raise SystemExit("Exiting since --debug-only was true") + + def similarity_score_function(x, y): + if x == y: + return args.correct_score + return -args.substitution_penalty + + del_score = -args.deletion_penalty + ins_score = -args.insertion_penalty + + reco2file_and_channel = {} + file_and_channel2reco = {} + + if args.reco2file_and_channel is not None: + for line in args.reco2file_and_channel: + parts = line.strip().split() + + reco2file_and_channel[parts[0]] = (parts[1], parts[2]) + file_and_channel2reco[(parts[1], parts[2])] = parts[0] + args.reco2file_and_channel.close() + else: + file_and_channel2reco = None + + symbol_table = {} + if args.symbol_table is not None: + for line in args.symbol_table: + parts = line.strip().split() + symbol_table[parts[0]] = int(parts[1]) + args.symbol_table.close() + + if args.hyp_format == "Text": + hyp_lines = {key: value + for (key, value) in read_text(args.hyp_in_file)} + else: + hyp_lines = {key: value + for (key, value) in read_ctm(args.hyp_in_file, + file_and_channel2reco)} + + num_err = 0 + num_done = 0 + for reco, ref_text in read_text(args.ref_in_file): + try: + if reco not in hyp_lines: + num_err += 1 + raise Warning("Could not find recording {0} " + "in hypothesis {1}".format( + reco, args.hyp_in_file.name)) + continue + + if args.hyp_format == "CTM": + hyp_array = [x[2] for x in hyp_lines[reco]] + else: + hyp_array = hyp_lines[reco] + + if args.reco2file_and_channel is None: + reco2file_and_channel[reco] = "1" + + logger.debug("Running 
Smith-Waterman alignment for %s", reco) + + output, score = smith_waterman_alignment( + ref_text, hyp_array, eps_symbol=args.eps_symbol, + similarity_score_function=similarity_score_function, + del_score=del_score, ins_score=ins_score, + align_full_hyp=args.align_full_hyp) + + if args.hyp_format == "CTM": + ctm_edits = get_ctm_edits(output, hyp_lines[reco], + eps_symbol=args.eps_symbol, + oov_word=args.oov_word, + symbol_table=symbol_table) + for line in ctm_edits: + ctm_line = list(reco2file_and_channel[reco]) + ctm_line.extend(line) + print(ctm_line_to_string(ctm_line), + file=args.alignment_out_file) + else: + print_alignment( + reco, output, out_file_handle=args.alignment_out_file) + num_done += 1 + except: + logger.error("Alignment failed for recording {0} " + "with ref = {1} and hyp = {2}".format( + reco, " ".join(ref_text), + " ".join(hyp_array))) + raise + + logger.info("Processed %d recordings; failed with %d", num_done, num_err) + + if num_done == 0: + raise RuntimeError("Processed 0 recordings.") + + +def main(): + args = get_args() + + try: + run(args) + except Exception: + logger.error("Failed to align ref and hypotheses; " + "got exception ", exc_info=True) + raise SystemExit(1) + finally: + if args.reco2file_and_channel is not None: + args.reco2file_and_channel.close() + args.ref_in_file.close() + args.hyp_in_file.close() + args.alignment_out_file.close() + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/cleanup/internal/compute_tf_idf.py b/egs/wsj/s5/steps/cleanup/internal/compute_tf_idf.py new file mode 100755 index 00000000000..92d2f8a2b9d --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/compute_tf_idf.py @@ -0,0 +1,141 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import logging +import sys + +import tf_idf +sys.path.insert(0, 'steps') + +logger = logging.getLogger('tf_idf') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def _get_args(): + parser = argparse.ArgumentParser( + description="""This script takes in a set of documents and computes the + TF-IDF for each n-gram up to the specified order. The script can also + load IDF stats from a different file instead of computing them from the + input set of documents.""") + + parser.add_argument("--tf-weighting-scheme", type=str, default="raw", + choices=["binary", "raw", "log", "normalized"], + help="""The function applied on the raw + term-frequencies f(t,d) when computing tf(t,d). + TF weighting schemes:- + binary : tf(t,d) = 1 if t in d else 0 + raw : tf(t,d) = f(t,d) + log : tf(t,d) = 1 + log(f(t,d)) + normalized : tf(t,d) = K + (1-K) * """ + """f(t,d) / max{f(t',d): t' in d}""") + parser.add_argument("--tf-normalization-factor", type=float, default=0.5, + help="K value for normalized TF weighting scheme") + parser.add_argument("--idf-weighting-scheme", type=str, default="log", + choices=["unary", "log", "log-smoothed", + "probabilistic"], + help="""The function applied on the raw + inverse-document frequencies n(t) = |d in D: t in d| + when computing idf(t,d). 
+ IDF weighting schemes:- + unary : idf(t,D) = 1 + log : idf(t,D) = log (N / 1 + n(t)) + log-smoothed : idf(t,D) = log(1 + N / n(t)) + probabilistic: idf(t,D) = log((N - n(t)) / n(t))""") + parser.add_argument("--ngram-order", type=int, default=2, + help="Accumulate for terms upto this n-grams order") + + parser.add_argument("--input-idf-stats", type=argparse.FileType('r'), + help="If provided, IDF stats are loaded from this " + "file") + parser.add_argument("--output-idf-stats", type=argparse.FileType('w'), + help="If providied, IDF stats are written to this " + "file") + parser.add_argument("--accumulate-over-docs", type=str, default="true", + choices=["true", "false"], + help="If true, the stats are accumulated over all the " + "documents and a single tf-idf-file is written out.") + parser.add_argument("docs", type=argparse.FileType('r'), + help="Input documents in kaldi text format i.e. " + " ") + parser.add_argument("tf_idf_file", type=argparse.FileType('w'), + help="Output tf-idf for each (t,d) pair in the " + "input documents written in the format " + " ") + + args = parser.parse_args() + + if args.tf_normalization_factor >= 1.0 or args.tf_normalization_factor < 0: + raise ValueError("--tf-normalization-factor must be in [0,1)") + + args.accumulate_over_docs = bool(args.accumulate_over_docs == "true") + + if not args.accumulate_over_docs and args.input_idf_stats is None: + raise TypeError( + "If --accumulate-over-docs=false is provided, " + "then --input-idf-stats must be provided.") + + return args + + +def _run(args): + tf_stats = tf_idf.TFStats() + idf_stats = tf_idf.IDFStats() + + if args.input_idf_stats is not None: + idf_stats.read(args.input_idf_stats) + + for line in args.docs: + parts = line.strip().split() + doc = parts[0] + tf_stats.accumulate(doc, parts[1:], args.ngram_order) + + if not args.accumulate_over_docs: + # Write the document-id and the corresponding tf-idf values. 
+ print (doc, file=args.tf_idf_file, end=' ') + tf_idf.write_tfidf_from_stats( + tf_stats, idf_stats, args.tf_idf_file, + tf_weighting_scheme=args.tf_weighting_scheme, + idf_weighting_scheme=args.idf_weighting_scheme, + tf_normalization_factor=args.tf_normalization_factor, + expected_document_id=doc) + tf_stats = tf_idf.TFStats() + + if args.accumulate_over_docs: + tf_stats.compute_term_stats(idf_stats=idf_stats + if args.input_idf_stats is None + else None) + + if args.output_idf_stats is not None: + idf_stats.write(args.output_idf_stats) + args.output_idf_stats.close() + + tf_idf.write_tfidf_from_stats( + tf_stats, idf_stats, args.tf_idf_file, + tf_weighting_scheme=args.tf_weighting_scheme, + idf_weighting_scheme=args.idf_weighting_scheme, + tf_normalization_factor=args.tf_normalization_factor) + + +def main(): + args = _get_args() + + try: + _run(args) + finally: + if args.input_idf_stats is not None: + args.input_idf_stats.close() + if args.output_idf_stats is not None: + args.output_idf_stats.close() + args.docs.close() + args.tf_idf_file.close() + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/cleanup/internal/get_ctm.sh b/egs/wsj/s5/steps/cleanup/internal/get_ctm.sh new file mode 100755 index 00000000000..05f96ed35f3 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/get_ctm.sh @@ -0,0 +1,81 @@ +#!/bin/bash +# Copyright Johns Hopkins University (Author: Daniel Povey) 2012. Apache 2.0. +# Copyright 2017 Vimal Manohar + +# This script produces CTM files from a decoding directory that has lattices +# present. +# This is similar to get_ctm.sh, but gets the +# CTM at the utterance-level. + + +# begin configuration section. +cmd=run.pl +stage=0 +frame_shift=0.01 +lmwt=10 +print_silence=false +#end configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh +. 
parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + echo " --stage (0|1|2) # start scoring script from part-way through." + echo " --frame-shift (default=0.01) # specify this if your lattices have a frame-shift" + echo " # not equal to 0.01 seconds" + echo "e.g.:" + echo "$0 data/train data/lang exp/tri4a/decode/" + echo "See also: steps/get_train_ctm.sh" + exit 1; +fi + +data=$1 +lang=$2 # Note: may be graph directory not lang directory, but has the necessary stuff copied. +dir=$3 + +if [ -f $dir/final.mdl ]; then + model=$dir/final.mdl +else + model=$dir/../final.mdl # assume model one level up from decoding dir. +fi + +for f in $lang/words.txt $model $dir/lat.1.gz; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +name=`basename $data`; # e.g. eval2000 + +mkdir -p $dir/scoring/log + +if [ $stage -le 0 ]; then + nj=$(cat $dir/num_jobs) + if [ -f $lang/phones/word_boundary.int ]; then + $cmd JOB=1:$nj $dir/scoring/log/get_ctm.JOB.log \ + set -o pipefail '&&' mkdir -p $dir/score_$lmwt/ '&&' \ + lattice-1best --lm-scale=$lmwt "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-align-words $lang/phones/word_boundary.int $model ark:- ark:- \| \ + nbest-to-ctm --frame-shift=$frame_shift --print-silence=$print_silence ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' $dir/score_$lmwt/${name}.ctm.JOB || exit 1; + elif [ -f $lang/phones/align_lexicon.int ]; then + $cmd JOB=1:$nj $dir/scoring/log/get_ctm.JOB.log \ + set -o pipefail '&&' mkdir -p $dir/score_$lmwt/ '&&' \ + lattice-1best --lm-scale=$lmwt "ark:gunzip -c $dir/lat.JOB.gz|" ark:- \| \ + lattice-align-words-lexicon $lang/phones/align_lexicon.int $model ark:- ark:- \| \ + lattice-1best ark:- ark:- \| \ + nbest-to-ctm --frame-shift=$frame_shift --print-silence=$print_silence ark:- - \| \ + utils/int2sym.pl -f 5 $lang/words.txt \ + '>' 
$dir/score_$lmwt/${name}.ctm.JOB || exit 1;
+  else
+    echo "$0: neither $lang/phones/word_boundary.int nor $lang/phones/align_lexicon.int exists: cannot align."
+    exit 1;
+  fi
+fi
+
+
+
diff --git a/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py b/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py
index 84d1ca0fbf6..aa71fa47d84 100755
--- a/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py
+++ b/egs/wsj/s5/steps/cleanup/internal/get_non_scored_words.py
@@ -5,9 +5,23 @@
# Apache 2.0
from __future__ import print_function
-import sys, operator, argparse, os
+import argparse
+import logging
+import operator
+import os
+import sys
from collections import defaultdict
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+handler = logging.StreamHandler()
+handler.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
+                              "%(funcName)s - %(levelname)s ] %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
# If you supply the directory (the one that corresponds to
# how you decoded the data) to this script, it assumes that the
# directory contains phones/align_lexicon.int, and it uses this to work
@@ -35,24 +49,28 @@
non_scored_words = set()
-def ReadLang(lang_dir):
+def read_lang(lang_dir):
global non_scored_words
if not os.path.isdir(lang_dir):
-        sys.exit("get_non_scored_words.py expected lang/ directory {0} to "
-                 "exist.".format(lang_dir))
+        logger.error("expected lang/ directory %s to "
+                     "exist.", lang_dir)
+        raise RuntimeError
+
for f in [ '/words.txt', '/phones/silence.int', '/phones/align_lexicon.int' ]:
if not os.path.exists(lang_dir + f):
-            sys.exit("get_non_scored_words.py: expected file {0}{1} to exist.".format(
-                lang_dir, f))
+            logger.error("expected file %s%s to exist.", lang_dir, f)
+            raise RuntimeError
+
# read silence-phones. 
try: silence_phones = set() for line in open(lang_dir + '/phones/silence.int').readlines(): silence_phones.add(int(line)) - except Exception as e: - sys.exit("get_non_scored_words.py: problem reading file " - "{0}/phones/silence.int: {1}".format(lang_dir, str(e))) + except Exception: + logger.error("problem reading file " + "%s/phones/silence.int", lang_dir) + raise # read align_lexicon.int. # format is: .. @@ -66,25 +84,25 @@ def ReadLang(lang_dir): if len(a) == 3 and a[0] == a[1] and int(a[0]) > 0 and \ int(a[2]) in silence_phones: silence_word_ints.add(int(a[0])) - except Exception as e: - sys.exit("get_non_scored_words.py: problem reading file " - "{0}/phones/align_lexicon.int: " - "{1}".format(lang_dir, str(e))) + except Exception: + logger.error("problem reading file %s/phones/align_lexicon.int", + lang_dir) + raise try: for line in open(lang_dir + '/words.txt').readlines(): [ word, integer ] = line.split() if int(integer) in silence_word_ints: non_scored_words.add(word) - except Exception as e: - sys.exit("get_non_scored_words.py: problem reading file " - "{0}/words.txt.int: {1}".format(lang_dir, str(e))) + except Exception: + logger.error("problem reading file %s/words.txt.int", lang_dir) + raise if not len(non_scored_words) == len(silence_word_ints): - sys.exit("get_non_scored_words.py: error getting silence words, len({0}) != len({1})", - str(non_scored_words), str(silence_word_ints)) + raise RuntimeError("error getting silence words, len({0}) != len({1})" + "".format(non_scored_words, silence_word_ints)) for word in non_scored_words: print(word) -ReadLang(args.lang) +read_lang(args.lang) diff --git a/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py index ea56219fe2a..ed83595a224 100755 --- a/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/modify_ctm_edits.py @@ -5,49 +5,62 @@ # Apache 2.0 from __future__ import print_function -import sys, operator, 
argparse, os +import argparse +import logging +import sys from collections import defaultdict -# This script reads and writes the 'ctm-edits' file that is -# produced by get_ctm_edits.py. - -# It modifies the ctm-edits so that non-scored words -# are not counted as errors: for instance, if there are things like -# [COUGH] and [NOISE] in the transcript, deletions, insertions and -# substitutions involving them are allowed, and we modify the reference -# to correspond to the hypothesis. -# -# If you supply the directory (the one that corresponds to -# how you decoded the data) to this script, it assumes that the -# directory contains phones/align_lexicon.int, and it uses this to work -# out a reasonable guess of the non-scored phones, based on which have -# a single-word pronunciation that maps to a silence phone. -# It then uses the words.txt to work out the written form of those words. -# -# Alternatively, you may specify a file containing the non-scored words one -# per line, with the --non-scored-words option. -# -# Non-scored words that were deleted (i.e. they were in the ref but not the -# hyp) are simply removed from the ctm. For non-scored words that -# were inserted or substituted, we change the reference word to match the -# hyp word, but instead of marking the operation as 'cor' (correct), we -# mark it as 'fix' (fixed), so that it will not be positively counted as a correct -# word for purposes of finding the optimal segment boundaries. -# -# e.g. -# -# [note: the will always be 1]. 
- -# AJJacobs_2007P-0001605-0003029 1 0 0.09 1.0 sil -# AJJacobs_2007P-0001605-0003029 1 0.09 0.15 i 1.0 i cor -# AJJacobs_2007P-0001605-0003029 1 0.24 0.25 thought 1.0 thought cor -# AJJacobs_2007P-0001605-0003029 1 0.49 0.14 i'd 1.0 i'd cor -# AJJacobs_2007P-0001605-0003029 1 0.63 0.22 tell 1.0 tell cor -# AJJacobs_2007P-0001605-0003029 1 0.85 0.11 you 1.0 you cor -# AJJacobs_2007P-0001605-0003029 1 0.96 0.05 a 1.0 a cor -# AJJacobs_2007P-0001605-0003029 1 1.01 0.24 little 1.0 little cor -# AJJacobs_2007P-0001605-0003029 1 1.25 0.5 about 1.0 about cor -# AJJacobs_2007P-0001605-0003029 1 1.75 0.48 [UH] 1.0 [UH] cor +""" +This script reads and writes the 'ctm-edits' file that is +produced by get_ctm_edits.py. + +It modifies the ctm-edits so that non-scored words +are not counted as errors: for instance, if there are things like +[COUGH] and [NOISE] in the transcript, deletions, insertions and +substitutions involving them are allowed, and we modify the reference +to correspond to the hypothesis. + +If you supply the directory (the one that corresponds to +how you decoded the data) to this script, it assumes that the +directory contains phones/align_lexicon.int, and it uses this to work +out a reasonable guess of the non-scored phones, based on which have +a single-word pronunciation that maps to a silence phone. +It then uses the words.txt to work out the written form of those words. + +Alternatively, you may specify a file containing the non-scored words one +per line, with the --non-scored-words option. + +Non-scored words that were deleted (i.e. they were in the ref but not the +hyp) are simply removed from the ctm. For non-scored words that +were inserted or substituted, we change the reference word to match the +hyp word, but instead of marking the operation as 'cor' (correct), we +mark it as 'fix' (fixed), so that it will not be positively counted as a correct +word for purposes of finding the optimal segment boundaries. + +e.g. 
+ +[note: the will always be 1]. + +AJJacobs_2007P-0001605-0003029 1 0 0.09 1.0 sil +AJJacobs_2007P-0001605-0003029 1 0.09 0.15 i 1.0 i cor +AJJacobs_2007P-0001605-0003029 1 0.24 0.25 thought 1.0 thought cor +AJJacobs_2007P-0001605-0003029 1 0.49 0.14 i'd 1.0 i'd cor +AJJacobs_2007P-0001605-0003029 1 0.63 0.22 tell 1.0 tell cor +AJJacobs_2007P-0001605-0003029 1 0.85 0.11 you 1.0 you cor +AJJacobs_2007P-0001605-0003029 1 0.96 0.05 a 1.0 a cor +AJJacobs_2007P-0001605-0003029 1 1.01 0.24 little 1.0 little cor +AJJacobs_2007P-0001605-0003029 1 1.25 0.5 about 1.0 about cor +AJJacobs_2007P-0001605-0003029 1 1.75 0.48 [UH] 1.0 [UH] cor +""" + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)s - ' + '%(funcName)s - %(levelname)s ] %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) parser = argparse.ArgumentParser( @@ -142,6 +155,7 @@ def ProcessLineForNonScoredWords(a): ref_change_stats[ref_word + ' -> ' + hyp_word] += 1 return [] elif edit_type == 'sub': + assert hyp_word != '' if hyp_word in non_scored_words and ref_word in non_scored_words: # we also allow replacing one non-scored word with another. 
ref_change_stats[ref_word + ' -> ' + hyp_word] += 1 @@ -156,12 +170,10 @@ def ProcessLineForNonScoredWords(a): a[7] = edit_type return a - except Exception as e: - print("modify_ctm_edits.py: bad line in ctm-edits input: " + ' '.join(a), - file = sys.stderr) - print("modify_ctm_edits.py: exception was: " + str(e), - file = sys.stderr) - sys.exit(1) + except Exception: + logger.error("bad line in ctm-edits input: " + "{0}".format(a)) + raise RuntimeError # This function processes the split lines of one utterance (as a # list of lists of fields), to allow repetitions of words, so if the diff --git a/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py b/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py new file mode 100755 index 00000000000..5b74f0ef592 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/resolve_ctm_edits_overlaps.py @@ -0,0 +1,387 @@ +#! /usr/bin/env python + +# Copyright 2014 Johns Hopkins University (Authors: Daniel Povey) +# 2014 Vijayaditya Peddinti +# 2016 Vimal Manohar +# Apache 2.0. + +""" +Script to combine ctms edits with overlapping segments obtained from +smith-waterman alignment. This script is similar to resolve_ctm_edits.py, +where the overlapping region is just split in two. The approach here is a +little more advanced since we have access to the WER +(w.r.t. the reference text). It finds the WER of the overlapped region +in the two overlapping segments, and chooses the better one. 
+""" + +from __future__ import print_function +import argparse +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter( + '%(asctime)s [%(pathname)s:%(lineno)s - ' + '%(funcName)s - %(levelname)s ] %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def get_args(): + """gets command line arguments""" + + usage = """ Python script to resolve overlaps in ctms """ + parser = argparse.ArgumentParser(usage) + parser.add_argument('segments', type=argparse.FileType('r'), + help='use segments to resolve overlaps') + parser.add_argument('ctm_edits_in', type=argparse.FileType('r'), + help='input_ctm_file') + parser.add_argument('ctm_edits_out', type=argparse.FileType('w'), + help='output_ctm_file') + parser.add_argument('--verbose', type=int, default=0, + help="Higher value for more verbose logging.") + args = parser.parse_args() + + if args.verbose > 2: + logger.setLevel(logging.DEBUG) + handler.setLevel(logging.DEBUG) + + return args + + +def read_segments(segments_file): + """Read from segments and returns two dictionaries, + {utterance-id: (recording_id, start_time, end_time)} + {recording_id: list-of-utterances} + """ + segments = {} + reco2utt = defaultdict(list) + + num_lines = 0 + for line in segments_file: + num_lines += 1 + parts = line.strip().split() + assert len(parts) in [4, 5] + segments[parts[0]] = (parts[1], float(parts[2]), float(parts[3])) + reco2utt[parts[1]].append(parts[0]) + + logger.info("Read %d lines from segments file %s", + num_lines, segments_file.name) + segments_file.close() + + return segments, reco2utt + + +def read_ctm_edits(ctm_edits_file, segments): + """Read CTM from ctm_edits_file into a dictionary of values indexed by the + recording. + It is assumed to be sorted by the recording-id and utterance-id. 
+ + Returns a dictionary {recording : ctm_edit_lines} + where ctm_lines is a list of lines of CTM corresponding to the + utterances in the recording. + The format is as follows: + [[(utteranceA, channelA, start_time1, duration1, hyp_word1, conf1, ref_word1, edit_type1), + (utteranceA, channelA, start_time2, duration2, hyp_word2, conf2, ref_word2, edit_type2), + ... + (utteranceA, channelA, start_timeN, durationN, hyp_wordN, confN, ref_wordN, edit_typeN)], + [(utteranceB, channelB, start_time1, duration1, hyp_word1, conf1, ref_word1, edit_type1), + (utteranceB, channelB, start_time2, duration2, hyp_word2, conf2, ref_word2, edit_type2), + ...], + ... + [... + (utteranceZ, channelZ, start_timeN, durationN, hyp_wordN, confN, ref_wordN, edit_typeN)] + ] + + Arguments: + segments - Dictionary containing the output of read_segments() + { utterance_id: (recording_id, start_time, end_time) } + """ + ctm_edits = {} + + num_lines = 0 + for line in ctm_edits_file: + + for key in [x[0] for x in segments.values()]: + ctm_edits[key] = [] + num_lines += 1 + parts = line.split() + + utt = parts[0] + reco = segments[utt][0] + + if (reco, utt) not in ctm_edits: + ctm_edits[(reco, utt)] = [] + + ctms[(reco, utt)].append([parts[0], parts[1], float(parts[2]), + float(parts[3]), parts[4], float(parts[5])] + + parts[6:]) + + logger.info("Read %d lines from CTM %s", num_lines, ctm_edits_file.name) + + ctm_edits_file.close() + return ctm_edits + + +def wer(ctm_edit_lines): + num_words = 0 + num_incorrect_words = 0 + for line in ctm_edit_lines: + if line[7] != 'sil': + num_words += 1 + if line[7] in ['ins', 'del', 'sub']: + num_incorrect_words += 1 + if num_words == 0 and num_incorrect_words > 0: + return float('inf') + if num_words == 0 and num_incorrect_words == 0: + return 0 + return (float(num_incorrect_words) / num_words, -num_words) + + +def choose_best_ctm_lines(first_lines, second_lines, + window_length, overlap_length): + """Returns ctm lines that have lower WER. 
If the WER is the lines with + the higher number of words is returned. + """ + i, best_lines = min((0, first_lines), (1, second_lines), + key=lambda x: wer(x[1])) + + return i + + #ctm_edit = [] + #prev_utt = "" + #num_lines = 0 + #num_utts = 0 + #for line in ctm_edits_file: + # num_lines += 1 + # try: + # parts = line.split() + # if prev_utt == parts[0]: + # ctm_edit.append([parts[0], parts[1], float(parts[2]), + # float(parts[3]), parts[4], float(parts[5])] + # + parts[6:]) + # else: + # if prev_utt != "": + # assert parts[0] > prev_utt # sorted by utterance-id + + # # New utterance. Append the previous utterance's CTM + # # into the list for the utterance's recording. + # reco = segments[prev_utt][0] + # ctm_edits[reco].append(ctm_edit) + # assert ctm_edit[0][0] == prev_utt + # num_utts += 1 + + # # Start a new CTM for the new utterance-id parts[0]. + # ctm_edit = [[parts[0], parts[1], float(parts[2]), + # float(parts[3]), parts[4], float(parts[5])] + # + parts[6:]] + # prev_utt = parts[0] + # except: + # logger.error("Error while reading line %s in CTM file %s", + # line, ctm_edits_file.name) + # raise + + ## Append the last ctm. + #reco = segments[prev_utt][0] + #ctm_edits[reco].append(ctm_edit) + + #logger.info("Read %d lines from CTM %s; got %d recordings, " + # "%d utterances.", + # num_lines, ctm_edits_file.name, len(ctm_edits), num_utts) + #ctm_edits_file.close() + #return ctm_edits + + +def resolve_overlaps(ctm_edits, segments): + """Resolve overlaps within segments of the same recording. + + Returns new lines of CTM for the recording. + + Arguments: + ctms - The CTM lines for a single recording. This is one value stored + in the dictionary read by read_ctm(). Assumes that the lines + are sorted by the utterance-ids. + The format is the following: + [[(utteranceA, channelA, start_time1, duration1, hyp_word1, conf1), + (utteranceA, channelA, start_time2, duration2, hyp_word2, conf2), + ... 
+ (utteranceA, channelA, start_timeN, durationN, hyp_wordN, confN) + ], + [(utteranceB, channelB, start_time1, duration1, hyp_word1, conf1), + (utteranceB, channelB, start_time2, duration2, hyp_word2, conf2), + ...], + ... + [... + (utteranceZ, channelZ, start_timeN, durationN, hyp_wordN, confN)] + ] + segments - Dictionary containing the output of read_segments() + { utterance_id: (recording_id, start_time, end_time) } + """ + total_ctm_edits = [] + if len(ctm_edits) == 0: + raise RuntimeError('CTMs for recording is empty. ' + 'Something wrong with the input ctms') + + # First column of first line in CTM for first utterance + next_utt = ctm_edits[0][0][0] + for utt_index, ctm_edits_for_cur_utt in enumerate(ctm_edits): + if utt_index == len(ctm_edits) - 1: + break + + if len(ctm_edits_for_cur_utt) == 0: + next_utt = ctm_edits[utt_index + 1][0][0] + continue + + cur_utt = ctm_edits_for_cur_utt[0][0] + if cur_utt != next_utt: + logger.error( + "Current utterance %s is not the same as the next " + "utterance %s in previous iteration.\n" + "CTM is not sorted by utterance-id?", + cur_utt, next_utt) + raise ValueError + + # Assumption here is that the segments are written in + # consecutive order in time. + ctm_edits_for_next_utt = ctm_edits[utt_index + 1] + next_utt = ctm_edits_for_next_utt[0][0] + if segments[next_utt][1] < segments[cur_utt][1]: + logger.error( + "Next utterance %s <= Current utterance %s. " + "CTM edits is not sorted by utterance-id.", + next_utt, cur_utt) + raise ValueError + + try: + # length of this utterance + window_length = segments[cur_utt][2] - segments[cur_utt][1] + + # overlap of this segment with the next segment + # i.e. current_utterance_end_time - next_utterance_start_time + # Note: It is possible for this to be negative when there is + # actually no overlap between consecutive segments. 
+ try: + overlap = segments[cur_utt][2] - segments[next_utt][1] + except KeyError: + logger("Could not find utterance %s in segments", + next_utt) + raise + + # find the first word that is in the overlap + # at the end of the cur utt + try: + cur_utt_end_index = next( + (i for i, line in enumerate(ctm_edits_for_cur_utt) + if line[2] + line[3] / 2.0 > window_length - overlap)) + except StopIteration: + cur_utt_end_index = len(ctm_edits_for_cur_utt) + + cur_utt_end_lines = ctm_edits_for_cur_utt[cur_utt_end_index:] + + # find the last word that is not in the overlap + # at the beginning of the next utt + try: + next_utt_start_index = next( + (i for i, line in enumerate(ctm_edits_for_next_utt) + if line[2] + line[3] / 2.0 > overlap)) + except StopIteration: + next_utt_start_index = 0 + + next_utt_start_lines = ctm_edits_for_next_utt[: + next_utt_start_index] + + choose_index = choose_best_ctm_lines( + cur_utt_end_lines, next_utt_start_lines, + window_length, overlap) + + # Ignore the hypotheses beyond this midpoint. They will be + # considered as part of the next segment. + if choose_index == 1: + total_ctm_edits.extend( + ctm_edits_for_cur_utt[:cur_utt_end_index]) + else: + total_ctm_edits.extend(ctm_edits_for_cur_utt) + + if choose_index == 0 and next_utt_start_index > 0: + # Update the ctm_edits_for_next_utt to include only the lines + # starting from index. + ctm_edits[utt_index + 1] = ( + ctm_edits_for_next_utt[next_utt_start_index:]) + # else leave the ctm_edits as is. 
+ except: + logger.error("Could not resolve overlaps between CTM edits for " + "%s and %s", cur_utt, next_utt) + logger.error("Current CTM:") + for line in ctm_edits_for_cur_utt: + logger.error(ctm_edit_line_to_string(line)) + logger.error("Next CTM:") + for line in ctm_edits_for_next_utt: + logger.error(ctm_edit_line_to_string(line)) + raise + + # merge the last ctm entirely + total_ctm_edits.extend(ctm_edits[-1]) + + return total_ctm_edits + + +def ctm_edit_line_to_string(line): + """Converts a line of CTM edit to string.""" + return "{0} {1} {2} {3} {4} {5} {6}".format(line[0], line[1], line[2], + line[3], line[4], line[5], + " ".join(line[6:])) + + +def write_ctm_edits(ctm_edit_lines, out_file): + """Writes CTM lines stored in a list to file.""" + for line in ctm_edit_lines: + print(ctm_edit_line_to_string(line), file=out_file) + + +def run(args): + """this method does everything in this script""" + segments, reco2utt = read_segments(args.segments) + ctm_edits = read_ctm_edits(args.ctm_edits_in, segments) + + for reco, utts in reco2utt.iteritems(): + ctm_edits_for_reco = [] + for utt in sorted(utts, key=lambda x: segments[x][1]): + if (reco, utt) in ctms: + ctm_edits_for_reco.append(ctm_edits[(reco, utt)]) + try: + # Process CTMs in the recordings + ctm_edits_for_reco = resolve_overlaps(ctm_edits_for_reco, segments) + write_ctm_edits(ctm_edits_for_reco, args.ctm_edits_out) + except Exception: + logger.error("Failed to process CTM edits for recording %s", + reco) + raise + args.ctm_edits_out.close() + logger.info("Wrote CTM for %d recordings.", len(ctm_edits)) + + +def main(): + """The main function which parses arguments and call run().""" + args = get_args() + try: + run(args) + except: + logger.error("Failed to resolve overlaps", exc_info=True) + raise RuntimeError + finally: + try: + for f in [args.segments, args.ctm_edits_in, args.ctm_edits_out]: + if f is not None: + f.close() + except IOError: + logger.error("Could not close some files. 
" + "Disk error or broken pipes?") + raise + except UnboundLocalError: + raise SystemExit(1) + + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py b/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py new file mode 100755 index 00000000000..d5dc6a643ac --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py @@ -0,0 +1,353 @@ +#! /usr/bin/env python + +# Copyright 2017 Vimal Manohar +# Apache 2.0. + +"""This script retrieves documents similar to the query documents +using a similarity score based on the total TFIDF for all the terms in the +query document. + +Some terminology: + original utterance-id = The utterance-id of the original long audio segments + and the corresponding reference transcript + source-text = reference transcript + source-text-id = original utterance-id + sub-segment = Approximately 30s long chunk of the original utterance + query-id = utterance-id of the sub-segment + document = Approximately 1000 words of a source-text + doc-id = Id of the document + +e.g. +foo1 A B C D E F is in the original text file +and foo1 foo 100 200 is in the original segments file. + +Here foo1 is the source-text-id and "A B C D" is the reference transcript. It +is a 100s long segment from the recording foo. + +foo1 is split into 30s long sub-segments as follows: +foo1-1 foo1 100 130 +foo1-2 foo1 125 155 +foo1-3 foo1 150 180 +foo1-4 foo1 175 200 + +foo1-{1,2,3,4} are query-ids. + +The source-text for foo1 is split into two-word documents. +doc1 A B +doc2 C D +doc3 E F + +doc{1,2,3} are doc-ids. + +--source-text2doc-ids option is given a mapping that contains +foo1 doc1 doc2 doc3 + +--query-id2source-text-id option is given a mapping that contains +foo1-1 foo1 +foo1-2 foo1 +foo1-3 foo1 +foo1-4 foo1 + +The query TF-IDFs are all indexed by the utterance-id of the sub-segments +of the original utterances. 
"""Retrieves documents similar to query documents using a similarity
score based on the total TF-IDF of all the terms in the query document.

For each query (a sub-segment of an original utterance), the documents
created from the same original utterance are scored against the query and
the best-matching document(s) are written out, one line per query:
<query-id> <doc-id>,<start-fraction>,<end-fraction> ...
See the option help strings in get_args() for details of the inputs.
"""

import argparse
import logging

# note: the original used logging.getLogger('__name__') -- the literal
# string, not the module's name.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - "
                              "%(funcName)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)

for l in [logger, logging.getLogger('tf_idf'), logging.getLogger('libs')]:
    l.setLevel(logging.DEBUG)
    l.addHandler(handler)


def get_args():
    """Parses and validates the command-line arguments."""
    parser = argparse.ArgumentParser(
        description="""This script retrieves documents similar to the
        query documents using a similarity score based on the total TFIDF
        for all the terms in the query document.
        See the beginning of the script for more details about the
        arguments to the script.""")

    parser.add_argument("--verbose", type=int, default=0,
                        choices=[0, 1, 2, 3],
                        help="Higher for more logging statements")

    parser.add_argument("--num-neighbors-to-search", type=int, default=0,
                        help="""Number of neighboring documents to search
                        around the one retrieved based on maximum tf-idf
                        similarity.  A value of 0 means only the document
                        with the maximum tf-idf similarity is retrieved,
                        and none of the documents adjacent to it.""")
    parser.add_argument("--neighbor-tfidf-threshold", type=float,
                        default=0.9,
                        help="""Ignore neighbors that have tf-idf
                        similarity with the query document less than this
                        threshold factor lower than the best score.""")
    # note: the original omitted type=float, so the value arrived as a
    # string and both the range check and the later arithmetic would fail
    parser.add_argument("--partial-doc-fraction", type=float, default=0.2,
                        help="""The fraction of a neighboring document
                        that will be part of the retrieved document set.
                        If this is greater than 0, then a fraction of
                        words from the neighboring documents is added to
                        the retrieved document.""")

    parser.add_argument("--source-text-id2doc-ids",
                        type=argparse.FileType('r'), required=True,
                        help="""A mapping from the source text to a list
                        of documents that it is broken into:
                        <source-text-id> <doc-id-1> <doc-id-2> ...""")
    parser.add_argument("--query-id2source-text-id",
                        type=argparse.FileType('r'), required=True,
                        help="""A mapping from the query document-id to a
                        source text from which a document needs to be
                        retrieved.""")
    parser.add_argument("--source-text-id2tfidf",
                        type=argparse.FileType('r'), required=True,
                        help="""An SCP file for the TF-IDF for source
                        documents indexed by the source-text-id.""")
    parser.add_argument("--query-tfidf", type=argparse.FileType('r'),
                        required=True,
                        help="""Archive of TF-IDF objects for query
                        documents indexed by the query-id.""")
    parser.add_argument("--relevant-docs", type=argparse.FileType('w'),
                        required=True,
                        help="""Output archive of a list of source
                        documents similar to a query document, indexed by
                        the query document id.""")

    args = parser.parse_args()

    if args.partial_doc_fraction < 0 or args.partial_doc_fraction > 1:
        logger.error("--partial-doc-fraction must be in [0,1]")
        raise ValueError

    return args


def read_map(file_handle, num_values_per_key=None,
             min_num_values_per_key=None, must_contain_unique_key=True):
    """Reads a map from an opened file into a dictionary and returns it.
    The expected file format is:
    <key> <value-1> <value-2> ... <value-N>

    Arguments:
        file_handle - handle to an opened input file containing the map;
                      closed before returning
        num_values_per_key - if given, raise TypeError when a line does
                             not have exactly this many values
        min_num_values_per_key - if given, raise TypeError when a line
                                 has fewer than this many values
        must_contain_unique_key - if True, raise TypeError on a
                                  duplicate key

    Returns {key: list-of-values}, except when num_values_per_key == 1,
    in which case the single value is stored directly (not in a list).
    """
    dict_map = {}
    for line in file_handle:
        try:
            parts = line.strip().split()
            key = parts[0]

            if (num_values_per_key is not None
                    and len(parts) - 1 != num_values_per_key):
                logger.error(
                    "Expecting {0} columns; Got {1}.".format(
                        num_values_per_key + 1, len(parts)))
                raise TypeError

            if (min_num_values_per_key is not None
                    and len(parts) - 1 < min_num_values_per_key):
                logger.error(
                    "Expecting at least {0} columns; Got {1}.".format(
                        min_num_values_per_key + 1, len(parts)))
                raise TypeError

            if must_contain_unique_key and key in dict_map:
                logger.error("Found duplicate key %s", key)
                raise TypeError

            if num_values_per_key is not None and num_values_per_key == 1:
                dict_map[key] = parts[1]
            else:
                dict_map[key] = parts[1:]
        except Exception:
            logger.error("Failed reading line %s in file %s",
                         line, file_handle.name)
            raise
    file_handle.close()
    return dict_map


def get_document_ids(source_docs, indexes):
    """Converts {index: (start-fraction, end-fraction)} into a list of
    (doc-id, start-fraction, end-fraction) tuples sorted by index, using
    source_docs to map each index to its document-id.

    Indexes out of range of source_docs are silently dropped
    (best-effort behavior kept from the original).
    """
    indexes = sorted(
        [(key, value[0], value[1]) for key, value in indexes.items()],
        key=lambda x: x[0])

    doc_ids = []
    for i, partial_start, partial_end in indexes:
        try:
            doc_ids.append((source_docs[i], partial_start, partial_end))
        except IndexError:
            pass
    return doc_ids


def run(args):
    """The main function that does all the processing.
    Takes as argument the Namespace object obtained from get_args().
    """
    # Imported here rather than at module level so that the generic
    # helpers in this file (read_map, get_document_ids) can be used and
    # tested without the project-local tf_idf module being on the path.
    import tf_idf

    query_id2source_text_id = read_map(args.query_id2source_text_id,
                                       num_values_per_key=1)
    source_text_id2doc_ids = read_map(args.source_text_id2doc_ids,
                                      min_num_values_per_key=1)
    source_text_id2tfidf = read_map(args.source_text_id2tfidf,
                                    num_values_per_key=1)

    num_queries = 0
    prev_source_text_id = ""
    for query_id, query_tfidf in tf_idf.read_tfidf_ark(args.query_tfidf):
        num_queries += 1

        # The source text from which a document is to be retrieved for
        # the input query
        source_text_id = query_id2source_text_id[query_id]

        # Only (re-)load the source TF-IDF when the source text changes;
        # queries for the same source text arrive consecutively.
        if prev_source_text_id != source_text_id:
            source_tfidf = tf_idf.TFIDF()
            source_tfidf.read(source_text_id2tfidf[source_text_id])
            prev_source_text_id = source_text_id

        # The source documents corresponding to the source text.
        # This is the set of documents which will be searched over for
        # the query.
        source_doc_ids = source_text_id2doc_ids[source_text_id]

        scores = query_tfidf.compute_similarity_scores(
            source_tfidf, source_docs=source_doc_ids, query_id=query_id)

        assert len(scores) > 0, (
            "Did not get scores for query {0}".format(query_id))

        if args.verbose > 2:
            for tup, score in scores.items():
                logger.debug("Score, {num}: {0} {1} {2}".format(
                    tup[0], tup[1], score, num=num_queries))

        best_index, best_doc_id = max(
            enumerate(source_doc_ids),
            key=lambda x: scores[(query_id, x[1])])
        best_score = scores[(query_id, best_doc_id)]

        assert source_doc_ids[best_index] == best_doc_id
        assert best_score == max([scores[(query_id, x)]
                                  for x in source_doc_ids])

        best_indexes = {}

        if args.num_neighbors_to_search == 0:
            best_indexes[best_index] = (1, 1)
            if best_index > 0:
                best_indexes[best_index - 1] = (0, args.partial_doc_fraction)
            if best_index < len(source_doc_ids) - 1:
                best_indexes[best_index + 1] = (args.partial_doc_fraction, 0)
        else:
            excluded_indexes = set()
            for index in range(
                    max(best_index - args.num_neighbors_to_search, 0),
                    min(best_index + args.num_neighbors_to_search + 1,
                        len(source_doc_ids))):
                if (scores[(query_id, source_doc_ids[index])]
                        >= args.neighbor_tfidf_threshold * best_score):
                    best_indexes[index] = (1, 1)  # Type 2
                    if index > 0 and index - 1 in excluded_indexes:
                        try:
                            # Type 1 and 3
                            start_frac, end_frac = best_indexes[index - 1]
                            assert end_frac == 0
                            best_indexes[index - 1] = (
                                start_frac, args.partial_doc_fraction)
                        except KeyError:
                            # Type 1
                            best_indexes[index - 1] = (
                                0, args.partial_doc_fraction)
                else:
                    excluded_indexes.add(index)
                    if index > 0 and index - 1 not in excluded_indexes:
                        # Type 3
                        best_indexes[index] = (args.partial_doc_fraction, 0)

        best_docs = get_document_ids(source_doc_ids, best_indexes)

        assert len(best_docs) > 0, (
            "Did not get best docs for query {0}\n"
            "Scores: {1}\n"
            "Source docs: {2}\n"
            "Best index: {best_index}, score: {best_score}\n".format(
                query_id, scores, source_doc_ids,
                best_index=best_index, best_score=best_score))
        assert (best_doc_id, 1.0, 1.0) in best_docs

        print("{0} {1}".format(query_id, " ".join(
            ["%s,%.2f,%.2f" % x for x in best_docs])),
            file=args.relevant_docs)
    logger.info("Retrieved similar documents for "
                "%d queries", num_queries)


def main():
    """Parses arguments, runs the retrieval and closes all files."""
    args = get_args()

    if args.verbose > 1:
        handler.setLevel(logging.DEBUG)
    try:
        run(args)
    finally:
        # note: the original closed 'args.source_tfidf', which is not a
        # defined option; the handle is args.source_text_id2tfidf.
        # read_map() already closes its input, but close() is idempotent.
        for f in [args.query_id2source_text_id, args.source_text_id2doc_ids,
                  args.relevant_docs, args.query_tfidf,
                  args.source_text_id2tfidf]:
            f.close()


if __name__ == '__main__':
    main()
ctm-edits input and unk-padding.") +parser.add_argument("--min-split-point-duration", type=float, default=0.1, + help="""Minimum duration of silence or non-scored word + to be considered a viable split point when + truncating based on junk proportion.""") parser.add_argument("--max-deleted-words-kept-when-merging", type = str, default = 1, help = "When merging segments that are found to be overlapping or " "adjacent after all other processing, keep in the transcript the " @@ -536,12 +540,15 @@ def PossiblyTruncateStartForJunkProportion(self): # We'll consider splitting on silence and on non-scored words. # (i.e. making the silence or non-scored word the left boundary of # the new utterance and discarding the piece to the left of that). - if this_edit_type == 'sil' or \ - (this_edit_type == 'cor' and this_ref_word in non_scored_words): + if ((this_edit_type == 'sil' + or (this_edit_type == 'cor' + and this_ref_word in non_scored_words)) + and (float(this_split_line[3]) + > args.min_split_point_duration)): candidate_start_index = i candidate_start_time = float(this_split_line[2]) break # Consider only the first potential truncation. - if candidate_start_index == None: + if candidate_start_index is None: return # Nothing to do as there is no place to split. candidate_removed_piece_duration = candidate_start_time - self.StartTime() if begin_junk_duration / candidate_removed_piece_duration < args.max_junk_proportion: @@ -575,12 +582,15 @@ def PossiblyTruncateEndForJunkProportion(self): # We'll consider splitting on silence and on non-scored words. # (i.e. making the silence or non-scored word the right boundary of # the new utterance and discarding the piece to the right of that). 
- if this_edit_type == 'sil' or \ - (this_edit_type == 'cor' and this_ref_word in non_scored_words): + if ((this_edit_type == 'sil' + or (this_edit_type == 'cor' + and this_ref_word in non_scored_words)) + and (float(this_split_line[3]) + > args.min_split_point_duration)): candidate_end_index = i + 1 # note: end-indexes are one past the last. candidate_end_time = float(this_split_line[2]) + float(this_split_line[3]) break # Consider only the latest potential truncation. - if candidate_end_index == None: + if candidate_end_index is None: return # Nothing to do as there is no place to split. candidate_removed_piece_duration = self.EndTime() - candidate_end_time if end_junk_duration / candidate_removed_piece_duration < args.max_junk_proportion: diff --git a/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py new file mode 100755 index 00000000000..35b9ed605ee --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/segment_ctm_edits_mild.py @@ -0,0 +1,2072 @@ +#! /usr/bin/env python + +# Copyright 2016 Vimal Manohar +# 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +from __future__ import print_function +import argparse +import copy +import logging +import heapq +import sys +from collections import defaultdict + +""" +This script reads 'ctm-edits' file format that is produced by align_ctm_ref.py +and modified by modify_ctm_edits.py and taint_ctm_edits.py. Its function is to +produce a segmentation and text from the ctm-edits input. + +It is a milder version of the script segment_ctm_edits.py i.e. it allows +to keep more of the reference. This is useful for segmenting long-audio +based on imperfect transcripts. + +The ctm-edits file format that this script expects is as follows + +['tainted'] +[note: file-id is really utterance-id at this point]. 
+""" + +_global_logger = logging.getLogger(__name__) +_global_logger.setLevel(logging.INFO) +_global_handler = logging.StreamHandler() +_global_handler.setLevel(logging.INFO) +_global_formatter = logging.Formatter( + '%(asctime)s [%(pathname)s:%(lineno)s - ' + '%(funcName)s - %(levelname)s ] %(message)s') +_global_handler.setFormatter(_global_formatter) +_global_logger.addHandler(_global_handler) + +_global_non_scored_words = {} + + +def non_scored_words(): + return _global_non_scored_words + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This program produces segmentation and text information + based on reading ctm-edits input format which is produced by + steps/cleanup/internal/get_ctm_edits.py, + steps/cleanup/internal/modify_ctm_edits.py and + steps/cleanup/internal/taint_ctm_edits.py.""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--min-segment-length", type=float, default=0.5, + help="""Minimum allowed segment length (in seconds) for + any segment; shorter segments than this will be + discarded.""") + parser.add_argument("--min-new-segment-length", type=float, default=1.0, + help="""Minimum allowed segment length (in seconds) for + newly created segments (i.e. not identical to the input + utterances). + Expected to be >= --min-segment-length.""") + parser.add_argument("--frame-length", type=float, default=0.01, + help="""This only affects rounding of the output times; + they will be constrained to multiples of this + value.""") + parser.add_argument("--max-tainted-length", type=float, default=0.05, + help="""Maximum allowed length of any 'tainted' line. + Note: 'tainted' lines may only appear at the boundary + of a segment""") + parser.add_argument("--max-edge-silence-length", type=float, default=0.5, + help="""Maximum allowed length of silence if it appears + at the edge of a segment (will be truncated). 
This + rule is relaxed if such truncation would take a segment + below the --min-segment-length or + --min-new-segment-length.""") + parser.add_argument("--max-edge-non-scored-length", type=float, + default=0.5, + help="""Maximum allowed length of a non-scored word + (noise, cough, etc.) if it appears at the edge of a + segment (will be truncated). This rule is relaxed if + such truncation would take a segment below the + --min-segment-length.""") + parser.add_argument("--max-internal-silence-length", type=float, + default=2.0, + help="""Maximum allowed length of silence if it appears + inside a segment (will cause the segment to be + split).""") + parser.add_argument("--max-internal-non-scored-length", type=float, + default=2.0, + help="""Maximum allowed length of a non-scored word + (noise, etc.) if it appears inside a segment (will + cause the segment to be split). + Note: reference words which are real words but OOV are + not included in this category.""") + parser.add_argument("--unk-padding", type=float, default=0.05, + help="""Amount of padding with <unk> that we do if a + segment boundary is next to errors (ins, del, sub). + That is, we add this amount of time to the segment and + add the <unk> word to cover the acoustics. 
If nonzero, + the --oov-symbol-file option must be supplied.""") + parser.add_argument("--max-junk-proportion", type=float, default=0.1, + help="""Maximum proportion of the time of the segment + that may consist of potentially bad data, in which we + include 'tainted' lines of the ctm-edits input and + unk-padding.""") + parser.add_argument("--min-split-point-duration", type=float, default=0.0, + help="""Minimum duration of silence or non-scored word + to be considered a viable split point when + truncating based on junk proportion.""") + parser.add_argument("--max-deleted-words-kept-when-merging", + dest='max_deleted_words', type=int, default=1, + help="""When merging segments that are found to be + overlapping or adjacent after all other processing, + keep in the transcript the reference words that were + deleted between the segments [if any] as long as there + were no more than this many reference words. Setting + this to zero will mean that any reference words that + were deleted between the segments we're about to + reattach will not appear in the generated transcript + (so we'll match the hyp).""") + + parser.add_argument("--splitting.min-silence-length", + dest="min_silence_length_to_split", + type=float, default=0.3, + help="""Only considers silences that are at least this + long as potential split points""") + parser.add_argument("--splitting.min-non-scored-length", + dest="min_non_scored_length_to_split", + type=float, default=0.1, + help="""Only considers non-scored words that are at + least this long as potential split points""") + parser.add_argument("--splitting.max-segment-length", + dest="max_segment_length_for_splitting", + type=float, default=10, + help="""Try to split long segments into segments that + are smaller that this size. 
See + possibly_split_long_segments() in Segment class.""") + parser.add_argument("--splitting.hard-max-segment-length", + dest="hard_max_segment_length", + type=float, default=15, + help="""Split all segments that are longer than this + uniformly into segments of size + --splitting.max-segment-length""") + + parser.add_argument("--merging-score.silence-factor", + dest="silence_factor", + type=float, default=1, + help="""Weightage on the silence length when merging + segments""") + parser.add_argument("--merging-score.incorrect-words-factor", + dest="incorrect_words_factor", + type=float, default=1, + help="""Weightage on the incorrect_words_length when + merging segments""") + parser.add_argument("--merging-score.tainted-words-factor", + dest="tainted_words_factor", + type=float, default=1, + help="""Weightage on the WER including the + tainted words as incorrect words.""") + + parser.add_argument("--merging.max-wer", + dest="max_wer", + type=float, default=10.0, + help="Max WER%% of merged segments when merging") + parser.add_argument("--merging.max-bad-proportion", + dest="max_bad_proportion", + type=float, default=0.2, + help="""Maximum length of silence, junk and incorrect + words in a merged segment allowed as a fraction of the + total length of merged segment.""") + parser.add_argument("--merging.max-segment-length", + dest='max_segment_length_for_merging', + type=float, default=10, + help="""Maximum segment length allowed for merged + segment""") + parser.add_argument("--merging.max-intersegment-incorrect-words-length", + dest='max_intersegment_incorrect_words_length', + type=float, default=0.2, + help="""Maximum length of intersegment region that + can be of incorrect word. 
This is to + allow cases where there may be a lot of silence in the + segment but the incorrect words are few, while + preventing regions that have a lot of incorrect + words.""") + + parser.add_argument("--oov-symbol-file", type=argparse.FileType('r'), + help="""Filename of file such as data/lang/oov.txt + which contains the text form of the OOV word, normally + ''. Supplied as a file to avoid complications + with escaping. Necessary if the --unk-padding option + has a nonzero value (which it does by default.""") + parser.add_argument("--ctm-edits-out", type=argparse.FileType('w'), + help="""Filename to output an extended version of the + ctm-edits format with segment start and end points + noted. This file is intended to be read by humans; + there are currently no scripts that will read it.""") + parser.add_argument("--word-stats-out", type=argparse.FileType('w'), + help="""Filename for output of word-level stats, of the + form ' ', + e.g. 'hello 0.12 12408', where the is + the proportion of the time that this reference word + does not make it into a segment. It can help reveal + words that have problematic pronunciations or are + associated with transcription errors.""") + + parser.add_argument("non_scored_words_in", + metavar="", + type=argparse.FileType('r'), + help="""Filename of file containing a list of + non-scored words, one per line. See + steps/cleanup/internal/get_nonscored_words.py.""") + parser.add_argument("ctm_edits_in", metavar="", + type=argparse.FileType('r'), + help="""Filename of input ctm-edits file. Use + /dev/stdin for standard input.""") + parser.add_argument("text_out", metavar="", + type=argparse.FileType('w'), + help="""Filename of output text file (same format as + data/train/text, i.e. + ... """) + parser.add_argument("segments_out", metavar="", + type=argparse.FileType('w'), + help="""Filename of output segments. 
This has the same + format as data/train/segments, but instead of + , the second field is the old + utterance-id, i.e + """) + + parser.add_argument("--verbose", type=int, default=0, + help="Use higher verbosity for more debugging output") + + args = parser.parse_args() + + if args.verbose > 2: + _global_handler.setLevel(logging.DEBUG) + _global_logger.setLevel(logging.DEBUG) + + return args + + +def is_tainted(split_line_of_utt): + """Returns True if this line in ctm-edit is "tainted.""" + return len(split_line_of_utt) > 8 and split_line_of_utt[8] == 'tainted' + + +def compute_segment_cores(split_lines_of_utt): + """ + This function returns a list of pairs (start-index, end-index) representing + the cores of segments (so if a pair is (s, e), then the core of a segment + would span (s, s+1, ... e-1). + + The argument 'split_lines_of_utt' is list of lines from a ctm-edits file + corresponding to a single utterance. + + By the 'core of a segment', we mean a sequence of ctm-edits lines including + at least one 'cor' line and a contiguous sequence of other lines of the + type 'cor', 'fix' and 'sil' that must be not tainted. The segment core + excludes any tainted lines at the edge of a segment, which will be added + later. + + We only initiate segments when it contains something correct and not + realized as unk (i.e. ref==hyp); and we extend it with anything that is + 'sil' or 'fix' or 'cor' that is not tainted. Contiguous regions of 'true' + in the resulting boolean array will then become the cores of prototype + segments, and we'll add any adjacent tainted words (or parts of them). 
+ """ + num_lines = len(split_lines_of_utt) + line_is_in_segment_core = [False] * num_lines + # include only the correct lines + for i in range(num_lines): + if (split_lines_of_utt[i][7] == 'cor' + and split_lines_of_utt[i][4] == split_lines_of_utt[i][6]): + line_is_in_segment_core[i] = True + + # extend each proto-segment forwards as far as we can: + for i in range(1, num_lines): + if line_is_in_segment_core[i - 1] and not line_is_in_segment_core[i]: + edit_type = split_lines_of_utt[i][7] + if (not is_tainted(split_lines_of_utt[i]) + and (edit_type == 'cor' or edit_type == 'sil' + or edit_type == 'fix')): + line_is_in_segment_core[i] = True + + # extend each proto-segment backwards as far as we can: + for i in reversed(range(0, num_lines - 1)): + if line_is_in_segment_core[i + 1] and not line_is_in_segment_core[i]: + edit_type = split_lines_of_utt[i][7] + if (not is_tainted(split_lines_of_utt[i]) + and (edit_type == 'cor' or edit_type == 'sil' + or edit_type == 'fix')): + line_is_in_segment_core[i] = True + + # Get contiguous regions of line in the form of a list + # of (start_index, end_index) + segment_ranges = [] + cur_segment_start = None + for i in range(0, num_lines): + if line_is_in_segment_core[i]: + if cur_segment_start is None: + cur_segment_start = i + else: + if cur_segment_start is not None: + segment_ranges.append((cur_segment_start, i)) + cur_segment_start = None + if cur_segment_start is not None: + segment_ranges.append((cur_segment_start, num_lines)) + + return segment_ranges + + +class SegmentStats(object): + """Class to store various statistics of segments.""" + + def __init__(self): + self.num_incorrect_words = 0 + self.num_tainted_words = 0 + self.incorrect_words_length = 0 + self.tainted_nonsilence_length = 0 + self.silence_length = 0 + self.num_words = 0 + self.total_length = 0 + + def wer(self): + """Returns WER%""" + try: + return float(self.num_incorrect_words) * 100.0 / self.num_words + except ZeroDivisionError: + return float("inf") + 
+ def bad_proportion(self): + assert self.total_length > 0 + proportion = float(self.silence_length + self.tainted_nonsilence_length + + self.incorrect_words_length) / self.total_length + if proportion > 1.00005: + raise RuntimeError("Error in segment stats {0}".format(self)) + return proportion + + def incorrect_proportion(self): + assert self.total_length > 0 + proportion = float(self.incorrect_words_length) / self.total_length + if proportion > 1.00005: + raise RuntimeError("Error in segment stats {0}".format(self)) + return proportion + + def combine(self, other, scale=1): + """Merges this stats with another stats object.""" + self.num_incorrect_words += scale * other.num_incorrect_words + self.num_tainted_words += scale * other.num_tainted_words + self.num_words += scale * other.num_words + self.incorrect_words_length += scale * other.incorrect_words_length + self.tainted_nonsilence_length += (scale + * other.tainted_nonsilence_length) + self.silence_length += scale * other.silence_length + self.total_length += scale * other.total_length + + def assert_equal(self, other): + try: + assert self.num_incorrect_words == other.num_incorrect_words + assert self.num_tainted_words == other.num_tainted_words + assert (abs(self.incorrect_words_length + - other.incorrect_words_length) < 0.01) + assert (abs(self.tainted_nonsilence_length + - other.tainted_nonsilence_length) < 0.01) + assert abs(self.silence_length - other.silence_length) < 0.01 + assert self.num_words == other.num_words + assert abs(self.total_length - other.total_length) < 0.01 + except AssertionError: + _global_logger.error("self %s != other %s", self, other) + raise + + def compare(self, other): + """Returns true if this stats is same as another stats object.""" + if self.num_incorrect_words != other.num_incorrect_words: + return False + if self.num_tainted_words != other.num_tainted_words: + return False + if self.incorrect_words_length != other.incorrect_words_length: + return False + if 
self.tainted_nonsilence_length != other.tainted_nonsilence_length: + return False + if self.silence_length != other.silence_length: + return False + if self.num_words != other.num_words: + return False + if self.total_length != other.total_length: + return False + return True + + def __str__(self): + return ("num-incorrect-words={num_incorrect:d}," + "num-tainted-words={num_tainted:d}," + "num-words={num_words:d}," + "incorrect-length={incorrect_length:.2f}," + "silence-length={sil_length:.2f}," + "tainted-nonsilence-length={tainted_nonsilence_length:.2f}," + "total-length={total_length:.2f}".format( + num_incorrect=self.num_incorrect_words, + num_tainted=self.num_tainted_words, + num_words=self.num_words, + incorrect_length=self.incorrect_words_length, + sil_length=self.silence_length, + tainted_nonsilence_length=self.tainted_nonsilence_length, + total_length=self.total_length)) + + +class Segment(object): + """Class to store segments.""" + + def __init__(self, split_lines_of_utt, start_index, end_index, + debug_str=None, compute_segment_stats=False, + segment_stats=None): + self.split_lines_of_utt = split_lines_of_utt + + # start_index is the index of the first line that appears in this + # segment, and end_index is one past the last line. This does not + # include unk-padding. + self.start_index = start_index + self.end_index = end_index + assert end_index > start_index + + # If the following values are nonzero, then when we create the segment + # we will add at the start and end of the segment [representing + # partial words], with this amount of additional audio. + self.start_unk_padding = 0.0 + self.end_unk_padding = 0.0 + + # debug_str keeps track of the 'core' of the segment. + if debug_str is None: + debug_str = 'core-start={0},core-end={1}'.format(start_index, + end_index) + else: + assert type(debug_str) == str + self.debug_str = debug_str + + # This gives the proportion of the time of the first line in the + # segment that we keep. 
Usually 1.0 but may be less if we've trimmed + # away some proportion of the time. + self.start_keep_proportion = 1.0 + # This gives the proportion of the time of the last line in the segment + # that we keep. Usually 1.0 but may be less if we've trimmed away some + # proportion of the time. + self.end_keep_proportion = 1.0 + + self.stats = None + + if compute_segment_stats: + self.compute_stats() + + if segment_stats is not None: + self.compute_stats() + self.stats.assert_equal(segment_stats) + self.stats = segment_stats + + def copy(self, copy_stats=True): + segment = Segment(self.split_lines_of_utt, self.start_index, + self.end_index, debug_str=self.debug_str, + segment_stats=(None if not copy_stats + else copy.deepcopy(self.stats))) + segment.start_keep_proportion = self.start_keep_proportion + segment.end_keep_proportion = self.end_keep_proportion + segment.start_unk_padding = self.start_unk_padding + segment.end_unk_padding = self.end_unk_padding + return segment + + def __str__(self): + return self.debug_info() + + def compute_stats(self): + """Compute stats for this segment and store them in SegmentStats + structure. + This is typically called just before merging segments. + """ + self.stats = SegmentStats() + for i in range(self.start_index, self.end_index): + this_duration = float(self.split_lines_of_utt[i][3]) + assert self.start_keep_proportion == 1.0 + assert self.end_keep_proportion == 1.0 + # TODO(vimal): Decide if keep proportion must be applied + # if i == self.start_index: + # this_duration *= self.start_keep_proportion + # if i == self.end_index - 1: + # this_duration *= self.end_keep_proportion + if self.end_index - 1 == self.start_index: + # TODO(vimal): Is this true? + assert self.start_keep_proportion == self.end_keep_proportion + + try: + if self.split_lines_of_utt[i][7] not in ['cor', 'fix', 'sil']: + # TODO(vimal): The commented part below is is apparently + # not true in modify_ctm_edits.py. 
+ # Need to check this or change comments there. + # assert (self.split_lines_of_utt[i][6] + # not in non_scored_words) + assert not is_tainted(self.split_lines_of_utt[i]) + self.stats.num_incorrect_words += 1 + self.stats.incorrect_words_length += this_duration + if self.split_lines_of_utt[i][7] == 'sil': + self.stats.silence_length += this_duration + else: + if (self.split_lines_of_utt[i][6] + not in non_scored_words()): + self.stats.num_words += 1 + if (is_tainted(self.split_lines_of_utt[i]) + and self.split_lines_of_utt[i][7] not in 'sil' + and (self.split_lines_of_utt[i][6] + not in non_scored_words())): + # If ref_word is not a non-scored word, this would be + # counted as an incorrect word. + self.stats.num_tainted_words += 1 + self.stats.tainted_nonsilence_length += this_duration + except Exception: + _global_logger.error( + "Something went wrong when computing stats at " + "ctm line %s", self.split_lines_of_utt[i]) + raise + self.stats.total_length = self.length() + + try: + assert (self.stats.tainted_nonsilence_length + + self.stats.silence_length + + self.stats.incorrect_words_length - 0.001 + <= self.stats.total_length) + except AssertionError: + _global_logger.error( + "Something wrong with the stats for segment %s", self) + raise + + def possibly_add_tainted_lines(self): + """ + This is stage 1 of segment processing (after creating the boundaries of + the core of the segment, which is done outside of this class). + + This function may reduce start_index and/or increase end_index by + including a single adjacent 'tainted' line from the ctm-edits file. + This is only done if the lines at the boundaries of the segment are + currently real non-silence words and not non-scored words. The idea is + that we probably don't want to start or end the segment right at the + boundary of a real word, we want to add some kind of padding. 
+ """ + split_lines_of_utt = self.split_lines_of_utt + # we're iterating over the segment (start, end) + for b in [False, True]: + if b: + boundary_index = self.end_index - 1 + adjacent_index = self.end_index + else: + boundary_index = self.start_index + adjacent_index = self.start_index - 1 + if (adjacent_index >= 0 + and adjacent_index < len(split_lines_of_utt)): + # only consider merging the adjacent word into the segment if + # we're not at the boundary of the utterance. + adjacent_line_is_tainted = is_tainted( + split_lines_of_utt[adjacent_index]) + # if the adjacent line wasn't tainted, then there must have + # been another stronger reason why we didn't include it in the + # core of the segment (probably that it was an ins, del or + # sub), so there is no point considering it. + if adjacent_line_is_tainted: + boundary_edit_type = split_lines_of_utt[boundary_index][7] + boundary_ref_word = split_lines_of_utt[boundary_index][6] + # Even if the edit_type is 'cor', it is possible that + # column 4 (hyp_word) is not the same as column 6 + # (ref_word) because the ref_word is an OOV and the + # hyp_word is OOV symbol. + + # we only add the tainted line to the segment if the word + # at the boundary was a non-silence word that was correctly + # decoded and not fixed [see modify_ctm_edits.py.] + if (boundary_edit_type == 'cor' + and (boundary_ref_word + not in non_scored_words())): + # Add the adjacent tainted line to the segment. + if b: + self.end_index += 1 + else: + self.start_index -= 1 + + def possibly_split_segment(self, max_internal_silence_length, + max_internal_non_scored_length): + """ + This is stage 3 of segment processing. + This function will split a segment into multiple pieces if any of the + internal [non-boundary] silences or non-scored words are longer + than the allowed values --max-internal-silence-length and + --max-internal-non-scored-length. + This function returns a list of segments. 
+ In the normal case (where there is no splitting) it just returns an + array with a single element 'self'. + + Note: --max-internal-silence-length and + --max-internal-non-scored-length can be set to very large values + to avoid any splitting. + """ + # make sure the segment hasn't been processed more than we expect. + assert (self.start_unk_padding == 0.0 and self.end_unk_padding == 0.0 + and self.start_keep_proportion == 1.0 + and self.end_keep_proportion == 1.0) + segments = [] # the answer + cur_start_index = self.start_index + cur_start_is_split = False + # only consider splitting at non-boundary lines. [we'd just truncate + # the boundary lines.] + for index_to_split_at in range(cur_start_index + 1, + self.end_index - 1): + this_split_line = self.split_lines_of_utt[index_to_split_at] + this_duration = float(this_split_line[3]) + this_edit_type = this_split_line[7] + this_ref_word = this_split_line[6] + if ((this_edit_type == 'sil' and + this_duration > max_internal_silence_length) + or (this_ref_word in non_scored_words() + and (this_duration + > max_internal_non_scored_length))): + # We split this segment at this index, dividing the word in two + # [later on, in possibly_truncate_boundaries, it may be further + # truncated.] + # Note: we use 'index_to_split_at + 1' because the Segment + # constructor takes an 'end-index' which is interpreted as one + # past the end. + new_segment = Segment(self.split_lines_of_utt, cur_start_index, + index_to_split_at + 1, + debug_str=self.debug_str) + if cur_start_is_split: + new_segment.start_keep_proportion = 0.5 + new_segment.end_keep_proportion = 0.5 + cur_start_is_split = True + cur_start_index = index_to_split_at + segments.append(new_segment) + if len(segments) == 0: # We did not split. + segments.append(self) + else: + # We did split. Add the very last segment. 
+ new_segment = Segment(self.split_lines_of_utt, cur_start_index, + self.end_index, + debug_str=self.debug_str) + assert cur_start_is_split + new_segment.start_keep_proportion = 0.5 + segments.append(new_segment) + return segments + + def possibly_split_long_segment(self, max_segment_length, + hard_max_segment_length, + min_silence_length_to_split, + min_non_scored_length_to_split): + """ + This is stage 4 of segment processing. + This function will split a segment into multiple pieces if it is + longer than the value --max-segment-length. + It tries to split at silences and non-scored words that are + at least --min-silence-length-to-split or + --min-non-scored-length-to-split long. + If this is not possible and the segments are still longer than + --hard-max-segment-length, then this is split into equal length + pieces of approximately --max-segment-length long. + This function returns a list of segments. + In the normal case (where there is no splitting) it just returns an + array with a single element 'self'. + """ + # make sure the segment hasn't been processed more than we expect. + assert self.start_unk_padding == 0.0 and self.end_unk_padding == 0.0 + if self.length() < max_segment_length: + return [self] + + segments = [self] # the answer + cur_start_index = self.start_index + + split_indexes = [] + # only consider splitting at non-boundary lines. [we'd just truncate + # the boundary lines.] 
+ for index_to_split_at in range(cur_start_index + 1, + self.end_index - 1): + this_split_line = self.split_lines_of_utt[index_to_split_at] + this_duration = float(this_split_line[3]) + this_edit_type = this_split_line[7] + this_ref_word = this_split_line[6] + this_is_tainted = is_tainted(this_split_line) + if (this_edit_type == 'sil' + and this_duration > min_silence_length_to_split): + split_indexes.append((index_to_split_at, this_duration, + this_is_tainted)) + + if (this_ref_word in non_scored_words() + and (this_duration > min_non_scored_length_to_split)): + split_indexes.append((index_to_split_at, this_duration, + this_is_tainted)) + split_indexes.sort(key=lambda x: x[1], reverse=True) + split_indexes.sort(key=lambda x: x[2]) + + while True: + if len(split_indexes) == 0: + break + + new_segments = [] + + for segment in segments: + if segment.length() < max_segment_length: + new_segments.append(segment) + continue + + try: + index_to_split_at = next( + (x[0] for x in split_indexes + if (x[0] > segment.start_index + and x[0] < segment.end_index - 1))) + except StopIteration: + _global_logger.debug( + "Could not find an index in the range (%d, %d) in " + "split-indexes %s", segment.start_index, + segment.end_index - 1, split_indexes) + new_segments.append(segment) + continue + + # We split this segment at this index, dividing the word in two + # [later on, in possibly_truncate_boundaries, it may be further + # truncated.] + # Note: we use 'index_to_split_at + 1' because the Segment + # constructor takes an 'end-index' which is interpreted as one + # past the end. 
+ new_segment = Segment( + self.split_lines_of_utt, segment.start_index, + index_to_split_at + 1, debug_str=self.debug_str) + new_segment.end_keep_proportion = 0.5 + new_segments.append(new_segment) + + new_segment = Segment( + self.split_lines_of_utt, index_to_split_at, + segment.end_index, debug_str=self.debug_str) + new_segment.start_keep_proportion = 0.5 + new_segments.append(new_segment) + + if len(segments) == len(new_segments): + # No splitting done + break + segments = new_segments + + for i, x in enumerate(segments): + _global_logger.debug("Segment %d = %s", i, x) + + new_segments = [] + # Split segments that are still very long + for segment in segments: + if segment.length() < hard_max_segment_length: + new_segments.append(segment) + continue + + cur_start_index = segment.start_index + cur_start = segment.start_time() + + index_to_split_at = None + try: + while True: + index_to_split_at = next( + (i for i in range(cur_start_index, segment.end_index) + if (float(self.split_lines_of_utt[i][2]) + >= cur_start + max_segment_length))) + + new_segment = Segment( + self.split_lines_of_utt, cur_start_index, + index_to_split_at) + new_segments.append(new_segment) + + cur_start_index = index_to_split_at + cur_start = float( + self.split_lines_of_utt[cur_start_index][2]) + index_to_split_at = None + + if (segment.end_time() - cur_start + < hard_max_segment_length): + raise StopIteration + except StopIteration: + if index_to_split_at is None: + _global_logger.debug( + "Could not find an index in the range (%d, %d) with " + "start time > %.2f", cur_start_index, + segment.end_index, cur_start + max_segment_length) + new_segment = Segment( + self.split_lines_of_utt, cur_start_index, + segment.end_index) + new_segments.append(new_segment) + break + segments = new_segments + return segments + + def possibly_truncate_boundaries(self, max_edge_silence_length, + max_edge_non_scored_length): + """ + This is stage 5 of segment processing. 
+ It will truncate the silences and non-scored words at the segment + boundaries if they are longer than the --max-edge-silence-length and + --max-edge-non-scored-length respectively + (and to the extent that this wouldn't take us below the + --min-segment-length or --min-new-segment-length. See + relax_boundary_truncation()). + + Note: --max-edge-silence-length and --max-edge-non-scored-length + can be set to very large values to avoid any truncation. + """ + for b in [True, False]: + if b: + this_index = self.start_index + else: + this_index = self.end_index - 1 + this_split_line = self.split_lines_of_utt[this_index] + truncated_duration = None + this_duration = float(this_split_line[3]) + this_edit = this_split_line[7] + this_ref_word = this_split_line[6] + if (this_edit == 'sil' + and this_duration > max_edge_silence_length): + truncated_duration = max_edge_silence_length + elif (this_ref_word in non_scored_words() + and this_duration > max_edge_non_scored_length): + truncated_duration = max_edge_non_scored_length + if truncated_duration is not None: + keep_proportion = truncated_duration / this_duration + if b: + self.start_keep_proportion = keep_proportion + else: + self.end_keep_proportion = keep_proportion + + def relax_boundary_truncation(self, min_segment_length, + min_new_segment_length): + """ + This relaxes the segment-boundary truncation of + possibly_truncate_boundaries(), if it would take us below + min-new-segment-length or min-segment-length. + + Note: this does not relax the boundary truncation for a particular + boundary (start or end) if that boundary corresponds to a 'tainted' + line of the ctm (because it's dangerous to include too much 'tainted' + audio). + """ + # this should be called before adding unk padding. + assert self.start_unk_padding == self.end_unk_padding == 0.0 + if self.start_keep_proportion == self.end_keep_proportion == 1.0: + return # nothing to do there was no truncation. 
+ length_cutoff = max(min_new_segment_length, min_segment_length) + length_with_truncation = self.length() + if length_with_truncation >= length_cutoff: + return # Nothing to do. + orig_start_keep_proportion = self.start_keep_proportion + orig_end_keep_proportion = self.end_keep_proportion + if not is_tainted(self.split_lines_of_utt[self.start_index]): + self.start_keep_proportion = 1.0 + if not is_tainted(self.split_lines_of_utt[self.end_index - 1]): + self.end_keep_proportion = 1.0 + length_with_relaxed_boundaries = self.length() + if length_with_relaxed_boundaries <= length_cutoff: + # Completely undo the truncation [to the extent allowed by the + # presence of tainted lines at the start/end] if, even without + # truncation, we'd be below the length cutoff. This segment may be + # removed later on (but it may not, if removing truncation makes us + # identical to the input utterance, and the length is between + # min_segment_length min_new_segment_length). + return + # Next, compute an interpolation constant a such that the + # {start,end}_keep_proportion values will equal + # a + # * [values-computed-by-possibly_truncate_boundaries()] + # + (1-a) * [completely-relaxed-values]. + # we're solving the equation: + # length_cutoff = a * length_with_truncation + # + (1-a) * length_with_relaxed_boundaries + # -> length_cutoff - length_with_relaxed_boundaries = + # a * (length_with_truncation - length_with_relaxed_boundaries) + # -> a = (length_cutoff - length_with_relaxed_boundaries) + # / (length_with_truncation - length_with_relaxed_boundaries) + a = ((length_cutoff - length_with_relaxed_boundaries) + / (length_with_truncation - length_with_relaxed_boundaries)) + if a < 0.0 or a > 1.0: + # TODO(vimal): Should this be an error? 
+ _global_logger.warn("bad 'a' value = %.4f", a) + return + self.start_keep_proportion = ( + a * orig_start_keep_proportion + + (1 - a) * self.start_keep_proportion) + self.end_keep_proportion = ( + a * orig_end_keep_proportion + (1 - a) * self.end_keep_proportion) + if abs(self.length() - length_cutoff) >= 0.01: + # TODO(vimal): Should this be an error? + _global_logger.warn( + "possible problem relaxing boundary " + "truncation, length is %.2f vs %.2f", self.length(), + length_cutoff) + + def possibly_add_unk_padding(self, max_unk_padding): + """ + This is stage 7 of segment processing. + This function may set start_unk_padding and end_unk_padding to nonzero + values. This is done if the current boundary words are real, scored + words and we're not next to the beginning or end of the utterance. + """ + for b in [True, False]: + if b: + this_index = self.start_index + else: + this_index = self.end_index - 1 + this_split_line = self.split_lines_of_utt[this_index] + this_start_time = float(this_split_line[2]) + this_ref_word = this_split_line[6] + this_edit = this_split_line[7] + if this_edit == 'cor' and this_ref_word not in non_scored_words(): + # we can consider adding unk-padding. + if b: # start of utterance. + unk_padding = max_unk_padding + # close to beginning of file + if unk_padding > this_start_time: + unk_padding = this_start_time + # If we could add less than half of the specified + # unk-padding, don't add any (because when we add + # unk-padding we add the unknown-word symbol '', and + # if there isn't enough space to traverse the HMM we don't + # want to do it at all. + if unk_padding < 0.5 * max_unk_padding: + unk_padding = 0.0 + self.start_unk_padding = unk_padding + else: # end of utterance. 
+ this_end_time = this_start_time + float(this_split_line[3]) + last_line = self.split_lines_of_utt[-1] + utterance_end_time = (float(last_line[2]) + + float(last_line[3])) + max_allowable_padding = utterance_end_time - this_end_time + assert max_allowable_padding > -0.01 + unk_padding = max_unk_padding + if unk_padding > max_allowable_padding: + unk_padding = max_allowable_padding + # If we could add less than half of the specified + # unk-padding, don't add any (because when we add + # unk-padding we add the unknown-word symbol '', + # and if there isn't enough space to traverse the HMM we + # don't want to do it at all. + if unk_padding < 0.5 * max_unk_padding: + unk_padding = 0.0 + self.end_unk_padding = unk_padding + + def start_time(self): + """Returns the start time of the utterance (within the enclosing + utterance). + This is before any rounding. + """ + if self.start_index == len(self.split_lines_of_utt): + assert self.end_index == len(self.split_lines_of_utt) + return self.end_time() + first_line = self.split_lines_of_utt[self.start_index] + first_line_start = float(first_line[2]) + first_line_duration = float(first_line[3]) + first_line_end = first_line_start + first_line_duration + return (first_line_end - self.start_unk_padding + - (first_line_duration * self.start_keep_proportion)) + + def debug_info(self, include_stats=True): + """Returns some string-valued information about 'this' that is useful + for debugging.""" + if include_stats and self.stats is not None: + stats = 'wer={wer:.2f},{stats},'.format( + wer=self.stats.wer(), stats=self.stats) + else: + stats = '' + + return ('start={start:d},end={end:d},' + 'unk-padding={start_unk_padding:.2f},{end_unk_padding:.2f},' + 'keep-proportion={start_prop:.2f},{end_prop:.2f},' + 'start-time={start_time:.2f},end-time={end_time:.2f},' + '{stats}' + 'debug-str={debug_str}'.format( + start=self.start_index, end=self.end_index, + start_unk_padding=self.start_unk_padding, + 
end_unk_padding=self.end_unk_padding,
+ start_prop=self.start_keep_proportion,
+ end_prop=self.end_keep_proportion,
+ start_time=self.start_time(), end_time=self.end_time(),
+ stats=stats, debug_str=self.debug_str))
+
+ def end_time(self):
+ """Returns the end time of the utterance (within the enclosing
+ utterance)."""
+ if self.end_index == 0:
+ assert self.start_index == 0
+ return self.start_time()
+ last_line = self.split_lines_of_utt[self.end_index - 1]
+ last_line_start = float(last_line[2])
+ last_line_duration = float(last_line[3])
+ return (last_line_start
+ + (last_line_duration * self.end_keep_proportion)
+ + self.end_unk_padding)
+
+ def length(self):
+ """Returns the segment length in seconds."""
+ return self.end_time() - self.start_time()
+
+ def is_whole_utterance(self):
+ """Returns true if this segment corresponds to the whole utterance that
+ it's a part of (i.e. its start/end time are zero and the end-time of
+ the last segment)."""
+ last_line_of_utt = self.split_lines_of_utt[-1]
+ last_line_end_time = (float(last_line_of_utt[2])
+ + float(last_line_of_utt[3]))
+ return (abs(self.start_time() - 0.0) < 0.001
+ and abs(self.end_time() - last_line_end_time) < 0.001)
+
+ def get_junk_proportion(self):
+ """Returns the proportion of the duration of this segment that consists
+ of unk-padding and tainted lines of input (will be between 0.0 and
+ 1.0)."""
+ # Note: only the first and last lines could possibly be tainted as
+ # that's how we create the segments; and if either or both are tainted
+ # the utterance must contain other lines, so double-counting is not a
+ # problem. 
+ junk_duration = self.start_unk_padding + self.end_unk_padding + first_split_line = self.split_lines_of_utt[self.start_index] + if is_tainted(first_split_line): + first_duration = float(first_split_line[3]) + junk_duration += first_duration * self.start_keep_proportion + last_split_line = self.split_lines_of_utt[self.end_index - 1] + if is_tainted(last_split_line): + last_duration = float(last_split_line[3]) + junk_duration += last_duration * self.end_keep_proportion + return junk_duration / self.length() + + def get_junk_duration(self): + """Returns duration of junk""" + return self.get_junk_proportion() * self.length() + + def merge_adjacent_segment(self, other): + """ + This function will merge the segment in 'other' with the segment + in 'self'. It is only to be called when 'self' and 'other' are from + the same utterance, 'other' is after 'self' in time order (based on + the original segment cores), and self.end_index <= self.start_index + i.e. the two segments might have at most one index in common, + which is usually a tainted word or silence. 
+ """ + try: + assert self.end_index <= other.start_index + 1 + assert self.start_time() < other.end_time() + assert self.split_lines_of_utt is other.split_lines_of_utt + except AssertionError: + _global_logger.error("self: %s", self) + _global_logger.error("other: %s", other) + raise + + assert self.start_index == 0 or self.start_index != other.start_index + + _global_logger.debug("Before merging: %s", self) + + assert not self.stats.compare(other.stats), "%s %s" % (self, other) + self.stats.combine(other.stats) + + if self.end_index == other.start_index + 1: + overlapping_segment = Segment( + self.split_lines_of_utt, other.start_index, + self.end_index, compute_segment_stats=True) + self.stats.combine(overlapping_segment.stats, scale=-1) + + _global_logger.debug("Other segment: %s", other) + + self.debug_str = "({0}/merged-with-adjacent/{1})".format( + self.debug_str, other.debug_str) + + # everything that relates to the end of this segment gets copied + # from 'other'. + self.end_index = other.end_index + self.end_unk_padding = other.end_unk_padding + self.end_keep_proportion = other.end_keep_proportion + + _global_logger.debug("After merging %s", self) + return + + def merge_with_segment(self, other, max_deleted_words): + """ + This function will merge the segment in 'other' with the segment + in 'self'. It is only to be called when 'self' and 'other' are from + the same utterance, 'other' is after 'self' in time order (based on + the original segment cores), and self.end_time() >= other.start_time(). + Note: in this situation there will normally be deleted words + between the two segments. What this program does with the deleted + words depends on '--max-deleted-words-kept-when-merging'. If there + were any inserted words in the transcript (less likely), this + program will keep the reference. + + Note: --max-deleted-words-kept-when-merging can be set to a very + large value to keep all the words. 
+ """ + try: + assert self.end_time() >= other.start_time() + assert self.start_time() < other.end_time() + assert self.split_lines_of_utt is other.split_lines_of_utt + except AssertionError: + _global_logger.error("self: %s", self) + _global_logger.error("other: %s", other) + raise + + assert self.start_index == 0 or self.start_index != other.start_index + + _global_logger.debug("Before merging: %s", self) + + assert (not self.stats.compare(other.stats) + or self.start_time() != other.start_time() + or self.end_time() != other.end_time() + ), "%s %s" % (self, other) + self.stats.combine(other.stats) + + _global_logger.debug("Other segment: %s", other) + + orig_self_end_index = self.end_index + self.debug_str = "({0}/merged-with/{1})".format( + self.debug_str, other.debug_str) + + # everything that relates to the end of this segment gets copied + # from 'other'. + self.end_index = other.end_index + self.end_unk_padding = other.end_unk_padding + self.end_keep_proportion = other.end_keep_proportion + + _global_logger.debug("After merging %s", self) + + # The next thing we have to do is to go over any lines of the ctm that + # appear between 'self' and 'other', or are shared between both (this + # would only happen for tainted silence or non-scored-word segments), + # and decide what to do with them. We'll keep the reference for any + # substitutions or insertions (which anyway are unlikely to appear + # in these merged segments). Note: most of this happens in + # self.Text(), but at this point we need to decide whether to mark any + # deletions as 'discard-this-word'. 
+ try: + if orig_self_end_index <= other.start_index: + # No overlap in indexes + first_index_of_overlap = orig_self_end_index + last_index_of_overlap = other.start_index - 1 + segment = Segment( + self.split_lines_of_utt, orig_self_end_index, + other.start_index, compute_segment_stats=True) + self.stats.combine(segment.stats) + else: + first_index_of_overlap = other.start_index + last_index_of_overlap = orig_self_end_index - 1 + + num_deleted_words = 0 + for i in range(first_index_of_overlap, last_index_of_overlap + 1): + edit_type = self.split_lines_of_utt[i][7] + if edit_type == 'del': + num_deleted_words += 1 + if num_deleted_words > max_deleted_words: + for i in range(first_index_of_overlap, + last_index_of_overlap + 1): + if self.split_lines_of_utt[i][7] == 'del': + self.split_lines_of_utt[i].append( + 'do-not-include-in-text') + except: + _global_logger.error( + "first-index-of-overlap = %d", first_index_of_overlap) + _global_logger.error( + "last-index-of-overlap = %d", last_index_of_overlap) + _global_logger.error("line = %d = %s", i, + self.split_lines_of_utt[i]) + raise + _global_logger.debug("After merging %s", self) + + def contains_atleast_one_scored_non_oov_word(self): + """ + this will return true if there is at least one word in the utterance + that's a scored word (not a non-scored word) and not an OOV word that's + realized as unk. This becomes a filter on keeping segments. 
+ """ + for i in range(self.start_index, self.end_index): + this_split_line = self.split_lines_of_utt[i] + this_hyp_word = this_split_line[4] + this_ref_word = this_split_line[6] + this_edit = this_split_line[7] + if (this_edit == 'cor' and this_ref_word not in non_scored_words() + and this_ref_word == this_hyp_word): + return True + return False + + def text(self, oov_symbol, eps_symbol=""): + """Returns the text corresponding to this utterance, as a string.""" + text_array = [] + if self.start_unk_padding != 0.0: + text_array.append(oov_symbol) + for i in range(self.start_index, self.end_index): + this_split_line = self.split_lines_of_utt[i] + this_ref_word = this_split_line[6] + if (this_ref_word != eps_symbol + and this_split_line[-1] != 'do-not-include-in-text'): + text_array.append(this_ref_word) + if self.end_unk_padding != 0.0: + text_array.append(oov_symbol) + return ' '.join(text_array) + + +class SegmentsMerger(object): + """This class contains methods for merging segments. It stores the + appropriate statistics required for this process in objects of + SegmentStats class. 
+
+ Parameters:
+ segments - a reference to the list of initial segments
+ merged_segments - stores all the initial segments as well
+ as the newly created segments
+ between_segments - stores the inter-segment "segments"
+ for the initial segments
+ split_lines_of_utt - a reference to the CTM lines
+ """
+
+ def __init__(self, segments):
+ self.segments = segments
+
+ try:
+ self.split_lines_of_utt = segments[0].split_lines_of_utt
+ except IndexError as e:
+ _global_logger.error("No input segments found!")
+ raise e
+
+ self.merged_segments = {}
+ self.between_segments = [None for i in range(len(segments) + 1)]
+
+ if segments[0].start_index > 0:
+ self.between_segments[0] = Segment(
+ self.split_lines_of_utt, 0, segments[0].start_index,
+ compute_segment_stats=True)
+
+ for i, x in enumerate(segments):
+ x.compute_stats()
+ self.merged_segments[(i, )] = x
+
+ if i > 0 and segments[i].start_index > segments[i - 1].end_index:
+ self.between_segments[i] = Segment(
+ self.split_lines_of_utt, segments[i - 1].end_index,
+ segments[i].start_index, compute_segment_stats=True)
+
+ if segments[-1].end_index < len(self.split_lines_of_utt):
+ self.between_segments[-1] = Segment(
+ self.split_lines_of_utt, segments[-1].end_index,
+ len(self.split_lines_of_utt), compute_segment_stats=True)
+
+ def _get_merged_cluster(self, cluster1, cluster2, rejected_clusters=None,
+ max_intersegment_incorrect_words_length=1):
+ try:
+ assert cluster2[0] > cluster1[-1]
+ new_cluster = cluster1 + cluster2
+ new_cluster_tup = tuple(new_cluster)
+
+ if (rejected_clusters is not None
+ and new_cluster_tup in rejected_clusters):
+ return (None, new_cluster, True)
+
+ if new_cluster_tup in self.merged_segments:
+ return (self.merged_segments[new_cluster_tup],
+ new_cluster, False)
+
+ if cluster1[-1] == -1:
+ assert len(cluster1) == 1
+ # Consider merging cluster2 with the region before the 0^th
+ # segment
+ if (self.between_segments[0] is None
+ or self.between_segments[0].stats.total_length == 0
+ 
or (self.between_segments[0] + .stats.incorrect_words_length + > max_intersegment_incorrect_words_length)): + # Reject zero length or bad start region + return (None, new_cluster, True) + merged_segment = self.between_segments[0].copy() + else: + merged_segment = self.merged_segments[tuple(cluster1)].copy() + + if cluster2[0] == len(self.segments): + assert len(cluster2) == 1 + if (self.between_segments[-1] is None + or (self.between_segments[-1] + .stats.total_length == 0) + or (self.between_segments[-1] + .stats.incorrect_words_length + > max_intersegment_incorrect_words_length)): + # Reject zero length or bad end region + return (None, new_cluster, True) + if self.between_segments[cluster2[0]] is not None: + if (self.between_segments[cluster2[0]] + .stats.incorrect_words_length + > max_intersegment_incorrect_words_length): + return (None, new_cluster, True) + merged_segment.merge_adjacent_segment( + self.between_segments[cluster2[0]]) + + if cluster2[0] < len(self.segments): + merged_segment.merge_adjacent_segment( + self.merged_segments[tuple(cluster2)]) + # else: + # Already done + # merged_segment.merge_adjacent_segment(self.between_segments[-1]) + + self.merged_segments[new_cluster_tup] = merged_segment + return (merged_segment, new_cluster, False) + except: + _global_logger.error("Failed merging cluster1 %s and cluster2 %s", + cluster1, cluster2) + for i in (cluster1 + cluster2): + if i >= 0 and i < len(self.segments): + _global_logger.error("Segment %d = %s", i, + self.segments[i]) + raise + + def merge_clusters(self, scoring_function, + max_wer=10, max_bad_proportion=0.3, + max_segment_length=10, + max_intersegment_incorrect_words_length=1): + for i, x in enumerate(self.segments): + _global_logger.debug("before agglomerative clustering, segment %d" + " = %s", i, x) + + # Initial clusters are the individual segments themselves. 
+ clusters = [[x] for x in range(-1, len(self.segments) + 1)] + + rejected_clusters = set() + + while len(clusters) > 1: + try: + _global_logger.debug("Current clusters: %s", clusters) + + heap = [] + + for i in range(len(clusters) - 1): + merged_segment, new_cluster, reject = ( + self._get_merged_cluster( + clusters[i], clusters[i + 1], rejected_clusters, + max_intersegment_incorrect_words_length=( + max_intersegment_incorrect_words_length))) + if reject: + rejected_clusters.add(tuple(new_cluster)) + continue + + heapq.heappush(heap, (-scoring_function(merged_segment), + (merged_segment, i, new_cluster))) + + candidate_index = -1 + candidate_cluster = None + + while True: + try: + score, tup = heapq.heappop(heap) + except IndexError: + break + + segment, index, cluster = tup + + _global_logger.debug( + "Considering new cluster: (%d, %s)", index, cluster) + + if segment.stats.wer() > max_wer: + _global_logger.debug( + "Rejecting cluster with " + "WER%% %.2f > %.2f", segment.stats.wer(), max_wer) + rejected_clusters.add(tuple(cluster)) + continue + + if segment.stats.bad_proportion() > max_bad_proportion: + _global_logger.debug( + "Rejecting cluster with bad-proportion " + "%.2f > %.2f", segment.stats.bad_proportion(), + max_bad_proportion) + rejected_clusters.add(tuple(cluster)) + continue + + if segment.stats.total_length > max_segment_length: + _global_logger.debug( + "Rejecting cluster with length " + "%.2f > %.2f", segment.stats.total_length, + max_segment_length) + rejected_clusters.add(tuple(cluster)) + continue + + candidate_index, candidate_cluster = tup[1:] + _global_logger.debug("Accepted cluster (%d, %s)", + candidate_index, candidate_cluster) + break + + if candidate_index == -1: + return clusters + + new_clusters = [] + + for i in range(candidate_index): + new_clusters.append(clusters[i]) + new_clusters.append(candidate_cluster) + for i in range(candidate_index + 2, len(clusters)): + new_clusters.append(clusters[i]) + + if len(new_clusters) >= 
len(clusters): + raise RuntimeError("Old: {0}; New: {1}".format( + clusters, new_clusters)) + clusters = new_clusters + except Exception: + _global_logger.error( + "Failed merging clusters %s", clusters) + raise + + return clusters + + +def merge_segments(segments, args): + if len(segments) == 0: + _global_logger.debug("Got no segments at merging segments stage") + return [] + + def scoring_function(segment): + stats = segment.stats + try: + return (-stats.wer() - args.silence_factor * stats.silence_length + - args.incorrect_words_factor + * stats.incorrect_words_length + - args.tainted_words_factor + * stats.num_tainted_words * 100.0 / stats.num_words) + except ZeroDivisionError: + return float("-inf") + + # Do agglomerative clustering on the initial segments with the score + # for combining neighboring segments being the scoring_function on the + # stats of the combined segment. + merger = SegmentsMerger(segments) + clusters = merger.merge_clusters( + scoring_function, max_wer=args.max_wer, + max_bad_proportion=args.max_bad_proportion, + max_segment_length=args.max_segment_length_for_merging, + max_intersegment_incorrect_words_length=( + args.max_intersegment_incorrect_words_length)) + + _global_logger.debug("Clusters to be merged: %s", clusters) + + # Do the actual merging based on the clusters. 
+ new_segments = [] + for cluster_index, cluster in enumerate(clusters): + _global_logger.debug( + "Merging cluster (%d, %s)", cluster_index, cluster) + + try: + if cluster_index == 0 and len(cluster) == 1: + assert cluster[0] == -1 + _global_logger.debug( + "Not adding region before the first segment") + # skip adding the lines before the initial segment if its + # not merged with the initial segment + continue + elif cluster_index == len(clusters) - 1 and len(cluster) == 1: + _global_logger.debug( + "Not adding remaining end region %s", + cluster[0]) + assert cluster[0] == len(segments) + # skip adding the lines after the last segment if its + # not merged with the last segment + break + + new_segments.append(merger.merged_segments[tuple(cluster)]) + except Exception: + _global_logger.error("Error with cluster (%d, %s)", + cluster_index, cluster) + raise + + segments = new_segments + + for i, x in enumerate(segments): + _global_logger.debug( + "after agglomerative clustering: segment %d = %s", i, x) + + assert len(segments) > 0 + segment_index = 0 + # Ignore all the initial segments that have WER > max_wer + while segment_index < len(segments): + segment = segments[segment_index] + if segment.stats.wer() < args.max_wer: + break + segment_index += 1 + + if segment_index == len(segments): + _global_logger.debug("No merged segments were below " + "WER%% %.2f", args.max_wer) + return [] + + _global_logger.debug("Merging overlapping segments starting from the " + "first segment with WER%% < max_wer i.e. 
%d = %s", + segment_index, segments[segment_index]) + + new_segments = [segments[segment_index]] + segment_index += 1 + while segment_index < len(segments): + if segments[segment_index].stats.wer() > args.max_wer: + # ignore this segment + segment_index += 1 + continue + if new_segments[-1].end_time() >= segments[segment_index].start_time(): + new_segments[-1].merge_with_segment( + segments[segment_index], args.max_deleted_words) + else: + new_segments.append(segments[segment_index]) + segment_index += 1 + segments = new_segments + + return segments + + +def get_segments_for_utterance(split_lines_of_utt, args, utterance_stats): + """ + This function creates the segments for an utterance as a list + of class Segment. + It returns a 2-tuple (list-of-segments, list-of-deleted-segments) + where the deleted segments are only useful for diagnostic printing. + Note: split_lines_of_utt is a list of lists, one per line, each containing + the sequence of fields. + """ + utterance_stats.num_utterances += 1 + + segment_ranges = compute_segment_cores(split_lines_of_utt) + + utterance_end_time = (float(split_lines_of_utt[-1][2]) + + float(split_lines_of_utt[-1][3])) + utterance_stats.total_length_of_utterances += utterance_end_time + + segments = [Segment(split_lines_of_utt, x[0], x[1]) + for x in segment_ranges] + + utterance_stats.accumulate_segment_stats( + segments, 'stage 0 [segment cores]') + + for i, x in enumerate(segments): + _global_logger.debug("stage 0: segment %d = %s", i, x) + + if args.verbose > 4: + print ("Stage 0 [segment cores]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + for segment in segments: + segment.possibly_add_tainted_lines() + utterance_stats.accumulate_segment_stats( + segments, 'stage 1 [add tainted lines]') + + for i, x in enumerate(segments): + _global_logger.debug("stage 1: segment %d = %s", i, x) + + if args.verbose > 
4: + print ("Stage 1 [add tainted lines]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + segments = merge_segments(segments, args) + utterance_stats.accumulate_segment_stats( + segments, 'stage 2 [merge segments]') + + for i, x in enumerate(segments): + _global_logger.debug("stage 2: segment %d = %s", i, x) + + if args.verbose > 4: + print ("Stage 2 [merge segments]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + new_segments = [] + for s in segments: + new_segments += s.possibly_split_segment( + args.max_internal_silence_length, + args.max_internal_non_scored_length) + segments = new_segments + utterance_stats.accumulate_segment_stats( + segments, 'stage 3 [split segments]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 3: segment %d, %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 3 [split segments]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + new_segments = [] + for s in segments: + new_segments += s.possibly_split_long_segment( + args.max_segment_length_for_splitting, + args.hard_max_segment_length, + args.min_silence_length_to_split, + args.min_non_scored_length_to_split) + segments = new_segments + utterance_stats.accumulate_segment_stats( + segments, 'stage 4 [split long segments]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 4: segment %d, %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 4 [split long segments]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + for s in 
segments: + s.possibly_truncate_boundaries(args.max_edge_silence_length, + args.max_edge_non_scored_length) + utterance_stats.accumulate_segment_stats( + segments, 'stage 5 [truncate boundaries]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 5: segment %d = %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 5 [truncate boundaries]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + for s in segments: + s.relax_boundary_truncation(args.min_segment_length, + args.min_new_segment_length) + utterance_stats.accumulate_segment_stats( + segments, 'stage 6 [relax boundary truncation]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 6: segment %d = %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 6 [relax boundary truncation]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + for s in segments: + s.possibly_add_unk_padding(args.unk_padding) + utterance_stats.accumulate_segment_stats( + segments, 'stage 7 [unk-padding]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 7: segment %d = %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 7 [unk-padding]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + deleted_segments = [] + new_segments = [] + for s in segments: + # the 0.999 allows for roundoff error. 
+ if (not s.is_whole_utterance() + and s.length() < 0.999 * args.min_new_segment_length): + s.debug_str += '[deleted-because-of--min-new-segment-length]' + deleted_segments.append(s) + else: + new_segments.append(s) + segments = new_segments + utterance_stats.accumulate_segment_stats( + segments, + 'stage 8 [remove new segments under --min-new-segment-length') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 8: segment %d = %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 8 [remove new segments under " + "--min-new-segment-length]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + new_segments = [] + for s in segments: + # the 0.999 allows for roundoff error. + if s.length() < 0.999 * args.min_segment_length: + s.debug_str += '[deleted-because-of--min-segment-length]' + deleted_segments.append(s) + else: + new_segments.append(s) + segments = new_segments + utterance_stats.accumulate_segment_stats( + segments, 'stage 9 [remove segments under --min-segment-length]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 9: segment %d = %s", i, x.debug_info(False)) + + if args.verbose > 4: + print ("Stage 9 [remove segments under " + "--min-segment-length]:", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + new_segments = [] + for s in segments: + if s.contains_atleast_one_scored_non_oov_word(): + new_segments.append(s) + else: + s.debug_str += '[deleted-because-no-scored-non-oov-words]' + deleted_segments.append(s) + segments = new_segments + utterance_stats.accumulate_segment_stats( + segments, 'stage 10 [remove segments without scored,non-OOV words]') + + for i, x in enumerate(segments): + _global_logger.debug( + "stage 10: segment %d = %s", i, x.debug_info(False)) + + 
if args.verbose > 4: + print ("Stage 10 [remove segments without scored, non-OOV words " + "", file=sys.stderr) + segments_copy = [x.copy() for x in segments] + print_debug_info_for_utterance(sys.stderr, + copy.deepcopy(split_lines_of_utt), + segments_copy, []) + + for i in range(len(segments) - 1): + if segments[i].end_time() > segments[i + 1].start_time(): + # this just adds something to --ctm-edits-out output + segments[i + 1].debug_str += ",overlaps-previous-segment" + + if len(segments) == 0: + utterance_stats.num_utterances_without_segments += 1 + + return (segments, deleted_segments) + + +def float_to_string(f): + """ this prints a number with a certain number of digits after the point, + while removing trailing zeros. + """ + num_digits = 6 # we want to print 6 digits after the zero + g = f + while abs(g) > 1.0: + g *= 0.1 + num_digits += 1 + format_str = '%.{0}g'.format(num_digits) + return format_str % f + + +def time_to_string(time, frame_length): + """ Gives time in string form as an exact multiple of the frame-length, + e.g. 0.01 (after rounding). + """ + n = round(time / frame_length) + assert n >= 0 + # The next function call will remove trailing zeros while printing it, so + # that e.g. 0.01 will be printed as 0.01 and not 0.0099999999999999. It + # seems that doing this in a simple way is not really possible (at least, + # not without assuming that frame_length is of the form 10^-n, which we + # don't really want to do). + return float_to_string(n * frame_length) + + +def write_segments_for_utterance(text_output_handle, segments_output_handle, + old_utterance_name, segments, oov_symbol, + eps_symbol="", frame_length=0.01): + for n, segment in enumerate(segments): + # split utterances will be named foo-bar-1 foo-bar-2, etc. 
+ new_utterance_name = old_utterance_name + "-" + str(n + 1)
+ # print a line to the text output of the form like
+ # 
+ # like:
+ # foo-bar-1 hello this is dan
+ print(new_utterance_name, segment.text(oov_symbol, eps_symbol),
+ file=text_output_handle)
+ # print a line to the segments output of the form
+ # 
+ # like:
+ # foo-bar-1 foo-bar 5.1 7.2
+ print(new_utterance_name, old_utterance_name,
+ time_to_string(segment.start_time(), frame_length),
+ time_to_string(segment.end_time(), frame_length),
+ file=segments_output_handle)
+
+
+# Note, this is destructive of 'segments_for_utterance', but it won't matter.
+def print_debug_info_for_utterance(ctm_edits_out_handle,
+ split_lines_of_cur_utterance,
+ segments_for_utterance,
+ deleted_segments_for_utterance,
+ frame_length=0.01):
+ # info_to_print will be list of 2-tuples
+ # (time, 'start-segment-n'|'end-segment-n')
+ # representing the start or end times of segments.
+ info_to_print = []
+ for n, segment in enumerate(segments_for_utterance):
+ start_string = 'start-segment-{0}[{1}]'.format(n + 1,
+ segment.debug_info())
+ info_to_print.append((segment.start_time(), start_string))
+ end_string = 'end-segment-{0}'.format(n + 1)
+ info_to_print.append((segment.end_time(), end_string))
+ # for segments that were deleted we print info like
+ # start-deleted-segment-1, and otherwise similar info to segments that were
+ # retained.
+ for n, segment in enumerate(deleted_segments_for_utterance):
+ start_string = 'start-deleted-segment-{0}[{1}]'.format(
+ n + 1, segment.debug_info(False))
+ info_to_print.append((segment.start_time(), start_string))
+ end_string = 'end-deleted-segment-{0}'.format(n + 1)
+ info_to_print.append((segment.end_time(), end_string))
+
+ info_to_print = sorted(info_to_print)
+
+ for i, split_line in enumerate(split_lines_of_cur_utterance):
+ # add an index like [0], [1], to the utterance-id so we can easily look
+ # up segment indexes. 
class WordStats(object):
    """
    This accumulates word-level stats about, for each reference word, with
    what probability it will end up in the core of a segment.  Words with
    low probabilities of being in segments will generally be associated
    with some kind of error (there is a higher probability of having a
    wrong lexicon entry).
    """
    def __init__(self):
        # word_count_pair maps a word (string) to a list
        # [total-count, count-not-within-segments].
        self.word_count_pair = defaultdict(lambda: [0, 0])

    def accumulate_for_utterance(self, split_lines_of_utt,
                                 segments_for_utterance,
                                 eps_symbol="<eps>"):
        """Accumulates stats from one utterance: each real (non-epsilon)
        reference word is counted once, plus once more in the second slot
        if the ctm-edits line containing it lies outside all segments.

        Note: the default eps_symbol is restored to '<eps>' (the
        angle-bracket token had been stripped from this copy of the file).
        """
        line_is_in_segment = [False] * len(split_lines_of_utt)
        for segment in segments_for_utterance:
            for i in range(segment.start_index, segment.end_index):
                line_is_in_segment[i] = True
        for i, split_line in enumerate(split_lines_of_utt):
            this_ref_word = split_line[6]  # field 6 is the reference word.
            if this_ref_word != eps_symbol:
                self.word_count_pair[this_ref_word][0] += 1
                if not line_is_in_segment[i]:
                    self.word_count_pair[this_ref_word][1] += 1

    def print(self, word_stats_out):
        """Writes per-word stats to 'word_stats_out' and closes it.

        Sorted from most to least problematic.  We want to give more
        prominence to words that are most frequently not in segments, but
        also to high-count words.  Define badness = pair[1] / pair[0] and
        total_count = pair[0], where 'pair' is a value of word_count_pair;
        we reverse-sort on badness^3 * total_count = pair[1]^3 / pair[0]^2.
        """
        for key, pair in sorted(
                self.word_count_pair.items(),
                key=lambda item: (item[1][1] ** 3) * 1.0 / (item[1][0] ** 2),
                reverse=True):
            badness = pair[1] * 1.0 / pair[0]
            total_count = pair[0]
            print(key, badness, total_count, file=word_stats_out)
        try:
            word_stats_out.close()
        except Exception:
            _global_logger.error("error closing file --word-stats-out=%s "
                                 "(full disk?)", word_stats_out.name)
            raise

        _global_logger.info(
            """please see the file %s for word-level
            statistics saying how frequently each word was excluded from a
            segment; format is
            <word> <proportion-not-in-segment> <total-count>.
            Particularly problematic words appear near the top of the
            file.""", word_stats_out.name)
# Set of words that are not scored (e.g. hesitations); filled in by
# read_non_scored_words().  Initialized here so the module is importable
# and the function is callable before main() runs.
_global_non_scored_words = set()


def read_non_scored_words(non_scored_words_file):
    """Reads the non-scored-words file (one word per line) into the module
    global _global_non_scored_words, closes the file, and returns the set.

    Raises RuntimeError on any line that does not contain exactly one word.
    """
    for line in non_scored_words_file.readlines():
        parts = line.split()
        if len(parts) != 1:
            # Bug fix: the message used to format the file *object*; use
            # its name so the error is readable.
            raise RuntimeError(
                "segment_ctm_edits.py: bad line in non-scored-words "
                "file {0}: {1}".format(non_scored_words_file.name, line))
        _global_non_scored_words.add(parts[0])
    non_scored_words_file.close()
    return _global_non_scored_words


class UtteranceStats(object):
    """Accumulates utterance- and segment-level statistics for the summary
    that print_segment_stats() logs at the end of processing."""

    def __init__(self):
        # segment_total_length and num_segments are maps from a 'stage'
        # string (see accumulate_segment_stats()) to, respectively, the
        # total segment length and the number of segments at that stage.
        self.segment_total_length = defaultdict(int)
        self.num_segments = defaultdict(int)
        self.num_utterances = 0
        self.num_utterances_without_segments = 0
        self.total_length_of_utterances = 0

    def accumulate_segment_stats(self, segment_list, text):
        """
        Here, 'text' indicates the stage of processing, e.g.
        'Stage 0: segment cores', 'Stage 1: add tainted lines', etc.
        """
        for segment in segment_list:
            self.num_segments[text] += 1
            self.segment_total_length[text] += segment.length()

    def print_segment_stats(self):
        """Logs per-stage segment counts/lengths and overall coverage."""
        _global_logger.info(
            """Number of utterances is %d, of which %.2f%% had no segments
            after all processing; total length of data in original
            utterances (in seconds) was %d""",
            self.num_utterances,
            (self.num_utterances_without_segments * 100.0
             / self.num_utterances),
            self.total_length_of_utterances)

        keys = sorted(self.segment_total_length.keys())
        for i, key in enumerate(keys):
            if i > 0:
                # Change in retained length versus the previous stage.
                delta_percentage = '[%+.2f%%]' % (
                    (self.segment_total_length[key]
                     - self.segment_total_length[keys[i - 1]])
                    * 100.0 / self.total_length_of_utterances)
            _global_logger.info(
                'At %s, num-segments is %d, total length %.2f%% of '
                'original total %s',
                key, self.num_segments[key],
                (self.segment_total_length[key]
                 * 100.0 / self.total_length_of_utterances),
                delta_percentage if i > 0 else '')
#! /usr/bin/perl

# Copyright 2017  Vimal Manohar
# Apache 2.0.

# Splits each utterance's text into documents of at most --max-words words.
# Document ids are the utterance id suffixed with a 0-based chunk index.
# If 'text' contains:
#  utterance1 A B C D
#  utterance2 C B
# and you ran:
#  split_text_into_docs.pl --max-words 2 text doc2text docs
# then 'doc2text' would contain:
#  utterance1-0 utterance1
#  utterance1-1 utterance1
#  utterance2-0 utterance2
# and 'docs' would contain:
#  utterance1-0 A B
#  utterance1-1 C D
#  utterance2-0 C B

use warnings;
use strict;

my $max_words = 1000;

my $usage = "Usage: steps/cleanup/internal/split_text_into_docs.pl [--max-words <int>] text doc2text docs\n";

while (@ARGV > 3) {
  if ($ARGV[0] eq "--max-words") {
    shift @ARGV;
    $max_words = shift @ARGV;
  } else {
    print STDERR "$usage";
    exit (1);
  }
}

if (scalar @ARGV != 3) {
  print STDERR "$usage";
  exit (1);
}

# Returns the smaller of two numbers.
sub min {
  my ($x, $y) = @_;
  return $x < $y ? $x : $y;
}

open TEXT, $ARGV[0] or die "$0: Could not open file $ARGV[0] for reading\n";
open DOC2TEXT, ">", $ARGV[1] or die "$0: Could not open file $ARGV[1] for writing\n";
open DOCS, ">", $ARGV[2] or die "$0: Could not open file $ARGV[2] for writing\n";

while (<TEXT>) {
  chomp;
  my @F = split;
  my $utt = shift @F;
  my $num_words = scalar @F;

  # Short utterances pass through as a single document.
  if ($num_words <= $max_words) {
    print DOCS "$_\n";
    print DOC2TEXT "$utt $utt\n";
    next;
  }

  my $num_docs = int($num_words / $max_words) + 1;
  my $num_words_shift = int($num_words / $num_docs) + 1;
  my $words_per_doc = $num_words_shift;

  for (my $i = 0; $i < $num_docs; $i++) {
    my $st = $i * $num_words_shift;
    # Bug fix: when $max_words exactly divides $num_words, $num_docs
    # over-allocates by one and the last chunk would start past the end of
    # the utterance, emitting an empty document; skip it.
    last if ($st >= $num_words);
    my $end = min($st + $words_per_doc, $num_words) - 1;
    print DOCS ("$utt-$i " . join(" ", @F[$st..$end]) . "\n");
    print DOC2TEXT "$utt-$i $utt\n";
  }
}
/usr/bin/env python + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +"""This script reads an archive of mapping from query to +documents and stitches the documents for each query into a +new document. +Here "document" is just a list of words. + +query2docs is a mapping from query-id to a list of tuples +(document-id, start-fraction, end-fraction) +The tuple can be just the document-id, which is equivaluent to +specifying a start-fraction and end-fraction of 1.0 +The start and end fractions are used to stitch only a part of the +document to the retrieved set for the query. + +e.g. +query1 doc1 doc2 +query2 doc1,0,0.3 doc2,1,1 + +input-documents +doc1 A B C +doc2 D E +output-documents +query1 A B C D E +query2 C D E +""" + +from __future__ import print_function +import argparse +import logging + +logger = logging.getLogger(__name__) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) + +for l in [logger, logging.getLogger('libs')]: + l.setLevel(logging.DEBUG) + l.addHandler(handler) + + +def get_args(): + """Returns arguments parsed from command-line.""" + + parser = argparse.ArgumentParser( + description="""This script reads an archive of mapping from query to + documents and stitches the documents for each query into a new + document.""") + + parser.add_argument("--query2docs", type=argparse.FileType('r'), + required=True, + help="""Input file containing an archive + of list of documents indexed by a query document + id.""") + parser.add_argument("--input-documents", type=argparse.FileType('r'), + required=True, + help="""Input file containing the documents + indexed by the document id.""") + parser.add_argument("--output-documents", type=argparse.FileType('w'), + required=True, + help="""Output documents indexed by the query + document-id, obtained by stitching input documents + corresponding to 
the query.""") + parser.add_argument("--check-sorted-docs-per-query", type=str, + choices=["true", "false"], default="false", + help="If specified, the script will expect " + "the document ids in --query2docs to be " + "sorted.") + + args = parser.parse_args() + + args.check_sorted_docs_per_query = bool( + args.check_sorted_docs_per_query == "true") + + return args + + +def run(args): + documents = {} + for line in args.input_documents: + parts = line.strip().split() + key = parts[0] + documents[key] = parts[1:] + args.input_documents.close() + + for line in args.query2docs: + try: + parts = line.strip().split() + query = parts[0] + document_infos = parts[1:] + + output_document = [] + prev_doc_id = '' + for doc_info in document_infos: + try: + doc_id, start_fraction, end_fraction = doc_info.split(',') + start_fraction = float(start_fraction) + end_fraction = float(end_fraction) + except ValueError: + doc_id = doc_info + start_fraction = 1.0 + end_fraction = 1.0 + + if args.check_sorted_docs_per_query: + if prev_doc_id != '': + if doc_id <= prev_doc_id: + raise RuntimeError( + "Documents not sorted and " + "--check-sorted-docs-per-query was True; " + "{0} <= {1}".format(doc_id, prev_doc_id)) + prev_doc_id = doc_id + + doc = documents[doc_id] + num_words = len(doc) + + if start_fraction == 1.0 or end_fraction == 1.0: + assert end_fraction == end_fraction + output_document.extend(doc) + else: + assert (start_fraction + end_fraction < 1.0) + if start_fraction > 0: + output_document.extend( + doc[0:int(start_fraction * num_words)]) + if end_fraction > 0: + output_document.extend( + doc[int(end_fraction * num_words):]) + + print ("{0} {1}".format(query, " ".join(output_document)), + file=args.output_documents) + except Exception: + logger.error("Error processing line %s in file %s", line, + args.query2docs.name) + raise + + +def main(): + args = get_args() + + try: + run(args) + except: + logger.error("Failed to stictch document; got error ", + exc_info=True) + raise 
SystemExit(1) + finally: + for f in [args.query2docs, args.input_documents, + args.output_documents]: + if f is not None: + f.close() + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py b/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py index 2230a10aee2..85e1df997a7 100755 --- a/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py +++ b/egs/wsj/s5/steps/cleanup/internal/taint_ctm_edits.py @@ -53,6 +53,9 @@ parser.add_argument("--verbose", type = int, default = 1, choices=[0,1,2,3], help = "Verbose level, higher = more verbose output") +parser.add_argument("--remove-deletions", type=str, default="true", + choices=["true", "false"], + help = "Remove deletions next to taintable lines") parser.add_argument("ctm_edits_in", metavar = "", help = "Filename of input ctm-edits file. " "Use /dev/stdin for standard input.") @@ -61,6 +64,7 @@ "Use /dev/stdout for standard output.") args = parser.parse_args() +args.remove_deletions = bool(args.remove_deletions == "true") @@ -70,7 +74,7 @@ # sequence of fields. Returns the same format of data after processing to add # the 'tainted' field. Note: this function is destructive of its input; the # input will not have the same value afterwards. 
-def ProcessUtterance(split_lines_of_utt): +def ProcessUtterance(split_lines_of_utt, remove_deletions=True): global num_lines_of_type, num_tainted_lines, \ num_del_lines_giving_taint, num_sub_lines_giving_taint, \ num_ins_lines_giving_taint @@ -114,7 +118,8 @@ def ProcessUtterance(split_lines_of_utt): j += 1 if tainted_an_adjacent_line: if edit_type == 'del': - split_lines_of_utt[i][7] = 'remove-this-line' + if remove_deletions: + split_lines_of_utt[i][7] = 'remove-this-line' num_del_lines_giving_taint += 1 elif edit_type == 'sub': num_sub_lines_giving_taint += 1 @@ -123,7 +128,8 @@ def ProcessUtterance(split_lines_of_utt): new_split_lines_of_utt = [] for i in range(len(split_lines_of_utt)): - if split_lines_of_utt[i][7] != 'remove-this-line': + if (not remove_deletions + or split_lines_of_utt[i][7] != 'remove-this-line'): new_split_lines_of_utt.append(split_lines_of_utt[i]) return new_split_lines_of_utt @@ -156,7 +162,8 @@ def ProcessData(): while True: if len(split_pending_line) == 0 or split_pending_line[0] != cur_utterance: - split_lines_of_cur_utterance = ProcessUtterance(split_lines_of_cur_utterance) + split_lines_of_cur_utterance = ProcessUtterance( + split_lines_of_cur_utterance, args.remove_deletions) for split_line in split_lines_of_cur_utterance: print(' '.join(split_line), file = f_out) split_lines_of_cur_utterance = [] diff --git a/egs/wsj/s5/steps/cleanup/internal/tf_idf.py b/egs/wsj/s5/steps/cleanup/internal/tf_idf.py new file mode 100644 index 00000000000..1eaff2d380f --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/internal/tf_idf.py @@ -0,0 +1,426 @@ +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +"""This module contains structures to accumulate, store and use stats +for Term-frequency and Inverse-document-frequency values. 
+""" + +from __future__ import print_function +import logging +import math +import re +import sys + +sys.path.insert(0, 'steps') + +logger = logging.getLogger('__name__') +logger.addHandler(logging.NullHandler()) + + +class IDFStats(object): + """Stores stats for computing inverse-document-frequencies. + """ + def __init__(self): + self.num_docs_for_term = {} + self.num_docs = 0 + + def get_inverse_document_frequency(self, term, weighting_scheme="log"): + """Get IDF for a term. + + Weighting scheme is the function applied on the raw + inverse-document frequencies n(t) = |d in D: t in d| + when computing idf(t,d). + Let N = Total number of documents. + + IDF weighting schemes:- + unary : idf(t,D) = 1 + log : idf(t,D) = log (N / (1 + n(t))) + log-smoothed : idf(t,D) = log(1 + N / n(t)) + probabilistic: idf(t,D) = log((N - n(t)) / n(t)) + """ + n_t = float(self.num_docs_for_term.get(term, 0)) + num_terms = len(self.num_docs_for_term) + + if num_terms == 0: + raise RuntimeError("No IDF stats have been accumulated.") + + if weighting_scheme == "unary": + return 1 + if weighting_scheme == "log": + return math.log(float(self.num_docs) / (1.0 + n_t)) + if weighting_scheme == "log-smoothed": + return math.log(1.0 + float(self.num_docs) / (1.0 + n_t)) + if weighting_scheme == "probabilitic": + return math.log((self.num_docs - n_t - 1) + / (1.0 + n_t)) + + def accumulate(self, term): + """Adds one count to the number of docs containing the term "term". + """ + self.num_docs_for_term[term] = self.num_docs_for_term.get(term, 0) + 1 + if len(term) == 1: + self.num_docs += 1 + + def write(self, file_handle): + """Writes the IDF stats to file using the format: + ... + for n-gram (, ... ) + """ + for term, num in self.num_docs_for_term.iteritems(): + if num == 0: + continue + assert isinstance(term, tuple) + print ("{term} {n}".format(term=" ".join(term), n=num), + file=file_handle) + + def read(self, file_handle): + """Loads IDF stats from file. 
""" + for line in file_handle: + parts = line.strip().split() + term = tuple(parts[0:-1]) + self.num_docs_for_term[term] = float(parts[-1]) + if len(term) == 1: + self.num_docs += 1 + + if len(self.num_docs_for_term) == 0: + raise RuntimeError("Read no IDF stats.") + + +class TFStats(object): + """Store stats for TF-IDF computation. + A separate object of IDFStats is stored within this object. + """ + def __init__(self): + self.raw_counts = {} + self.max_counts_for_term = {} + + def get_term_frequency(self, term, doc, weighting_scheme="raw", + normalization_factor=0.5): + """Returns the term-frequency for (term, document) pair. + + The function applied on the raw term-frequencies f(t,d) when computing + tf(t,d) is specified by the weighting_scheme. + binary : tf(t,d) = 1 if t in d else 0 + raw : tf(t,d) = f(t,d) + log : tf(t,d) = 1 + log(f(t,d)) + normalized : tf(t,d) = K + (1-K) * f(t,d) / max{f(t',d): t' in d} + """ + if weighting_scheme == "binary": + return 1 if (term, doc) in self.raw_counts else 0 + if weighting_scheme == "raw": + return self.raw_counts.get((term, doc), 0) + if weighting_scheme == "log": + if (term, doc) in self.raw_counts: + return 1 + math.log(self.raw_counts[(term, doc)]) + return 0 + if weighting_scheme == "normalized": + return (normalization_factor + + (1 - normalization_factor) + * self.raw_counts.get((term, doc), 0) + / (1.0 + self.max_counts_for_term.get(term, 0))) + raise KeyError("Unknown tf-weighting-scheme {0}".format( + weighting_scheme)) + + def accumulate(self, doc, text, ngram_order): + """Accumulate raw stats from a document for upto the specified + ngram-order.""" + for n in range(1, ngram_order + 1): + for i in range(len(text)): + term = tuple(text[i:(i+n)]) + self.raw_counts.setdefault((term, doc), 0) + self.raw_counts[(term, doc)] += 1 + + def compute_term_stats(self, idf_stats=None): + """Compute the maximum counts for each term over all the documents + based on the stored raw counts.""" + if len(self.raw_counts) == 0: 
+ raise RuntimeError("No (term, doc) found in tf-stats.") + for tup, counts in self.raw_counts.iteritems(): + term = tup[0] + + if counts > self.max_counts_for_term.get(term, 0): + self.max_counts_for_term[term] = counts + + if idf_stats is not None: + idf_stats.accumulate(term) + + def __str__(self): + """Returns a string with all the stats in the following format: + ... + """ + lines = [] + for tup, counts in self.raw_counts.iteritems(): + term, doc = tup + lines.append("{order} {term} {doc} {counts}".format( + order=len(term), term=" ".join(term), + doc=doc, counts=counts)) + return "\n".join(lines) + + def read(self, file_handle, ngram_order=None, idf_stats=None): + """Reads the TF stats stored in a file in the following format: + ... + + If idf_stats is provided then idf_stats is accumulated simultaneously. + """ + for line in file_handle: + parts = line.strip().split() + order = parts[0] + assert len(parts) - 3 == order + if ngram_order is not None and order > ngram_order: + continue + term = tuple(parts[1:(order+1)]) + doc = parts[-2] + counts = float(parts[-1]) + + self.raw_counts[(term, doc)] = counts + + if counts > self.max_counts_for_term.get(term, 0): + self.max_counts_for_term[term] = counts + + if idf_stats is not None: + idf_stats.accumulate(term) + + if len(self.raw_counts) == 0: + raise RuntimeError("Read no TF stats.") + + +class TFIDF(object): + """Class to store TF-IDF values for term-document pairs. + + Parameters: + tf_idf - A dictionary of TF-IDF values indexed by (term, document) + tuple as key + """ + + def __init__(self): + self.tf_idf = {} + + def get_value(self, term, doc): + """Returns TF-IDF value for (term, doc) tuple if it exists. + Otherwise returns 0. 
+ """ + return self.tf_idf[(term, doc)] + + def compute_similarity_scores(self, source_tfidf, source_docs=None, + do_length_normalization=False, + query_id=None): + """Computes TF-IDF similarity score between each pair of query + document contained in this object and the source documents + in the source_tfidf object. + + Arguments: + source_docs - If provided, the similarity scores are computed + for only the source documents contained in + source_docs. + use_average - If True, then the similarity scores is + normalized by the length of query. This is usually + not required when the scores are only utilized + for ranking the source documents. + query_id - If provided, check that this tf_idf object + contains values only for document with id 'query_id' + + Returns a dictionary + { (query_document_id, source_document_id): similarity_score } + """ + num_terms_per_doc = {} + similarity_scores = {} + + for tup, value in self.tf_idf.iteritems(): + term, doc = tup + num_terms_per_doc[doc] = num_terms_per_doc.get(doc, 0) + 1 + + if query_id is not None and doc != query_id: + raise RuntimeError("TF-IDF contains document {0}, which is " + "not the required query {1}. \n" + "Something wrong in how this TF-IDF object " + "was created or a bug in the " + "calling script.".format( + doc, query_id)) + + if source_docs is not None: + for src_doc in source_docs: + try: + src_value = source_tfidf.get_value(term, src_doc) + except KeyError: + logger.debug( + "Could not find ({term}, {src}) in " + "source_tfidf. 
" + "Choosing a tf-idf value of 0.".format( + term=term, src=src_doc)) + src_value = 0 + + similarity_scores[(doc, src_doc)] = ( + similarity_scores.get((doc, src_doc), 0) + + src_value * value) + else: + for src_tup, src_value in source_tfidf.tf_idf.iteritems(): + similarity_scores[(doc, src_doc)] = ( + similarity_scores.get((doc, src_doc), 0) + + src_value * value) + + if do_length_normalization: + for doc_pair, value in similarity_scores.iteritems(): + doc, src_doc = doc_pair + similarity_scores[(doc, src_doc)] = (value + / num_terms_per_doc[doc]) + + if logger.isEnabledFor(logging.DEBUG): + for doc, count in num_terms_per_doc.iteritems(): + logger.debug( + 'Seen {0} terms in query document {1}'.format(count, doc)) + + return similarity_scores + + def read(self, tf_idf_file): + """Loads TFIDF object from file.""" + + if len(self.tf_idf) != 0: + raise RuntimeError("TD-IDF object is not empty.") + seen_footer = False + line = tf_idf_file.readline() + parts = line.strip().split() + if re.search('^', line) is None: + raise TypeError( + "Invalid format of TD-IDF object. " + "Missing header ; got {0}".format(line)) + assert parts[0] == "" + if len(parts) > 1: + # Read header; go to the rest of line + line = " ".join(parts[1:]) + else: + # Nothing in this line. Read the next lines. 
+ line = tf_idf_file.readline() + while line: + parts = line.strip().split() + if re.search('', line): + if len(parts) > 1: + raise TypeError( + "Expecting footer " + "to be on a separate line; got {0}".format(line)) + assert parts[0] == "" + seen_footer = True + break + if re.search('', line): + raise TypeError("Got unexpected header in line " + "{0}".format(line)) + + order = int(parts[0]) + term = tuple(parts[1:(order + 1)]) + doc = parts[-2] + tfidf = float(parts[-1]) + + entry = (term, doc) + if entry in self.tf_idf: + raise RuntimeError("Duplicate entry {0} found while reading " + "TFIDF object.".format(entry)) + self.tf_idf[entry] = tfidf + + line = tf_idf_file.readline() + if not seen_footer: + raise TypeError( + "Did not see footer " + "in TFIDF object; got {0}".format(line)) + + if len(self.tf_idf) == 0: + raise RuntimeError( + "Read no TF-IDF values from file {0}".format(tf_idf_file.name)) + + def write(self, tf_idf_file): + """Writes TFIDF object to file.""" + + print ("", file=tf_idf_file) + for tup, value in self.tf_idf.iteritems(): + term, doc = tup + print("{order} {term} {doc} {tfidf}".format( + order=len(term), term=" ".join(term), + doc=doc, tfidf=value), + file=tf_idf_file) + print ("", file=tf_idf_file) + + +def write_tfidf_from_stats( + tf_stats, idf_stats, tf_idf_file, tf_weighting_scheme="raw", + idf_weighting_scheme="log", tf_normalization_factor=0.5, + expected_document_id=None): + """Writes TF-IDF values to file args.tf_idf_file. + The format used is + . + Markers "" and "" are added for parsing this file + easily. + + Arguments: + tf_stats - A TFStats object + idf_stats - An IDFStats object + tf_idf_file - Output file to which the TF-IDF values will be written + tf_weighting_scheme - See doc_string in TFStats class + idf_weighting_scheme - See doc_string in IDFStats class + tf_normalization_factor - See doc_string in TFStats class + document_id - If provided, checks that the TFStats object contains + stats only for this document_id. 
def read_key(fd):
    """Reads and returns the utterance key (a space-terminated token) from
    the opened ark/stream descriptor 'fd'; returns None at end of file.

    Fixes: no longer shadows the builtin 'str', and uses a raw string for
    the validation regex.
    """
    chars = []
    while True:
        char = fd.read(1)
        # Stop at end-of-file or at the space separating key and value.
        if char == '' or char == ' ':
            break
        chars.append(char)
    key = ''.join(chars).strip()
    if key == '':
        return None  # end of file,
    # Sanity-check the key format.
    assert re.match(r'^[.a-zA-Z0-9_-]+$', key) is not None
    return key
+ """ + try: + key = read_key(file_handle) + while key: + tf_idf = TFIDF() + try: + tf_idf.read(file_handle) + except RuntimeError: + raise + yield key, tf_idf + key = read_key(file_handle) + finally: + file_handle.close() diff --git a/egs/wsj/s5/steps/cleanup/make_biased_lm_graphs.sh b/egs/wsj/s5/steps/cleanup/make_biased_lm_graphs.sh index fdfdbda766a..799a379f7a1 100755 --- a/egs/wsj/s5/steps/cleanup/make_biased_lm_graphs.sh +++ b/egs/wsj/s5/steps/cleanup/make_biased_lm_graphs.sh @@ -39,9 +39,9 @@ echo "$0 $@" # Print the command line for logging [ -f path.sh ] && . ./path.sh # source the path. . parse_options.sh || exit 1; -if [ $# != 3 ]; then - echo "usage: $0 " - echo "e.g.: $0 data/train data/lang exp/tri3_cleanup" +if [ $# != 4 ]; then + echo "usage: $0 " + echo "e.g.: $0 data/train data/lang exp/tri3_cleanup exp/tri3_cleanup/graphs" echo " This script creates biased decoding graphs per utterance (or possibly" echo " groups of utterances, depending on --min-words-per-graph). Its output" echo " goes to /HCLG.fsts.scp, indexed by utterance. Directory is" @@ -72,33 +72,44 @@ if [ $# != 3 ]; then exit 1; fi -data=$1 +data_or_text=$1 lang=$2 dir=$3 +graph_dir=$4 +if [ -d $data_or_text ]; then + text=$data_or_text/text +else + text=$data_or_text +fi + +mkdir -p $graph_dir -for f in $lang/oov.int $dir/tree $dir/final.mdl \ +for f in $text $lang/oov.int $dir/tree $dir/final.mdl \ $lang/L_disambig.fst $lang/phones/disambig.int; do [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1; done +utils/lang/check_phones_compatible.sh $lang/phones.txt $dir/phones.txt +cp $lang/phones.txt $graph_dir + oov=`cat $lang/oov.int` || exit 1; -mkdir -p $dir/log +mkdir -p $graph_dir/log # create top_words.{int,txt} if [ $stage -le 0 ]; then export LC_ALL=C # the following pipe will be broken due to the 'head'; don't fail. 
set +o pipefail - utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt <$data/text | \ + utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt < $text | \ awk '{for(x=2;x<=NF;x++) print $x;}' | sort | uniq -c | \ - sort -nr | head -n $top_n_words > $dir/word_counts.int + sort -nr | head -n $top_n_words > $graph_dir/word_counts.int set -o pipefail - total_count=$(awk '{x+=$1} END{print x}' < $dir/word_counts.int) + total_count=$(awk '{x+=$1} END{print x}' < $graph_dir/word_counts.int) # print top-n words with their unigram probabilities. awk -v tot=$total_count -v weight=$top_n_words_weight '{print ($1*weight)/tot, $2;}' \ - <$dir/word_counts.int >$dir/top_words.int - utils/int2sym.pl -f 2 $lang/words.txt <$dir/top_words.int >$dir/top_words.txt + <$graph_dir/word_counts.int >$graph_dir/top_words.int + utils/int2sym.pl -f 2 $lang/words.txt <$graph_dir/top_words.int >$graph_dir/top_words.txt fi word_disambig_symbol=$(cat $lang/words.txt | grep -w "#0" | awk '{print $2}') @@ -107,12 +118,18 @@ if [ -z "$word_disambig_symbol" ]; then exit 1 fi -utils/split_data.sh --per-utt $data $nj +mkdir -p $graph_dir/texts +split_text= +for n in `seq $nj`; do + split_text="$split_text $graph_dir/texts/text.$n" +done -sdata=$data/split${nj}utt +utils/split_scp.pl $text $split_text +mkdir -p $graph_dir/log $graph_dir/fsts -mkdir -p $dir/log $dir/fsts +# Make $dir an absolute pathname +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` if [ $stage -le 1 ]; then echo "$0: creating utterance-group-specific decoding graphs with biased LMs" @@ -120,27 +137,27 @@ if [ $stage -le 1 ]; then # These options are passed through directly to make_one_biased_lm.py. 
lm_opts="--word-disambig-symbol=$word_disambig_symbol --ngram-order=$ngram_order --min-lm-state-count=$min_lm_state_count --discounting-constant=$discounting_constant" - $cmd JOB=1:$nj $dir/log/compile_decoding_graphs.JOB.log \ - utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $sdata/JOB/text \| \ + $cmd JOB=1:$nj $graph_dir/log/compile_decoding_graphs.JOB.log \ + utils/sym2int.pl --map-oov $oov -f 2- $lang/words.txt $graph_dir/texts/text.JOB \| \ steps/cleanup/make_biased_lms.py --min-words-per-graph=$min_words_per_graph \ - --lm-opts="$lm_opts" $dir/fsts/utt2group.JOB \| \ + --lm-opts="$lm_opts" $graph_dir/fsts/utt2group.JOB \| \ compile-train-graphs-fsts $scale_opts --read-disambig-syms=$lang/phones/disambig.int \ $dir/tree $dir/final.mdl $lang/L_disambig.fst ark:- \ - ark,scp:$dir/fsts/HCLG.fsts.JOB.ark,$dir/fsts/HCLG.fsts.JOB.scp || exit 1 + ark,scp:$graph_dir/fsts/HCLG.fsts.JOB.ark,$graph_dir/fsts/HCLG.fsts.JOB.scp || exit 1 fi -for j in $(seq $nj); do cat $dir/fsts/HCLG.fsts.$j.scp; done > $dir/fsts/HCLG.fsts.per_utt.scp -for j in $(seq $nj); do cat $dir/fsts/utt2group.$j; done > $dir/fsts/utt2group +for j in $(seq $nj); do cat $graph_dir/fsts/HCLG.fsts.$j.scp; done > $graph_dir/fsts/HCLG.fsts.per_utt.scp +for j in $(seq $nj); do cat $graph_dir/fsts/utt2group.$j; done > $graph_dir/fsts/utt2group -cp $lang/words.txt $dir/ +cp $lang/words.txt $graph_dir/ +cp -r $lang/phones $graph_dir/ # The following command gives us an scp file relative to utterance-id. -utils/apply_map.pl -f 2 $dir/fsts/HCLG.fsts.per_utt.scp <$dir/fsts/utt2group > $dir/HCLG.fsts.scp - -n1=$(cat $data/utt2spk | wc -l) -n2=$(cat $dir/HCLG.fsts.scp | wc -l) +utils/apply_map.pl -f 2 $graph_dir/fsts/HCLG.fsts.per_utt.scp <$graph_dir/fsts/utt2group > $graph_dir/HCLG.fsts.scp +n1=$(cat $text | wc -l) +n2=$(cat $graph_dir/HCLG.fsts.scp | wc -l) if [ $[$n1*9] -gt $[$n2*10] ]; then echo "$0: too many utterances have no scp, something seems to have gone wrong." 
diff --git a/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh b/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh new file mode 100755 index 00000000000..de1a04c3c23 --- /dev/null +++ b/egs/wsj/s5/steps/cleanup/segment_long_utterances.sh @@ -0,0 +1,386 @@ +#!/bin/bash + +# Copyright 2014 Guoguo Chen +# 2016 Vimal Manohar +# Apache 2.0 + +. path.sh + +set -e +set -o pipefail +set -u + +# Uniform segmentation options +max_segment_duration=30 +overlap_duration=5 +seconds_per_spk_max=30 + +# Decode options +graph_opts= +beam=15.0 +lattice_beam=1.0 +nj=4 +lmwt=10 + +# TF-IDF similarity search options +max_words=1000 +num_neighbors_to_search=1 # Number of neighboring documents to search around the one retrieved based on maximum tf-idf similarity. +neighbor_tfidf_threshold=0.5 + +align_full_hyp=false # Align full hypothesis i.e. trackback from the end to get the alignment. + +# First-pass segmentation opts +# These options are passed to the script +# steps/cleanup/internal/segment_ctm_edits_mild.py +segmentation_extra_opts= +min_split_point_duration=0.1 +max_deleted_words_kept_when_merging=1 +max_wer=50 +max_segment_length_for_merging=60 +max_bad_proportion=0.75 +max_intersegment_incorrect_words_length=1 +max_segment_length_for_splitting=10 +hard_max_segment_length=15 +min_silence_length_to_split_at=0.3 +min_non_scored_length_to_split_at=0.3 + +stage=-1 + +cmd=run.pl + +. utils/parse_options.sh + +if [ $# -ne 6 ]; then + cat < + e.g.: $0 exp/wsj_tri2b data/lang_nosp data/train_long data/train_long/text data/train_reseg exp/segment_wsj_long_utts_train +This script performs segmentation of the data in and +transcript , writing the segmented data (with a segments file) to + along with the corresponding aligned transcription. 
+Note: must be indexed by the utterance-ids of the utterances in + +The purpose of this script is to divide up the input data (which may consist of +long recordings such as television shows or audiobooks) into segments which are +of manageable length for further processing, along with the portion of the +transcript that seems to match each segment. +The output data is not necessarily particularly clean; you are advised to run +steps/cleanup/clean_and_segment_data.sh on the output in order to further clean +it and eliminate data where the transcript doesn't seem to match. +EOF + exit 1 +fi + +srcdir=$1 +lang=$2 +data=$3 +text=$4 +out_data=$5 +dir=$6 + +for f in $data/feats.scp $text $srcdir/tree \ + $srcdir/final.mdl $srcdir/cmvn_opts; do + if [ ! -f $f ]; then + echo "$0: Could not find file $f" + exit 1 + fi +done + +data_id=`basename $data` +mkdir -p $dir + +data_uniform_seg=$dir/${data_id}_uniform_seg + +frame_shift=`utils/data/get_frame_shift.sh $data` + +# First we split the data into segments of around 30s long, on which +# it would be possible to do a decoding. +# A diarization step will be added in the future. +if [ $stage -le 1 ]; then + echo "$0: Stage 1 (Splitting data directory $data into uniform segments)" + + utils/data/get_utt2dur.sh $data + if [ ! -f $data/segments ]; then + utils/data/get_segments_for_data.sh $data > $data/segments + fi + + utils/data/get_uniform_subsegments.py \ + --max-segment-duration=$max_segment_duration \ + --overlap-duration=$overlap_duration \ + --max-remaining-duration=$(perl -e "print $max_segment_duration / 2.0") \ + $data/segments > $dir/uniform_sub_segments + + # Get a mapping from the new to old utterance-ids. + # Typically, the old-utterance is the whole recording. + awk '{print $1" "$2}' $dir/uniform_sub_segments > $dir/new2orig_utt +fi + +if [ $stage -le 2 ]; then + echo "$0: Stage 2 (Prepare uniform sub-segmented data directory)" + rm -r $data_uniform_seg || true + + if [ ! 
-z "$seconds_per_spk_max" ]; then + utils/data/subsegment_data_dir.sh \ + $data $dir/uniform_sub_segments $dir/${data_id}_uniform_seg.temp + + utils/data/modify_speaker_info.sh --seconds-per-spk-max $seconds_per_spk_max \ + $dir/${data_id}_uniform_seg.temp $data_uniform_seg + else + utils/data/subsegment_data_dir.sh \ + $data $dir/uniform_sub_segments $data_uniform_seg + fi + + utils/fix_data_dir.sh $data_uniform_seg + + # Compute new cmvn stats for the segmented data directory + steps/compute_cmvn_stats.sh $data_uniform_seg/ +fi + +graph_dir=$dir/graphs_uniform_seg + +if [ $stage -le 3 ]; then + echo "$0: Stage 3 (Building biased-language-model decoding graphs)" + + cp $srcdir/final.mdl $dir + cp $srcdir/tree $dir + cp $srcdir/cmvn_opts $dir + cp $srcdir/{splice_opts,delta_opts,final.mat,final.alimdl} $dir 2>/dev/null || true + cp $srcdir/phones.txt $dir + + # Make graphs w.r.t. to the original text (usually recording-level) + steps/cleanup/make_biased_lm_graphs.sh $graph_opts \ + --nj $nj --cmd "$cmd" \ + $text $lang $dir $dir/graphs + + # and then copy it to the sub-segments. + mkdir -p $graph_dir + cat $dir/uniform_sub_segments | awk '{print $1" "$2}' | \ + utils/apply_map.pl -f 2 $dir/graphs/HCLG.fsts.scp > \ + $graph_dir/HCLG.fsts.scp + + cp $lang/words.txt $graph_dir + cp -r $lang/phones $graph_dir + [ -f $dir/graphs/num_pdfs ] && cp $dir/graphs/num_pdfs $graph_dir/ +fi + +decode_dir=$dir/lats +mkdir -p $decode_dir + +if [ $stage -le 4 ]; then + echo "$0: Decoding with biased language models..." 
+ + if [ -f $srcdir/trans.1 ]; then + steps/cleanup/decode_fmllr_segmentation.sh \ + --beam $beam --lattice-beam $lattice_beam --nj $nj --cmd "$cmd --mem 4G" \ + --skip-scoring true --allow-partial false \ + $graph_dir $data_uniform_seg $decode_dir + else + steps/cleanup/decode_segmentation.sh \ + --beam $beam --lattice-beam $lattice_beam --nj $nj --cmd "$cmd --mem 4G" \ + --skip-scoring true --allow-partial false \ + $graph_dir $data_uniform_seg $decode_dir + fi +fi + +if [ $stage -le 5 ]; then + steps/cleanup/internal/get_ctm.sh \ + --lmwt $lmwt --cmd "$cmd --mem 4G" \ + --print-silence true \ + $data_uniform_seg $lang $decode_dir +fi + +# Split the original text into documents, over which we can do +# searching reasonably efficiently. Also get a mapping from the original +# text to the created documents (i.e. text2doc) +# Since the Smith-Waterman alignment is linear in the length of the +# text, we want to keep it reasonably small (a few thousand words). + +if [ $stage -le 6 ]; then + # Split the reference text into documents. + mkdir -p $dir/docs + steps/cleanup/internal/split_text_into_docs.pl --max-words $max_words \ + $text $dir/docs/doc2text $dir/docs/docs.txt + utils/utt2spk_to_spk2utt.pl $dir/docs/doc2text > $dir/docs/text2doc +fi + +if [ $stage -le 7 ]; then + # Get TF-IDF for the reference documents. + echo $nj > $dir/docs/num_jobs + + utils/split_data.sh $data_uniform_seg $nj + + mkdir -p $dir/docs/split$nj/ + + # First compute IDF stats + $cmd $dir/log/compute_source_idf_stats.log \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="raw" \ + --idf-weighting-scheme="log" \ + --output-idf-stats=$dir/docs/idf_stats.txt \ + $dir/docs/docs.txt $dir/docs/src_tf_idf.txt + + # Split documents so that they can be accessed easily by parallel jobs. 
+ mkdir -p $dir/docs/split$nj/ + text2doc_splits= + for n in `seq $nj`; do + text2doc_splits="$text2doc_splits $dir/docs/split$nj/text2doc.$n" + done + sdir=$dir/docs/split$nj + + utils/split_scp.pl $dir/docs/text2doc $text2doc_splits + $cmd JOB=1:$nj $dir/docs/log/split_docs.JOB.log \ + utils/spk2utt_to_utt2spk.pl $sdir/text2doc.JOB \| \ + utils/filter_scp.pl /dev/stdin $dir/docs/docs.txt '>' \ + $sdir/docs.JOB.txt + + # Compute TF-IDF for the source documents. + $cmd JOB=1:$nj $dir/docs/log/get_tfidf_for_source_texts.JOB.log \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="raw" \ + --idf-weighting-scheme="log" \ + --input-idf-stats=$dir/docs/idf_stats.txt \ + $sdir/docs.JOB.txt $sdir/src_tf_idf.JOB.txt + + # Make $sdir an absolute pathname. + sdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $sdir ${PWD}` + + for n in `seq $nj`; do + awk -v f="$sdir/src_tf_idf.$n.txt" '{print $1" "f}' \ + $sdir/text2doc.$n + done > $dir/docs/source2tf_idf.scp + + +fi + +if [ $stage -le 8 ]; then + echo "$0: using default values of non-scored words..." + + # At the level of this script we just hard-code it that non-scored words are + # those that map to silence phones (which is what get_non_scored_words.py + # gives us), although this could easily be made user-configurable. This list + # of non-scored words affects the behavior of several of the data-cleanup + # scripts; essentially, we view the non-scored words as negotiable when it + # comes to the reference transcript, so we'll consider changing the reference + # to match the hyp when it comes to these words. + steps/cleanup/internal/get_non_scored_words.py $lang > $dir/non_scored_words.txt +fi + +if [ $stage -le 9 ]; then + sdir=$dir/query_docs/split$nj + mkdir -p $sdir + + # Compute TF-IDF for the query documents (decode hypotheses). + # The output is an archive of TF-IDF indexed by the query. 
+ $cmd JOB=1:$nj $dir/lats/log/compute_query_tf_idf.JOB.log \ + steps/cleanup/internal/ctm_to_text.pl --non-scored-words $dir/non_scored_words.txt \ + $dir/lats/score_$lmwt/${data_id}_uniform_seg.ctm.JOB \| \ + steps/cleanup/internal/compute_tf_idf.py \ + --tf-weighting-scheme="normalized" \ + --idf-weighting-scheme="log" \ + --input-idf-stats=$dir/docs/idf_stats.txt \ + --accumulate-over-docs=false \ + - $sdir/query_tf_idf.JOB.ark.txt + + # The relevant documents can be found using TF-IDF similarity and nearby + # documents can also be picked for the Smith-Waterman alignment stage. + + # The query TF-IDFs are all indexed by the utterance-id of the sub-segments. + # The source TF-IDFs use the document-ids created by splitting the reference + # text into documents. + # For each query, we need to retrieve the documents that were created from + # the same original utterance that the sub-segment was from. For this, + # we have to load the source TF-IDF that has those documents. This + # information is provided using the option --source-text2tf-idf-file. + # The output of this script is a file where the first column is the + # query-id (i.e. sub-segment-id) and the remaining columns, which is at least + # one in number and a maxmium of (1 + 2 * num-neighbors-to-search) columns + # is the document-ids for the retrieved documents. 
+ $cmd JOB=1:$nj $dir/lats/log/retrieve_similar_docs.JOB.log \ + steps/cleanup/internal/retrieve_similar_docs.py \ + --query-tfidf=$dir/query_docs/split$nj/query_tf_idf.JOB.ark.txt \ + --source-text2tfidf-file=$dir/docs/source2tf_idf.scp \ + --source-text-id2doc-ids=$dir/docs/text2doc \ + --query-id2source-text-id=$dir/new2orig_utt \ + --num-neighbors-to-search=$num_neighbors_to_search \ + --neighbor-tfidf-threshold=$neighbor_tfidf_threshold \ + --relevant-docs=$dir/query_docs/split$nj/relevant_docs.JOB.txt + + $cmd JOB=1:$nj $dir/lats/log/get_ctm_edits.JOB.log \ + steps/cleanup/internal/stitch_documents.py \ + --query2docs=$dir/query_docs/split$nj/relevant_docs.JOB.txt \ + --input-documents=$dir/docs/split$nj/docs.JOB.txt \ + --output-documents=- \| \ + steps/cleanup/internal/align_ctm_ref.py --eps-symbol='""' \ + --oov-word="'`cat $lang/oov.txt`'" --symbol-table=$lang/words.txt \ + --hyp-format=CTM --align-full-hyp=$align_full_hyp \ + --hyp=$dir/lats/score_$lmwt/${data_id}_uniform_seg.ctm.JOB --ref=- \ + --output=$dir/lats/score_$lmwt/${data_id}_uniform_seg.ctm_edits.JOB + + for n in `seq $nj`; do + cat $dir/lats/score_$lmwt/${data_id}_uniform_seg.ctm_edits.$n + done > $dir/lats/score_$lmwt/ctm_edits + +fi + +if [ $stage -le 10 ]; then + steps/cleanup/internal/resolve_ctm_edits_overlaps.py \ + ${data_uniform_seg}/segments $dir/lats/score_$lmwt/ctm_edits $dir/ctm_edits +fi + +if [ $stage -le 11 ]; then + echo "$0: modifying ctm-edits file to allow repetitions [for dysfluencies] and " + echo " ... to fix reference mismatches involving non-scored words. " + + $cmd $dir/log/modify_ctm_edits.log \ + steps/cleanup/internal/modify_ctm_edits.py --verbose=3 $dir/non_scored_words.txt \ + $dir/ctm_edits $dir/ctm_edits.modified + + echo " ... See $dir/log/modify_ctm_edits.log for details and stats, including" + echo " a list of commonly-repeated words." +fi + +if [ $stage -le 12 ]; then + echo "$0: applying 'taint' markers to ctm-edits file to mark silences and" + echo " ... 
non-scored words that are next to errors." + $cmd $dir/log/taint_ctm_edits.log \ + steps/cleanup/internal/taint_ctm_edits.py --remove-deletions=false \ + $dir/ctm_edits.modified $dir/ctm_edits.tainted + echo "... Stats, including global cor/ins/del/sub stats, are in $dir/log/taint_ctm_edits.log." +fi + +if [ $stage -le 13 ]; then + echo "$0: creating segmentation from ctm-edits file." + + segmentation_opts=( + --min-split-point-duration=$min_split_point_duration + --max-deleted-words-kept-when-merging=$max_deleted_words_kept_when_merging + --merging.max-wer=$max_wer + --merging.max-segment-length=$max_segment_length_for_merging + --merging.max-bad-proportion=$max_bad_proportion + --merging.max-intersegment-incorrect-words-length=$max_intersegment_incorrect_words_length + --splitting.max-segment-length=$max_segment_length_for_splitting + --splitting.hard-max-segment-length=$hard_max_segment_length + --splitting.min-silence-length=$min_silence_length_to_split_at + --splitting.min-non-scored-length=$min_non_scored_length_to_split_at + ) + + $cmd $dir/log/segment_ctm_edits.log \ + steps/cleanup/internal/segment_ctm_edits_mild.py \ + ${segmentation_opts[@]} $segmentation_extra_opts \ + --oov-symbol-file=$lang/oov.txt \ + --ctm-edits-out=$dir/ctm_edits.segmented \ + --word-stats-out=$dir/word_stats.txt \ + $dir/non_scored_words.txt \ + $dir/ctm_edits.tainted $dir/text $dir/segments + + echo "$0: contents of $dir/log/segment_ctm_edits.log are:" + cat $dir/log/segment_ctm_edits.log + echo "For word-level statistics on p(not-being-in-a-segment), with 'worst' words at the top," + echo "see $dir/word_stats.txt" + echo "For detailed utterance-level debugging information, see $dir/ctm_edits.segmented" +fi + +mkdir -p $out_data +if [ $stage -le 14 ]; then + utils/data/subsegment_data_dir.sh $data_uniform_seg \ + $dir/segments $dir/text $out_data +fi diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index 0a13fc3504a..34d1c3818b8 100644 --- 
a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -246,9 +246,21 @@ def get_feat_dim_from_scp(feat_scp): return feat_dim +#<<<<<<< HEAD +#def split_data(data, num_jobs, per_utt=False): +# if per_utt: +# run_kaldi_command("utils/split_data.sh --per-utt {data} {num_jobs}" +# "".format(data=data, num_jobs=num_jobs)) +# return "{data}/split{num_jobs}utt".format(data=data, num_jobs=num_jobs) +# +# run_kaldi_command("utils/split_data.sh {data} {num_jobs}" +# "".format(data=data, num_jobs=num_jobs)) +# return "{data}/split{num_jobs}".format(data=data, num_jobs=num_jobs) +#======= def split_data(data, num_jobs): execute_command("utils/split_data.sh {data} {num_jobs}".format( data=data, num_jobs=num_jobs)) +#>>>>>>> 720133715566a823be93e634de094f448b67b7f1 def read_kaldi_matrix(matrix_file): diff --git a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh index 75b08bc4779..d3e6ca73dd4 100755 --- a/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh +++ b/egs/wsj/s5/steps/lmrescore_rnnlm_lat.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright 2015 Guoguo Chen +# 2017 Hainan Xu # Apache 2.0 # This script rescores lattices with RNNLM. See also rnnlmrescore.sh which is @@ -14,6 +15,8 @@ N=10 inv_acwt=12 weight=1.0 # Interpolation weight for RNNLM. # End configuration section. 
+rnnlm_ver= +#layer_string= echo "$0 $@" # Print the command line for logging @@ -39,6 +42,25 @@ data=$3 indir=$4 outdir=$5 +rescoring_binary=lattice-lmrescore-rnnlm + +first_arg=ark:$rnnlm_dir/unk.probs # this is for mikolov's rnnlm +extra_arg= + +if [ "$rnnlm_ver" == "cuedrnnlm" ]; then + layer_string=`cat $rnnlm_dir/layer_string | sed "s=:= =g"` + total_size=`wc -l $rnnlm_dir/unigram.counts | awk '{print $1}'` + rescoring_binary="lattice-lmrescore-cuedrnnlm" + cat $rnnlm_dir/rnnlm.input.wlist.index | tail -n +2 | awk '{print $1-1,$2}' > $rnnlm_dir/rnn.wlist + extra_arg="--full-voc-size=$total_size --layer-sizes=\"$layer_string\"" + first_arg=$rnnlm_dir/rnn.wlist +fi + +if [ "$rnnlm_ver" == "tensorflow" ]; then + rescoring_binary="lattice-lmrescore-tf-rnnlm" + first_arg="$first_arg $rnnlm_dir/wordlist.rnn.final" +fi + oldlm=$oldlang/G.fst if [ -f $oldlang/G.carpa ]; then oldlm=$oldlang/G.carpa @@ -48,7 +70,7 @@ elif [ ! -f $oldlm ]; then fi [ ! -f $oldlm ] && echo "$0: Missing file $oldlm" && exit 1; -[ ! -f $rnnlm_dir/rnnlm ] && echo "$0: Missing file $rnnlm_dir/rnnlm" && exit 1; +[ ! -f $rnnlm_dir/rnnlm ] && [ ! -d $rnnlm_dir/rnnlm ] && echo "$0: Missing file $rnnlm_dir/rnnlm" && exit 1; [ ! -f $rnnlm_dir/unk.probs ] &&\ echo "$0: Missing file $rnnlm_dir/unk.probs" && exit 1; [ ! 
-f $oldlang/words.txt ] &&\ @@ -72,20 +94,19 @@ if [ "$oldlm" == "$oldlang/G.fst" ]; then $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore --lm-scale=$oldlm_weight \ "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm_command" ark:- \| \ - lattice-lmrescore-rnnlm --lm-scale=$weight \ - --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ - $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + $rescoring_binary $extra_arg --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order \ + $first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; else $cmd JOB=1:$nj $outdir/log/rescorelm.JOB.log \ lattice-lmrescore-const-arpa --lm-scale=$oldlm_weight \ - "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \ - lattice-lmrescore-rnnlm --lm-scale=$weight \ - --max-ngram-order=$max_ngram_order ark:$rnnlm_dir/unk.probs \ - $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ + "ark:gunzip -c $indir/lat.JOB.gz|" "$oldlm" ark:- \| \ + $rescoring_binary $extra_arg --lm-scale=$weight \ + --max-ngram-order=$max_ngram_order \ + $first_arg $oldlang/words.txt ark:- "$rnnlm_dir/rnnlm" \ "ark,t:|gzip -c>$outdir/lat.JOB.gz" || exit 1; fi - if ! $skip_scoring ; then err_msg="Not scoring because local/score.sh does not exist or not executable." [ ! -x local/score.sh ] && echo $err_msg && exit 1; diff --git a/egs/wsj/s5/steps/make_fbank.sh b/egs/wsj/s5/steps/make_fbank.sh index 1baecb3939a..39490e992dc 100755 --- a/egs/wsj/s5/steps/make_fbank.sh +++ b/egs/wsj/s5/steps/make_fbank.sh @@ -10,6 +10,7 @@ nj=4 cmd=run.pl fbank_config=conf/fbank.conf compress=true +write_utt2num_frames=false # if true writes utt2num_frames # End configuration section. echo "$0 $@" # Print the command line for logging @@ -25,6 +26,7 @@ if [ $# -lt 1 ] || [ $# -gt 3 ]; then echo " --fbank-config # config passed to compute-fbank-feats " echo " --nj # number of parallel jobs" echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." 
+ echo " --write-utt2num-frames # If true, write utt2num_frames file." exit 1; fi @@ -83,6 +85,12 @@ for n in $(seq $nj); do utils/create_data_link.pl $fbankdir/raw_fbank_$name.$n.ark done +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$logdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + if [ -f $data/segments ]; then echo "$0 [info]: segments file exists: using that." split_segments="" @@ -96,7 +104,7 @@ if [ -f $data/segments ]; then $cmd JOB=1:$nj $logdir/make_fbank_${name}.JOB.log \ extract-segments scp,p:$scp $logdir/segments.JOB ark:- \| \ compute-fbank-feats $vtln_opts --verbose=2 --config=$fbank_config ark:- ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$fbankdir/raw_fbank_$name.JOB.ark,$fbankdir/raw_fbank_$name.JOB.scp \ || exit 1; @@ -111,7 +119,7 @@ else $cmd JOB=1:$nj $logdir/make_fbank_${name}.JOB.log \ compute-fbank-feats $vtln_opts --verbose=2 --config=$fbank_config scp,p:$logdir/wav.JOB.scp ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$fbankdir/raw_fbank_$name.JOB.ark,$fbankdir/raw_fbank_$name.JOB.scp \ || exit 1; @@ -129,6 +137,13 @@ for n in $(seq $nj); do cat $fbankdir/raw_fbank_$name.$n.scp || exit 1; done > $data/feats.scp +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $logdir/utt2num_frames.$n || exit 1; + done > $data/utt2num_frames || exit 1 + rm $logdir/uttnum_frames.* +fi + rm $logdir/wav.*.scp $logdir/segments.* 2>/dev/null nf=`cat $data/feats.scp | wc -l` diff --git a/egs/wsj/s5/steps/make_fbank_pitch.sh b/egs/wsj/s5/steps/make_fbank_pitch.sh index 4dbd00e09bd..dcabded4770 100755 --- a/egs/wsj/s5/steps/make_fbank_pitch.sh +++ b/egs/wsj/s5/steps/make_fbank_pitch.sh @@ -15,6 +15,7 @@ pitch_config=conf/pitch.conf pitch_postprocess_config= paste_length_tolerance=2 compress=true +write_utt2num_frames=false # if true writes 
utt2num_frames # End configuration section. echo "$0 $@" # Print the command line for logging @@ -33,6 +34,7 @@ if [ $# -lt 1 ] || [ $# -gt 3 ]; then echo " --paste-length-tolerance # length tolerance passed to paste-feats" echo " --nj # number of parallel jobs" echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --write-utt2num-frames # If true, write utt2num_frames file." exit 1; fi @@ -97,6 +99,12 @@ for n in $(seq $nj); do utils/create_data_link.pl $fbank_pitch_dir/raw_fbank_pitch_$name.$n.ark done +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$logdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + if [ -f $data/segments ]; then echo "$0 [info]: segments file exists: using that." split_segments="" @@ -112,7 +120,7 @@ if [ -f $data/segments ]; then $cmd JOB=1:$nj $logdir/make_fbank_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$fbank_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$fbank_pitch_dir/raw_fbank_pitch_$name.JOB.ark,$fbank_pitch_dir/raw_fbank_pitch_$name.JOB.scp \ || exit 1; @@ -130,7 +138,7 @@ else $cmd JOB=1:$nj $logdir/make_fbank_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$fbank_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$fbank_pitch_dir/raw_fbank_pitch_$name.JOB.ark,$fbank_pitch_dir/raw_fbank_pitch_$name.JOB.scp \ || exit 1; @@ -148,6 +156,13 @@ for n in $(seq $nj); do cat $fbank_pitch_dir/raw_fbank_pitch_$name.$n.scp || exit 1; done > $data/feats.scp +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $logdir/utt2num_frames.$n || exit 1; + done > $data/utt2num_frames || exit 1 + rm $logdir/utt2num_frames.* +fi + rm $logdir/wav.*.scp $logdir/segments.* 2>/dev/null nf=`cat $data/feats.scp | wc -l` diff 
--git a/egs/wsj/s5/steps/make_mfcc_pitch.sh b/egs/wsj/s5/steps/make_mfcc_pitch.sh index ff9a7d2f5f3..996dd0367bf 100755 --- a/egs/wsj/s5/steps/make_mfcc_pitch.sh +++ b/egs/wsj/s5/steps/make_mfcc_pitch.sh @@ -15,6 +15,7 @@ pitch_config=conf/pitch.conf pitch_postprocess_config= paste_length_tolerance=2 compress=true +write_utt2num_frames=false # if true writes utt2num_frames # End configuration section. echo "$0 $@" # Print the command line for logging @@ -33,6 +34,7 @@ if [ $# -lt 1 ] || [ $# -gt 3 ]; then echo " --paste-length-tolerance # length tolerance passed to paste-feats" echo " --nj # number of parallel jobs" echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + echo " --write-utt2num-frames # If true, write utt2num_frames file." exit 1; fi @@ -96,6 +98,12 @@ for n in $(seq $nj); do utils/create_data_link.pl $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.ark done +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$logdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + if [ -f $data/segments ]; then echo "$0 [info]: segments file exists: using that." 
split_segments="" @@ -111,7 +119,7 @@ if [ -f $data/segments ]; then $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -129,7 +137,7 @@ else $cmd JOB=1:$nj $logdir/make_mfcc_pitch_${name}.JOB.log \ paste-feats --length-tolerance=$paste_length_tolerance "$mfcc_feats" "$pitch_feats" ark:- \| \ - copy-feats --compress=$compress ark:- \ + copy-feats --compress=$compress $write_num_frames_opt ark:- \ ark,scp:$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.ark,$mfcc_pitch_dir/raw_mfcc_pitch_$name.JOB.scp \ || exit 1; @@ -147,6 +155,13 @@ for n in $(seq $nj); do cat $mfcc_pitch_dir/raw_mfcc_pitch_$name.$n.scp || exit 1; done > $data/feats.scp +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $logdir/utt2num_frames.$n || exit 1; + done > $data/utt2num_frames || exit 1 + rm $logdir/uttnum_frames.* +fi + rm $logdir/wav_${name}.*.scp $logdir/segments.* 2>/dev/null nf=`cat $data/feats.scp | wc -l` diff --git a/egs/wsj/s5/steps/nnet/ivector/extract_ivectors.sh b/egs/wsj/s5/steps/nnet/ivector/extract_ivectors.sh index 36af3ab49d8..0e920e4a9b4 100755 --- a/egs/wsj/s5/steps/nnet/ivector/extract_ivectors.sh +++ b/egs/wsj/s5/steps/nnet/ivector/extract_ivectors.sh @@ -164,13 +164,13 @@ if [ $stage -le 2 ]; then weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ - $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark + $srcdir/final.ie "$feats" ark,s,cs:- ark:$dir/ivectors_spk.JOB.ark else $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ gmm-global-get-post --n=$num_gselect 
--min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ - $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark + $srcdir/final.ie "$feats" ark,s,cs:- ark:$dir/ivectors_spk.JOB.ark fi fi @@ -181,39 +181,28 @@ if [ $stage -le 3 ]; then gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true --max-count=$max_count \ - $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_utt.JOB.ark + $srcdir/final.ie "$feats" ark,s,cs:- ark:$dir/ivectors_utt.JOB.ark else $cmd JOB=1:$nj $dir/log/extract_ivectors_utt.JOB.log \ gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true --max-count=$max_count \ - $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_utt.JOB.ark + $srcdir/final.ie "$feats" ark,s,cs:- ark:$dir/ivectors_utt.JOB.ark fi fi - -# get an utterance-level set of iVectors (just duplicate the speaker-level ones). -# note: if $this_sdata is set $dir/split$nj, then these won't be real speakers, they'll -# be "sub-speakers" (speakers split up into multiple utterances). 
-if [ $stage -le 4 ]; then - for j in $(seq $nj); do - utils/apply_map.pl -f 2 $dir/ivectors_spk.${j}.ark <$this_sdata/$j/utt2spk >$dir/ivectors_spk-as-utt.${j}.ark - done -fi - -ivector_dim=$[$(head -n 1 $dir/ivectors_spk.1.ark | wc -w) - 3] -echo "$0: iVector dim is $ivector_dim" - absdir=$(readlink -f $dir) - -if [ $stage -le 5 ]; then +if [ $stage -le 4 ]; then echo "$0: merging iVectors across jobs" copy-vector "ark:cat $dir/ivectors_spk.*.ark |" ark,scp:$absdir/ivectors_spk.ark,$dir/ivectors_spk.scp rm $dir/ivectors_spk.*.ark - copy-vector "ark:cat $dir/ivectors_spk-as-utt.*.ark |" ark,scp:$absdir/ivectors_spk-as-utt.ark,$dir/ivectors_spk-as-utt.scp - rm $dir/ivectors_spk-as-utt.*.ark copy-vector "ark:cat $dir/ivectors_utt.*.ark |" ark,scp:$absdir/ivectors_utt.ark,$dir/ivectors_utt.scp rm $dir/ivectors_utt.*.ark fi +# duplicate the `speaker' i-vector to all `utterances' of that speaker, +if [ $stage -le 5 ]; then + utils/apply_map.pl -f 2 $dir/ivectors_spk.scp <$data/utt2spk >$dir/ivectors_spk-as-utt.scp +fi + echo "$0: done extracting iVectors (per-speaker, per-sentence) into '$dir'" diff --git a/egs/wsj/s5/steps/nnet/train_mpe.sh b/egs/wsj/s5/steps/nnet/train_mpe.sh index 1d2a6256ea8..7b17c88e8ec 100755 --- a/egs/wsj/s5/steps/nnet/train_mpe.sh +++ b/egs/wsj/s5/steps/nnet/train_mpe.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2013-2015 Brno University of Technology (author: Karel Vesely) +# Copyright 2013-2017 Brno University of Technology (author: Karel Vesely) # Apache 2.0. # Sequence-discriminative MPE/sMBR training of DNN. @@ -20,10 +20,14 @@ learn_rate=0.00001 momentum=0.0 halving_factor=1.0 #ie. 
disable halving do_smbr=true -exclude_silphones=true # exclude silphones from approximate accuracy computation -unkphonelist= # exclude unkphones from approximate accuracy computation (overrides exclude_silphones) -one_silence_class=true # true : reduce insertions in sMBR/MPE FW/BW, more stable training, - # (all silphones are seen as a single class in the sMBR/MPE FW/BW) +one_silence_class=true # if true : all the `silphones' are mapped to a single class in the Forward-backward of sMBR/MPE, + # (this prevents the sMBR from WER explosion, which was happenning with some data). + # if false : the silphone-frames are always counted as 'wrong' in the calculation of the approximate accuracies, +silphonelist= # this overrides default silphone-list (for selecting a subset of sil-phones) + +unkphonelist= # dummy deprecated option, for backward compatibility, +exclude_silphones= # dummy deprecated option, for backward compatibility, + verbose=1 ivector= nnet= # For non-default location of nnet, @@ -78,7 +82,7 @@ cp $lang/phones.txt $dir cp $alidir/{final.mdl,tree} $dir -silphonelist=`cat $lang/phones/silence.csl` +[ -z $silphonelist ] && silphonelist=`cat $lang/phones/silence.csl` # Default 'silphonelist', #Get the files we will need [ -z "$nnet" ] && nnet=$srcdir/$(readlink $srcdir/final.nnet || echo final.nnet); @@ -99,18 +103,13 @@ cp $feature_transform $dir/final.feature_transform model=$dir/final.mdl [ -z "$model" ] && echo "Error transition model '$model' does not exist!" && exit 1; -# The argument '--silence-phones=csl' together with '--one-silence-class=true' -# will cause regrouping of the silenece phones into a single class in the FW/BW -# which calculates the Loss derivative (the 'new' behavior). -mpe_silphones_arg= #empty -$exclude_silphones && mpe_silphones_arg="--silence-phones=$silphonelist" # all silphones -[ ! -z $unkphonelist ] && mpe_silphones_arg="--silence-phones=$unkphonelist" # unk only - - # Shuffle the feature list to make the GD stochastic! 
# By shuffling features, we have to use lattices with random access (indexed by .scp file). cat $data/feats.scp | utils/shuffle_list.pl --srand $seed > $dir/train.scp +[ -n "$unkphonelist" ] && echo "WARNING: The option '--unkphonelist' is now deprecated. Please remove it from your recipe..." +[ -n "$exclude_silphones" ] && echo "WARNING: The option '--exclude-silphones' is now deprecated. Please remove it from your recipe..." + ### ### PREPARE FEATURE EXTRACTION PIPELINE ### @@ -192,7 +191,7 @@ while [ $x -le $num_iters ]; do --do-smbr=$do_smbr \ --verbose=$verbose \ --one-silence-class=$one_silence_class \ - $mpe_silphones_arg \ + ${silphonelist:+ --silence-phones=$silphonelist} \ $cur_mdl $alidir/final.mdl "$feats" "$lats" "$ali" $dir/$x.nnet fi cur_mdl=$dir/$x.nnet diff --git a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh new file mode 100755 index 00000000000..3a01627cb16 --- /dev/null +++ b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh @@ -0,0 +1,111 @@ +#! /bin/bash + +# This scripts converts a data directory into a "whole" data directory +# by removing the segments and using the recordings themselves as +# utterances + +set -o pipefail + +. path.sh + +cmd=run.pl +stage=-1 + +. parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: convert_data_dir_to_whole.sh " + echo " e.g.: convert_data_dir_to_whole.sh data/dev data/dev_whole" + exit 1 +fi + +data=$1 +dir=$2 + +if [ ! -f $data/segments ]; then + # Data directory already does not contain segments. So just copy it. + utils/copy_data_dir.sh $data $dir + exit 0 +fi + +mkdir -p $dir +cp $data/wav.scp $dir +cp $data/reco2file_and_channel $dir +rm -f $dir/{utt2spk,text} || true + +[ -f $data/stm ] && cp $data/stm $dir +[ -f $data/glm ] && cp $data/glm $dir + +text_files= +[ -f $data/text ] && text_files="$data/text $dir/text" + +# Combine utt2spk and text from the segments into utt2spk and text for the whole +# recording. 
+cat $data/segments | sort -k2,2 -k3,4n | perl -e ' +if (scalar @ARGV == 4) { + ($utt2spk_in, $utt2spk_out, $text_in, $text_out) = @ARGV; +} elsif (scalar @ARGV == 2) { + ($utt2spk_in, $utt2spk_out) = @ARGV; +} else { + die "Unexpected number of arguments"; +} + +if (defined $text_in) { + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; +} +open(UI, "<$utt2spk_in") || die "Error: fail to open $utt2spk_in\n"; +open(UO, ">$utt2spk_out") || die "Error: fail to open $utt2spk_out\n"; + +my %file2utt = (); +while () { + chomp; + my @col = split; + @col >= 4 or die "bad line $_\n"; + + if (! defined $file2utt{$col[1]}) { + $file2utt{$col[1]} = []; + } + push @{$file2utt{$col[1]}}, $col[0]; +} + +my %text = (); +my %utt2spk = (); + +while () { + chomp; + my @col = split; + $utt2spk{$col[0]} = $col[1]; +} + +if (defined $text_in) { + while () { + chomp; + my @col = split; + @col >= 1 or die "bad line $_\n"; + + my $utt = shift @col; + $text{$utt} = join(" ", @col); + } +} + +foreach $file (keys %file2utt) { + my @utts = @{$file2utt{$file}}; + #print STDERR $file . " " . join(" ", @utts) . "\n"; + print UO "$file $file\n"; + + if (defined $text_in) { + $text_line = ""; + foreach $utt (@utts) { + $text_line = "$text_line " . 
$text{$utt} + } + print TO "$file $text_line\n"; + } +} +' $data/utt2spk $dir/utt2spk $text_files + +sort -u $dir/utt2spk > $dir/utt2spk.tmp +mv $dir/utt2spk.tmp $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/wsj/s5/utils/data/get_frame_shift.sh b/egs/wsj/s5/utils/data/get_frame_shift.sh index d032c9c17fa..7413dbd0917 100755 --- a/egs/wsj/s5/utils/data/get_frame_shift.sh +++ b/egs/wsj/s5/utils/data/get_frame_shift.sh @@ -53,7 +53,7 @@ if [ -z $temp ]; then fi head -n 10 $dir/utt2dur | paste - $temp | \ - awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }' || exit 1; + awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.0098 && shift < 0.0102) shift = 0.01; print shift; }' || exit 1; rm $temp diff --git a/egs/wsj/s5/utils/data/get_segments_for_data.sh b/egs/wsj/s5/utils/data/get_segments_for_data.sh index 694acc6a256..7adc4c465d3 100755 --- a/egs/wsj/s5/utils/data/get_segments_for_data.sh +++ b/egs/wsj/s5/utils/data/get_segments_for_data.sh @@ -19,7 +19,7 @@ fi data=$1 -if [ ! -f $data/utt2dur ]; then +if [ ! -s $data/utt2dur ]; then utils/data/get_utt2dur.sh $data 1>&2 || exit 1; fi diff --git a/egs/wsj/s5/utils/data/get_uniform_subsegments.py b/egs/wsj/s5/utils/data/get_uniform_subsegments.py new file mode 100755 index 00000000000..8479251bed7 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_uniform_subsegments.py @@ -0,0 +1,86 @@ +#! /usr/bin/env python + +# Copyright 2017 Vimal Manohar +# Apache 2.0. + +import argparse +import logging +import sys +import textwrap + +def get_args(): + parser = argparse.ArgumentParser( + description=textwrap.dedent(""" + Creates a subsegments file from an input segments file + that has the format + , + where the timing are relative to the start-time of the + in the input segments file. 
+ + e.g.: get_uniform_subsegments.py data/dev/segments > \\ + data/dev_uniform_segments/sub_segments + + utils/data/subsegment_data_dir.sh data/dev \\ + data/dev_uniform_segments/sub_segments data/dev_uniform_segments + + The output is written to stdout. The resulting file can be + passed to utils/data/subsegment_data_dir.sh to sub-segment + the data directory."""), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--max-segment-duration", type=float, + default=30, help="""Maximum duration of the + subsegments (in seconds)""") + parser.add_argument("--overlap-duration", type=float, + default=5, help="""Overlap between + adjacent segments (in seconds)""") + parser.add_argument("--max-remaining-duration", type=float, + default=10, help="""Segment is not split + if the left-over duration is more than this + many seconds""") + parser.add_argument("segments_file", type=argparse.FileType('r'), + help="""Input kaldi segments file""") + + args = parser.parse_args() + return args + + +def run(args): + for line in args.segments_file: + parts = line.strip().split() + utt_id = parts[0] + start_time = float(parts[2]) + end_time = float(parts[3]) + + dur = end_time - start_time + + start = start_time + while (dur > args.max_segment_duration + + args.max_remaining_duration): + end = start + args.max_segment_duration + new_utt = "{utt_id}-{s:06d}-{e:06d}".format( + utt_id=utt_id, s=int(100 * start), e=int(100 * end)) + print ("{new_utt} {utt_id} {s} {e}".format( + new_utt=new_utt, utt_id=utt_id, s=start, + e=start + args.max_segment_duration)) + start += args.max_segment_duration - args.overlap_duration + dur -= args.max_segment_duration - args.overlap_duration + + new_utt = "{utt_id}-{s:06d}-{e:06d}".format( + utt_id=utt_id, s=int(100 * start), e=int(100 * end_time)) + print ("{new_utt} {utt_id} {s} {e}".format( + new_utt=new_utt, utt_id=utt_id, s=start, e=end_time)) + + +def main(): + args = get_args() + try: + run(args) + except Exception: + 
logging.error("Failed creating subsegments", exc_info=True) + raise SystemExit(1) + finally: + args.segments_file.close() + + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/utils/data/get_utt2dur.sh b/egs/wsj/s5/utils/data/get_utt2dur.sh index f14fc2c5e81..c415e8dfb81 100755 --- a/egs/wsj/s5/utils/data/get_utt2dur.sh +++ b/egs/wsj/s5/utils/data/get_utt2dur.sh @@ -35,7 +35,7 @@ if [ -s $data/utt2dur ] && \ exit 0; fi -if [ -f $data/segments ]; then +if [ -s $data/segments ]; then echo "$0: working out $data/utt2dur from $data/segments" cat $data/segments | awk '{len=$4-$3; print $1, len;}' > $data/utt2dur elif [ -f $data/wav.scp ]; then diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh new file mode 100755 index 00000000000..e2921601ec9 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "This script writes a file utt2num_frames with the " + echo "number of frames in each utterance as measured based on the " + echo "duration of the utterances (in utt2dur) and the specified " + echo "frame_shift and frame_overlap." + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -f $data/utt2num_frames ]; then + echo "$0: $data/utt2num_frames already present!" + exit 0; +fi + +if [ ! 
-f $data/feats.scp ]; then + utils/data/get_utt2dur.sh $data + awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/utt2dur > $data/utt2num_frames + exit 0 +fi + +utils/split_data.sh $data $nj || exit 1 +$cmd JOB=1:$nj $data/log/get_utt2num_frames.JOB.log \ + feat-to-len scp:$data/split${nj}/JOB/feats.scp ark,t:$data/split$nj/JOB/utt2num_frames || exit 1 + +for n in `seq $nj`; do + cat $data/split$nj/$n/utt2num_frames +done > $data/utt2num_frames + +echo "$0: Computed and wrote $data/utt2num_frames" diff --git a/egs/wsj/s5/utils/data/normalize_data_range.pl b/egs/wsj/s5/utils/data/normalize_data_range.pl index f7936d98a31..d58421aa9be 100755 --- a/egs/wsj/s5/utils/data/normalize_data_range.pl +++ b/egs/wsj/s5/utils/data/normalize_data_range.pl @@ -45,14 +45,15 @@ sub combine_ranges { # though they are supported at the C++ level. if ($start1 eq "" || $start2 eq "" || $end1 eq "" || $end2 == "") { chop $line; - print("normalize_data_range.pl: could not make sense of line $line\n"); + print STDERR ("normalize_data_range.pl: could not make sense of line $line\n"); exit(1) } if ($start1 + $end2 > $end1) { chop $line; - print("normalize_data_range.pl: could not make sense of line $line " . + print STDERR ("normalize_data_range.pl: could not make sense of line $line " . "[second $row_or_column range too large vs first range, $start1 + $end2 > $end1]\n"); - exit(1); + # exit(1); + return ($start2+$start1, $end1); } return ($start2+$start1, $end2+$start1); } @@ -72,11 +73,11 @@ sub combine_ranges { # sometimes in scp files, we use the command concat-feats to splice together # two feature matrices. Handling this correctly is complicated and we don't # anticipate needing it, so we just refuse to process this type of data. - print "normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . 
- "if concat-feats was in the input data\n"; + print STDERR ("normalize_data_range.pl: this script cannot [yet] normalize the data ranges " . + "if concat-feats was in the input data\n"); exit(1); } - print STDERR "matched: $before_range $first_range $second_range\n"; + # print STDERR "matched: $before_range $first_range $second_range\n"; if ($first_range !~ m/^((\d*):(\d*)|)(,(\d*):(\d*)|)$/) { print STDERR "normalize_data_range.pl: could not make sense of input line $_"; exit(1); diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh index c575166534e..5b007cadb3f 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh @@ -43,5 +43,4 @@ utils/data/combine_data.sh $destdir ${srcdir} ${destdir}_speed0.9 ${destdir}_spe rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 echo "$0: generated 3-way speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir - +utils/validate_data_dir.sh --no-feats --no-text $destdir diff --git a/egs/wsj/s5/utils/data/subsegment_data_dir.sh b/egs/wsj/s5/utils/data/subsegment_data_dir.sh index 18a00c3df7d..ce006b6d2b3 100755 --- a/egs/wsj/s5/utils/data/subsegment_data_dir.sh +++ b/egs/wsj/s5/utils/data/subsegment_data_dir.sh @@ -24,9 +24,9 @@ segment_end_padding=0.0 . utils/parse_options.sh -if [ $# != 4 ]; then +if [ $# != 4 ] && [ $# != 3 ]; then echo "Usage: " - echo " $0 [options] " + echo " $0 [options] [] " echo "This script sub-segments a data directory. is to" echo "have lines of the form " echo "and is of the form ... ." @@ -50,11 +50,23 @@ export LC_ALL=C srcdir=$1 subsegments=$2 -new_text=$3 -dir=$4 +add_subsegment_text=false +if [ $# -eq 4 ]; then + new_text=$3 + dir=$4 + add_subsegment_text=true -for f in "$subsegments" "$new_text" "$srcdir/utt2spk"; do + if [ ! 
-f "$new_text" ]; then + echo "$0: no such file $new_text" + exit 1 + fi + +else + dir=$3 +fi + +for f in "$subsegments" "$srcdir/utt2spk"; do if [ ! -f "$f" ]; then echo "$0: no such file $f" exit 1; @@ -65,9 +77,11 @@ if ! mkdir -p $dir; then echo "$0: failed to create directory $dir" fi -if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then - echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" - exit 1 +if $add_subsegment_text; then + if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then + echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" + exit 1 + fi fi # create the utt2spk in $dir @@ -86,8 +100,11 @@ awk '{print $1, $2}' < $subsegments > $dir/new2old_utt utils/apply_map.pl -f 2 $srcdir/utt2spk < $dir/new2old_utt >$dir/utt2spk # .. and the new spk2utt file. utils/utt2spk_to_spk2utt.pl <$dir/utt2spk >$dir/spk2utt -# the new text file is just what the user provides. -cp $new_text $dir/text + +if $add_subsegment_text; then + # the new text file is just what the user provides. + cp $new_text $dir/text +fi # copy the source wav.scp cp $srcdir/wav.scp $dir @@ -184,6 +201,7 @@ utils/data/fix_data_dir.sh $dir validate_opts= [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" [ ! -f $srcdir/wav.scp ] && validate_opts="$validate_opts --no-wav" +! 
$add_subsegment_text && validate_opts="$validate_opts --no-text" utils/data/validate_data_dir.sh $validate_opts $dir diff --git a/egs/wsj/s5/utils/perturb_data_dir_speed.sh b/egs/wsj/s5/utils/perturb_data_dir_speed.sh index 20ff86755eb..667bd934f04 100755 --- a/egs/wsj/s5/utils/perturb_data_dir_speed.sh +++ b/egs/wsj/s5/utils/perturb_data_dir_speed.sh @@ -112,4 +112,5 @@ cat $srcdir/utt2dur | utils/apply_map.pl -f 1 $destdir/utt_map | \ rm $destdir/spk_map $destdir/utt_map 2>/dev/null echo "$0: generated speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir + +utils/validate_data_dir.sh --no-feats --no-text $destdir diff --git a/egs/wsj/s5/utils/scoring/wer_ops_details.pl b/egs/wsj/s5/utils/scoring/wer_ops_details.pl index f894754bdb1..269b31d45b4 100755 --- a/egs/wsj/s5/utils/scoring/wer_ops_details.pl +++ b/egs/wsj/s5/utils/scoring/wer_ops_details.pl @@ -70,7 +70,7 @@ sub max { next if @entries < 2; next if ($entries[1] ne "hyp") and ($entries[1] ne "ref") ; if (scalar @entries <= 2 ) { - print STDERR "Warning: skipping entry \"$_\", either an empty phrase or incompatible format\n" ; + print STDERR "$0: Warning: skipping entry \"$_\", either an empty phrase or incompatible format\n" ; next; } diff --git a/src/base/kaldi-error.h b/src/base/kaldi-error.h index a7bb3fb67a0..172ea675312 100644 --- a/src/base/kaldi-error.h +++ b/src/base/kaldi-error.h @@ -195,19 +195,6 @@ typedef void (*LogHandler)(const LogMessageEnvelope &envelope, /// stderr. SetLogHandler is obviously not thread safe. 
LogHandler SetLogHandler(LogHandler); - -/***** WRITING 'std::vector' TO LOGPRINT *****/ -template -std::ostream& operator<< (std::ostream& os, const std::vector& v) { - os << "[ "; - typename std::vector::const_iterator it; - for (it = v.begin(); it != v.end(); ++it) { - os << *it << " "; - } - os << "]"; - return os; -} - /// @} end "addtogroup error_group" } // namespace kaldi diff --git a/src/lm/const-arpa-lm.cc b/src/lm/const-arpa-lm.cc index 8c848d245a9..72636ccbd9e 100644 --- a/src/lm/const-arpa-lm.cc +++ b/src/lm/const-arpa-lm.cc @@ -278,7 +278,16 @@ void ConstArpaLmBuilder::ConsumeNGram(const NGram &ngram) { cur_order == ngram_order_ - 1, ngram.logprob, ngram.backoff); - KALDI_ASSERT(seq_to_state_.find(ngram.words) == seq_to_state_.end()); + if (seq_to_state_.find(ngram.words) != seq_to_state_.end()) { + std::ostringstream os; + os << "[ "; + for (size_t i = 0; i < ngram.words.size(); i++) { + os << ngram.words[i] << " "; + } + os <<"]"; + + KALDI_ERR << "N-gram " << os.str() << " appears twice in the arpa file"; + } seq_to_state_[ngram.words] = lm_state; } diff --git a/src/nnetbin/nnet-train-multistream-perutt.cc b/src/nnetbin/nnet-train-multistream-perutt.cc index 154c7fd9c9d..e7bcb0d45b6 100644 --- a/src/nnetbin/nnet-train-multistream-perutt.cc +++ b/src/nnetbin/nnet-train-multistream-perutt.cc @@ -274,7 +274,14 @@ int main(int argc, char *argv[]) { nnet.SetSeqLengths(frame_num_utt); // Show the 'utt' lengths in the VLOG[2], if (GetVerboseLevel() >= 2) { - KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << frame_num_utt; + std::ostringstream os; + os << "[ "; + for (size_t i = 0; i < frame_num_utt.size(); i++) { + os << frame_num_utt[i] << " "; + } + os << "]"; + + KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str(); } // Reset all the streams (we have new sentences), nnet.ResetStreams(std::vector(frame_num_utt.size(), 1)); diff --git a/src/nnetbin/nnet-train-multistream.cc b/src/nnetbin/nnet-train-multistream.cc index 
7424759f45b..1ecd4757d96 100644 --- a/src/nnetbin/nnet-train-multistream.cc +++ b/src/nnetbin/nnet-train-multistream.cc @@ -352,7 +352,14 @@ int main(int argc, char *argv[]) { nnet.SetSeqLengths(frame_num_utt); // Show the 'utt' lengths in the VLOG[2], if (GetVerboseLevel() >= 2) { - KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << frame_num_utt; + std::ostringstream os; + os << "[ "; + for (size_t i = 0; i < frame_num_utt.size(); i++) { + os << frame_num_utt[i] << " "; + } + os << "]"; + + KALDI_LOG << "frame_num_utt[" << frame_num_utt.size() << "]" << os.str(); } // with new utterance we reset the history, diff --git a/src/tensorflow/Makefile b/src/tensorflow/Makefile new file mode 100644 index 00000000000..083d22becb0 --- /dev/null +++ b/src/tensorflow/Makefile @@ -0,0 +1,24 @@ +include ../kaldi.mk + +CURDIR = $(shell pwd) +TENSORFLOW = $(CURDIR)/../../tools/tensorflow + +all: + +EXTRA_CXXFLAGS = -Wno-sign-compare -I$(TENSORFLOW)/bazel-tensorflow/external/protobuf/src -I$(TENSORFLOW)/bazel-genfiles -I$(TENSORFLOW) -I$(TENSORFLOW)/tensorflow/contrib/makefile/downloads/eigen/ +#EXTRA_CXXFLAGS = -Wno-sign-compare -fPIC -I$(TENSORFLOW)/bazel-tensorflow/external/protobuf/src -I$(TENSORFLOW)/bazel-genfiles -I$(TENSORFLOW) -I$(TENSORFLOW)/tensorflow/contrib/makefile/downloads/eigen/ + +OBJFILES = tensorflow-rnnlm-lib.o + +TESTFILES = + +LIBNAME = kaldi-tensorflow-rnnlm + +ADDLIBS = ../lm/kaldi-lm.a ../util/kaldi-util.a ../thread/kaldi-thread.a \ + ../matrix/kaldi-matrix.a ../base/kaldi-base.a \ + $(CURDIR)/../../tools/tensorflow/bazel-bin/tensorflow/tensorflow_cc.so + +LDLIBS += -lz -ldl -fPIC -lrt +LDLIBS += $(OTHERLIBS) -L$(TENSORFLOW)/bazel-bin/tensorflow -ltensorflow_cc + +include ../makefiles/default_rules.mk diff --git a/src/tensorflow/tensorflow-rnnlm-lib.cc b/src/tensorflow/tensorflow-rnnlm-lib.cc new file mode 100644 index 00000000000..7513b9207a7 --- /dev/null +++ b/src/tensorflow/tensorflow-rnnlm-lib.cc @@ -0,0 +1,336 @@ +// Copyright 2017 
Hainan Xu +// wrapper for tensorflow rnnlm + +#include +#include + +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +#include "tensorflow/tensorflow-rnnlm-lib.h" +#include "util/stl-utils.h" +#include "util/text-utils.h" + +namespace kaldi { +using std::ifstream; +using tf_rnnlm::KaldiTfRnnlmWrapper; +using tf_rnnlm::TfRnnlmDeterministicFst; +using tensorflow::Status; + +void SetUnkPenalties(const string &filename, const fst::SymbolTable& fst_word_symbols, + std::vector *out) { + if (filename == "") + return; + out->resize(fst_word_symbols.NumSymbols(), 0); // default is 0 + ifstream ifile(filename.c_str()); + string word; + float count, total_count = 0; + while (ifile >> word >> count) { + int id = fst_word_symbols.Find(word); + KALDI_ASSERT(id != fst::SymbolTable::kNoSymbol); + (*out)[id] = count; + total_count += count; + } + + for (int i = 0; i < out->size(); i++) { + if ((*out)[i] != 0) { + (*out)[i] = log ((*out)[i] / total_count); + } + } +} + +void KaldiTfRnnlmWrapper::ReadTfModel(const std::string &tf_model_path) { + string graph_path = tf_model_path + ".meta"; + + Status status = tensorflow::NewSession(tensorflow::SessionOptions(), &session_); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + + tensorflow::MetaGraphDef graph_def; + status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_path, &graph_def); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + + // Add the graph to the session + status = session_->Create(graph_def.graph_def()); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + + Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); + checkpointPathTensor.scalar()() = tf_model_path; + + status = session_->Run( + {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },}, + {}, + {graph_def.saver_def().restore_op_name()}, + nullptr); + if (!status.ok()) { + 
KALDI_ERR << status.ToString(); + } +} + +KaldiTfRnnlmWrapper::KaldiTfRnnlmWrapper( + const KaldiTfRnnlmWrapperOpts &opts, + const std::string &rnn_wordlist, + const std::string &word_symbol_table_rxfilename, + const std::string &unk_prob_file, + const std::string &tf_model_path): opts_(opts) { + ReadTfModel(tf_model_path); + + fst::SymbolTable *fst_word_symbols = NULL; + if (!(fst_word_symbols = + fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) { + KALDI_ERR << "Could not read symbol table from file " + << word_symbol_table_rxfilename; + } + + fst_label_to_word_.resize(fst_word_symbols->NumSymbols()); + + for (int32 i = 0; i < fst_label_to_word_.size(); ++i) { + fst_label_to_word_[i] = fst_word_symbols->Find(i); + if (fst_label_to_word_[i] == "") { + KALDI_ERR << "Could not find word for integer " << i << "in the word " + << "symbol table, mismatched symbol table or you have discoutinuous " + << "integers in your symbol table?"; + } + } + + fst_label_to_rnn_label_.resize(fst_word_symbols->NumSymbols(), -1); + num_total_words = fst_word_symbols->NumSymbols(); + + // read rnn wordlist and then generate ngram-label-to-rnn-label map + oos_ = -1; + { // input. 
+ ifstream ifile(rnn_wordlist.c_str()); + string word; + int id = -1; + eos_ = 0; + while (ifile >> word) { + id++; + rnn_label_to_word_.push_back(word); // vector[i] = word + + int fst_label = fst_word_symbols->Find(word); + if (fst::SymbolTable::kNoSymbol == fst_label) { + if (id == eos_) { + KALDI_ASSERT(word == opts_.eos_symbol); + continue; + } +// KALDI_LOG << word << " " << opts_.unk_symbol << " " << oos_; + KALDI_ASSERT(word == opts_.unk_symbol && oos_ == -1); + oos_ = id; + continue; + } + KALDI_ASSERT(fst_label >= 0); + fst_label_to_rnn_label_[fst_label] = id; + } + } + if (fst_label_to_word_.size() > rnn_label_to_word_.size()) { + KALDI_ASSERT(oos_ != -1); + } + num_rnn_words = rnn_label_to_word_.size(); + + // we must have a oos symbol in the wordlist + if (oos_ == -1) { + return; + } + for (int i = 0; i < fst_label_to_rnn_label_.size(); i++) { + if (fst_label_to_rnn_label_[i] == -1) { + fst_label_to_rnn_label_[i] = oos_; + } + } + + AcquireInitialTensors(); + SetUnkPenalties(unk_prob_file, *fst_word_symbols, &unk_probs_); +} + +void KaldiTfRnnlmWrapper::AcquireInitialTensors() { + Status status; + // get the initial context + { + std::vector state; + status = session_->Run(std::vector>(), {"Train/Model/test_initial_state"}, {}, &state); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + initial_context_ = state[0]; + } + + { + std::vector state; + Tensor bosword(tensorflow::DT_INT32, {1, 1}); + bosword.scalar()() = eos_; // eos_ is more like a sentence boundary + + std::vector> inputs = { + {"Train/Model/test_word_in", bosword}, + {"Train/Model/test_state_in", initial_context_}, + }; + + status = session_->Run(inputs, {"Train/Model/test_cell_out"}, {}, &state); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + initial_cell_ = state[0]; + } +} + +BaseFloat KaldiTfRnnlmWrapper::GetLogProb( + int32 word, + int32 fst_word, +// const std::vector &wseq, + const Tensor &context_in, + const Tensor &cell_in, + Tensor *context_out, + 
Tensor *new_cell) { + + std::vector> inputs; + + Tensor thisword(tensorflow::DT_INT32, {1, 1}); + + thisword.scalar()() = word; + std::vector outputs; + + if (context_out != NULL) { + inputs = { + {"Train/Model/test_word_in", thisword}, + {"Train/Model/test_word_out", thisword}, + {"Train/Model/test_state_in", context_in}, + {"Train/Model/test_cell_in", cell_in}, + }; + + // The session will initialize the outputs + // Run the session, evaluating our "c" operation from the graph + Status status = session_->Run(inputs, + {"Train/Model/test_out", + "Train/Model/test_state_out", + "Train/Model/test_cell_out"}, {}, &outputs); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + + *context_out = outputs[1]; + *new_cell = outputs[2]; + } else { + inputs = { + {"Train/Model/test_word_out", thisword}, + {"Train/Model/test_cell_in", cell_in}, + }; + + // Run the session, evaluating our "c" operation from the graph + Status status = session_->Run(inputs, + {"Train/Model/test_out"}, {}, &outputs); + if (!status.ok()) { + KALDI_ERR << status.ToString(); + } + } + + float ans; + if (word != oos_) { + ans = outputs[0].scalar()(); + } else { + if (unk_probs_.size() == 0) { + ans = outputs[0].scalar()() - log (num_total_words - num_rnn_words); + } else { + ans = outputs[0].scalar()() + unk_probs_[fst_word]; + } + } + +// KALDI_LOG << "Computing logprob of word " << rnn_label_to_word_[word] << "(" << word << ")" +// << " given history " << his_str.str() << " is " << exp(ans); +// KALDI_LOG << "prob is " << outputs[0].scalar()(); + return ans; +} + +const Tensor& KaldiTfRnnlmWrapper::GetInitialContext() const { + return initial_context_; +} + +const Tensor& KaldiTfRnnlmWrapper::GetInitialCell() const { + return initial_cell_; +} + +TfRnnlmDeterministicFst::TfRnnlmDeterministicFst(int32 max_ngram_order, + KaldiTfRnnlmWrapper *rnnlm) { + KALDI_ASSERT(rnnlm != NULL); + max_ngram_order_ = max_ngram_order; + rnnlm_ = rnnlm; + + // Uses empty history for . 
+ std::vector") {} + + void Register(OptionsItf *opts) { + opts->Register("unk-symbol", &unk_symbol, "Symbol for out-of-vocabulary " + "words in rnnlm."); + opts->Register("eos-symbol", &eos_symbol, "End of setence symbol in " + "rnnlm."); + } +}; + +class KaldiTfRnnlmWrapper { + public: + KaldiTfRnnlmWrapper(const KaldiTfRnnlmWrapperOpts &opts, + const std::string &rnn_wordlist, + const std::string &word_symbol_table_rxfilename, + const std::string &unk_prob_file, + const std::string &tf_model_path); + + ~KaldiTfRnnlmWrapper() { + session_->Close(); + } + + int32 GetEos() const { return eos_; } + + // get an all-zero Tensor of the size that matches the hidden state of the TF model + const Tensor& GetInitialContext() const; + + // get the 2nd-to-last layer of RNN when feeding input of + // (initial-context, sentence-boundary) + const Tensor& GetInitialCell() const; + + // compute p(word | wseq) and return the log of that + // the computation used the input cell, + // which is the 2nd-to-last layer of the RNNLM associated with history wseq; + // + // and we generate (context_out, new_cell) by passing (context_in, word) into the model + BaseFloat GetLogProb(int32 word, + int32 fst_word, + const Tensor &context_in, // context to pass into RNN + const Tensor &cell_in, // 2nd-to-last layer + Tensor *context_out, + Tensor *new_cell); + + std::vector fst_label_to_rnn_label_; + std::vector rnn_label_to_word_; + std::vector fst_label_to_word_; + private: + void ReadTfModel(const std::string &tf_model_path); + + // do queries on the session to get the initial tensors (cell + context) + void AcquireInitialTensors(); + + KaldiTfRnnlmWrapperOpts opts_; + Tensor initial_context_; + Tensor initial_cell_; + int32 num_total_words; + int32 num_rnn_words; + + Session* session_; // owned here + int32 eos_; + int32 oos_; + + std::vector unk_probs_; + + KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiTfRnnlmWrapper); +}; + +class TfRnnlmDeterministicFst + : public fst::DeterministicOnDemandFst { + 
public: + typedef fst::StdArc::Weight Weight; + typedef fst::StdArc::StateId StateId; + typedef fst::StdArc::Label Label; + + // Does not take ownership. + TfRnnlmDeterministicFst(int32 max_ngram_order, KaldiTfRnnlmWrapper *rnnlm); + + // We cannot use "const" because the pure virtual function in the interface is + // not const. + virtual StateId Start() { return start_state_; } + + // We cannot use "const" because the pure virtual function in the interface is + // not const. + virtual Weight Final(StateId s); + + virtual bool GetArc(StateId s, Label ilabel, fst::StdArc* oarc); + + private: + typedef unordered_map, + StateId, VectorHasher