Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions egs/wsj/s5/steps/libs/nnet3/train/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,5 +901,16 @@ def __init__(self,
then only failure notifications are sent""")


def is_lda_added(config_dir):
    """Return True if init.config in config_dir references an lda.mat file,
    which suggests the LDA matrix training stage needs to be executed in the
    nnet training.

    Arguments:
        config_dir -- path to the configs directory containing init.config.
    """
    # Use a context manager so the file handle is always closed.
    # NOTE: the original code called .format(config_dir) with a positional
    # argument against a named {config_dir} placeholder, which raises
    # KeyError at runtime; fixed by naming the argument.
    with open("{config_dir}/init.config".format(config_dir=config_dir)) as f:
        for line in f:
            # Drop trailing "#" comments so a commented-out lda.mat
            # reference does not count as a match.
            line = line.strip().split("#")[0]
            if re.search(r"lda\.mat", line):
                return True
    return False


# Run the module's self tests when executed directly as a script
# (self_test is defined elsewhere in this file).
if __name__ == '__main__':
    self_test()
18 changes: 11 additions & 7 deletions egs/wsj/s5/steps/nnet3/train_raw_dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_args():
rule as accepted by the --minibatch-size option of
nnet3-merge-egs; run that program without args to see
the format.""")

parser.add_argument("--trainer.optimization.proportional-shrink", type=float,
dest='proportional_shrink', default=0.0,
help="""If nonzero, this will set a shrinkage (scaling)
Expand All @@ -92,6 +93,11 @@ def get_args():
Unlike for train_rnn.py, this is applied unconditionally,
it does not depend on saturation of nonlinearities.
Can be used to roughly approximate l2 regularization.""")
parser.add_argument("--compute-average-posteriors",
type=str, action=common_lib.StrToBoolAction,
choices=["true", "false"], default=False,
help="""If true, then the average output of the
network is computed and dumped as post.final.vec""")

# General options
parser.add_argument("--nj", type=int, default=4,
Expand Down Expand Up @@ -198,11 +204,7 @@ def train(args, run_opts):
try:
model_left_context = variables['model_left_context']
model_right_context = variables['model_right_context']
if 'include_log_softmax' in variables:
include_log_softmax = common_lib.str_to_bool(
variables['include_log_softmax'])
else:
include_log_softmax = False

except KeyError as e:
raise Exception("KeyError {0}: Variables need to be defined in "
"{1}".format(str(e), '{0}/configs'.format(args.dir)))
Expand Down Expand Up @@ -286,7 +288,9 @@ def train(args, run_opts):
# use during decoding
common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir)

if (args.stage <= -3) and os.path.exists(args.dir+"/configs/init.config"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other reasons we use init.config other than to make the LDA-like transform?

add_lda = common_train_lib.is_lda_added(config_dir)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there other reasons we use init.config other than to make the LDA-like transform?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not. It seems like it will be removed in the transfer learning PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init.config is still used to create the initial model. The check is needed to know if the LDA needs to be trained.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you see my comment:
"I think you are mistaken, if we are talking about the current kaldi_52
code; init.config is only used if we are doing the LDA thing."
can you please update the PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand what needs to be done. The current solution is needed to know if LDA needs to be trained. The function is_lda_added can be changed to read init.raw if needed.

if (add_lda and args.stage <= -3):
logger.info('Computing the preconditioning matrix for input features')

train_lib.common.compute_preconditioning_matrix(
Expand Down Expand Up @@ -417,7 +421,7 @@ def learning_rate(iter, current_num_jobs, num_archives_processed):
common_lib.force_symlink("{0}.raw".format(num_iters),
"{0}/final.raw".format(args.dir))

if include_log_softmax and args.stage <= num_iters + 1:
if compute_average_posteriors and args.stage <= num_iters + 1:
logger.info("Getting average posterior for output-node 'output'.")
train_lib.common.compute_average_posterior(
dir=args.dir, iter='final', egs_dir=egs_dir,
Expand Down