Skip to content

Commit 3342541

Browse files
authored
fix(typo): class name SanmDecoer => SanmDecoder (#2110)
* fix(type): fix class name * fix(typo): fix class name
1 parent 091df4c commit 3342541

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

wenet/paraformer/export_jit.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import yaml
77
from wenet.paraformer.cif import Cif
8-
from wenet.paraformer.layers import (SanmDecoer, SanmEncoder)
8+
from wenet.paraformer.layers import (SanmDecoder, SanmEncoder)
99
from wenet.paraformer.paraformer import Paraformer
1010
from wenet.transformer.cmvn import GlobalCMVN
1111
from wenet.utils.checkpoint import load_checkpoint
@@ -33,9 +33,9 @@ def init_model(configs):
3333
encoder = SanmEncoder(global_cmvn=global_cmvn,
3434
input_size=configs['lfr_conf']['lfr_m'] * input_dim,
3535
**configs['encoder_conf'])
36-
decoder = decoder = SanmDecoer(vocab_size=vocab_size,
37-
encoder_output_size=encoder.output_size(),
38-
**configs['decoder_conf'])
36+
decoder = decoder = SanmDecoder(vocab_size=vocab_size,
37+
encoder_output_size=encoder.output_size(),
38+
**configs['decoder_conf'])
3939
predictor = Cif(**configs['cif_predictor_conf'])
4040
model = Paraformer(
4141
encoder=encoder,

wenet/paraformer/layers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def forward(
372372
return x, tgt_mask, memory, memory_mask
373373

374374

375-
class SanmDecoer(TransformerDecoder):
375+
class SanmDecoder(TransformerDecoder):
376376

377377
def __init__(
378378
self,

wenet/paraformer/paraformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import torch
2323
from wenet.paraformer.cif import Cif
2424

25-
from wenet.paraformer.layers import SanmDecoer, SanmEncoder
25+
from wenet.paraformer.layers import SanmDecoder, SanmEncoder
2626
from wenet.paraformer.layers import LFR
2727
from wenet.paraformer.search import (paraformer_beam_search,
2828
paraformer_greedy_search)
@@ -37,7 +37,7 @@ class Paraformer(torch.nn.Module):
3737
3838
"""
3939

40-
def __init__(self, encoder: SanmEncoder, decoder: SanmDecoer,
40+
def __init__(self, encoder: SanmEncoder, decoder: SanmDecoder,
4141
predictor: Cif):
4242

4343
super().__init__()

0 commit comments

Comments
 (0)