diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py index 5184b6eed41..942e45da300 100755 --- a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py @@ -207,21 +207,30 @@ def write_config_files(config_dir, all_layers): def add_back_compatibility_info(config_dir): """This will be removed when python script refactoring is done.""" - common_lib.run_kaldi_command("nnet3-init {0}/ref.config " - "{0}/ref.raw".format(config_dir)) - out, err = common_lib.run_kaldi_command("nnet3-info {0}/ref.raw | " - "head -4".format(config_dir)) - # out looks like this - # left-context: 7 - # right-context: 0 - # num-parameters: 90543902 - # modulus: 1 - info = {} - for line in out.split("\n"): - parts = line.split(":") - if len(parts) != 2: - continue - info[parts[0].strip()] = int(parts[1].strip()) + info = {'left-context' : 0, 'right-context' : 0} + for file_name in ['init', 'ref']: + if os.path.exists('{0}/{1}.config'.format(config_dir, file_name)): + common_lib.run_kaldi_command("nnet3-init {0}/{1}.config " + "{0}/{1}.raw".format(config_dir, file_name)) + out, err = common_lib.run_kaldi_command("nnet3-info {0}/{1}.raw | " + "head -4".format(config_dir, file_name)) + # out looks like this + # left-context: 7 + # right-context: 0 + # num-parameters: 90543902 + # modulus: 1 + for line in out.split("\n"): + parts = line.split(":") + if len(parts) != 2: + continue + key = parts[0].strip() + value = int(parts[1].strip()) + if key in ['left-context', 'right-context']: + info[key] = max(info[key], value) + else: + info[key] = value + + # Writing the back-compatible vars file # model_left_context=0