Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion egs/wsj/s5/steps/nnet3/chain/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def get_args():
should halve --trainer.samples-per-iter. May be
a comma-separated list of alternatives: first width
is the 'principal' chunk-width, used preferentially""")
parser.add_argument("--egs.nj", type=int, required=False,
default=0, dest="egs_nj",
help="""Number of jobs to use when generating egs.
Default: the same number as used for tree generation.
You probably do not need to tweak this, unless you
want to adapt a neural network on some different,
smaller-size data.""")

# chain options
parser.add_argument("--chain.lm-opts", type=str, dest='lm_opts',
Expand Down Expand Up @@ -283,7 +290,10 @@ def train(args, run_opts):
shutil.copy('{0}/phones.txt'.format(args.tree_dir), args.dir)

# Set some variables.
num_jobs = common_lib.get_number_of_jobs(args.tree_dir)
if args.egs_nj <= 0:
num_jobs = common_lib.get_number_of_jobs(args.tree_dir)
else:
num_jobs = args.egs_nj
feat_dim = common_lib.get_feat_dim(args.feat_dir)
ivector_dim = common_lib.get_ivector_dim(args.online_ivector_dir)
ivector_id = common_lib.get_ivector_extractor_id(args.online_ivector_dir)
Expand Down