Skip to content

Commit

Permalink
reuse init_model.py
Browse files: browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 23, 2023
1 parent c8cccdc commit 852165c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 29 deletions.
11 changes: 8 additions & 3 deletions wenet/paraformer/ali_paraformer/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# network architecture
# encoder related
encoder: SanEncoder
encoder: SanmEncoder
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 4
Expand All @@ -14,8 +14,9 @@ encoder_conf:
kernel_size: 11
sanm_shfit: 0

paraformer: true
# decoder related
decoder: transformer
decoder: SanmDecoder
decoder_conf:
attention_heads: 4
linear_units: 2048
Expand All @@ -28,7 +29,11 @@ decoder_conf:
kernel_size: 11
sanm_shfit: 0

predictor_conf:
lfr_conf:
lfr_m: 7
lfr_n: 6

cif_predictor_conf:
idim: 512
threshold: 1.0
l_order: 1
Expand Down
32 changes: 19 additions & 13 deletions wenet/paraformer/ali_paraformer/export_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.cmvn import load_cmvn
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.init_model import init_model


def get_args():
Expand Down Expand Up @@ -43,19 +44,24 @@ def main():
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)

mean, istd = load_cmvn(args.cmvn, is_json=True)
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float())
configs['encoder_conf']['input_size'] = 80 * 7
encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf'])
configs['decoder_conf']['vocab_size'] = len(char_dict)
configs['decoder_conf']['encoder_output_size'] = encoder.output_size()
decoder = SanmDecoer(**configs['decoder_conf'])

# predictor = PredictorV3(**configs['predictor_conf'])
predictor = Predictor(**configs['predictor_conf'])
model = AliParaformer(encoder, decoder, predictor)
# mean, istd = load_cmvn(args.cmvn, is_json=True)
# global_cmvn = GlobalCMVN(
# torch.from_numpy(mean).float(),
# torch.from_numpy(istd).float())
# configs['encoder_conf']['input_size'] = 80 * 7
# encoder = SanmEncoder(global_cmvn=global_cmvn, **configs['encoder_conf'])
# configs['decoder_conf']['vocab_size'] = len(char_dict)
# configs['decoder_conf']['encoder_output_size'] = encoder.output_size()
# decoder = SanmDecoer(**configs['decoder_conf'])

# # predictor = PredictorV3(**configs['predictor_conf'])
# predictor = Predictor(**configs['predictor_conf'])
# model = AliParaformer(encoder, decoder, predictor)
configs['cmvn_file'] = args.cmvn
configs['is_json_cmvn'] = True
configs['input_dim'] = 80
configs['output_dim'] = len(char_dict)
model = init_model(configs)
load_checkpoint(model, args.ali_paraformer)
model.eval()

Expand Down
44 changes: 31 additions & 13 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import torch
from wenet.paraformer.ali_paraformer.model import SanmDecoer, SanmEncoder
from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
RNNPredictor)
Expand All @@ -27,6 +28,7 @@
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.paraformer.paraformer import Paraformer
from wenet.paraformer.ali_paraformer.model import AliParaformer
from wenet.cif.predictor import Predictor
from wenet.utils.cmvn import load_cmvn

Expand Down Expand Up @@ -55,13 +57,12 @@ def init_model(configs):
global_cmvn=global_cmvn,
**configs['encoder_conf'])
elif encoder_type == 'efficientConformer':
encoder = EfficientConformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'],
**configs['encoder_conf']
['efficient_conf']
if 'efficient_conf' in
configs['encoder_conf'] else {})
encoder = EfficientConformerEncoder(
input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'],
**configs['encoder_conf']['efficient_conf']
if 'efficient_conf' in configs['encoder_conf'] else {})
elif encoder_type == 'branchformer':
encoder = BranchformerEncoder(input_dim,
global_cmvn=global_cmvn,
Expand All @@ -70,13 +71,24 @@ def init_model(configs):
encoder = EBranchformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
elif encoder_type == 'SanmEncoder':
assert 'lfr_conf' in configs
encoder = SanmEncoder(global_cmvn=global_cmvn,
input_size=configs['lfr_conf']['lfr_m'] *
input_dim,
**configs['encoder_conf'])
else:
encoder = TransformerEncoder(input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'])
if decoder_type == 'transformer':
decoder = TransformerDecoder(vocab_size, encoder.output_size(),
**configs['decoder_conf'])
elif decoder_type == 'SanmDecoder':
assert isinstance(encoder, SanmEncoder)
decoder = SanmDecoer(vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**configs['decoder_conf'])
else:
assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
assert configs['decoder_conf']['r_num_blocks'] > 0
Expand Down Expand Up @@ -116,12 +128,18 @@ def init_model(configs):
**configs['model_conf'])
elif 'paraformer' in configs:
predictor = Predictor(**configs['cif_predictor_conf'])
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
if isinstance(encoder, SanmEncoder):
assert isinstance(decoder, SanmDecoer)
# NOTE(Mddct): this path only supports inference for now
print('hello world')
model = AliParaformer(encoder, decoder, predictor)
else:
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
else:
model = ASRModel(vocab_size=vocab_size,
encoder=encoder,
Expand Down

0 comments on commit 852165c

Please sign in to comment.