From e8fa0132fc4957d3448c0a57fe94f2520c300652 Mon Sep 17 00:00:00 2001 From: Mddct Date: Mon, 23 Oct 2023 19:54:02 +0800 Subject: [PATCH] mv the intermediate files to the assets directory --- .../ali_paraformer/{ => assets}/config.yaml | 2 ++ .../ali_paraformer/{ => assets}/global_cmvn | 0 .../ali_paraformer/{ => assets}/units.txt | 0 .../{export_jit.py => test_infer_jit.py} | 16 ---------------- wenet/utils/init_model.py | 1 - 5 files changed, 2 insertions(+), 17 deletions(-) rename wenet/paraformer/ali_paraformer/{ => assets}/config.yaml (97%) rename wenet/paraformer/ali_paraformer/{ => assets}/global_cmvn (100%) rename wenet/paraformer/ali_paraformer/{ => assets}/units.txt (100%) rename wenet/paraformer/ali_paraformer/{export_jit.py => test_infer_jit.py} (78%) diff --git a/wenet/paraformer/ali_paraformer/config.yaml b/wenet/paraformer/ali_paraformer/assets/config.yaml similarity index 97% rename from wenet/paraformer/ali_paraformer/config.yaml rename to wenet/paraformer/ali_paraformer/assets/config.yaml index 94dcea281..fa2b2a50e 100644 --- a/wenet/paraformer/ali_paraformer/config.yaml +++ b/wenet/paraformer/ali_paraformer/assets/config.yaml @@ -14,6 +14,8 @@ encoder_conf: kernel_size: 11 sanm_shfit: 0 +input_dim: 80 +output_dim: 8404 paraformer: true # decoder related decoder: SanmDecoder diff --git a/wenet/paraformer/ali_paraformer/global_cmvn b/wenet/paraformer/ali_paraformer/assets/global_cmvn similarity index 100% rename from wenet/paraformer/ali_paraformer/global_cmvn rename to wenet/paraformer/ali_paraformer/assets/global_cmvn diff --git a/wenet/paraformer/ali_paraformer/units.txt b/wenet/paraformer/ali_paraformer/assets/units.txt similarity index 100% rename from wenet/paraformer/ali_paraformer/units.txt rename to wenet/paraformer/ali_paraformer/assets/units.txt diff --git a/wenet/paraformer/ali_paraformer/export_jit.py b/wenet/paraformer/ali_paraformer/test_infer_jit.py similarity index 78% rename from wenet/paraformer/ali_paraformer/export_jit.py rename to wenet/paraformer/ali_paraformer/test_infer_jit.py index b77735571..b3a168add 100644 --- a/wenet/paraformer/ali_paraformer/export_jit.py +++ b/wenet/paraformer/ali_paraformer/test_infer_jit.py @@ -43,24 +43,8 @@ def main(): char_dict = {v: k for k, v in symbol_table.items()} with open(args.config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - - # mean, istd = load_cmvn(args.cmvn, is_json=True) - # global_cmvn = GlobalCMVN( - # torch.from_numpy(mean).float(), - # torch.from_numpy(istd).float()) - # configs['encoder_conf']['input_size'] = 80 * 7 - # encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf']) - # configs['decoder_conf']['vocab_size'] = len(char_dict) - # configs['decoder_conf']['encoder_output_size'] = encoder.output_size() - # decoder = SanmDecoer(**configs['decoder_conf']) - - # # predictor = PredictorV3(**configs['predictor_conf']) - # predictor = Predictor(**configs['predictor_conf']) - # model = AliParaformer(encoder, decoder, predictor) configs['cmvn_file'] = args.cmvn configs['is_json_cmvn'] = True - configs['input_dim'] = 80 - configs['output_dim'] = len(char_dict) model = init_model(configs) load_checkpoint(model, args.ali_paraformer) model.eval() diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index f4e62239d..22ad0fa07 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -131,7 +131,6 @@ def init_model(configs): if isinstance(encoder, SanmEncoder): assert isinstance(decoder, SanmDecoer) # NOTE(Mddct): only support inference for now - print('hello world') model = AliParaformer(encoder, decoder, predictor) else: model = Paraformer(vocab_size=vocab_size,