diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index aa301eb4398..7c01689e86c 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -998,8 +998,7 @@ def set_default_configs(self): self.config = {'input': '[-1]', 'dim': -1, 'cepstral-lifter': 22.0, - 'affine-transform-file': '', - 'write-init-config': True} + 'affine-transform-file': ''} def check_configs(self): if self.config['affine-transform-file'] is None: @@ -1031,13 +1030,6 @@ def get_full_config(self): output_dim = self.output_dim() transform_file = self.config['affine-transform-file'] - if self.config['write-init-config']: - # to init.config we write an output-node with the name 'output' and - # with a Descriptor equal to the descriptor that's the input to this - # layer. This will be used to accumulate stats to learn the LDA transform. - line = 'output-node name=output input={0}'.format(descriptor_final_string) - ans.append(('init', line)) - idct_mat = common_lib.compute_idct_matrix( input_dim, output_dim, self.config['cepstral-lifter']) # append a zero column to the matrix, this is the bias of the fixed diff --git a/egs/wsj/s5/steps/nnet3/chain/train.py b/egs/wsj/s5/steps/nnet3/chain/train.py index 65b09f9f1eb..6bc51dcbd3f 100755 --- a/egs/wsj/s5/steps/nnet3/chain/train.py +++ b/egs/wsj/s5/steps/nnet3/chain/train.py @@ -307,7 +307,7 @@ def train(args, run_opts): logger.info("Creating denominator FST") chain_lib.create_denominator_fst(args.dir, args.tree_dir, run_opts) - if (args.stage <= -4): + if (args.stage <= -4) and os.path.exists(args.dir+"/configs/init.config"): logger.info("Initializing a basic network for estimating " "preconditioning matrix") common_lib.execute_command( @@ -375,7 +375,7 @@ def train(args, run_opts): logger.info("Copying the properties from {0} to {1}".format(egs_dir, args.dir)) common_train_lib.copy_egs_properties_to_exp_dir(egs_dir, args.dir) - if (args.stage <= -2): + if (args.stage <= -2) and os.path.exists(args.dir+"/configs/init.config"): logger.info('Computing the preconditioning matrix for input features') chain_lib.compute_preconditioning_matrix(