diff --git a/egs/wsj/s5/steps/libs/nnet3/train/common.py b/egs/wsj/s5/steps/libs/nnet3/train/common.py index 7a5646e7a4d..49565e6bc7e 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/common.py @@ -45,13 +45,13 @@ def get_outputs_list(model_file, get_raw_nnet_from_am=True): It will normally return 'output'. """ if get_raw_nnet_from_am: - outputs_list = common_lib.get_command_stdout( - "nnet3-am-info --print-args=false {0} | " - "grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file)) + outputs_list = common_lib.get_command_stdout( + "nnet3-am-info --print-args=false {0} | " + "grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file)) else: - outputs_list = common_lib.get_command_stdout( - "nnet3-info --print-args=false {0} | " - "grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file)) + outputs_list = common_lib.get_command_stdout( + "nnet3-info --print-args=false {0} | " + "grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file)) return outputs_list.split() @@ -71,13 +71,13 @@ def get_multitask_egs_opts(egs_dir, egs_prefix="", "valid_diagnostic." for validation. """ multitask_egs_opts = "" - egs_suffix = ".{0}".format(archive_index) if archive_index > -1 else "" + egs_suffix = ".{0}".format(archive_index) if archive_index > -1 else "" if use_multitask_egs: output_file_name = ("{egs_dir}/{egs_prefix}output{egs_suffix}.ark" "".format(egs_dir=egs_dir, - egs_prefix=egs_prefix, - egs_suffix=egs_suffix)) + egs_prefix=egs_prefix, + egs_suffix=egs_suffix)) output_rename_opt = "" if os.path.isfile(output_file_name): output_rename_opt = ("--outputs=ark:{output_file_name}".format( @@ -102,7 +102,7 @@ def get_multitask_egs_opts(egs_dir, egs_prefix="", def get_successful_models(num_models, log_file_pattern, difference_threshold=1.0): - assert(num_models > 0) + assert num_models > 0 parse_regex = re.compile( "LOG .* Overall average objective function for " @@ -163,9 +163,9 @@ def get_best_nnet_model(dir, iter, best_model_index, run_opts, get_raw_nnet_from_am=True): best_model = "{dir}/{next_iter}.{best_model_index}.raw".format( - dir=dir, - next_iter=iter + 1, - best_model_index=best_model_index) + dir=dir, + next_iter=iter + 1, + best_model_index=best_model_index) if get_raw_nnet_from_am: out_model = ("""- \| nnet3-am-copy --set-raw-nnet=- \ @@ -406,21 +406,21 @@ def verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_extractor_id, logger.warning("The ivector ids are inconsistently used. It's your " "responsibility to make sure the ivector extractor " "has been used consistently") - elif (((egs_ivector_id is None) and (ivector_extractor_id is None))): + elif ((egs_ivector_id is None) and (ivector_extractor_id is None)): logger.warning("The ivector ids are not used. It's your " "responsibility to make sure the ivector extractor " "has been used consistently") - elif (ivector_extractor_id != egs_ivector_id): + elif ivector_extractor_id != egs_ivector_id: raise Exception("The egs were generated using a different ivector " "extractor. id1 = {0}, id2={1}".format( ivector_extractor_id, egs_ivector_id)); if (egs_left_context < left_context or - egs_right_context < right_context): + egs_right_context < right_context): raise Exception('The egs have insufficient (l,r) context ({0},{1}) ' 'versus expected ({2},{3})'.format( - egs_left_context, egs_right_context, - left_context, right_context)) + egs_left_context, egs_right_context, + left_context, right_context)) # the condition on the initial/final context is an equality condition, # not an inequality condition, as there is no mechanism to 'correct' the @@ -435,8 +435,8 @@ def verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_extractor_id, raise Exception('The egs have incorrect initial/final (l,r) context ' '({0},{1}) versus expected ({2},{3}). See code from ' 'where this exception was raised for more info'.format( - egs_left_context_initial, egs_right_context_final, - left_context_initial, right_context_final)) + egs_left_context_initial, egs_right_context_final, + left_context_initial, right_context_final)) frames_per_eg_str = open('{0}/info/frames_per_eg'.format( egs_dir)).readline().rstrip() @@ -476,9 +476,9 @@ def compute_presoftmax_prior_scale(dir, alidir, num_jobs, run_opts, os.remove(file) pdf_counts = common_lib.read_kaldi_matrix('{0}/pdf_counts'.format(dir))[0] scaled_counts = smooth_presoftmax_prior_scale_vector( - pdf_counts, - presoftmax_prior_scale_power=presoftmax_prior_scale_power, - smooth=0.01) + pdf_counts, + presoftmax_prior_scale_power=presoftmax_prior_scale_power, + smooth=0.01) output_file = "{0}/presoftmax_prior_scale.vec".format(dir) common_lib.write_kaldi_matrix(output_file, [scaled_counts]) @@ -601,7 +601,7 @@ def should_do_shrinkage(iter, model_file, shrink_saturation_threshold, "saturation from the output '{0}' of " "get_saturation.pl on the info of " "model {1}".format(output, model_file)) - return (saturation > shrink_saturation_threshold) + return saturation > shrink_saturation_threshold def remove_nnet_egs(egs_dir): @@ -651,7 +651,7 @@ def self_test(): assert validate_chunk_width('64,25,128') -class CommonParser: +class CommonParser(object): """Parser for parsing common options related to nnet3 training. This argument parser adds common options related to nnet3 training @@ -663,7 +663,7 @@ class CommonParser: parser = argparse.ArgumentParser(add_help=False) def __init__(self, - include_chunk_context = True, + include_chunk_context=True, default_chunk_left_context=0): # feat options self.parser.add_argument("--feat.online-ivector-dir", type=str, @@ -692,11 +692,11 @@ def __init__(self, the case of FF-DNN this extra context will be used to allow for frame-shifts""") self.parser.add_argument("--egs.chunk-right-context", type=int, - dest='chunk_right_context', default=0, - help="""Number of additional frames of input - to the right of the input chunk. This extra - context will be used in the estimation of - bidirectional RNN state before prediction of + dest='chunk_right_context', default=0, + help="""Number of additional frames of input + to the right of the input chunk. This extra + context will be used in the estimation of + bidirectional RNN state before prediction of the first label.""") self.parser.add_argument("--egs.chunk-left-context-initial", type=int, dest='chunk_left_context_initial', default=-1, @@ -780,6 +780,18 @@ def __init__(self, dest='presoftmax_prior_scale_power', default=-0.25, help="Scale on presofmax prior") + self.parser.add_argument("--trainer.optimization.proportional-shrink", type=float, + dest='proportional_shrink', default=0.0, + help="""If nonzero, this will set a shrinkage (scaling) + factor for the parameters, whose value is set as: + shrink-value=(1.0 - proportional-shrink * learning-rate), where + 'learning-rate' is the learning rate being applied + on the current iteration, which will vary from + initial-effective-lrate*num-jobs-initial to + final-effective-lrate*num-jobs-final. + Unlike for train_rnn.py, this is applied unconditionally, + it does not depend on saturation of nonlinearities. + Can be used to roughly approximate l2 regularization.""") # Parameters for the optimization self.parser.add_argument( diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index a6e63c6da0a..dacfae99a2a 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -145,19 +145,6 @@ def get_args(): steps/nnet3/get_saturation.pl) exceeds this threshold we scale the parameter matrices with the shrink-value.""") - parser.add_argument("--trainer.optimization.proportional-shrink", type=float, - dest='proportional_shrink', default=0.0, - help="""If nonzero, this will set a shrinkage (scaling) - factor for the parameters, whose value is set as: - shrink-value=(1.0 - proportional-shrink * learning-rate), where - 'learning-rate' is the learning rate being applied - on the current iteration, which will vary from - initial-effective-lrate*num-jobs-initial to - final-effective-lrate*num-jobs-final. - Unlike for train_rnn.py, this is applied unconditionally, - it does not depend on saturation of nonlinearities. - Can be used to roughly approximate l2 regularization.""") - # RNN-specific training options parser.add_argument("--trainer.deriv-truncate-margin", type=int, dest='deriv_truncate_margin', default=None, @@ -419,14 +406,6 @@ def train(args, run_opts): num_archives_expanded, args.max_models_combine, args.num_jobs_final) - def learning_rate(iter, current_num_jobs, num_archives_processed): - return common_train_lib.get_learning_rate(iter, current_num_jobs, - num_iters, - num_archives_processed, - num_archives_to_process, - args.initial_effective_lrate, - args.final_effective_lrate) - min_deriv_time = None max_deriv_time_relative = None if args.deriv_truncate_margin is not None: @@ -448,22 +427,23 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): if args.stage <= iter: model_file = "{dir}/{iter}.mdl".format(dir=args.dir, iter=iter) - lrate = learning_rate(iter, current_num_jobs, - num_archives_processed) - shrink_value = 1.0 - if args.proportional_shrink != 0.0: - shrink_value = 1.0 - (args.proportional_shrink * lrate) - if shrink_value <= 0.5: - raise Exception("proportional-shrink={0} is too large, it gives " - "shrink-value={1}".format(args.proportional_shrink, - shrink_value)) - - if args.shrink_value < shrink_value: - shrink_value = (args.shrink_value + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_archives_processed, + num_archives_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + shrinkage_value = 1.0 - (args.proportional_shrink * lrate) + if shrinkage_value <= 0.5: + raise Exception("proportional-shrink={0} is too large, it gives " + "shrink-value={1}".format(args.proportional_shrink, + shrinkage_value)) + if args.shrink_value < shrinkage_value: + shrinkage_value = (args.shrink_value if common_train_lib.should_do_shrinkage( iter, model_file, args.shrink_saturation_threshold) - else shrink_value) + else shrinkage_value) chain_lib.train_one_iteration( dir=args.dir, @@ -478,7 +458,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): args.dropout_schedule, float(num_archives_processed) / num_archives_to_process, iter), - shrinkage_value=shrink_value, + shrinkage_value=shrinkage_value, num_chunk_per_minibatch_str=args.num_chunk_per_minibatch, apply_deriv_weights=args.apply_deriv_weights, min_deriv_time=min_deriv_time, diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 3fb92117d78..0782edd75a8 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# Note: this file is part of some nnet3 config-creation tools that are now deprecated. from __future__ import print_function import os diff --git a/egs/wsj/s5/steps/nnet3/train_dnn.py b/egs/wsj/s5/steps/nnet3/train_dnn.py index fc79b3e3e8d..f28ddfeacd2 100755 --- a/egs/wsj/s5/steps/nnet3/train_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_dnn.py @@ -49,7 +49,7 @@ def get_args(): cross-entropy objective. DNNs include simple DNNs, TDNNs and CNNs.""", formatter_class=argparse.ArgumentDefaultsHelpFormatter, conflict_handler='resolve', - parents=[common_train_lib.CommonParser(include_chunk_context = False).parser]) + parents=[common_train_lib.CommonParser(include_chunk_context=False).parser]) # egs extraction options parser.add_argument("--egs.frames-per-eg", type=int, dest='frames_per_eg', @@ -105,7 +105,7 @@ def process_args(args): raise Exception("--egs.frames-per-eg should have a minimum value of 1") if not common_train_lib.validate_minibatch_size_str(args.minibatch_size): - raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value"); + raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value") if (not os.path.exists(args.dir) or not os.path.exists(args.dir+"/configs")): @@ -172,7 +172,7 @@ def train(args, run_opts): # split the training data into parts for individual jobs # we will use the same number of jobs as that used for alignment common_lib.execute_command("utils/split_data.sh {0} {1}".format( - args.feat_dir, num_jobs)) + args.feat_dir, num_jobs)) shutil.copy('{0}/tree'.format(args.ali_dir), args.dir) with open('{0}/num_jobs'.format(args.dir), 'w') as f: @@ -235,12 +235,12 @@ def train(args, run_opts): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, - ivector_dim, ivector_id, - left_context, right_context)) - assert(str(args.frames_per_eg) == frames_per_eg_str) + common_train_lib.verify_egs_dir(egs_dir, feat_dim, + ivector_dim, ivector_id, + left_context, right_context)) + assert str(args.frames_per_eg) == frames_per_eg_str - if (args.num_jobs_final > num_archives): + if args.num_jobs_final > num_archives: raise Exception('num_jobs_final cannot exceed the number of archives ' 'in the egs directory') @@ -248,7 +248,7 @@ def train(args, run_opts): # use during decoding common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -3) and os.path.exists(args.dir+"/configs/init.config"): + if args.stage <= -3 and os.path.exists(args.dir+"/configs/init.config"): logger.info('Computing the preconditioning matrix for input features') train_lib.common.compute_preconditioning_matrix( @@ -256,17 +256,17 @@ def train(args, run_opts): max_lda_jobs=args.max_lda_jobs, rand_prune=args.rand_prune) - if (args.stage <= -2): + if args.stage <= -2: logger.info("Computing initial vector for FixedScaleComponent before" " softmax, using priors^{prior_scale} and rescaling to" " average 1".format( prior_scale=args.presoftmax_prior_scale_power)) common_train_lib.compute_presoftmax_prior_scale( - args.dir, args.ali_dir, num_jobs, run_opts, - presoftmax_prior_scale_power=args.presoftmax_prior_scale_power) + args.dir, args.ali_dir, num_jobs, run_opts, + presoftmax_prior_scale_power=args.presoftmax_prior_scale_power) - if (args.stage <= -1): + if args.stage <= -1: logger.info("Preparing the initial acoustic model.") train_lib.acoustic_model.prepare_initial_acoustic_model( args.dir, args.ali_dir, run_opts) @@ -286,14 +286,6 @@ def train(args, run_opts): num_archives_expanded, args.max_models_combine, args.num_jobs_final) - def learning_rate(iter, current_num_jobs, num_archives_processed): - return common_train_lib.get_learning_rate(iter, current_num_jobs, - num_iters, - num_archives_processed, - num_archives_to_process, - args.initial_effective_lrate, - args.final_effective_lrate) - logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -306,6 +298,18 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): * float(iter) / num_iters) if args.stage <= iter: + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_archives_processed, + num_archives_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + shrinkage_value = 1.0 - (args.proportional_shrink * lrate) + if shrinkage_value <= 0.5: + raise Exception("proportional-shrink={0} is too large, it gives " + "shrink-value={1}".format(args.proportional_shrink, + shrinkage_value)) + train_lib.common.train_one_iteration( dir=args.dir, iter=iter, @@ -314,8 +318,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_jobs=current_num_jobs, num_archives_processed=num_archives_processed, num_archives=num_archives, - learning_rate=learning_rate(iter, current_num_jobs, - num_archives_processed), + learning_rate=lrate, dropout_edit_string=common_train_lib.get_dropout_edit_string( args.dropout_schedule, float(num_archives_processed) / num_archives_to_process, @@ -324,6 +327,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): frames_per_eg=args.frames_per_eg, momentum=args.momentum, max_param_change=args.max_param_change, + shrinkage_value=shrinkage_value, shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts) diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index 089046a510e..3311e51d88f 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -8,6 +8,7 @@ raw neural network instead of an acoustic model. """ +from __future__ import print_function import argparse import logging import pprint @@ -47,7 +48,7 @@ def get_args(): DNNs include simple DNNs, TDNNs and CNNs.""", formatter_class=argparse.ArgumentDefaultsHelpFormatter, conflict_handler='resolve', - parents=[common_train_lib.CommonParser(include_chunk_context = False).parser]) + parents=[common_train_lib.CommonParser(include_chunk_context=False).parser]) # egs extraction options parser.add_argument("--egs.frames-per-eg", type=int, dest='frames_per_eg', @@ -80,18 +81,11 @@ def get_args(): rule as accepted by the --minibatch-size option of nnet3-merge-egs; run that program without args to see the format.""") - parser.add_argument("--trainer.optimization.proportional-shrink", type=float, - dest='proportional_shrink', default=0.0, - help="""If nonzero, this will set a shrinkage (scaling) - factor for the parameters, whose value is set as: - shrink-value=(1.0 - proportional-shrink * learning-rate), where - 'learning-rate' is the learning rate being applied - on the current iteration, which will vary from - initial-effective-lrate*num-jobs-initial to - final-effective-lrate*num-jobs-final. - Unlike for train_rnn.py, this is applied unconditionally, - it does not depend on saturation of nonlinearities. - Can be used to roughly approximate l2 regularization.""") + parser.add_argument("--compute-average-posteriors", + type=str, action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""If true, then the average output of the + network is computed and dumped as post.final.vec""") # General options parser.add_argument("--nj", type=int, default=4, @@ -127,7 +121,7 @@ def process_args(args): raise Exception("--egs.frames-per-eg should have a minimum value of 1") if not common_train_lib.validate_minibatch_size_str(args.minibatch_size): - raise Exception("--trainer.optimization.minibatch-size has an invalid value"); + raise Exception("--trainer.optimization.minibatch-size has an invalid value") if (not os.path.exists(args.dir) or not os.path.exists(args.dir+"/configs")): @@ -198,11 +192,7 @@ def train(args, run_opts): try: model_left_context = variables['model_left_context'] model_right_context = variables['model_right_context'] - if 'include_log_softmax' in variables: - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) - else: - include_log_softmax = False + except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) @@ -273,12 +263,12 @@ def train(args, run_opts): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, - ivector_dim, ivector_id, - left_context, right_context)) - assert(str(args.frames_per_eg) == frames_per_eg_str) + common_train_lib.verify_egs_dir(egs_dir, feat_dim, + ivector_dim, ivector_id, + left_context, right_context)) + assert str(args.frames_per_eg) == frames_per_eg_str - if (args.num_jobs_final > num_archives): + if args.num_jobs_final > num_archives: raise Exception('num_jobs_final cannot exceed the number of archives ' 'in the egs directory') @@ -286,7 +276,7 @@ def train(args, run_opts): # use during decoding common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -3) and os.path.exists(args.dir+"/configs/init.config"): + if args.stage <= -3 and os.path.exists(args.dir+"/configs/init.config"): logger.info('Computing the preconditioning matrix for input features') train_lib.common.compute_preconditioning_matrix( @@ -294,7 +284,7 @@ def train(args, run_opts): max_lda_jobs=args.max_lda_jobs, rand_prune=args.rand_prune) - if (args.stage <= -1): + if args.stage <= -1: logger.info("Preparing the initial network.") common_train_lib.prepare_initial_network(args.dir, run_opts) @@ -313,28 +303,20 @@ def train(args, run_opts): num_archives_expanded, args.max_models_combine, args.num_jobs_final) - if (os.path.exists('{0}/valid_diagnostic.scp'.format(args.egs_dir))): - if (os.path.exists('{0}/valid_diagnostic.egs'.format(args.egs_dir))): + if os.path.exists('{0}/valid_diagnostic.scp'.format(args.egs_dir)): + if os.path.exists('{0}/valid_diagnostic.egs'.format(args.egs_dir)): raise Exception('both {0}/valid_diagnostic.egs and ' '{0}/valid_diagnostic.scp exist.' 'This script expects only one of them to exist.' ''.format(args.egs_dir)) use_multitask_egs = True else: - if (not os.path.exists('{0}/valid_diagnostic.egs'.format(args.egs_dir))): + if not os.path.exists('{0}/valid_diagnostic.egs'.format(args.egs_dir)): raise Exception('neither {0}/valid_diagnostic.egs nor ' '{0}/valid_diagnostic.scp exist.' 'This script expects one of them.'.format(args.egs_dir)) use_multitask_egs = False - def learning_rate(iter, current_num_jobs, num_archives_processed): - return common_train_lib.get_learning_rate(iter, current_num_jobs, - num_iters, - num_archives_processed, - num_archives_to_process, - args.initial_effective_lrate, - args.final_effective_lrate) - logger.info("Training will run for {0} epochs = " "{1} iterations".format(args.num_epochs, num_iters)) @@ -346,18 +328,20 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): + (args.num_jobs_final - args.num_jobs_initial) * float(iter) / num_iters) - lrate = learning_rate(iter, current_num_jobs, - num_archives_processed) - shrink_value = 1.0 - if args.proportional_shrink != 0.0: - shrink_value = 1.0 - (args.proportional_shrink * lrate) - if shrink_value <= 0.5: + if args.stage <= iter: + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_archives_processed, + num_archives_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + + shrinkage_value = 1.0 - (args.proportional_shrink * lrate) + if shrinkage_value <= 0.5: raise Exception("proportional-shrink={0} is too large, it gives " "shrink-value={1}".format(args.proportional_shrink, - shrink_value)) - + shrinkage_value)) - if args.stage <= iter: train_lib.common.train_one_iteration( dir=args.dir, iter=iter, @@ -375,7 +359,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): frames_per_eg=args.frames_per_eg, momentum=args.momentum, max_param_change=args.max_param_change, - shrinkage_value=shrink_value, + shrinkage_value=shrinkage_value, shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts, get_raw_nnet_from_am=False, @@ -383,7 +367,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): use_multitask_egs=use_multitask_egs) if args.cleanup: - # do a clean up everythin but the last 2 models, under certain + # do a clean up everything but the last 2 models, under certain # conditions common_train_lib.remove_model( args.dir, iter-2, num_iters, models_to_combine, @@ -417,7 +401,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): common_lib.force_symlink("{0}.raw".format(num_iters), "{0}/final.raw".format(args.dir)) - if include_log_softmax and args.stage <= num_iters + 1: + if args.compute_average_posteriors and args.stage <= num_iters + 1: logger.info("Getting average posterior for output-node 'output'.") train_lib.common.compute_average_posterior( dir=args.dir, iter='final', egs_dir=egs_dir, diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 812be8b95f3..489185ae72c 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -9,7 +9,7 @@ """ This script is similar to steps/nnet3/train_rnn.py but trains a raw neural network instead of an acoustic model. """ - +from __future__ import print_function import argparse import logging import pprint @@ -123,6 +123,11 @@ def get_args(): backpropagated to. E.g., 8 is a reasonable setting. Note: the 'required' part of the chunk is defined by the model's {left,right}-context.""") + parser.add_argument("--compute-average-posteriors", + type=str, action=common_lib.StrToBoolAction, + choices=["true", "false"], default=False, + help="""If true, then the average output of the + network is computed and dumped as post.final.vec""") # General options parser.add_argument("--nj", type=int, default=4, @@ -155,10 +160,10 @@ def process_args(args): """ if not common_train_lib.validate_chunk_width(args.chunk_width): - raise Exception("--egs.chunk-width has an invalid value"); + raise Exception("--egs.chunk-width has an invalid value") if not common_train_lib.validate_minibatch_size_str(args.num_chunk_per_minibatch): - raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value"); + raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value") if args.chunk_left_context < 0: raise Exception("--egs.chunk-left-context should be non-negative") @@ -234,11 +239,6 @@ def train(args, run_opts): try: model_left_context = variables['model_left_context'] model_right_context = variables['model_right_context'] - if 'include_log_softmax' in variables: - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) - else: - include_log_softmax = False except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) @@ -312,15 +312,15 @@ def train(args, run_opts): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, - ivector_dim, ivector_id, - left_context, right_context)) + common_train_lib.verify_egs_dir(egs_dir, feat_dim, + ivector_dim, ivector_id, + left_context, right_context)) if args.chunk_width != frames_per_eg_str: raise Exception("mismatch between --egs.chunk-width and the frames_per_eg " "in the egs dir {0} vs {1}".format(args.chunk_width, - frames_per_eg_str)) + frames_per_eg_str)) - if (args.num_jobs_final > num_archives): + if args.num_jobs_final > num_archives: raise Exception('num_jobs_final cannot exceed the number of archives ' 'in the egs directory') @@ -328,7 +328,7 @@ def train(args, run_opts): # use during decoding common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -2) and os.path.exists(args.dir+"/configs/init.config"): + if args.stage <= -2 and os.path.exists(args.dir+"/configs/init.config"): logger.info('Computing the preconditioning matrix for input features') train_lib.common.compute_preconditioning_matrix( @@ -336,7 +336,7 @@ def train(args, run_opts): max_lda_jobs=args.max_lda_jobs, rand_prune=args.rand_prune) - if (args.stage <= -1): + if args.stage <= -1: logger.info("Preparing the initial network.") common_train_lib.prepare_initial_network(args.dir, run_opts) @@ -354,13 +354,6 @@ def train(args, run_opts): num_archives, args.max_models_combine, args.num_jobs_final) - def learning_rate(iter, current_num_jobs, num_archives_processed): - return common_train_lib.get_learning_rate(iter, current_num_jobs, - num_iters, - num_archives_processed, - num_archives_to_process, - args.initial_effective_lrate, - args.final_effective_lrate) min_deriv_time = None max_deriv_time_relative = None @@ -383,15 +376,26 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): if args.stage <= iter: model_file = "{dir}/{iter}.raw".format(dir=args.dir, iter=iter) - shrinkage_value = 1.0 - if args.shrink_value != 1.0: + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_archives_processed, + num_archives_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + + # shrinkage_value is a scale on the parameters. + shrinkage_value = 1.0 - (args.proportional_shrink * lrate) + if shrinkage_value <= 0.5: + raise Exception("proportional-shrink={0} is too large, it gives " + "shrink-value={1}".format(args.proportional_shrink, + shrinkage_value)) + if args.shrink_value < shrinkage_value: shrinkage_value = (args.shrink_value if common_train_lib.should_do_shrinkage( - iter, model_file, - args.shrink_saturation_threshold, - get_raw_nnet_from_am=False) - else 1 - ) + iter, model_file, + args.shrink_saturation_threshold, + get_raw_nnet_from_am=False) + else shrinkage_value) train_lib.common.train_one_iteration( dir=args.dir, @@ -401,8 +405,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_jobs=current_num_jobs, num_archives_processed=num_archives_processed, num_archives=num_archives, - learning_rate=learning_rate(iter, current_num_jobs, - num_archives_processed), + learning_rate=lrate, dropout_edit_string=common_train_lib.get_dropout_edit_string( args.dropout_schedule, float(num_archives_processed) / num_archives_to_process, @@ -448,7 +451,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): get_raw_nnet_from_am=False, sum_to_one_penalty=args.combine_sum_to_one_penalty) - if include_log_softmax and args.stage <= num_iters + 1: + if args.compute_average_posteriors and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " "adjusting the priors.") train_lib.common.compute_average_posterior( diff --git a/egs/wsj/s5/steps/nnet3/train_rnn.py b/egs/wsj/s5/steps/nnet3/train_rnn.py index c374cd8e8f9..8ba198ec5bd 100755 --- a/egs/wsj/s5/steps/nnet3/train_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_rnn.py @@ -7,6 +7,7 @@ """ This script is based on steps/nnet3/lstm/train.sh """ +from __future__ import print_function import argparse import logging import os @@ -56,7 +57,7 @@ def get_args(): 3. RNNs can also be trained with state preservation training""", formatter_class=argparse.ArgumentDefaultsHelpFormatter, conflict_handler='resolve', - parents=[common_train_lib.CommonParser(default_chunk_left_context = 40).parser]) + parents=[common_train_lib.CommonParser(default_chunk_left_context=40).parser]) # egs extraction options parser.add_argument("--egs.chunk-width", type=str, dest='chunk_width', @@ -150,10 +151,10 @@ def process_args(args): """ if not common_train_lib.validate_chunk_width(args.chunk_width): - raise Exception("--egs.chunk-width has an invalid value"); + raise Exception("--egs.chunk-width has an invalid value") if not common_train_lib.validate_minibatch_size_str(args.num_chunk_per_minibatch): - raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value"); + raise Exception("--trainer.rnn.num-chunk-per-minibatch has an invalid value") if args.chunk_left_context < 0: raise Exception("--egs.chunk-left-context should be non-negative") @@ -226,7 +227,7 @@ def train(args, run_opts): # split the training data into parts for individual jobs # we will use the same number of jobs as that used for alignment common_lib.execute_command("utils/split_data.sh {0} {1}".format( - args.feat_dir, num_jobs)) + args.feat_dir, num_jobs)) shutil.copy('{0}/tree'.format(args.ali_dir), args.dir) with open('{0}/num_jobs'.format(args.dir), 'w') as f: @@ -257,7 +258,7 @@ def train(args, run_opts): # we do this as it's a convenient way to get the stats for the 'lda-like' # transform. - if (args.stage <= -5): + if args.stage <= -5: logger.info("Initializing a basic network for estimating " "preconditioning matrix") common_lib.execute_command( @@ -267,7 +268,7 @@ def train(args, run_opts): dir=args.dir)) default_egs_dir = '{0}/egs'.format(args.dir) - if (args.stage <= -4) and args.egs_dir is None: + if args.stage <= -4 and args.egs_dir is None: logger.info("Generating egs") if args.feat_dir is None: @@ -297,16 +298,16 @@ def train(args, run_opts): [egs_left_context, egs_right_context, frames_per_eg_str, num_archives] = ( - common_train_lib.verify_egs_dir(egs_dir, feat_dim, - ivector_dim, ivector_id, - left_context, right_context, - left_context_initial, right_context_final)) + common_train_lib.verify_egs_dir(egs_dir, feat_dim, + ivector_dim, ivector_id, + left_context, right_context, + left_context_initial, right_context_final)) if args.chunk_width != frames_per_eg_str: raise Exception("mismatch between --egs.chunk-width and the frames_per_eg " "in the egs dir {0} vs {1}".format(args.chunk_width, frames_per_eg_str)) - if (args.num_jobs_final > num_archives): + if args.num_jobs_final > num_archives: raise Exception('num_jobs_final cannot exceed the number of archives ' 'in the egs directory') @@ -314,7 +315,7 @@ def train(args, run_opts): # use during decoding common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -3): + if args.stage <= -3: logger.info('Computing the preconditioning matrix for input features') train_lib.common.compute_preconditioning_matrix( @@ -322,17 +323,17 @@ def train(args, run_opts): max_lda_jobs=args.max_lda_jobs, rand_prune=args.rand_prune) - if (args.stage <= -2): + if args.stage <= -2: logger.info("Computing initial vector for FixedScaleComponent before" " softmax, using priors^{prior_scale} and rescaling to" " average 1".format( prior_scale=args.presoftmax_prior_scale_power)) common_train_lib.compute_presoftmax_prior_scale( - args.dir, args.ali_dir, num_jobs, run_opts, - presoftmax_prior_scale_power=args.presoftmax_prior_scale_power) + args.dir, args.ali_dir, num_jobs, run_opts, + presoftmax_prior_scale_power=args.presoftmax_prior_scale_power) - if (args.stage <= -1): + if args.stage <= -1: logger.info("Preparing the initial acoustic model.") train_lib.acoustic_model.prepare_initial_acoustic_model( args.dir, args.ali_dir, run_opts) @@ -348,17 +349,9 @@ def train(args, run_opts): models_to_combine = common_train_lib.get_model_combine_iters( num_iters, args.num_epochs, - num_archives, args.max_models_combine, + num_archives, args.max_models_combine, args.num_jobs_final) - def learning_rate(iter, current_num_jobs, num_archives_processed): - return common_train_lib.get_learning_rate(iter, current_num_jobs, - num_iters, - num_archives_processed, - num_archives_to_process, - args.initial_effective_lrate, - args.final_effective_lrate) - min_deriv_time = None max_deriv_time_relative = None if args.deriv_truncate_margin is not None: @@ -380,14 +373,24 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): if args.stage <= iter: model_file = "{dir}/{iter}.mdl".format(dir=args.dir, iter=iter) - shrinkage_value = 1.0 - if args.shrink_value != 1.0: + + lrate = common_train_lib.get_learning_rate(iter, current_num_jobs, + num_iters, + num_archives_processed, + num_archives_to_process, + args.initial_effective_lrate, + args.final_effective_lrate) + + shrinkage_value = 1.0 - (args.proportional_shrink * lrate) + if shrinkage_value <= 0.5: + raise Exception("proportional-shrink={0} is too large, it gives " + "shrink-value={1}".format(args.proportional_shrink, + shrinkage_value)) + if args.shrink_value < shrinkage_value: shrinkage_value = (args.shrink_value if common_train_lib.should_do_shrinkage( - iter, model_file, - args.shrink_saturation_threshold) - else 1 - ) + iter, model_file, + args.shrink_saturation_threshold) else 1.0) train_lib.common.train_one_iteration( dir=args.dir, @@ -397,8 +400,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_jobs=current_num_jobs, num_archives_processed=num_archives_processed, num_archives=num_archives, - learning_rate=learning_rate(iter, current_num_jobs, - num_archives_processed), + learning_rate=lrate, dropout_edit_string=common_train_lib.get_dropout_edit_string( args.dropout_schedule, float(num_archives_processed) / num_archives_to_process,