Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR] support u2pp based cli and server, optimiz code of u2pp decode (reversed_weight parameter) #2489

Merged
merged 2 commits into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demos/streaming_asr_server/conf/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ engine_list: ['asr_online']
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_wenetspeech'
model_type: 'conformer_u2pp_online_wenetspeech'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
Expand Down
1 change: 1 addition & 0 deletions docs/source/released_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
[Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
[Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
[Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
[Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
[Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
Expand Down
4 changes: 2 additions & 2 deletions paddlespeech/cli/asr/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self):
self.parser.add_argument(
'--model',
type=str,
default='conformer_wenetspeech',
default='conformer_u2pp_wenetspeech',
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
Expand Down Expand Up @@ -465,7 +465,7 @@ def execute(self, argv: List[str]) -> bool:
@stats_wrapper
def __call__(self,
audio_file: os.PathLike,
model: str='conformer_wenetspeech',
model: str='conformer_u2pp_wenetspeech',
lang: str='zh',
sample_rate: int=16000,
config: os.PathLike=None,
Expand Down
2 changes: 2 additions & 0 deletions paddlespeech/resource/model_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"conformer": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_u2pp": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_u2pp_online": ["paddlespeech.s2t.models.u2:U2Model"],
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],
"wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],

Expand Down
40 changes: 40 additions & 0 deletions paddlespeech/resource/pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,46 @@
'',
},
},
"conformer_u2pp_wenetspeech-zh-16k": {
'1.1': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz',
'md5':
'eae678c04ed3b3f89672052fdc0c5e10',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer_u2pp/checkpoints/avg_10',
'model':
'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
'params':
'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
'lm_url':
'',
'lm_md5':
'',
},
},
"conformer_u2pp_online_wenetspeech-zh-16k": {
'1.1': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.2.model.tar.gz',
'md5':
'925d047e9188dea7f421a718230c9ae3',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer_u2pp/checkpoints/avg_10',
'model':
'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
'params':
'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
'lm_url':
'',
'lm_md5':
'',
},
},
"conformer_online_multicn-zh-16k": {
'1.0': {
'url':
Expand Down
4 changes: 1 addition & 3 deletions paddlespeech/s2t/exps/u2/bin/test_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(self, config, args):
self.preprocess_conf = config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
Expand Down Expand Up @@ -89,8 +88,7 @@ def run(self):
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
simulate_streaming=decode_config.simulate_streaming)
rsl = result_transcripts[0][0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
Expand Down
4 changes: 1 addition & 3 deletions paddlespeech/s2t/exps/u2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ def __init__(self, config, args):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

def id2token(self, texts, texts_len, text_feature):
""" ord() id to chr() chr """
Expand Down Expand Up @@ -351,8 +350,7 @@ def compute_metrics(self,
ctc_weight=decode_config.ctc_weight,
decoding_chunk_size=decode_config.decoding_chunk_size,
num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
simulate_streaming=decode_config.simulate_streaming,
reverse_weight=self.reverse_weight)
simulate_streaming=decode_config.simulate_streaming)
decode_time = time.time() - start_time

for utt, target, result, rec_tids in zip(
Expand Down
33 changes: 15 additions & 18 deletions paddlespeech/s2t/models/u2/u2.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,16 +507,14 @@ def ctc_prefix_beam_search(
num_decoding_left_chunks, simulate_streaming)
return hyps[0][0]

def attention_rescoring(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
beam_size: int,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
ctc_weight: float=0.0,
simulate_streaming: bool=False,
reverse_weight: float=0.0, ) -> List[int]:
def attention_rescoring(self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
beam_size: int,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
ctc_weight: float=0.0,
simulate_streaming: bool=False) -> List[int]:
""" Apply attention rescoring decoding, CTC prefix beam search
is applied first to get nbest, then we resoring the nbest on
attention decoder with corresponding encoder out
Expand All @@ -536,7 +534,7 @@ def attention_rescoring(
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
if reverse_weight > 0.0:
if self.reverse_weight > 0.0:
zh794390558 marked this conversation as resolved.
Show resolved Hide resolved
# decoder should be a bitransformer decoder if reverse_weight > 0.0
assert hasattr(self.decoder, 'right_decoder')
device = speech.place
Expand Down Expand Up @@ -574,7 +572,7 @@ def attention_rescoring(
self.eos)
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
self.reverse_weight) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
Expand All @@ -594,12 +592,13 @@ def attention_rescoring(
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
if reverse_weight > 0:
if self.reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
score = score * (1 - reverse_weight) + r_score * reverse_weight
score = score * (1 - self.reverse_weight
) + r_score * self.reverse_weight
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
Expand Down Expand Up @@ -748,8 +747,7 @@ def decode(self,
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False,
reverse_weight: float=0.0):
simulate_streaming: bool=False):
"""u2 decoding.

Args:
Expand Down Expand Up @@ -821,8 +819,7 @@ def decode(self,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
ctc_weight=ctc_weight,
simulate_streaming=simulate_streaming,
reverse_weight=reverse_weight)
simulate_streaming=simulate_streaming)
hyps = [hyp]
else:
raise ValueError(f"Not support decoding method: {decoding_method}")
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/server/conf/ws_conformer_application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ asr_online:
decode_method:
num_decoding_left_chunks: -1
force_yes: True
device: # cpu or gpu:id
device: cpu # cpu or gpu:id
continuous_decoding: True # enable continue decoding when endpoint detected

am_predictor_conf:
Expand Down
23 changes: 19 additions & 4 deletions paddlespeech/server/engine/asr/online/python/asr_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from yacs.config import CfgNode

from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
Expand Down Expand Up @@ -603,24 +604,31 @@ def rescoring(self):

hyps_pad = pad_sequence(
hyp_list, batch_first=True, padding_value=self.model.ignore_id)
ori_hyps_pad = hyps_pad
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at begining

encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)

decoder_out, _, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
self.model.sos, self.model.eos)
decoder_out, r_decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
self.model.reverse_weight) # (beam_size, max_hyps_len, vocab_size)
zh794390558 marked this conversation as resolved.
Show resolved Hide resolved
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()

# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy()

# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
Expand All @@ -632,6 +640,13 @@ def rescoring(self):

# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
if self.model.reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.model.eos]
score = score * (1 - self.model.reverse_weight
) + r_score * self.model.reverse_weight
# add ctc score (which in ln domain)
score += hyp[1] * self.ctc_decode_config.ctc_weight

Expand Down