Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions egs/wsj/s5/steps/libs/nnet3/train/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ def get_outputs_list(model_file, get_raw_nnet_from_am=True):
It will normally return 'output'.
"""
if get_raw_nnet_from_am:
outputs_list = common_lib.get_command_stdout(
"nnet3-am-info --print-args=false {0} | "
"grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file))
outputs_list = common_lib.get_command_stdout(
"nnet3-am-info --print-args=false {0} | "
"grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file))
else:
outputs_list = common_lib.get_command_stdout(
"nnet3-info --print-args=false {0} | "
"grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file))
outputs_list = common_lib.get_command_stdout(
"nnet3-info --print-args=false {0} | "
"grep -e 'output-node' | cut -f2 -d' ' | cut -f2 -d'=' ".format(model_file))

return outputs_list.split()

Expand All @@ -71,13 +71,13 @@ def get_multitask_egs_opts(egs_dir, egs_prefix="",
"valid_diagnostic." for validation.
"""
multitask_egs_opts = ""
egs_suffix = ".{0}".format(archive_index) if archive_index > -1 else ""
egs_suffix = ".{0}".format(archive_index) if archive_index > -1 else ""

if use_multitask_egs:
output_file_name = ("{egs_dir}/{egs_prefix}output{egs_suffix}.ark"
"".format(egs_dir=egs_dir,
egs_prefix=egs_prefix,
egs_suffix=egs_suffix))
egs_prefix=egs_prefix,
egs_suffix=egs_suffix))
output_rename_opt = ""
if os.path.isfile(output_file_name):
output_rename_opt = ("--outputs=ark:{output_file_name}".format(
Expand All @@ -102,7 +102,7 @@ def get_multitask_egs_opts(egs_dir, egs_prefix="",

def get_successful_models(num_models, log_file_pattern,
difference_threshold=1.0):
assert(num_models > 0)
assert num_models > 0

parse_regex = re.compile(
"LOG .* Overall average objective function for "
Expand Down Expand Up @@ -163,9 +163,9 @@ def get_best_nnet_model(dir, iter, best_model_index, run_opts,
get_raw_nnet_from_am=True):

best_model = "{dir}/{next_iter}.{best_model_index}.raw".format(
dir=dir,
next_iter=iter + 1,
best_model_index=best_model_index)
dir=dir,
next_iter=iter + 1,
best_model_index=best_model_index)

if get_raw_nnet_from_am:
out_model = ("""- \| nnet3-am-copy --set-raw-nnet=- \
Expand Down Expand Up @@ -406,21 +406,21 @@ def verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_extractor_id,
logger.warning("The ivector ids are inconsistently used. It's your "
"responsibility to make sure the ivector extractor "
"has been used consistently")
elif (((egs_ivector_id is None) and (ivector_extractor_id is None))):
elif ((egs_ivector_id is None) and (ivector_extractor_id is None)):
logger.warning("The ivector ids are not used. It's your "
"responsibility to make sure the ivector extractor "
"has been used consistently")
elif (ivector_extractor_id != egs_ivector_id):
elif ivector_extractor_id != egs_ivector_id:
raise Exception("The egs were generated using a different ivector "
"extractor. id1 = {0}, id2={1}".format(
ivector_extractor_id, egs_ivector_id));

if (egs_left_context < left_context or
egs_right_context < right_context):
egs_right_context < right_context):
raise Exception('The egs have insufficient (l,r) context ({0},{1}) '
'versus expected ({2},{3})'.format(
egs_left_context, egs_right_context,
left_context, right_context))
egs_left_context, egs_right_context,
left_context, right_context))

# the condition on the initial/final context is an equality condition,
# not an inequality condition, as there is no mechanism to 'correct' the
Expand All @@ -435,8 +435,8 @@ def verify_egs_dir(egs_dir, feat_dim, ivector_dim, ivector_extractor_id,
raise Exception('The egs have incorrect initial/final (l,r) context '
'({0},{1}) versus expected ({2},{3}). See code from '
'where this exception was raised for more info'.format(
egs_left_context_initial, egs_right_context_final,
left_context_initial, right_context_final))
egs_left_context_initial, egs_right_context_final,
left_context_initial, right_context_final))

frames_per_eg_str = open('{0}/info/frames_per_eg'.format(
egs_dir)).readline().rstrip()
Expand Down Expand Up @@ -476,9 +476,9 @@ def compute_presoftmax_prior_scale(dir, alidir, num_jobs, run_opts,
os.remove(file)
pdf_counts = common_lib.read_kaldi_matrix('{0}/pdf_counts'.format(dir))[0]
scaled_counts = smooth_presoftmax_prior_scale_vector(
pdf_counts,
presoftmax_prior_scale_power=presoftmax_prior_scale_power,
smooth=0.01)
pdf_counts,
presoftmax_prior_scale_power=presoftmax_prior_scale_power,
smooth=0.01)

output_file = "{0}/presoftmax_prior_scale.vec".format(dir)
common_lib.write_kaldi_matrix(output_file, [scaled_counts])
Expand Down Expand Up @@ -601,7 +601,7 @@ def should_do_shrinkage(iter, model_file, shrink_saturation_threshold,
"saturation from the output '{0}' of "
"get_saturation.pl on the info of "
"model {1}".format(output, model_file))
return (saturation > shrink_saturation_threshold)
return saturation > shrink_saturation_threshold


def remove_nnet_egs(egs_dir):
Expand Down Expand Up @@ -651,7 +651,7 @@ def self_test():
assert validate_chunk_width('64,25,128')


class CommonParser:
class CommonParser(object):
"""Parser for parsing common options related to nnet3 training.

This argument parser adds common options related to nnet3 training
Expand All @@ -663,7 +663,7 @@ class CommonParser:
parser = argparse.ArgumentParser(add_help=False)

def __init__(self,
include_chunk_context = True,
include_chunk_context=True,
default_chunk_left_context=0):
# feat options
self.parser.add_argument("--feat.online-ivector-dir", type=str,
Expand Down Expand Up @@ -692,11 +692,11 @@ def __init__(self,
the case of FF-DNN this extra context will be
used to allow for frame-shifts""")
self.parser.add_argument("--egs.chunk-right-context", type=int,
dest='chunk_right_context', default=0,
help="""Number of additional frames of input
to the right of the input chunk. This extra
context will be used in the estimation of
bidirectional RNN state before prediction of
dest='chunk_right_context', default=0,
help="""Number of additional frames of input
to the right of the input chunk. This extra
context will be used in the estimation of
bidirectional RNN state before prediction of
the first label.""")
self.parser.add_argument("--egs.chunk-left-context-initial", type=int,
dest='chunk_left_context_initial', default=-1,
Expand Down Expand Up @@ -780,6 +780,18 @@ def __init__(self,
dest='presoftmax_prior_scale_power',
default=-0.25,
help="Scale on presofmax prior")
self.parser.add_argument("--trainer.optimization.proportional-shrink", type=float,
dest='proportional_shrink', default=0.0,
help="""If nonzero, this will set a shrinkage (scaling)
factor for the parameters, whose value is set as:
shrink-value=(1.0 - proportional-shrink * learning-rate), where
'learning-rate' is the learning rate being applied
on the current iteration, which will vary from
initial-effective-lrate*num-jobs-initial to
final-effective-lrate*num-jobs-final.
Unlike for train_rnn.py, this is applied unconditionally,
it does not depend on saturation of nonlinearities.
Can be used to roughly approximate l2 regularization.""")

# Parameters for the optimization
self.parser.add_argument(
Expand Down
50 changes: 15 additions & 35 deletions egs/wsj/s5/steps/nnet3/chain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,6 @@ def get_args():
steps/nnet3/get_saturation.pl) exceeds this threshold
we scale the parameter matrices with the
shrink-value.""")
parser.add_argument("--trainer.optimization.proportional-shrink", type=float,
dest='proportional_shrink', default=0.0,
help="""If nonzero, this will set a shrinkage (scaling)
factor for the parameters, whose value is set as:
shrink-value=(1.0 - proportional-shrink * learning-rate), where
'learning-rate' is the learning rate being applied
on the current iteration, which will vary from
initial-effective-lrate*num-jobs-initial to
final-effective-lrate*num-jobs-final.
Unlike for train_rnn.py, this is applied unconditionally,
it does not depend on saturation of nonlinearities.
Can be used to roughly approximate l2 regularization.""")

# RNN-specific training options
parser.add_argument("--trainer.deriv-truncate-margin", type=int,
dest='deriv_truncate_margin', default=None,
Expand Down Expand Up @@ -419,14 +406,6 @@ def train(args, run_opts):
num_archives_expanded, args.max_models_combine,
args.num_jobs_final)

def learning_rate(iter, current_num_jobs, num_archives_processed):
return common_train_lib.get_learning_rate(iter, current_num_jobs,
num_iters,
num_archives_processed,
num_archives_to_process,
args.initial_effective_lrate,
args.final_effective_lrate)

min_deriv_time = None
max_deriv_time_relative = None
if args.deriv_truncate_margin is not None:
Expand All @@ -448,22 +427,23 @@ def learning_rate(iter, current_num_jobs, num_archives_processed):
if args.stage <= iter:
model_file = "{dir}/{iter}.mdl".format(dir=args.dir, iter=iter)

lrate = learning_rate(iter, current_num_jobs,
num_archives_processed)
shrink_value = 1.0
if args.proportional_shrink != 0.0:
shrink_value = 1.0 - (args.proportional_shrink * lrate)
if shrink_value <= 0.5:
raise Exception("proportional-shrink={0} is too large, it gives "
"shrink-value={1}".format(args.proportional_shrink,
shrink_value))

if args.shrink_value < shrink_value:
shrink_value = (args.shrink_value
lrate = common_train_lib.get_learning_rate(iter, current_num_jobs,
num_iters,
num_archives_processed,
num_archives_to_process,
args.initial_effective_lrate,
args.final_effective_lrate)
shrinkage_value = 1.0 - (args.proportional_shrink * lrate)
if shrinkage_value <= 0.5:
raise Exception("proportional-shrink={0} is too large, it gives "
"shrink-value={1}".format(args.proportional_shrink,
shrinkage_value))
if args.shrink_value < shrinkage_value:
shrinkage_value = (args.shrink_value
if common_train_lib.should_do_shrinkage(
iter, model_file,
args.shrink_saturation_threshold)
else shrink_value)
else shrinkage_value)

chain_lib.train_one_iteration(
dir=args.dir,
Expand All @@ -478,7 +458,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed):
args.dropout_schedule,
float(num_archives_processed) / num_archives_to_process,
iter),
shrinkage_value=shrink_value,
shrinkage_value=shrinkage_value,
num_chunk_per_minibatch_str=args.num_chunk_per_minibatch,
apply_deriv_weights=args.apply_deriv_weights,
min_deriv_time=min_deriv_time,
Expand Down
1 change: 1 addition & 0 deletions egs/wsj/s5/steps/nnet3/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# Note: this file is part of some nnet3 config-creation tools that are now deprecated.

from __future__ import print_function
import os
Expand Down
Loading