diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index e8f711e992e..62d4e302351 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -901,5 +901,16 @@ def __init__(self, then only failure notifications are sent""") +def is_lda_added(config_dir): + """Returns true if there is an lda.mat in init.config + which suggests the LDA matrix training stage needs to be executed in the + nnet training""" + for line in open("{config_dir}/init.config".format(config_dir)): + line = line.strip().split("#")[0] + if re.search(r"lda\.mat", line): + return True + return False + + if __name__ == '__main__': self_test() diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index 57886afefd8..8cf17267bb9 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -80,6 +80,7 @@ def get_args(): rule as accepted by the --minibatch-size option of nnet3-merge-egs; run that program without args to see the format.""") + parser.add_argument("--trainer.optimization.proportional-shrink", type=float, dest='proportional_shrink', default=0.0, help="""If nonzero, this will set a shrinkage (scaling) @@ -92,6 +93,11 @@ def get_args(): Unlike for train_rnn.py, this is applied unconditionally, it does not depend on saturation of nonlinearities. Can be used to roughly approximate l2 regularization.""") + parser.add_argument("--compute-average-posteriors", + type=str, action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""If true, then the average output of the + network is computed and dumped as post.final.vec""") # General options parser.add_argument("--nj", type=int, default=4, @@ -198,11 +204,7 @@ def train(args, run_opts): try: model_left_context = variables['model_left_context'] model_right_context = variables['model_right_context'] - if 'include_log_softmax' in variables: - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) - else: - include_log_softmax = False + except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) @@ -286,7 +288,9 @@ def train(args, run_opts): # use during decoding common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -3) and os.path.exists(args.dir+"/configs/init.config"): + add_lda = common_train_lib.is_lda_added(config_dir) + + if (add_lda and args.stage <= -3): logger.info('Computing the preconditioning matrix for input features') train_lib.common.compute_preconditioning_matrix( @@ -417,7 +421,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): common_lib.force_symlink("{0}.raw".format(num_iters), "{0}/final.raw".format(args.dir)) - if include_log_softmax and args.stage <= num_iters + 1: + if compute_average_posteriors and args.stage <= num_iters + 1: logger.info("Getting average posterior for output-node 'output'.") train_lib.common.compute_average_posterior( dir=args.dir, iter='final', egs_dir=egs_dir,