diff --git a/wenet/paraformer/ali_paraformer/config.yaml b/wenet/paraformer/ali_paraformer/config.yaml index e23593edd..94dcea281 100644 --- a/wenet/paraformer/ali_paraformer/config.yaml +++ b/wenet/paraformer/ali_paraformer/config.yaml @@ -1,6 +1,6 @@ # network architecture # encoder related -encoder: SanEncoder +encoder: SanmEncoder encoder_conf: output_size: 512 # dimension of attention attention_heads: 4 @@ -14,8 +14,9 @@ encoder_conf: kernel_size: 11 sanm_shfit: 0 +paraformer: true # decoder related -decoder: transformer +decoder: SanmDecoder decoder_conf: attention_heads: 4 linear_units: 2048 @@ -28,7 +29,11 @@ decoder_conf: kernel_size: 11 sanm_shfit: 0 -predictor_conf: +lfr_conf: + lfr_m: 7 + lfr_n: 6 + +cif_predictor_conf: idim: 512 threshold: 1.0 l_order: 1 diff --git a/wenet/paraformer/ali_paraformer/export_jit.py b/wenet/paraformer/ali_paraformer/export_jit.py index d50e5bcd5..b77735571 100644 --- a/wenet/paraformer/ali_paraformer/export_jit.py +++ b/wenet/paraformer/ali_paraformer/export_jit.py @@ -16,6 +16,7 @@ from wenet.utils.checkpoint import load_checkpoint from wenet.utils.cmvn import load_cmvn from wenet.utils.file_utils import read_symbol_table +from wenet.utils.init_model import init_model def get_args(): @@ -43,19 +44,24 @@ def main(): with open(args.config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) - mean, istd = load_cmvn(args.cmvn, is_json=True) - global_cmvn = GlobalCMVN( - torch.from_numpy(mean).float(), - torch.from_numpy(istd).float()) - configs['encoder_conf']['input_size'] = 80 * 7 - encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf']) - configs['decoder_conf']['vocab_size'] = len(char_dict) - configs['decoder_conf']['encoder_output_size'] = encoder.output_size() - decoder = SanmDecoer(**configs['decoder_conf']) - - # predictor = PredictorV3(**configs['predictor_conf']) - predictor = Predictor(**configs['predictor_conf']) - model = AliParaformer(encoder, decoder, predictor) + # mean, istd = 
load_cmvn(args.cmvn, is_json=True) + # global_cmvn = GlobalCMVN( + # torch.from_numpy(mean).float(), + # torch.from_numpy(istd).float()) + # configs['encoder_conf']['input_size'] = 80 * 7 + # encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf']) + # configs['decoder_conf']['vocab_size'] = len(char_dict) + # configs['decoder_conf']['encoder_output_size'] = encoder.output_size() + # decoder = SanmDecoer(**configs['decoder_conf']) + + # # predictor = PredictorV3(**configs['predictor_conf']) + # predictor = Predictor(**configs['predictor_conf']) + # model = AliParaformer(encoder, decoder, predictor) + configs['cmvn_file'] = args.cmvn + configs['is_json_cmvn'] = True + configs['input_dim'] = 80 + configs['output_dim'] = len(char_dict) + model = init_model(configs) load_checkpoint(model, args.ali_paraformer) model.eval() diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 886bca77c..f4e62239d 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch +from wenet.paraformer.ali_paraformer.model import SanmDecoer, SanmEncoder from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, RNNPredictor) @@ -27,6 +28,7 @@ from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.efficient_conformer.encoder import EfficientConformerEncoder from wenet.paraformer.paraformer import Paraformer +from wenet.paraformer.ali_paraformer.model import AliParaformer from wenet.cif.predictor import Predictor from wenet.utils.cmvn import load_cmvn @@ -55,13 +57,12 @@ def init_model(configs): global_cmvn=global_cmvn, **configs['encoder_conf']) elif encoder_type == 'efficientConformer': - encoder = EfficientConformerEncoder(input_dim, - global_cmvn=global_cmvn, - **configs['encoder_conf'], - **configs['encoder_conf'] - ['efficient_conf'] - if 'efficient_conf' in - configs['encoder_conf'] else {}) + encoder = EfficientConformerEncoder( + input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf'], + **configs['encoder_conf']['efficient_conf'] + if 'efficient_conf' in configs['encoder_conf'] else {}) elif encoder_type == 'branchformer': encoder = BranchformerEncoder(input_dim, global_cmvn=global_cmvn, @@ -70,6 +71,12 @@ def init_model(configs): encoder = EBranchformerEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'SanmEncoder': + assert 'lfr_conf' in configs + encoder = SanmEncoder(global_cmvn=global_cmvn, + input_size=configs['lfr_conf']['lfr_m'] * + input_dim, + **configs['encoder_conf']) else: encoder = TransformerEncoder(input_dim, global_cmvn=global_cmvn, @@ -77,6 +84,11 @@ def init_model(configs): if decoder_type == 'transformer': decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) + elif decoder_type == 'SanmDecoder': + assert isinstance(encoder, SanmEncoder) + decoder = SanmDecoer(vocab_size=vocab_size, + encoder_output_size=encoder.output_size(), + 
**configs['decoder_conf']) else: assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 assert configs['decoder_conf']['r_num_blocks'] > 0 @@ -116,12 +128,17 @@ def init_model(configs): **configs['model_conf']) elif 'paraformer' in configs: predictor = Predictor(**configs['cif_predictor_conf']) - model = Paraformer(vocab_size=vocab_size, - encoder=encoder, - decoder=decoder, - ctc=ctc, - predictor=predictor, - **configs['model_conf']) + if isinstance(encoder, SanmEncoder): + assert isinstance(decoder, SanmDecoer) + # NOTE(Mddct): only supports inference for now + model = AliParaformer(encoder, decoder, predictor) + else: + model = Paraformer(vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + predictor=predictor, + **configs['model_conf']) else: model = ASRModel(vocab_size=vocab_size, encoder=encoder,